elselse committed on
Commit
4c74307
·
verified ·
1 Parent(s): 0bfd012

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -65
app.py CHANGED
@@ -1,70 +1,53 @@
1
  import gradio as gr
2
- import json
3
- from transformers import pipeline
4
  import torch
5
- import random
6
- import numpy as np
7
-
8
- torch.manual_seed(42)
9
- random.seed(42)
10
- np.random.seed(42)
11
-
12
- torch.use_deterministic_algorithms(True)
13
 
14
# Load the Hugging Face text-classification pipeline once, at import time.
# top_k=None is passed so that the pipeline reports a score for every label
# rather than only the best one (predict_cwe sorts and filters the full list).
classifier = pipeline(
    task="text-classification",
    model="CIRCL/cwe-parent-vulnerability-classification-roberta-base",
    top_k=None,
)
# Put the underlying torch module in evaluation mode for inference.
classifier.model.eval()

# Load the child-to-parent CWE mapping shipped next to this script.
# The encoding is given explicitly: JSON files are UTF-8, while open()'s
# default encoding depends on the host locale.
with open("child_to_parent_mapping.json", "r", encoding="utf-8") as f:
    child_to_parent = json.load(f)
25
-
26
def predict_cwe(commit_message: str, *, threshold: float = 0.2, top_n: int = 5) -> dict:
    """Predict CWE(s) from a commit message and map them to parent CWEs.

    Args:
        commit_message: Free-text commit message to classify.
        threshold: Minimum classifier score a label needs to be kept.
            Defaults to 0.2, the value that was previously hard-coded.
        top_n: Maximum number of parent CWEs to return. Defaults to 5.

    Returns:
        A dict mapping ``"CWE-<id>"`` to the summed score of every kept label
        that resolves to that parent, rounded to 4 decimals, ordered from the
        highest to the lowest score.
    """
    # The pipeline returns one list of {"label", "score"} dicts per input text;
    # a single string is passed, so only the first list is relevant.
    label_scores = classifier(commit_message)[0]
    ranked = sorted(label_scores, key=lambda item: item["score"], reverse=True)

    parent_scores = {}
    for item in ranked:
        if item["score"] < threshold:
            # ranked is in descending order, so nothing further can qualify.
            break
        # NOTE(review): "CWEEEE_" is kept exactly as found; confirm that it
        # matches the prefix the model's labels really carry.
        label = item["label"].replace("CWEEEE_", "")
        # Labels without an entry in the mapping are used as their own parent.
        parent_label = child_to_parent.get(label, label)
        parent_scores[parent_label] = (
            parent_scores.get(parent_label, 0.0) + float(item["score"])
        )

    # Sort by aggregated score, descending, and keep the best top_n entries.
    best = sorted(parent_scores.items(), key=lambda pair: pair[1], reverse=True)[:top_n]
    return {f"CWE-{cwe_id}": round(total, 4) for cwe_id, total in best}
52
-
53
-
54
# Gradio UI: a three-line text box feeds predict_cwe and the result is shown
# in a label widget limited to the five best classes.
demo = gr.Interface(
    fn=predict_cwe,
    inputs=gr.Textbox(lines=3, placeholder="Enter your commit message here..."),
    outputs=gr.Label(num_top_classes=5),
    title="CWE Prediction from Commit Message",
    description=(
        "This tool uses a fine-tuned model to predict CWE categories from Git commit messages. "
        "Predicted child CWEs are mapped to their parent CWEs if applicable."
    ),
    examples=[
        ["Fixed buffer overflow in input parsing"],
        ["SQL injection possible in login flow"],
        ["Improved input validation to prevent XSS"],
    ],
)

# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
 
1
  import gradio as gr
 
 
2
  import torch
3
+ import json
4
+ import base64
5
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
 
 
 
 
6
 
7
# Hub repository id of the fine-tuned classifier. This is NOT a local
# directory, so files inside it cannot be read with open().
model_path = "CIRCL/cwe-parent-vulnerability-classification-roberta-base"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
# Put the model in evaluation mode for inference.
model.eval()

# Mapping from a predicted (child) CWE id to its ancestor CWE id.
# The encoding is given explicitly: JSON files are UTF-8, while open()'s
# default encoding depends on the host locale.
with open("child_to_parent_mapping.json", "r", encoding="utf-8") as f:
    child_to_ancestor = json.load(f)

# Take the configuration from the model that was just loaded instead of
# opening "<repo id>/config.json", which only exists if the repository happens
# to be cloned into the working directory.
config = model.config.to_dict()
# Keys are normalised to strings (the form they have in config.json) so that
# lookups of the shape id2label[str(idx)] keep working.
id2label = {str(idx): label for idx, label in config["id2label"].items()}
18
+
19
# Extraction helper standing in for a real formatted-input parser.
def extract_commit_text_hg_style(input_text) -> str:
    """Return the text that should be fed to the classifier.

    Args:
        input_text: Raw text entered by the user, or ``None`` when no value
            was supplied.

    Returns:
        ``input_text`` with leading and trailing whitespace removed, or an
        empty string when ``input_text`` is ``None``.
    """
    # A real patch or commit could be parsed here; for now the raw input is
    # used as-is.
    if input_text is None:
        return ""
    return input_text.strip()
23
+
24
# Gradio prediction function
def predict_ancestors(input_text):
    """Classify a commit message and report the most likely CWEs.

    Args:
        input_text: Commit message or patch text entered by the user.

    Returns:
        A list of up to five strings, best prediction first, each of the form
        ``"<rank>. CWE-<child> (ancestor: CWE-<ancestor>) - <score>"``.
        ``<ancestor>`` is ``N/A`` when the predicted label has no entry in
        ``child_to_ancestor``.
    """
    text = extract_commit_text_hg_style(input_text)
    # Truncate to 512 tokens but do not pad: only one sequence is encoded, so
    # padding it to max_length would add tokens without adding information.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        logits = model(**inputs).logits
        probs = torch.softmax(logits, dim=-1)

    # torch.topk raises when k exceeds the number of classes, so clamp it.
    k = min(5, probs.shape[-1])
    topk = torch.topk(probs, k=k)
    # Index 0 selects the single sequence of the batch.
    top_ids = topk.indices[0].tolist()
    top_scores = topk.values[0].tolist()

    results = []
    for rank, (idx, score) in enumerate(zip(top_ids, top_scores), 1):
        # id2label is keyed by the class index rendered as a string.
        cwe_child = id2label[str(idx)]
        ancestor = child_to_ancestor.get(cwe_child, "N/A")
        results.append(f"{rank}. CWE-{cwe_child} (ancestor: CWE-{ancestor}) - {score:.4f}")

    return results
45
+
46
# Gradio interface
def _predictions_as_text(input_text):
    """Render the predictions for ``input_text`` as one prediction per line.

    The output component is a text box, which displays a single string, so the
    sequence of lines returned by ``predict_ancestors`` is joined here. A
    result that is already a string is passed through untouched.
    """
    predictions = predict_ancestors(input_text)
    if isinstance(predictions, str):
        return predictions
    return "\n".join(predictions)


gr.Interface(
    fn=_predictions_as_text,
    inputs=gr.Textbox(label="Commit message or patch (e.g., 'hg')"),
    # gr.Textbox is used for the output as well: the legacy gr.outputs
    # namespace is deprecated and no longer available in current Gradio
    # releases.
    outputs=gr.Textbox(label="Top 5 Predicted CWE Ancestors", lines=5),
    title="CWE Ancestor Predictor",
    description="Entrez un message de commit ou un patch. Le modèle prédit les 5 CWE enfants les plus probables et affiche leurs ancêtres."
).launch()