elselse commited on
Commit
bcd6411
·
verified ·
1 Parent(s): c364fcb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -40
app.py CHANGED
@@ -1,46 +1,58 @@
1
  import gradio as gr
2
- import torch
3
  import json
4
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
-
6
- model_path = "CIRCL/cwe-parent-vulnerability-classification-roberta-base"
7
- tokenizer = AutoTokenizer.from_pretrained(model_path)
8
- model = AutoModelForSequenceClassification.from_pretrained(model_path)
9
- model.eval()
10
-
11
- with open("child_to_parent_mapping.json", "r") as f:
12
- child_to_ancestor = json.load(f)
13
-
14
- id2label = model.config.id2label
15
-
16
- def extract_commit_text_hg_style(input_text):
17
- return input_text.strip()
18
-
19
- def predict_ancestors(input_text):
20
- text = extract_commit_text_hg_style(input_text)
21
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=512)
22
-
23
- with torch.no_grad():
24
- outputs = model(**inputs)
25
- logits = outputs.logits
26
- probs = torch.softmax(logits, dim=-1)
27
 
28
- topk = torch.topk(probs, k=5)
29
- top_ids = topk.indices[0].tolist()
30
- top_scores = topk.values[0].tolist()
31
 
32
- results = []
33
- for i, (idx, score) in enumerate(zip(top_ids, top_scores), 1):
34
- cwe_child = id2label[str(idx)]
35
- ancestor = child_to_ancestor.get(cwe_child, "N/A")
36
- results.append(f"{i}. CWE-{cwe_child} (ancestor: CWE-{ancestor}) - {score:.4f}")
37
 
38
- return "\n".join(results)
 
 
 
 
 
 
39
 
40
- gr.Interface(
41
- fn=predict_ancestors,
42
- inputs=gr.Textbox(label="Commit message or patch (e.g., 'hg')"),
43
- outputs=gr.Textbox(label="Top 5 Predicted CWE Ancestors"),
44
- title="CWE Ancestor Predictor",
45
- description="Entrez un message de commit ou un patch. Le modèle prédit les 5 CWE enfants les plus probables et affiche leurs ancêtres."
46
- ).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
2
  import json
3
+ from transformers import pipeline
4
+ import torch
5
+ import random
6
+ import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
+ torch.manual_seed(42)
9
+ random.seed(42)
10
+ np.random.seed(42)
11
 
12
+ torch.use_deterministic_algorithms(True)
 
 
 
 
13
 
14
+ # Load Hugging Face model (text classification)
15
+ classifier = pipeline(
16
+ task="text-classification",
17
+ model="CIRCL/cwe-parent-vulnerability-classification-roberta-base",
18
+ top_k=None
19
+ )
20
+ classifier.model.eval()
21
 
22
+ # Load child-to-parent mapping
23
+ with open("child_to_parent_mapping.json", "r") as f:
24
+ child_to_parent = json.load(f)
25
+
26
+ def predict_cwe(commit_message: str):
27
+ """
28
+ Predict CWE(s) from a commit message and map to parent CWEs.
29
+ """
30
+ results = classifier(commit_message)[0]
31
+ sorted_results = sorted(results, key=lambda x: x["score"], reverse=True)
32
+
33
+ threshold = 0.2
34
+ filtered_results = [item for item in sorted_results if item["score"] >= threshold]
35
+
36
+ # Map predictions to parent CWE (if available)
37
+ mapped_results = {}
38
+ for item in sorted_results[:5]:
39
+ mapped_results[item["label"]] = round(float(item["score"]), 4)
40
+ return mapped_results
41
+
42
+ # Gradio UI
43
+ demo = gr.Interface(
44
+ fn=predict_cwe,
45
+ inputs=gr.Textbox(lines=3, placeholder="Enter your commit message here..."),
46
+ outputs=gr.Label(num_top_classes=5),
47
+ title="CWE Prediction from Commit Message",
48
+ description="This tool uses a fine-tuned model to predict CWE categories from Git commit messages. "
49
+ "Predicted child CWEs are mapped to their parent CWEs if applicable.",
50
+ examples=[
51
+ ["Fixed buffer overflow in input parsing"],
52
+ ["SQL injection possible in login flow"],
53
+ ["Improved input validation to prevent XSS"],
54
+ ]
55
+ )
56
+
57
+ if __name__ == "__main__":
58
+ demo.launch()