elselse committed on
Commit
daf9b52
·
verified ·
1 Parent(s): 9ccb1c5

back to ancient version

Browse files
Files changed (1) hide show
  1. app.py +32 -12
app.py CHANGED
@@ -1,35 +1,55 @@
1
- from transformers import pipeline
2
- import json
3
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
4
 
 
5
  classifier = pipeline(
6
  task="text-classification",
7
  model="CIRCL/cwe-parent-vulnerability-classification-roberta-base",
8
  top_k=None
9
  )
 
10
 
 
11
  with open("child_to_parent_mapping.json", "r") as f:
12
  child_to_parent = json.load(f)
13
 
14
  def predict_cwe(commit_message: str):
 
 
 
15
  results = classifier(commit_message)[0]
 
 
16
  threshold = 0.2
17
- filtered = [r for r in results if r["score"] >= threshold]
18
 
19
- mapped = {}
20
- for r in filtered:
21
- cwe_id = r["label"].replace("CWE-", "")
22
- parent_id = child_to_parent.get(cwe_id, cwe_id)
23
- mapped[f"CWE-{parent_id}"] = round(float(r["score"]), 4)
 
24
 
25
- return mapped
26
 
 
27
  demo = gr.Interface(
28
  fn=predict_cwe,
29
- inputs=gr.Textbox(lines=3, placeholder="Enter your commit message or vulnerability description here..."),
30
  outputs=gr.Label(num_top_classes=5),
31
- title="CWE Prediction from Commit Message Or Description",
32
- description="Predict top CWE parent classes from a commit message or description.",
 
33
  examples=[
34
  ["Fixed buffer overflow in input parsing"],
35
  ["SQL injection possible in login flow"],
 
 
 
1
  import gradio as gr
2
+ import json
3
+ from transformers import pipeline
4
+ import torch
5
+ import random
6
+ import numpy as np
7
+
8
+ torch.manual_seed(42)
9
+ random.seed(42)
10
+ np.random.seed(42)
11
+
12
+ torch.use_deterministic_algorithms(True)
13
 
14
+ # Load Hugging Face model (text classification)
15
  classifier = pipeline(
16
  task="text-classification",
17
  model="CIRCL/cwe-parent-vulnerability-classification-roberta-base",
18
  top_k=None
19
  )
20
+ classifier.model.eval()
21
 
22
+ # Load child-to-parent mapping
23
  with open("child_to_parent_mapping.json", "r") as f:
24
  child_to_parent = json.load(f)
25
 
26
  def predict_cwe(commit_message: str):
27
+ """
28
+ Predict CWE(s) from a commit message and map to parent CWEs.
29
+ """
30
  results = classifier(commit_message)[0]
31
+ sorted_results = sorted(results, key=lambda x: x["score"], reverse=True)
32
+
33
  threshold = 0.2
34
+ filtered_results = [item for item in sorted_results if item["score"] >= threshold]
35
 
36
+ # Map predictions to parent CWE (if available)
37
+ mapped_results = {}
38
+ for item in sorted_results[:5]:
39
+ child_cwe = item["label"].replace("CWE-", "")
40
+ parent_cwe = child_to_parent.get(child_cwe, child_cwe) # default to child if no parent
41
+ mapped_results[f"CWE-{parent_cwe}"] = round(float(item["score"]), 4)
42
 
43
+ return mapped_results
44
 
45
+ # Gradio UI
46
  demo = gr.Interface(
47
  fn=predict_cwe,
48
+ inputs=gr.Textbox(lines=3, placeholder="Enter your commit message here..."),
49
  outputs=gr.Label(num_top_classes=5),
50
+ title="CWE Prediction from Commit Message",
51
+ description="This tool uses a fine-tuned model to predict CWE categories from Git commit messages. "
52
+ "Predicted child CWEs are mapped to their parent CWEs if applicable.",
53
  examples=[
54
  ["Fixed buffer overflow in input parsing"],
55
  ["SQL injection possible in login flow"],