elselse committed on
Commit
9ccb1c5
·
verified ·
1 Parent(s): 427ff3a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -53
app.py CHANGED
@@ -1,69 +1,35 @@
1
- import gradio as gr
2
  import json
3
- from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
4
- import torch
5
- import random
6
- import numpy as np
7
-
8
- torch.manual_seed(42)
9
- random.seed(42)
10
- np.random.seed(42)
11
- torch.use_deterministic_algorithms(True)
12
-
13
- model_path = "CIRCL/cwe-parent-vulnerability-classification-roberta-base"
14
- model = AutoModelForSequenceClassification.from_pretrained(model_path)
15
- tokenizer = AutoTokenizer.from_pretrained(model_path)
16
 
17
  classifier = pipeline(
18
  task="text-classification",
19
- model=model,
20
- tokenizer=tokenizer,
21
- top_k=None,
22
- return_all_scores=True
23
  )
24
- model.eval()
25
 
26
- with open(f"{model_path}/config.json", "r") as f:
27
- config = json.load(f)
28
- id_to_cwe = {int(k): v for k, v in config["id2label"].items()}
29
- valid_cwes = set(id_to_cwe.values())
30
 
31
- with open("deep_child_to_ancestor.json", "r") as f:
32
- child_to_ancestor = json.load(f)
33
-
34
- def map_prediction_to_valid_cwes(predictions, id_to_cwe, child_to_ancestor, threshold=0.2, top_k=5):
35
- """
36
- Map model predictions to CWE ancestors and return top_k valid results.
37
- """
38
- results = []
39
- for item in predictions:
40
- for label_idx, score in enumerate(item):
41
- if score["score"] >= threshold:
42
- label_id = score["label"].split("_")[-1] # "LABEL_123" → "123"
43
- label_id = int(label_id)
44
- if label_id in id_to_cwe:
45
- cwe = id_to_cwe[label_id]
46
- ancestor = child_to_ancestor.get(cwe, cwe)
47
- if ancestor in valid_cwes:
48
- results.append((f"CWE-{ancestor}", round(score["score"], 4)))
49
-
50
- aggregated = {}
51
- for cwe, score in results:
52
- aggregated[cwe] = max(aggregated.get(cwe, 0), score)
53
 
54
- sorted_results = sorted(aggregated.items(), key=lambda x: x[1], reverse=True)
55
- return dict(sorted_results[:top_k])
 
 
 
56
 
57
- def predict_cwe(commit_message: str):
58
- raw_preds = classifier(commit_message)
59
- return map_prediction_to_valid_cwes(raw_preds, id_to_cwe, child_to_ancestor)
60
 
61
  demo = gr.Interface(
62
  fn=predict_cwe,
63
- inputs=gr.Textbox(lines=3, placeholder="Enter your commit message here..."),
64
  outputs=gr.Label(num_top_classes=5),
65
- title="CWE Prediction from Commit Message and vulnerability description",
66
- description="This tool predicts CWE ancestor categories from Git commit messages and vulnerability descriptions, based on a fine-tuned transformer model.",
67
  examples=[
68
  ["Fixed buffer overflow in input parsing"],
69
  ["SQL injection possible in login flow"],
 
1
+ from transformers import pipeline
2
  import json
3
+ import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  classifier = pipeline(
6
  task="text-classification",
7
+ model="CIRCL/cwe-parent-vulnerability-classification-roberta-base",
8
+ top_k=None
 
 
9
  )
 
10
 
11
+ with open("child_to_parent_mapping.json", "r") as f:
12
+ child_to_parent = json.load(f)
 
 
13
 
14
+ def predict_cwe(commit_message: str):
15
+ results = classifier(commit_message)[0]
16
+ threshold = 0.2
17
+ filtered = [r for r in results if r["score"] >= threshold]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
+ mapped = {}
20
+ for r in filtered:
21
+ cwe_id = r["label"].replace("CWE-", "")
22
+ parent_id = child_to_parent.get(cwe_id, cwe_id)
23
+ mapped[f"CWE-{parent_id}"] = round(float(r["score"]), 4)
24
 
25
+ return mapped
 
 
26
 
27
  demo = gr.Interface(
28
  fn=predict_cwe,
29
+ inputs=gr.Textbox(lines=3, placeholder="Enter your commit message or vulnerability description here..."),
30
  outputs=gr.Label(num_top_classes=5),
31
+ title="CWE Prediction from Commit Message Or Description",
32
+ description="Predict top CWE parent classes from a commit message or description.",
33
  examples=[
34
  ["Fixed buffer overflow in input parsing"],
35
  ["SQL injection possible in login flow"],