SCBconsulting commited on
Commit
576eef7
·
verified ·
1 Parent(s): 1e80870

Update utils/risk_detector.py

Browse files
Files changed (1) hide show
  1. utils/risk_detector.py +19 -3
utils/risk_detector.py CHANGED
@@ -2,14 +2,30 @@
2
 
3
  from transformers import pipeline
4
 
 
5
  classifier = pipeline("zero-shot-classification", model="typeform/distilbert-base-uncased-mnli")
 
 
6
  labels = ["Indemnity", "Exclusivity", "Termination", "Jurisdiction", "Confidentiality"]
7
 
8
- def detect_risks(text):
9
- if not text:
 
 
 
 
 
 
 
 
10
  return []
11
 
12
  result = classifier(text[:1000], candidate_labels=labels, multi_label=True)
13
 
14
- # Package into list of [label, score] for Gradio dataframe
 
 
 
 
 
15
  return list(zip(result["labels"], result["scores"]))
 
2
 
3
  from transformers import pipeline
4
 
5
+ # ⚖️ Load zero-shot classification model
6
  classifier = pipeline("zero-shot-classification", model="typeform/distilbert-base-uncased-mnli")
7
+
8
+ # 🎯 Define risk-related labels
9
  labels = ["Indemnity", "Exclusivity", "Termination", "Jurisdiction", "Confidentiality"]
10
 
11
+ def detect_risks(text, verbose=False):
12
+ """
13
+ Classify clauses into predefined legal risk categories.
14
+ If verbose=True, include detailed scores for each label.
15
+
16
+ Returns:
17
+ - List of (label, score) tuples (default)
18
+ - Or dict of full model output if verbose
19
+ """
20
+ if not text.strip():
21
  return []
22
 
23
  result = classifier(text[:1000], candidate_labels=labels, multi_label=True)
24
 
25
+ if verbose:
26
+ return {
27
+ "sequence": result["sequence"],
28
+ "predictions": list(zip(result["labels"], result["scores"]))
29
+ }
30
+
31
  return list(zip(result["labels"], result["scores"]))