# NOTE: Hugging Face Space page chrome (status lines, commit hashes, and the
# rendered line-number gutter) removed during cleanup; original file size: 9,858 bytes.
import re, math, torch
from transformers import pipeline
import spaces
# ------------- Model (CPU-friendly); use device=0 + fp16 on GPU -------------
# Zero-shot NLI classifier shared by every axis; loaded once at import time.
ZSHOT = pipeline(
"zero-shot-classification",
model="MoritzLaurer/deberta-v3-base-zeroshot-v2.0",
multi_label=True,  # score each candidate label independently (sigmoid per label)
device=-1,  # CPU; switch to device=0 (and fp16) when a GPU is available
model_kwargs={"torch_dtype": torch.float32}
)
# ------------------ Taxonomy with descriptions (helps NLI) -------------------
# Each entry is "short_name: natural-language description". The description
# feeds the NLI hypothesis; downstream code keeps only the text before ":"
# (see classify_and_explain) as the emitted label.
TAXO = {
"intent_type": [
"objective: declares goals or aims",
"principle: states guiding values",
"strategy: outlines measures or actions",
"obligation: mandates an action (shall/must)",
"prohibition: forbids an action",
"permission: allows an action (may)",
"exception: states conditions where rules change",
"definition: defines a term",
"scope: states applicability or coverage"
],
"disposition": [
"restrictive: limits or constrains the topic",
"cautionary: warns or urges care",
"neutral: descriptive with no clear stance",
"enabling: allows or facilitates the topic",
"supportive: promotes or expands the topic"
],
"rigidity": [
"must: mandatory (shall/must)",
"should: advisory (should)",
"may: permissive (may/can)"
],
"temporal": [
"deadline: requires completion by a date or period",
"schedule: sets a cadence (e.g., annually, quarterly)",
"ongoing: continuing requirement without end date",
"effective_date: specifies when rules start/apply"
],
"scope": [
"actor_specific: targets a group or entity (e.g., county governments, permit holders)",
"geography_specific: targets a location or region",
"subject_specific: targets a topic (e.g., permits, sanitation)",
"nationwide: applies across the country"
],
# "none_detected" is a sentinel label: emitted only when every other label
# on the axis scores below TAU_LOW (see classify_and_explain).
"enforcement": [
"penalty: fines or sanctions for non-compliance",
"remedy: corrective actions required",
"monitoring: oversight or audits",
"reporting: reports/returns required",
"none_detected: no enforcement mechanisms present"
],
"resourcing": [
"funding: funds or budget allocations",
"fees_levies: charges or levies",
"capacity_hr: staffing or training",
"infrastructure: capital works or equipment",
"none_detected: no resourcing present"
],
"impact": [
"low: limited effect on regulated parties",
"medium: moderate practical effect",
"high: significant obligations or restrictions"
]
}
# ---------------- Axis-specific thresholds (calibrate later) -----------------
# Minimum raw (sigmoid) score for a label to be kept on each axis.
TAU = {
"intent_type": 0.55, "disposition": 0.55, "rigidity": 0.60,
"temporal": 0.62, "scope": 0.55,
"enforcement": 0.50, "resourcing": 0.50, "impact": 0.60
}
TAU_LOW = 0.40 # only for deciding if we can safely emit "none_detected"
# ------------------------- Cleaning & evidence rules -------------------------
def _clean(t: str) -> str:
t = re.sub(r"[ \t]*\n[ \t]*", " ", str(t))
t = re.sub(r"\s{2,}", " ", t).strip()
return t
# Heuristic evidence patterns used to pull supporting sentences for the "why"
# output; all are applied case-insensitively via _spans/re.
# Fix: the original "resourcing" pattern had broken alternation grouping —
# `fee(s)?\b` lacked a leading \b (it matched inside "coffee") and `levies`
# carried no word boundaries at all; both are now properly bounded.
PAT = {
    "actor": r"\bCounty Government(?:s)?\b|\bAuthority\b|\bMinistry\b|\bAgency\b|\bBoard\b|\bCommission\b",
    "nationwide": r"\bKenya\b|\bnational\b|\bnationwide\b|\bacross the country\b|\bthe country\b",
    "objective": r"\b(Objective[s]?|Purpose)\b|(?:^|\.\s+)To [A-Za-z]",
    "imperative": r"(?:^|\.\s+)(Promote|Ensure|Encourage|Strengthen|Adopt)\b.*?(?:\.|;)",
    "modal_must": r"\bshall\b|\bmust\b",
    "modal_should": r"\bshould\b",
    "modal_may": r"\bmay\b|\bcan\b",
    "temporal": r"\bwithin \d+\s+(day|days|month|months|year|years)\b|\bby \d{4}\b|\beffective\b",
    "enforcement": r"\bpenalt(y|ies)\b|\bfine(s)?\b|\brevocation\b|\bsuspension\b|\breport(ing)?\b|\bmonitor(ing)?\b",
    "resourcing": r"\bfund(?:ing)?\b|\blev(?:y|ies)\b|\bfees?\b|\bbudget\b|\binfrastructure\b|\bcapacity\b|\btraining\b"
}
def _spans(text, pattern, max_spans=2):
spans = []
for m in re.finditer(pattern, text, flags=re.I):
# sentence-level extraction
start = text.rfind('.', 0, m.start()) + 1
end = text.find('.', m.end())
if end == -1: end = len(text)
snippet = text[start:end].strip()
if snippet and snippet not in spans:
spans.append(snippet)
if len(spans) >= max_spans: break
return spans
def _softmax(d):
vals = list(d.values())
if not vals: return {k: 0.0 for k in d}
m = max(vals)
exps = [math.exp(v - m) for v in vals]
Z = sum(exps)
return {k: (e / Z) for k, e in zip(d.keys(), exps)}
# -------------------- Main: classify + explanations + % ----------------------
def classify_and_explain(text: str, topic: str = "water and sanitation", per_axis_top_k=2):
    """Classify a policy passage along every TAXO axis and explain the picks.

    Parameters
    ----------
    text : str
        Passage to classify; line breaks and repeated whitespace are collapsed first.
    topic : str
        Topic spliced into the NLI hypothesis templates.
    per_axis_top_k : int
        Maximum number of labels kept per axis.

    Returns
    -------
    dict with keys: ``decision_summary``, ``labels`` (kept labels per axis),
    ``percents_raw`` (independent model confidences, 0-100, need not sum to
    100), ``percents_norm`` (softmax-normalized per axis, sums ~100), ``why``
    (evidence snippets for each axis's top label), ``text_preview``.
    """
    text = _clean(text)
    if not text:
        # Empty passage: return a well-formed, all-empty result.
        return {"decision_summary": "No operative decision; empty passage.",
                "labels": {ax: [] for ax in TAXO},
                "percents_raw": {ax: {} for ax in TAXO},
                "percents_norm": {ax: {} for ax in TAXO},
                "why": [], "text_preview": ""}

    # Topic-aware hypotheses (improves stance/intent for NLI).
    # NOTE(review): `topic` is embedded in a str.format template, so a topic
    # containing "{" or "}" would raise — confirm callers never pass one.
    def hyp(axis):
        base = "This passage {} regarding " + topic + "."
        return {
            "intent_type": base.format("states a {}"),
            "disposition": base.format("is {}"),
            "rigidity": "Compliance in this passage is {}.",
            "temporal": base.format("specifies a {} aspect"),
            "scope": base.format("is {} in applicability"),
            "enforcement": base.format("includes {} for compliance"),
            "resourcing": base.format("provides {}"),
            "impact": base.format("has {} impact")
        }[axis]

    # Single batched call if the pipeline version supports it; else fall back
    # to one call per axis.
    tasks = [{"sequences": text, "candidate_labels": labels, "hypothesis_template": hyp(axis)}
             for axis, labels in TAXO.items()]
    try:
        results = ZSHOT(tasks)
    except TypeError:
        results = [ZSHOT(text, labels, hypothesis_template=hyp(axis))
                   for axis, labels in TAXO.items()]

    labels_out, perc_raw, perc_norm, why = {}, {}, {}, []
    for (axis, labels), r in zip(TAXO.items(), results):
        # Raw scores keyed by the short label name (text before the ":").
        raw = {lbl.split(":")[0].strip(): float(s) for lbl, s in zip(r["labels"], r["scores"])}
        perc_raw[axis] = {k: round(raw[k]*100, 1) for k in raw}  # independent sigmoid
        norm = _softmax(raw)
        perc_norm[axis] = {k: round(norm[k]*100, 1) for k in norm}  # sums ~100%
        # Keep labels above the axis threshold, strongest first, capped at top-k.
        keep = [k for k, s in raw.items() if s >= TAU[axis]]
        keep = sorted(keep, key=lambda k: raw[k], reverse=True)[:per_axis_top_k]
        # Only emit "none_detected" when every other label is weak.
        if not keep and "none_detected" in raw:
            if max([v for k, v in raw.items() if k != "none_detected"] or [0.0]) < TAU_LOW:
                keep = ["none_detected"]
        labels_out[axis] = keep
        # Compact "why" entry with regex evidence for the top choice.
        if keep and keep[0] != "none_detected":
            if axis == "intent_type":
                ev = _spans(text, PAT["objective"]) or _spans(text, PAT["imperative"])
                why.append({"axis": axis, "label": keep[0], "reason": "functional cues", "evidence": ev[:2]})
            elif axis == "disposition":
                ev = _spans(text, PAT["imperative"])
                why.append({"axis": axis, "label": keep[0], "reason": "promotional/allowing framing", "evidence": ev[:2]})
            elif axis == "rigidity":
                pat = {"must": PAT["modal_must"], "should": PAT["modal_should"], "may": PAT["modal_may"]}[keep[0]]
                why.append({"axis": axis, "label": keep[0], "reason": "modal verb", "evidence": _spans(text, pat)[:2]})
            elif axis == "temporal":
                why.append({"axis": axis, "label": keep[0], "reason": "time expressions", "evidence": _spans(text, PAT["temporal"])[:2]})
            elif axis == "scope":
                ev = _spans(text, PAT["nationwide"]) or _spans(text, PAT["actor"])
                why.append({"axis": axis, "label": keep[0], "reason": "applicability cues", "evidence": ev[:2]})
            elif axis == "enforcement":
                why.append({"axis": axis, "label": keep[0], "reason": "compliance hooks", "evidence": _spans(text, PAT["enforcement"])[:2]})
            elif axis == "resourcing":
                why.append({"axis": axis, "label": keep[0], "reason": "resourcing hooks", "evidence": _spans(text, PAT["resourcing"])[:2]})

    # Decision summary built from imperative sentences plus detected flags;
    # never fabricate content absent from the passage.
    # (Fix: dropped a dead `re.findall(PAT["imperative"], ...)` whose result
    # was never used — _spans below already extracts the full sentences.)
    summary_bits = []
    imp_sents = _spans(text, PAT["imperative"], max_spans=3)
    if imp_sents:
        summary_bits.append("Strategies: " + " ".join(imp_sents))
    if "nationwide" in labels_out.get("scope", []):
        summary_bits.append("Applies nationwide.")
    if labels_out.get("enforcement") == ["none_detected"]:
        summary_bits.append("Enforcement: none detected in this passage.")
    if labels_out.get("resourcing") == ["none_detected"]:
        summary_bits.append("Resourcing: none detected in this passage.")
    decision_summary = " ".join(summary_bits) if summary_bits else "No operative decision beyond high-level description detected."

    return {
        "decision_summary": decision_summary,
        "labels": labels_out,
        "percents_raw": perc_raw,   # model confidences per label (0-100, do NOT sum to 100)
        "percents_norm": perc_norm, # normalized per axis (sums to ~100)
        "why": why,
        "text_preview": text[:300] + ("..." if len(text) > 300 else "")
    }
# Get the sentiment for all the docs
@spaces.GPU(duration=120)
def get_sentiment(texts):
    """Classify each document's page_content; returns one result dict per doc.

    Fixes: removed a trailing "|" extraction artifact that made the line a
    syntax error, and replaced the index loop with a direct comprehension.
    """
    # NOTE(review): assumes each element exposes a .page_content attribute
    # (e.g. a LangChain Document) — confirm against the caller.
    return [classify_and_explain(doc.page_content) for doc in texts]