# policy-analysis/utils/sentiment_analysis.py
import re, math, torch
from transformers import pipeline
import spaces
# ------------- Model (CPU-friendly); use device=0 + fp16 on GPU -------------
ZSHOT = pipeline(
"zero-shot-classification",
model="MoritzLaurer/deberta-v3-base-zeroshot-v2.0",
multi_label=True,
device=-1,
model_kwargs={"torch_dtype": torch.float32}
)
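# A hedged alternative for a GPU runtime (assumes a CUDA device is available;
# this mirrors the hint above rather than a tested configuration):
# ZSHOT = pipeline(
#     "zero-shot-classification",
#     model="MoritzLaurer/deberta-v3-base-zeroshot-v2.0",
#     multi_label=True,
#     device=0,
#     model_kwargs={"torch_dtype": torch.float16}
# )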
# ------------------ Taxonomy with descriptions (helps NLI) -------------------
TAXO = {
"intent_type": [
"objective: declares goals or aims",
"principle: states guiding values",
"strategy: outlines measures or actions",
"obligation: mandates an action (shall/must)",
"prohibition: forbids an action",
"permission: allows an action (may)",
"exception: states conditions where rules change",
"definition: defines a term",
"scope: states applicability or coverage"
],
"disposition": [
"restrictive: limits or constrains the topic",
"cautionary: warns or urges care",
"neutral: descriptive with no clear stance",
"enabling: allows or facilitates the topic",
"supportive: promotes or expands the topic"
],
"rigidity": [
"must: mandatory (shall/must)",
"should: advisory (should)",
"may: permissive (may/can)"
],
"temporal": [
"deadline: requires completion by a date or period",
"schedule: sets a cadence (e.g., annually, quarterly)",
"ongoing: continuing requirement without end date",
"effective_date: specifies when rules start/apply"
],
"scope": [
"actor_specific: targets a group or entity (e.g., county governments, permit holders)",
"geography_specific: targets a location or region",
"subject_specific: targets a topic (e.g., permits, sanitation)",
"nationwide: applies across the country"
],
"enforcement": [
"penalty: fines or sanctions for non-compliance",
"remedy: corrective actions required",
"monitoring: oversight or audits",
"reporting: reports/returns required",
"none_detected: no enforcement mechanisms present"
],
"resourcing": [
"funding: funds or budget allocations",
"fees_levies: charges or levies",
"capacity_hr: staffing or training",
"infrastructure: capital works or equipment",
"none_detected: no resourcing present"
],
"impact": [
"low: limited effect on regulated parties",
"medium: moderate practical effect",
"high: significant obligations or restrictions"
]
}
# ---------------- Axis-specific thresholds (calibrate later) -----------------
TAU = {
"intent_type": 0.55, "disposition": 0.55, "rigidity": 0.60,
"temporal": 0.62, "scope": 0.55,
"enforcement": 0.50, "resourcing": 0.50, "impact": 0.60
}
TAU_LOW = 0.40 # only for deciding if we can safely emit "none_detected"
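# Illustrative walkthrough (made-up scores): for "enforcement" with raw scores
# {"penalty": 0.31, "remedy": 0.12, "monitoring": 0.22, "reporting": 0.18, "none_detected": 0.45},
# nothing clears TAU["enforcement"] = 0.50, and the best substantive label (0.31)
# is below TAU_LOW = 0.40, so the axis resolves to ["none_detected"].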
# ------------------------- Cleaning & evidence rules -------------------------
def _clean(t: str) -> str:
t = re.sub(r"[ \t]*\n[ \t]*", " ", str(t))
t = re.sub(r"\s{2,}", " ", t).strip()
return t
PAT = {
"actor": r"\bCounty Government(?:s)?\b|\bAuthority\b|\bMinistry\b|\bAgency\b|\bBoard\b|\bCommission\b",
"nationwide": r"\bKenya\b|\bnational\b|\bnationwide\b|\bacross the country\b|\bthe country\b",
"objective": r"\b(Objective[s]?|Purpose)\b|(?:^|\.\s+)To [A-Za-z]",
"imperative": r"(?:^|\.\s+)(Promote|Ensure|Encourage|Strengthen|Adopt)\b.*?(?:\.|;)",
"modal_must": r"\bshall\b|\bmust\b",
"modal_should": r"\bshould\b",
"modal_may": r"\bmay\b|\bcan\b",
"temporal": r"\bwithin \d+\s+(day|days|month|months|year|years)\b|\bby \d{4}\b|\beffective\b",
"enforcement": r"\bpenalt(y|ies)\b|\bfine(s)?\b|\brevocation\b|\bsuspension\b|\breport(ing)?\b|\bmonitor(ing)?\b",
"resourcing": r"\bfund(?:ing)?\b|\blevy|levies|fee(s)?\b|\bbudget\b|\binfrastructure\b|\bcapacity\b|\btraining\b"
}
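# Illustrative matches: in "County Governments shall report annually", PAT["actor"]
# matches "County Governments", PAT["modal_must"] matches "shall", and
# PAT["enforcement"] matches "report".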
def _spans(text, pattern, max_spans=2):
spans = []
for m in re.finditer(pattern, text, flags=re.I):
# sentence-level extraction
start = text.rfind('.', 0, m.start()) + 1
end = text.find('.', m.end())
if end == -1: end = len(text)
snippet = text[start:end].strip()
if snippet and snippet not in spans:
spans.append(snippet)
if len(spans) >= max_spans: break
return spans
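# Illustrative call: _spans("Fees apply. The Authority shall collect fees annually.",
# PAT["modal_must"]) returns ["The Authority shall collect fees annually"], i.e. the
# sentence surrounding the first match, trimmed at the nearest periods.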
def _softmax(d):
vals = list(d.values())
if not vals: return {k: 0.0 for k in d}
m = max(vals)
exps = [math.exp(v - m) for v in vals]
Z = sum(exps)
return {k: (e / Z) for k, e in zip(d.keys(), exps)}
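# Worked example: _softmax({"a": 2.0, "b": 0.0}) ≈ {"a": 0.881, "b": 0.119};
# subtracting the max before exponentiating only guards against overflow and
# does not change the result.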
# -------------------- Main: classify + explanations + % ----------------------
def classify_and_explain(text: str, topic: str = "water and sanitation", per_axis_top_k=2):
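    """Zero-shot classify a policy passage along each TAXO axis and attach regex-based evidence.

    Returns a dict with: decision_summary (stitched from detected cues),
    labels (per-axis selections above TAU), percents_raw (independent sigmoid
    confidences, 0-100), percents_norm (softmax-normalized per axis, sums to ~100),
    why (evidence spans for the top label), and text_preview (first 300 chars).
    """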
text = _clean(text)
if not text:
return {"decision_summary": "No operative decision; empty passage.",
"labels": {ax: [] for ax in TAXO},
"percents_raw": {ax: {} for ax in TAXO},
"percents_norm": {ax: {} for ax in TAXO},
"why": [], "text_preview": ""}
# Topic-aware hypotheses (improves stance/intent)
def hyp(axis):
base = "This passage {} regarding " + topic + "."
return {
"intent_type": base.format("states a {}"),
"disposition": base.format("is {}"),
"rigidity": "Compliance in this passage is {}.",
"temporal": base.format("specifies a {} aspect"),
"scope": base.format("is {} in applicability"),
"enforcement": base.format("includes {} for compliance"),
"resourcing": base.format("provides {}"),
"impact": base.format("has {} impact")
}[axis]
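    # e.g. hyp("disposition") -> "This passage is {} regarding water and sanitation.";
    # the inner "{}" survives str.format here and is later filled with each candidate
    # label via the pipeline's hypothesis_template mechanism.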
    # Try a single batched call (a list of task dicts); if the pipeline rejects that
    # input form, fall back to one call per axis.
tasks = [{"sequences": text, "candidate_labels": labels, "hypothesis_template": hyp(axis)}
for axis, labels in TAXO.items()]
try:
results = ZSHOT(tasks)
except TypeError:
results = [ZSHOT(text, labels, hypothesis_template=hyp(axis))
for axis, labels in TAXO.items()]
labels_out, perc_raw, perc_norm, why = {}, {}, {}, []
for (axis, labels), r in zip(TAXO.items(), results):
# raw scores
raw = {lbl.split(":")[0].strip(): float(s) for lbl, s in zip(r["labels"], r["scores"])}
perc_raw[axis] = {k: round(raw[k]*100, 1) for k in raw} # independent sigmoid
norm = _softmax(raw)
perc_norm[axis] = {k: round(norm[k]*100, 1) for k in norm} # sums ~100%
# select labels by threshold
keep = [k for k, s in raw.items() if s >= TAU[axis]]
keep = sorted(keep, key=lambda k: raw[k], reverse=True)[:per_axis_top_k]
# only emit none_detected when everything else is weak and no heuristic evidence
if not keep and "none_detected" in raw:
if max([v for k, v in raw.items() if k != "none_detected"] or [0.0]) < TAU_LOW:
keep = ["none_detected"]
labels_out[axis] = keep
# compact "why" with evidence for the top choice
if keep and keep[0] != "none_detected":
if axis == "intent_type":
ev = _spans(text, PAT["objective"]) or _spans(text, PAT["imperative"])
why.append({"axis": axis, "label": keep[0], "reason": "functional cues", "evidence": ev[:2]})
elif axis == "disposition":
ev = _spans(text, PAT["imperative"])
why.append({"axis": axis, "label": keep[0], "reason": "promotional/allowing framing", "evidence": ev[:2]})
elif axis == "rigidity":
pat = {"must": PAT["modal_must"], "should": PAT["modal_should"], "may": PAT["modal_may"]}[keep[0]]
why.append({"axis": axis, "label": keep[0], "reason": "modal verb", "evidence": _spans(text, pat)[:2]})
elif axis == "temporal":
why.append({"axis": axis, "label": keep[0], "reason": "time expressions", "evidence": _spans(text, PAT["temporal"])[:2]})
elif axis == "scope":
ev = _spans(text, PAT["nationwide"]) or _spans(text, PAT["actor"])
why.append({"axis": axis, "label": keep[0], "reason": "applicability cues", "evidence": ev[:2]})
elif axis == "enforcement":
why.append({"axis": axis, "label": keep[0], "reason": "compliance hooks", "evidence": _spans(text, PAT["enforcement"])[:2]})
elif axis == "resourcing":
why.append({"axis": axis, "label": keep[0], "reason": "resourcing hooks", "evidence": _spans(text, PAT["resourcing"])[:2]})
    # Decision summary: stitch imperative sentences with detected scope/enforcement/
    # resourcing flags; never fabricate content not present in the passage
    summary_bits = []
    # pull full imperative sentences as evidence (up to three)
    imp_sents = _spans(text, PAT["imperative"], max_spans=3)
if imp_sents:
summary_bits.append("Strategies: " + " ".join(imp_sents))
if "nationwide" in labels_out.get("scope", []):
summary_bits.append("Applies nationwide.")
if labels_out.get("enforcement") == ["none_detected"]:
summary_bits.append("Enforcement: none detected in this passage.")
if labels_out.get("resourcing") == ["none_detected"]:
summary_bits.append("Resourcing: none detected in this passage.")
decision_summary = " ".join(summary_bits) if summary_bits else "No operative decision beyond high-level description detected."
return {
"decision_summary": decision_summary,
"labels": labels_out,
"percents_raw": perc_raw, # model confidences per label (0–100, do NOT sum to 100)
"percents_norm": perc_norm, # normalized per axis (sums to ~100)
"why": why,
"text_preview": text[:300] + ("..." if len(text) > 300 else "")
}
# Classify every document; each item is expected to expose .page_content
# (e.g. a LangChain Document). Note: ZSHOT above is built with device=-1, so the
# GPU allocation only pays off if the pipeline is switched to device=0 as noted at the top.
@spaces.GPU(duration=120)
def get_sentiment(texts):
    return [classify_and_explain(doc.page_content) for doc in texts]
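# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative only): the sample passage below is made up,
# and the first run downloads the DeBERTa checkpoint. Run the module directly
# to inspect the output schema of classify_and_explain.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import json
    sample = ("County Governments shall ensure access to safe water within 5 years. "
              "Non-compliance may attract penalties, and annual reporting is required.")
    print(json.dumps(classify_and_explain(sample), indent=2))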