Spaces:
Sleeping
Sleeping
Updated interface after checkpoint
Browse files- inference.py +31 -138
inference.py
CHANGED
|
@@ -1,20 +1,9 @@
|
|
| 1 |
# inference.py
|
| 2 |
-
#
|
| 3 |
-
#
|
| 4 |
-
# to be loaded externally before calling analyze_clause().
|
| 5 |
|
| 6 |
from __future__ import annotations
|
| 7 |
-
import torch
|
| 8 |
|
| 9 |
-
# ── These globals must be set by app.py before calling analyze_clause ──
|
| 10 |
-
model = None
|
| 11 |
-
tokenizer = None
|
| 12 |
-
clause_mlb = None
|
| 13 |
-
risk_mlb = None
|
| 14 |
-
feature_extractor = None
|
| 15 |
-
device = None
|
| 16 |
-
|
| 17 |
-
RISK_LEVEL_ORDER = {"Low": 0, "Medium": 1, "High": 2}
|
| 18 |
IP_CLAUSE_TYPES = {
|
| 19 |
"IP Ownership Assignment", "Joint IP Ownership",
|
| 20 |
"Irrevocable Or Perpetual License",
|
|
@@ -22,9 +11,18 @@ IP_CLAUSE_TYPES = {
|
|
| 22 |
}
|
| 23 |
|
| 24 |
|
| 25 |
-
def _symbolic_rule_score(features: dict,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
triggered, total = [], 0.0
|
| 27 |
-
for rule in
|
| 28 |
try:
|
| 29 |
if rule["condition"](features):
|
| 30 |
triggered.append(rule)
|
|
@@ -37,137 +35,32 @@ def _symbolic_rule_score(features: dict, SYMBOLIC_RULES: list) -> dict:
|
|
| 37 |
}
|
| 38 |
|
| 39 |
|
| 40 |
-
def _neuro_symbolic_fusion(
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
if is_ip_clause and symbolic > 0:
|
| 43 |
w_n, w_s = 0.35, 0.65
|
| 44 |
else:
|
| 45 |
w_n, w_s = 0.60, 0.40
|
|
|
|
| 46 |
score = w_n * neural + w_s * symbolic
|
| 47 |
if symbolic > 0:
|
| 48 |
-
score = max(score, 0.30)
|
| 49 |
score = round(min(score, 1.0), 3)
|
|
|
|
| 50 |
if score <= 0.33: level, emoji = "Low", "🟢"
|
| 51 |
elif score <= 0.66: level, emoji = "Medium", "🟡"
|
| 52 |
else: level, emoji = "High", "🔴"
|
| 53 |
-
return {"score": score, "level": level, "emoji": emoji}
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
def analyze_clause(text: str, SYMBOLIC_RULES: list) -> dict:
|
| 57 |
-
"""
|
| 58 |
-
Run full neuro-symbolic inference on a single clause text.
|
| 59 |
-
Returns a clean dict with risk_score, risk_level, triggered_rules, etc.
|
| 60 |
-
Requires module-level globals to be set (model, tokenizer, ...).
|
| 61 |
-
"""
|
| 62 |
-
if model is None:
|
| 63 |
-
raise RuntimeError("Model not loaded. Call load_model() in app.py first.")
|
| 64 |
-
|
| 65 |
-
enc = tokenizer(
|
| 66 |
-
text, padding="max_length", truncation=True,
|
| 67 |
-
max_length=512, return_tensors="pt"
|
| 68 |
-
)
|
| 69 |
-
model.eval()
|
| 70 |
-
with torch.no_grad():
|
| 71 |
-
clause_logits, risk_logits, risk_score_tensor, _, _ = model(
|
| 72 |
-
enc["input_ids"].to(device),
|
| 73 |
-
enc["attention_mask"].to(device),
|
| 74 |
-
)
|
| 75 |
-
|
| 76 |
-
clause_probs = torch.sigmoid(clause_logits).cpu().numpy()[0]
|
| 77 |
-
top3_idx = clause_probs.argsort()[::-1][:3]
|
| 78 |
-
top_clauses = [
|
| 79 |
-
(clause_mlb.classes_[i], round(float(clause_probs[i]), 3))
|
| 80 |
-
for i in top3_idx if clause_probs[i] > 0.05
|
| 81 |
-
]
|
| 82 |
-
|
| 83 |
-
risk_probs = torch.sigmoid(risk_logits).cpu().numpy()[0]
|
| 84 |
-
top2_idx = risk_probs.argsort()[::-1][:2]
|
| 85 |
-
top_risks = [
|
| 86 |
-
(risk_mlb.classes_[i], round(float(risk_probs[i]), 3))
|
| 87 |
-
for i in top2_idx if risk_probs[i] > 0.05
|
| 88 |
-
]
|
| 89 |
-
|
| 90 |
-
neural_score = round(float(risk_score_tensor.item()), 3)
|
| 91 |
-
features = feature_extractor.extract(text)
|
| 92 |
-
sym_result = _symbolic_rule_score(features, SYMBOLIC_RULES)
|
| 93 |
-
|
| 94 |
-
top_clause_name = top_clauses[0][0] if top_clauses else ""
|
| 95 |
-
is_ip = top_clause_name in IP_CLAUSE_TYPES
|
| 96 |
-
fusion = _neuro_symbolic_fusion(neural_score, sym_result["symbolic_score"], is_ip)
|
| 97 |
|
| 98 |
-
|
| 99 |
-
{
|
| 100 |
-
"rule_id": r["rule_id"],
|
| 101 |
-
"name": r["name"],
|
| 102 |
-
"reference": r["reference"],
|
| 103 |
-
"penalty": r["penalty"],
|
| 104 |
-
"category": r["category"],
|
| 105 |
-
}
|
| 106 |
-
for r in sym_result["triggered_rules"]
|
| 107 |
-
]
|
| 108 |
-
|
| 109 |
-
return {
|
| 110 |
-
"risk_score": fusion["score"],
|
| 111 |
-
"neural_score": neural_score,
|
| 112 |
-
"symbolic_score": sym_result["symbolic_score"],
|
| 113 |
-
"risk_level": f"{fusion['emoji']} {fusion['level']}",
|
| 114 |
-
"risk_level_raw": fusion["level"],
|
| 115 |
-
"top_clauses": top_clauses,
|
| 116 |
-
"top_risk_cats": top_risks,
|
| 117 |
-
"triggered_rules": triggered_clean,
|
| 118 |
-
"features": {k: v for k, v in features.items() if v},
|
| 119 |
-
}
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
# ── Document-level analysis (added for dashboard) ���───────────────────────────
|
| 123 |
-
def analyze_document(text: str, SYMBOLIC_RULES: list, max_clauses: int = 50) -> dict:
|
| 124 |
-
"""
|
| 125 |
-
Split text into clauses, run analyze_clause() on each, return document summary.
|
| 126 |
-
|
| 127 |
-
Returns:
|
| 128 |
-
{
|
| 129 |
-
"overall_risk": float, # weighted-max of fused scores
|
| 130 |
-
"overall_level": str, # Low / Medium / High
|
| 131 |
-
"num_clauses": int,
|
| 132 |
-
"top_risks": list[dict], # top 3 by risk_score
|
| 133 |
-
"clauses": list[dict], # all clause results + index + text
|
| 134 |
-
}
|
| 135 |
-
"""
|
| 136 |
-
from pdf_utils import split_into_clauses
|
| 137 |
-
|
| 138 |
-
clauses = split_into_clauses(text)[:max_clauses]
|
| 139 |
-
if not clauses:
|
| 140 |
-
clauses = [text[:2000]] # fallback: treat whole text as one clause
|
| 141 |
-
|
| 142 |
-
results = []
|
| 143 |
-
for idx, clause_text in enumerate(clauses):
|
| 144 |
-
try:
|
| 145 |
-
r = analyze_clause(clause_text, SYMBOLIC_RULES)
|
| 146 |
-
except Exception:
|
| 147 |
-
r = {
|
| 148 |
-
"risk_score": 0.0, "neural_score": 0.0, "symbolic_score": 0.0,
|
| 149 |
-
"risk_level": "🟢 Low", "risk_level_raw": "Low",
|
| 150 |
-
"top_clauses": [], "top_risk_cats": [],
|
| 151 |
-
"triggered_rules": [], "features": {},
|
| 152 |
-
}
|
| 153 |
-
r["clause_index"] = idx + 1
|
| 154 |
-
r["clause_text"] = clause_text
|
| 155 |
-
results.append(r)
|
| 156 |
-
|
| 157 |
-
scores = [r["risk_score"] for r in results]
|
| 158 |
-
|
| 159 |
-
# Overall = 70% max + 30% mean (punishes worst clause, not just average)
|
| 160 |
-
overall = round(0.70 * max(scores) + 0.30 * (sum(scores) / len(scores)), 3)
|
| 161 |
-
if overall <= 0.33: level = "Low"
|
| 162 |
-
elif overall <= 0.66: level = "Medium"
|
| 163 |
-
else: level = "High"
|
| 164 |
-
|
| 165 |
-
top_risks = sorted(results, key=lambda x: x["risk_score"], reverse=True)[:3]
|
| 166 |
-
|
| 167 |
-
return {
|
| 168 |
-
"overall_risk": overall,
|
| 169 |
-
"overall_level": level,
|
| 170 |
-
"num_clauses": len(results),
|
| 171 |
-
"top_risks": top_risks,
|
| 172 |
-
"clauses": results,
|
| 173 |
-
}
|
|
|
|
| 1 |
# inference.py
|
| 2 |
+
# Pure utility functions for neuro-symbolic fusion.
|
| 3 |
+
# No module-level mutable globals — all state lives in ModelManager (app.py).
|
|
|
|
| 4 |
|
| 5 |
from __future__ import annotations
|
|
|
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
IP_CLAUSE_TYPES = {
|
| 8 |
"IP Ownership Assignment", "Joint IP Ownership",
|
| 9 |
"Irrevocable Or Perpetual License",
|
|
|
|
| 11 |
}
|
| 12 |
|
| 13 |
|
| 14 |
+
def _symbolic_rule_score(features: dict, symbolic_rules: list) -> dict:
|
| 15 |
+
"""
|
| 16 |
+
Evaluate symbolic rules against extracted features.
|
| 17 |
+
|
| 18 |
+
Returns:
|
| 19 |
+
{
|
| 20 |
+
"symbolic_score": float, # clamped to [0, 1]
|
| 21 |
+
"triggered_rules": list[dict], # rules whose condition fired
|
| 22 |
+
}
|
| 23 |
+
"""
|
| 24 |
triggered, total = [], 0.0
|
| 25 |
+
for rule in symbolic_rules:
|
| 26 |
try:
|
| 27 |
if rule["condition"](features):
|
| 28 |
triggered.append(rule)
|
|
|
|
| 35 |
}
|
| 36 |
|
| 37 |
|
| 38 |
+
def _neuro_symbolic_fusion(
|
| 39 |
+
neural: float,
|
| 40 |
+
symbolic: float,
|
| 41 |
+
is_ip_clause: bool = False,
|
| 42 |
+
) -> dict:
|
| 43 |
+
"""
|
| 44 |
+
Weighted fusion of neural and symbolic scores.
|
| 45 |
+
|
| 46 |
+
IP clauses shift weight toward symbolic rules (which capture IP-specific law).
|
| 47 |
+
Ensures score is non-trivially low when symbolic rules fire.
|
| 48 |
+
|
| 49 |
+
Returns:
|
| 50 |
+
{ "score": float, "level": str, "emoji": str }
|
| 51 |
+
"""
|
| 52 |
if is_ip_clause and symbolic > 0:
|
| 53 |
w_n, w_s = 0.35, 0.65
|
| 54 |
else:
|
| 55 |
w_n, w_s = 0.60, 0.40
|
| 56 |
+
|
| 57 |
score = w_n * neural + w_s * symbolic
|
| 58 |
if symbolic > 0:
|
| 59 |
+
score = max(score, 0.30) # symbolic trigger → at least Medium
|
| 60 |
score = round(min(score, 1.0), 3)
|
| 61 |
+
|
| 62 |
if score <= 0.33: level, emoji = "Low", "🟢"
|
| 63 |
elif score <= 0.66: level, emoji = "Medium", "🟡"
|
| 64 |
else: level, emoji = "High", "🔴"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
+
return {"score": score, "level": level, "emoji": emoji}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|