riyasuryawanshi746 commited on
Commit
6f39cfb
·
verified ·
1 Parent(s): ff5e155

Updated interface after checkpoint

Browse files
Files changed (1) hide show
  1. inference.py +31 -138
inference.py CHANGED
@@ -1,20 +1,9 @@
1
  # inference.py
2
- # Wraps enterprise_inference from clauseXplain v5.0
3
- # Requires: model, tokenizer, clause_mlb, risk_mlb, feature_extractor
4
- # to be loaded externally before calling analyze_clause().
5
 
6
  from __future__ import annotations
7
- import torch
8
 
9
- # ── These globals must be set by app.py before calling analyze_clause ──
10
- model = None
11
- tokenizer = None
12
- clause_mlb = None
13
- risk_mlb = None
14
- feature_extractor = None
15
- device = None
16
-
17
- RISK_LEVEL_ORDER = {"Low": 0, "Medium": 1, "High": 2}
18
  IP_CLAUSE_TYPES = {
19
  "IP Ownership Assignment", "Joint IP Ownership",
20
  "Irrevocable Or Perpetual License",
@@ -22,9 +11,18 @@ IP_CLAUSE_TYPES = {
22
  }
23
 
24
 
25
- def _symbolic_rule_score(features: dict, SYMBOLIC_RULES: list) -> dict:
 
 
 
 
 
 
 
 
 
26
  triggered, total = [], 0.0
27
- for rule in SYMBOLIC_RULES:
28
  try:
29
  if rule["condition"](features):
30
  triggered.append(rule)
@@ -37,137 +35,32 @@ def _symbolic_rule_score(features: dict, SYMBOLIC_RULES: list) -> dict:
37
  }
38
 
39
 
40
- def _neuro_symbolic_fusion(neural: float, symbolic: float,
41
- is_ip_clause: bool = False) -> dict:
 
 
 
 
 
 
 
 
 
 
 
 
42
  if is_ip_clause and symbolic > 0:
43
  w_n, w_s = 0.35, 0.65
44
  else:
45
  w_n, w_s = 0.60, 0.40
 
46
  score = w_n * neural + w_s * symbolic
47
  if symbolic > 0:
48
- score = max(score, 0.30)
49
  score = round(min(score, 1.0), 3)
 
50
  if score <= 0.33: level, emoji = "Low", "🟢"
51
  elif score <= 0.66: level, emoji = "Medium", "🟡"
52
  else: level, emoji = "High", "🔴"
53
- return {"score": score, "level": level, "emoji": emoji}
54
-
55
-
56
- def analyze_clause(text: str, SYMBOLIC_RULES: list) -> dict:
57
- """
58
- Run full neuro-symbolic inference on a single clause text.
59
- Returns a clean dict with risk_score, risk_level, triggered_rules, etc.
60
- Requires module-level globals to be set (model, tokenizer, ...).
61
- """
62
- if model is None:
63
- raise RuntimeError("Model not loaded. Call load_model() in app.py first.")
64
-
65
- enc = tokenizer(
66
- text, padding="max_length", truncation=True,
67
- max_length=512, return_tensors="pt"
68
- )
69
- model.eval()
70
- with torch.no_grad():
71
- clause_logits, risk_logits, risk_score_tensor, _, _ = model(
72
- enc["input_ids"].to(device),
73
- enc["attention_mask"].to(device),
74
- )
75
-
76
- clause_probs = torch.sigmoid(clause_logits).cpu().numpy()[0]
77
- top3_idx = clause_probs.argsort()[::-1][:3]
78
- top_clauses = [
79
- (clause_mlb.classes_[i], round(float(clause_probs[i]), 3))
80
- for i in top3_idx if clause_probs[i] > 0.05
81
- ]
82
-
83
- risk_probs = torch.sigmoid(risk_logits).cpu().numpy()[0]
84
- top2_idx = risk_probs.argsort()[::-1][:2]
85
- top_risks = [
86
- (risk_mlb.classes_[i], round(float(risk_probs[i]), 3))
87
- for i in top2_idx if risk_probs[i] > 0.05
88
- ]
89
-
90
- neural_score = round(float(risk_score_tensor.item()), 3)
91
- features = feature_extractor.extract(text)
92
- sym_result = _symbolic_rule_score(features, SYMBOLIC_RULES)
93
-
94
- top_clause_name = top_clauses[0][0] if top_clauses else ""
95
- is_ip = top_clause_name in IP_CLAUSE_TYPES
96
- fusion = _neuro_symbolic_fusion(neural_score, sym_result["symbolic_score"], is_ip)
97
 
98
- triggered_clean = [
99
- {
100
- "rule_id": r["rule_id"],
101
- "name": r["name"],
102
- "reference": r["reference"],
103
- "penalty": r["penalty"],
104
- "category": r["category"],
105
- }
106
- for r in sym_result["triggered_rules"]
107
- ]
108
-
109
- return {
110
- "risk_score": fusion["score"],
111
- "neural_score": neural_score,
112
- "symbolic_score": sym_result["symbolic_score"],
113
- "risk_level": f"{fusion['emoji']} {fusion['level']}",
114
- "risk_level_raw": fusion["level"],
115
- "top_clauses": top_clauses,
116
- "top_risk_cats": top_risks,
117
- "triggered_rules": triggered_clean,
118
- "features": {k: v for k, v in features.items() if v},
119
- }
120
-
121
-
122
- # ── Document-level analysis (added for dashboard) ���───────────────────────────
123
- def analyze_document(text: str, SYMBOLIC_RULES: list, max_clauses: int = 50) -> dict:
124
- """
125
- Split text into clauses, run analyze_clause() on each, return document summary.
126
-
127
- Returns:
128
- {
129
- "overall_risk": float, # weighted-max of fused scores
130
- "overall_level": str, # Low / Medium / High
131
- "num_clauses": int,
132
- "top_risks": list[dict], # top 3 by risk_score
133
- "clauses": list[dict], # all clause results + index + text
134
- }
135
- """
136
- from pdf_utils import split_into_clauses
137
-
138
- clauses = split_into_clauses(text)[:max_clauses]
139
- if not clauses:
140
- clauses = [text[:2000]] # fallback: treat whole text as one clause
141
-
142
- results = []
143
- for idx, clause_text in enumerate(clauses):
144
- try:
145
- r = analyze_clause(clause_text, SYMBOLIC_RULES)
146
- except Exception:
147
- r = {
148
- "risk_score": 0.0, "neural_score": 0.0, "symbolic_score": 0.0,
149
- "risk_level": "🟢 Low", "risk_level_raw": "Low",
150
- "top_clauses": [], "top_risk_cats": [],
151
- "triggered_rules": [], "features": {},
152
- }
153
- r["clause_index"] = idx + 1
154
- r["clause_text"] = clause_text
155
- results.append(r)
156
-
157
- scores = [r["risk_score"] for r in results]
158
-
159
- # Overall = 70% max + 30% mean (punishes worst clause, not just average)
160
- overall = round(0.70 * max(scores) + 0.30 * (sum(scores) / len(scores)), 3)
161
- if overall <= 0.33: level = "Low"
162
- elif overall <= 0.66: level = "Medium"
163
- else: level = "High"
164
-
165
- top_risks = sorted(results, key=lambda x: x["risk_score"], reverse=True)[:3]
166
-
167
- return {
168
- "overall_risk": overall,
169
- "overall_level": level,
170
- "num_clauses": len(results),
171
- "top_risks": top_risks,
172
- "clauses": results,
173
- }
 
1
  # inference.py
2
+ # Pure utility functions for neuro-symbolic fusion.
3
+ # No module-level mutable globals all state lives in ModelManager (app.py).
 
4
 
5
  from __future__ import annotations
 
6
 
 
 
 
 
 
 
 
 
 
7
  IP_CLAUSE_TYPES = {
8
  "IP Ownership Assignment", "Joint IP Ownership",
9
  "Irrevocable Or Perpetual License",
 
11
  }
12
 
13
 
14
+ def _symbolic_rule_score(features: dict, symbolic_rules: list) -> dict:
15
+ """
16
+ Evaluate symbolic rules against extracted features.
17
+
18
+ Returns:
19
+ {
20
+ "symbolic_score": float, # clamped to [0, 1]
21
+ "triggered_rules": list[dict], # rules whose condition fired
22
+ }
23
+ """
24
  triggered, total = [], 0.0
25
+ for rule in symbolic_rules:
26
  try:
27
  if rule["condition"](features):
28
  triggered.append(rule)
 
35
  }
36
 
37
 
38
+ def _neuro_symbolic_fusion(
39
+ neural: float,
40
+ symbolic: float,
41
+ is_ip_clause: bool = False,
42
+ ) -> dict:
43
+ """
44
+ Weighted fusion of neural and symbolic scores.
45
+
46
+ IP clauses shift weight toward symbolic rules (which capture IP-specific law).
47
+ Ensures score is non-trivially low when symbolic rules fire.
48
+
49
+ Returns:
50
+ { "score": float, "level": str, "emoji": str }
51
+ """
52
  if is_ip_clause and symbolic > 0:
53
  w_n, w_s = 0.35, 0.65
54
  else:
55
  w_n, w_s = 0.60, 0.40
56
+
57
  score = w_n * neural + w_s * symbolic
58
  if symbolic > 0:
59
+ score = max(score, 0.30) # symbolic trigger → at least Medium
60
  score = round(min(score, 1.0), 3)
61
+
62
  if score <= 0.33: level, emoji = "Low", "🟢"
63
  elif score <= 0.66: level, emoji = "Medium", "🟡"
64
  else: level, emoji = "High", "🔴"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
+ return {"score": score, "level": level, "emoji": emoji}