elselse committed on
Commit
ba82d85
·
verified ·
1 Parent(s): e93230f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -30
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  import json
3
- from transformers import pipeline
4
  import torch
5
  import random
6
  import numpy as np
@@ -8,48 +8,62 @@ import numpy as np
8
  torch.manual_seed(42)
9
  random.seed(42)
10
  np.random.seed(42)
11
-
12
  torch.use_deterministic_algorithms(True)
13
 
14
- # Load Hugging Face model (text classification)
 
 
 
15
  classifier = pipeline(
16
  task="text-classification",
17
- model="CIRCL/cwe-parent-vulnerability-classification-roberta-base",
18
- top_k=None
 
 
19
  )
20
- classifier.model.eval()
21
 
22
- # Load child-to-parent mapping
23
- with open("child_to_parent_mapping.json", "r") as f:
24
- child_to_parent = json.load(f)
 
25
 
26
- def predict_cwe(commit_message: str):
 
 
 
27
  """
28
- Predict CWE(s) from a commit message and map to parent CWEs.
29
  """
30
- results = classifier(commit_message)[0]
31
- sorted_results = sorted(results, key=lambda x: x["score"], reverse=True)
32
-
33
- threshold = 0.2
34
- filtered_results = [item for item in sorted_results if item["score"] >= threshold]
35
-
36
- # Map predictions to parent CWE (if available)
37
- mapped_results = {}
38
- for item in sorted_results[:5]:
39
- child_cwe = item["label"].replace("CWE-", "")
40
- parent_cwe = child_to_parent.get(child_cwe, child_cwe) # default to child if no parent
41
- mapped_results[f"CWE-{parent_cwe}"] = round(float(item["score"]), 4)
42
-
43
- return mapped_results
44
-
45
- # Gradio UI
 
 
 
 
 
 
 
46
  demo = gr.Interface(
47
  fn=predict_cwe,
48
  inputs=gr.Textbox(lines=3, placeholder="Enter your commit message here..."),
49
  outputs=gr.Label(num_top_classes=5),
50
- title="CWE Prediction from Commit Message",
51
- description="This tool uses a fine-tuned model to predict CWE categories from Git commit messages. "
52
- "Predicted child CWEs are mapped to their parent CWEs if applicable.",
53
  examples=[
54
  ["Fixed buffer overflow in input parsing"],
55
  ["SQL injection possible in login flow"],
 
1
  import gradio as gr
2
  import json
3
+ from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
4
  import torch
5
  import random
6
  import numpy as np
 
8
  torch.manual_seed(42)
9
  random.seed(42)
10
  np.random.seed(42)
 
11
  torch.use_deterministic_algorithms(True)
12
 
13
# Identifier handed to from_pretrained() for both the model and its tokenizer.
model_path = "CIRCL/cwe-parent-vulnerability-classification-roberta-base"
model = AutoModelForSequenceClassification.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# top_k=None asks the pipeline for a score per class. The deprecated
# return_all_scores flag is redundant next to an explicit top_k, so it is
# no longer passed.
classifier = pipeline(
    task="text-classification",
    model=model,
    tokenizer=tokenizer,
    top_k=None,
)
# The app only runs inference, so make sure the model is in evaluation mode.
model.eval()

# Take the class-index -> label table from the config that from_pretrained()
# already loaded. Opening f"{model_path}/config.json" only works when a
# directory with that relative path exists locally, which is not the case
# when model_path is resolved as a Hub repository id.
config = model.config.to_dict()
id_to_cwe = {int(k): v for k, v in config["id2label"].items()}
valid_cwes = set(id_to_cwe.values())

# Lookup table consulted as child_to_ancestor.get(cwe, cwe) when mapping
# predictions; labels without an entry are reported as themselves.
with open("deep_child_to_ancestor.json", "r", encoding="utf-8") as f:
    child_to_ancestor = json.load(f)
33
+
34
def map_prediction_to_valid_cwes(predictions, id_to_cwe, child_to_ancestor, threshold=0.2, top_k=5):
    """
    Map model predictions to CWE ancestors and return top_k valid results.

    Args:
        predictions: Iterable of per-input result lists; every entry is a
            dict with a "label" and a "score" key.
        id_to_cwe: Mapping from integer class index to the CWE label of
            that class.
        child_to_ancestor: Mapping from a CWE label to the label it should
            be reported as; labels without an entry map to themselves.
        threshold: Minimum score (inclusive) an entry needs to be kept.
        top_k: Maximum number of entries in the returned mapping.

    Returns:
        A dict from "CWE-<id>" to the highest score seen for that CWE
        (rounded to 4 decimals), ordered by descending score and holding
        at most ``top_k`` entries.
    """
    # Derived from the argument so the function does not depend on a
    # module-level global being defined.
    known_cwes = set(id_to_cwe.values())

    def _resolve(label):
        """Return the CWE label for a predicted label, or None if unknown."""
        # A pipeline whose model config defines id2label reports those label
        # strings directly, so a label that is already a known CWE is used
        # as-is.
        if label in known_cwes:
            return label
        # Generic form: "LABEL_123" -> class index 123 -> CWE label.
        suffix = label.rsplit("_", 1)[-1]
        if suffix.isdigit():
            return id_to_cwe.get(int(suffix))
        # Unrecognised label: skip it instead of raising from int().
        return None

    aggregated = {}
    for item in predictions:
        for entry in item:
            if not entry["score"] >= threshold:
                continue
            cwe = _resolve(entry["label"])
            if cwe is None:
                continue
            ancestor = child_to_ancestor.get(cwe, cwe)
            # Only report ancestors that are themselves classes of the model.
            if ancestor not in known_cwes:
                continue
            # Avoid a doubled "CWE-CWE-" prefix when the label already has one.
            text = str(ancestor)
            name = text if text.startswith("CWE-") else f"CWE-{text}"
            # Several children can share an ancestor: keep the best score.
            aggregated[name] = max(aggregated.get(name, 0), round(entry["score"], 4))

    sorted_results = sorted(aggregated.items(), key=lambda pair: pair[1], reverse=True)
    return dict(sorted_results[:top_k])
56
+
57
def predict_cwe(commit_message: str):
    """
    Classify a commit message and return the mapped CWE scores.

    The text is passed to the module-level ``classifier`` and its raw output
    is reduced by ``map_prediction_to_valid_cwes`` using the module-level
    ``id_to_cwe`` and ``child_to_ancestor`` tables.
    """
    return map_prediction_to_valid_cwes(
        classifier(commit_message),
        id_to_cwe,
        child_to_ancestor,
    )
60
+
61
  demo = gr.Interface(
62
  fn=predict_cwe,
63
  inputs=gr.Textbox(lines=3, placeholder="Enter your commit message here..."),
64
  outputs=gr.Label(num_top_classes=5),
65
+ title="CWE Prediction from Commit Message and vulnerability description",
66
+ description="This tool predicts CWE ancestor categories from Git commit messages and vulnerability descriptions, based on a fine-tuned transformer model.",
 
67
  examples=[
68
  ["Fixed buffer overflow in input parsing"],
69
  ["SQL injection possible in login flow"],