elselse commited on
Commit
e93230f
·
verified ·
1 Parent(s): 0ba47f9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -5
app.py CHANGED
@@ -15,14 +15,10 @@ torch.use_deterministic_algorithms(True)
15
  classifier = pipeline(
16
  task="text-classification",
17
  model="CIRCL/cwe-parent-vulnerability-classification-roberta-base",
18
- return_all_scores=True
19
  )
20
  classifier.model.eval()
21
 
22
- threshold = 0.2
23
- filtered_results = [item for item in sorted_results if item["score"] >= threshold]
24
-
25
-
26
  # Load child-to-parent mapping
27
  with open("child_to_parent_mapping.json", "r") as f:
28
  child_to_parent = json.load(f)
@@ -33,6 +29,9 @@ def predict_cwe(commit_message: str):
33
  """
34
  results = classifier(commit_message)[0]
35
  sorted_results = sorted(results, key=lambda x: x["score"], reverse=True)
 
 
 
36
 
37
  # Map predictions to parent CWE (if available)
38
  mapped_results = {}
 
15
  classifier = pipeline(
16
  task="text-classification",
17
  model="CIRCL/cwe-parent-vulnerability-classification-roberta-base",
18
+ top_k=None
19
  )
20
  classifier.model.eval()
21
 
 
 
 
 
22
  # Load child-to-parent mapping
23
  with open("child_to_parent_mapping.json", "r") as f:
24
  child_to_parent = json.load(f)
 
29
  """
30
  results = classifier(commit_message)[0]
31
  sorted_results = sorted(results, key=lambda x: x["score"], reverse=True)
32
+
33
+ threshold = 0.2
34
+ filtered_results = [item for item in sorted_results if item["score"] >= threshold]
35
 
36
  # Map predictions to parent CWE (if available)
37
  mapped_results = {}