Spaces:
Running
Running
| import re, math, torch | |
| from transformers import pipeline | |
| import spaces | |
# ------------- Model (CPU-friendly); use device=0 + fp16 on GPU -------------
# Shared zero-shot NLI classifier used for every taxonomy axis.
# device=-1 pins inference to CPU; float32 is the safe CPU dtype.
# multi_label=True makes the pipeline score each candidate label on its own
# rather than forcing the scores to compete across labels.
ZSHOT = pipeline(
    "zero-shot-classification",
    model="MoritzLaurer/deberta-v3-base-zeroshot-v2.0",
    multi_label=True,
    device=-1,  # CPU; switch to device=0 (+ fp16) when a GPU is available
    model_kwargs={"torch_dtype": torch.float32}
)
# ------------------ Taxonomy with descriptions (helps NLI) -------------------
# Candidate labels per classification axis. Each entry is "name: description";
# the description gives the NLI model richer hypotheses, and downstream code
# strips everything after the first ":" to recover the short label name.
TAXO = {
    # What kind of provision the passage is (goal, rule, definition, ...)
    "intent_type": [
        "objective: declares goals or aims",
        "principle: states guiding values",
        "strategy: outlines measures or actions",
        "obligation: mandates an action (shall/must)",
        "prohibition: forbids an action",
        "permission: allows an action (may)",
        "exception: states conditions where rules change",
        "definition: defines a term",
        "scope: states applicability or coverage"
    ],
    # Stance of the passage toward its topic
    "disposition": [
        "restrictive: limits or constrains the topic",
        "cautionary: warns or urges care",
        "neutral: descriptive with no clear stance",
        "enabling: allows or facilitates the topic",
        "supportive: promotes or expands the topic"
    ],
    # Strength of the modal language (mandatory / advisory / permissive)
    "rigidity": [
        "must: mandatory (shall/must)",
        "should: advisory (should)",
        "may: permissive (may/can)"
    ],
    # Time-related character of the requirement
    "temporal": [
        "deadline: requires completion by a date or period",
        "schedule: sets a cadence (e.g., annually, quarterly)",
        "ongoing: continuing requirement without end date",
        "effective_date: specifies when rules start/apply"
    ],
    # Who/where/what the passage applies to
    "scope": [
        "actor_specific: targets a group or entity (e.g., county governments, permit holders)",
        "geography_specific: targets a location or region",
        "subject_specific: targets a topic (e.g., permits, sanitation)",
        "nationwide: applies across the country"
    ],
    # Compliance mechanisms; "none_detected" is emitted only when every other
    # label scores below TAU_LOW (see classify_and_explain)
    "enforcement": [
        "penalty: fines or sanctions for non-compliance",
        "remedy: corrective actions required",
        "monitoring: oversight or audits",
        "reporting: reports/returns required",
        "none_detected: no enforcement mechanisms present"
    ],
    # Resources the passage provides; same "none_detected" convention
    "resourcing": [
        "funding: funds or budget allocations",
        "fees_levies: charges or levies",
        "capacity_hr: staffing or training",
        "infrastructure: capital works or equipment",
        "none_detected: no resourcing present"
    ],
    # Rough magnitude of practical effect on regulated parties
    "impact": [
        "low: limited effect on regulated parties",
        "medium: moderate practical effect",
        "high: significant obligations or restrictions"
    ]
}
# ---------------- Axis-specific thresholds (calibrate later) -----------------
# Minimum raw (per-label, independent) score a label must reach to be kept
# for a given axis; values are heuristic and intended for later calibration.
TAU = {
    "intent_type": 0.55, "disposition": 0.55, "rigidity": 0.60,
    "temporal": 0.62, "scope": 0.55,
    "enforcement": 0.50, "resourcing": 0.50, "impact": 0.60
}
# All non-"none_detected" labels on an axis must score below this ceiling
# before "none_detected" is emitted for that axis.
TAU_LOW = 0.40  # only for deciding if we can safely emit "none_detected"
| # ------------------------- Cleaning & evidence rules ------------------------- | |
| def _clean(t: str) -> str: | |
| t = re.sub(r"[ \t]*\n[ \t]*", " ", str(t)) | |
| t = re.sub(r"\s{2,}", " ", t).strip() | |
| return t | |
# Heuristic evidence regexes used to pull supporting sentences; every pattern
# is applied with re.I (case-insensitive) by the callers.
PAT = {
    "actor": r"\bCounty Government(?:s)?\b|\bAuthority\b|\bMinistry\b|\bAgency\b|\bBoard\b|\bCommission\b",
    "nationwide": r"\bKenya\b|\bnational\b|\bnationwide\b|\bacross the country\b|\bthe country\b",
    "objective": r"\b(Objective[s]?|Purpose)\b|(?:^|\.\s+)To [A-Za-z]",
    "imperative": r"(?:^|\.\s+)(Promote|Ensure|Encourage|Strengthen|Adopt)\b.*?(?:\.|;)",
    "modal_must": r"\bshall\b|\bmust\b",
    "modal_should": r"\bshould\b",
    "modal_may": r"\bmay\b|\bcan\b",
    "temporal": r"\bwithin \d+\s+(day|days|month|months|year|years)\b|\bby \d{4}\b|\beffective\b",
    "enforcement": r"\bpenalt(y|ies)\b|\bfine(s)?\b|\brevocation\b|\bsuspension\b|\breport(ing)?\b|\bmonitor(ing)?\b",
    # FIX: the original alternation "\blevy|levies|fee(s)?\b" had unanchored
    # branches — "levies" could match inside any word, and "fee(s)?\b" matched
    # the tail of words such as "coffee". Every branch is now word-bounded.
    "resourcing": r"\bfund(?:ing)?\b|\blev(?:y|ies)\b|\bfees?\b|\bbudget\b|\binfrastructure\b|\bcapacity\b|\btraining\b"
}
| def _spans(text, pattern, max_spans=2): | |
| spans = [] | |
| for m in re.finditer(pattern, text, flags=re.I): | |
| # sentence-level extraction | |
| start = text.rfind('.', 0, m.start()) + 1 | |
| end = text.find('.', m.end()) | |
| if end == -1: end = len(text) | |
| snippet = text[start:end].strip() | |
| if snippet and snippet not in spans: | |
| spans.append(snippet) | |
| if len(spans) >= max_spans: break | |
| return spans | |
| def _softmax(d): | |
| vals = list(d.values()) | |
| if not vals: return {k: 0.0 for k in d} | |
| m = max(vals) | |
| exps = [math.exp(v - m) for v in vals] | |
| Z = sum(exps) | |
| return {k: (e / Z) for k, e in zip(d.keys(), exps)} | |
# -------------------- Main: classify + explanations + % ----------------------
def classify_and_explain(text: str, topic: str = "water and sanitation", per_axis_top_k=2):
    """Zero-shot classify a policy passage along every axis in TAXO, attach
    threshold-filtered labels, per-label percentages, regex-grounded evidence
    snippets, and a short extractive decision summary.

    Parameters
    ----------
    text : str
        Raw passage; whitespace-normalized via _clean() before anything else.
    topic : str
        Subject injected into the NLI hypothesis templates (topic-aware).
    per_axis_top_k : int
        Maximum number of labels retained per axis after thresholding.

    Returns
    -------
    dict
        Keys: "decision_summary", "labels" (axis -> kept short label names),
        "percents_raw" (independent confidences; do NOT sum to 100),
        "percents_norm" (softmax-normalized per axis; sums to ~100),
        "why" (list of evidence records), "text_preview" (first 300 chars).
    """
    text = _clean(text)
    if not text:
        # Empty passage: return the full result shape with empty values so
        # downstream consumers never need a special case.
        return {"decision_summary": "No operative decision; empty passage.",
                "labels": {ax: [] for ax in TAXO},
                "percents_raw": {ax: {} for ax in TAXO},
                "percents_norm": {ax: {} for ax in TAXO},
                "why": [], "text_preview": ""}
    # Topic-aware hypotheses (improves stance/intent)
    def hyp(axis):
        # Build this axis's hypothesis template; the surviving "{}" is filled
        # with each candidate label by the pipeline itself.
        base = "This passage {} regarding " + topic + "."
        return {
            "intent_type": base.format("states a {}"),
            "disposition": base.format("is {}"),
            "rigidity": "Compliance in this passage is {}.",
            "temporal": base.format("specifies a {} aspect"),
            "scope": base.format("is {} in applicability"),
            "enforcement": base.format("includes {} for compliance"),
            "resourcing": base.format("provides {}"),
            "impact": base.format("has {} impact")
        }[axis]
    # Single call if supported; else per-axis fallback
    tasks = [{"sequences": text, "candidate_labels": labels, "hypothesis_template": hyp(axis)}
             for axis, labels in TAXO.items()]
    try:
        results = ZSHOT(tasks)
    except TypeError:
        # NOTE(review): assumes older pipeline versions raise TypeError for a
        # list of task dicts — then classify one axis at a time. Confirm
        # against the installed transformers version.
        results = [ZSHOT(text, labels, hypothesis_template=hyp(axis))
                   for axis, labels in TAXO.items()]
    labels_out, perc_raw, perc_norm, why = {}, {}, {}, []
    # Relies on TAXO.items() and `results` being in the same axis order.
    for (axis, labels), r in zip(TAXO.items(), results):
        # raw scores — strip the ": description" suffix to get short names
        raw = {lbl.split(":")[0].strip(): float(s) for lbl, s in zip(r["labels"], r["scores"])}
        perc_raw[axis] = {k: round(raw[k]*100, 1) for k in raw}  # independent sigmoid
        norm = _softmax(raw)
        perc_norm[axis] = {k: round(norm[k]*100, 1) for k in norm}  # sums ~100%
        # select labels by threshold, strongest first, capped at top_k
        keep = [k for k, s in raw.items() if s >= TAU[axis]]
        keep = sorted(keep, key=lambda k: raw[k], reverse=True)[:per_axis_top_k]
        # only emit none_detected when everything else is weak and no heuristic evidence
        if not keep and "none_detected" in raw:
            if max([v for k, v in raw.items() if k != "none_detected"] or [0.0]) < TAU_LOW:
                keep = ["none_detected"]
        labels_out[axis] = keep
        # compact "why" with evidence for the top choice
        if keep and keep[0] != "none_detected":
            if axis == "intent_type":
                ev = _spans(text, PAT["objective"]) or _spans(text, PAT["imperative"])
                why.append({"axis": axis, "label": keep[0], "reason": "functional cues", "evidence": ev[:2]})
            elif axis == "disposition":
                ev = _spans(text, PAT["imperative"])
                why.append({"axis": axis, "label": keep[0], "reason": "promotional/allowing framing", "evidence": ev[:2]})
            elif axis == "rigidity":
                # rigidity's labels are exactly must/should/may, so this lookup is total
                pat = {"must": PAT["modal_must"], "should": PAT["modal_should"], "may": PAT["modal_may"]}[keep[0]]
                why.append({"axis": axis, "label": keep[0], "reason": "modal verb", "evidence": _spans(text, pat)[:2]})
            elif axis == "temporal":
                why.append({"axis": axis, "label": keep[0], "reason": "time expressions", "evidence": _spans(text, PAT["temporal"])[:2]})
            elif axis == "scope":
                ev = _spans(text, PAT["nationwide"]) or _spans(text, PAT["actor"])
                why.append({"axis": axis, "label": keep[0], "reason": "applicability cues", "evidence": ev[:2]})
            elif axis == "enforcement":
                why.append({"axis": axis, "label": keep[0], "reason": "compliance hooks", "evidence": _spans(text, PAT["enforcement"])[:2]})
            elif axis == "resourcing":
                why.append({"axis": axis, "label": keep[0], "reason": "resourcing hooks", "evidence": _spans(text, PAT["resourcing"])[:2]})
    # Decision summary: imperative lines + problem statements; never fabricate
    summary_bits = []
    imperatives = re.findall(PAT["imperative"], text, flags=re.I)  # NOTE(review): unused — candidate for removal
    # pull full imperative sentences
    imp_sents = _spans(text, PAT["imperative"], max_spans=3)
    if imp_sents:
        summary_bits.append("Strategies: " + " ".join(imp_sents))
    if "nationwide" in labels_out.get("scope", []):
        summary_bits.append("Applies nationwide.")
    if labels_out.get("enforcement") == ["none_detected"]:
        summary_bits.append("Enforcement: none detected in this passage.")
    if labels_out.get("resourcing") == ["none_detected"]:
        summary_bits.append("Resourcing: none detected in this passage.")
    decision_summary = " ".join(summary_bits) if summary_bits else "No operative decision beyond high-level description detected."
    return {
        "decision_summary": decision_summary,
        "labels": labels_out,
        "percents_raw": perc_raw,  # model confidences per label (0–100, do NOT sum to 100)
        "percents_norm": perc_norm,  # normalized per axis (sums to ~100)
        "why": why,
        "text_preview": text[:300] + ("..." if len(text) > 300 else "")
    }
# Get the sentiment for all the docs
def get_sentiment(texts):
    """Run classify_and_explain over each document's ``page_content``.

    Parameters
    ----------
    texts : iterable
        Objects exposing a ``page_content`` string attribute — presumably
        LangChain Document chunks; confirm against the caller.

    Returns
    -------
    list
        One classify_and_explain() result dict per document, in input order.
    """
    # Iterate the documents directly instead of indexing with
    # range(len(...)); this is idiomatic and also accepts any iterable,
    # not just indexable sequences (backward compatible).
    return [classify_and_explain(doc.page_content) for doc in texts]