# File: policy-analysis/utils/coherence_bbscore.py
# Repository page residue, kept as a comment so the module parses:
#   kaburia's picture / rewrite / b022bee
# pip install sentence-transformers (if not already)
import math, re, unicodedata
from typing import List, Dict, Any, Optional, Tuple
import numpy as np
import os, re, unicodedata, numpy as np
from utils.retrieve_n_rerank import retrieve_and_rerank
try:
from sentence_transformers import SentenceTransformer
except Exception:
SentenceTransformer = None
# -----------------------------
# Text utilities
# -----------------------------
def _norm(t: str) -> str:
    """Return *t* as an NFKC-normalised, single-line, trimmed string.

    ``None`` maps to ``""``; any other value goes through ``str()`` first.
    Line breaks (together with the whitespace around them) become one space,
    then any remaining run of two or more whitespace characters collapses to
    one space. A lone non-space whitespace character (e.g. one tab) is kept.
    """
    if t is None:
        return ""
    cleaned = unicodedata.normalize("NFKC", str(t))
    # Order matters: join lines first, then squeeze whatever runs are left.
    for pattern in (r"\s*\n\s*", r"\s{2,}"):
        cleaned = re.sub(pattern, " ", cleaned)
    return cleaned.strip()
def split_sentences(text: str) -> List[str]:
    """Split *text* into sentences after normalising it with ``_norm``.

    A boundary is whitespace that follows ``.``, ``?`` or ``!`` and precedes
    an ASCII capital letter or an opening quote character. Pieces that are
    empty after stripping are dropped.
    """
    boundary = r"(?<=[\.\?\!])\s+(?=[A-Z“\"'])"
    sentences = []
    for piece in re.split(boundary, _norm(text)):
        piece = piece.strip()
        if piece:
            sentences.append(piece)
    return sentences
# -----------------------------
# Embeddings wrapper
# -----------------------------
class Embedder:
    """Thin wrapper around a ``SentenceTransformer`` model.

    ``encode`` always returns a 2-D ``float32`` array with one row per input
    sentence; the model is asked for normalised embeddings.
    """

    # Width used for the empty result only when the model cannot report its own.
    _FALLBACK_DIM = 768

    def __init__(self, model_name: str = "BAAI/bge-m3", device: str = "cpu"):
        """Load *model_name* on *device*.

        Raises:
            RuntimeError: if ``sentence-transformers`` could not be imported.
        """
        if SentenceTransformer is None:
            raise RuntimeError("Install sentence-transformers to enable coherence scoring.")
        self.model = SentenceTransformer(model_name, device=device)

    def _embedding_dim(self) -> int:
        """Return the model's embedding width, or the fallback if it is unknown."""
        getter = getattr(self.model, "get_sentence_embedding_dimension", None)
        dim = getter() if callable(getter) else None
        return int(dim) if dim else self._FALLBACK_DIM

    def encode(self, sentences: List[str]) -> np.ndarray:
        """Embed *sentences* and return an ``(len(sentences), dim)`` float32 array.

        For empty input the result has zero rows but the model's real
        embedding width, so it has the same column count as non-empty results
        from this model (a fixed 768 would not match every model).
        """
        if not sentences:
            return np.zeros((0, self._embedding_dim()), dtype=np.float32)
        X = self.model.encode(
            sentences,
            normalize_embeddings=True,
            batch_size=32,
            show_progress_bar=False,
        )
        return np.asarray(X, dtype=np.float32)
def _cos(a: np.ndarray, b: np.ndarray) -> float:
    """Return the dot product of *a* and *b* as a built-in ``float``.

    Nothing is normalised here, so the value is a cosine similarity only
    when both inputs already have unit length.
    """
    dot_product = np.dot(a, b)
    return float(dot_product)
def _normalize(v: np.ndarray) -> np.ndarray:
    """Return *v* as float32, divided by its Euclidean norm plus ``1e-8``.

    The small constant keeps an all-zero vector from dividing by zero; such
    a vector comes back as all zeros.
    """
    vec = np.asarray(v, dtype=np.float32)
    length = np.linalg.norm(vec) + 1e-8
    return vec / length
# -----------------------------
# Brownian-bridge style metric
# -----------------------------
def bb_coherence(sentences: List[str], E: np.ndarray) -> Dict[str, Any]:
    """
    Brownian-bridge-inspired coherence score for a sentence sequence.

    Steps:
    - Blend the mean of the first sentences, the mean of the last sentences
      and the overall centroid (weights 0.4 / 0.4 / 0.2) into a main-idea vector.
    - Compare each sentence's similarity to that vector against a target
      curve that is high at both ends and low in the middle.
    - Turn the largest bridge deviation into a score in (0, 1]
      (higher = more coherent), rounded to 3 decimals.

    Also reported: indices of sentences whose similarity is more than one
    standard deviation below the mean ("off_idx"), and non-adjacent sentence
    pairs whose dot product is >= 0.92 ("rep_pairs").

    Assumes row i of E is the unit-length embedding of sentences[i];
    nothing here verifies that.
    """
    n = len(sentences)
    if n == 0:
        return {"bbscore": 0.0, "sims": [], "off_idx": [], "rep_pairs": [], "sim_matrix": None}

    # Main-idea vector from the opening, the closing and the centroid.
    edge = max(1, min(3, n // 5))
    head_vec = E[:edge].mean(axis=0)
    tail_vec = E[-edge:].mean(axis=0)
    centroid = E.mean(axis=0)
    v_main = _normalize(0.4 * head_vec + 0.4 * tail_vec + 0.2 * centroid)
    sims = np.array([_cos(v_main, E[i]) for i in range(n)], dtype=np.float32)

    # Target curve, rescaled so that its mean matches the mean similarity.
    pos = np.linspace(0.0, 1.0, num=n, dtype=np.float32)
    target = 1.0 - 4.0 * pos * (1.0 - pos)  # peaks at ends
    target = target / (target.mean() + 1e-8) * (sims.mean() if sims.size else 0.0)

    # Bridge built from the centred residuals.
    resid = sims - target
    resid = resid - resid.mean()
    walk = np.cumsum(resid)
    end_value = walk[-1] if n > 1 else 0.0
    bridge = walk - pos * end_value
    scale = (np.std(resid) * math.sqrt(n)) + 1e-8
    ks = float(np.max(np.abs(bridge)) / scale)
    bbscore = float(1.0 / (1.0 + ks))

    # Off-topic: similarity below mean - 1 sigma.
    cutoff = float(sims.mean() - sims.std())
    off_idx = [i for i, s in enumerate(sims) if s < cutoff]

    # Repetition: very high pairwise similarity, adjacent sentences skipped.
    if n > 1:
        S = E @ E.T  # cosine, given normalised rows
        rep_pairs = [
            (i, j, float(S[i, j]))
            for i in range(n)
            for j in range(i + 2, n)
            if S[i, j] >= 0.92  # threshold tunable
        ]
    else:
        S = np.zeros((1, 1), dtype=np.float32)
        rep_pairs = []

    return {
        "bbscore": round(bbscore, 3),
        "sims": sims,
        "off_idx": off_idx,
        "rep_pairs": rep_pairs,
        "sim_matrix": S,
    }
# -----------------------------
# Zero-shot labeler (optional)
# -----------------------------
# Zero-shot classifiers already built, keyed by model id, so the model is
# loaded once per process instead of once per call.
_ZSHOT_CLASSIFIERS: Dict[str, Any] = {}


def zshot_label(text: str, topic: str = "the main topic") -> Dict[str, float]:
    """
    Optional: zero-shot verdict to complement the rule-based label.
    Labels: Coherent, Off topic, Repeated

    Args:
        text: passage to classify; it is normalised with ``_norm`` first.
        topic: what the passage is judged against. ``None`` or a blank
            string falls back to "the main topic".

    Returns:
        A ``{label: score}`` mapping, or ``{}`` when *text* is blank or
        ``transformers`` cannot be imported.
    """
    # Nothing to classify: skip the (expensive) model entirely.
    if text is None or not str(text).strip():
        return {}
    try:
        from transformers import pipeline
    except Exception:
        # Best effort by design: this verdict is optional.
        return {}

    model_id = "MoritzLaurer/deberta-v3-base-zeroshot-v2.0"
    clf = _ZSHOT_CLASSIFIERS.get(model_id)
    if clf is None:
        clf = pipeline("zero-shot-classification", model=model_id, multi_label=True)
        _ZSHOT_CLASSIFIERS[model_id] = clf

    topic_text = str(topic).strip() if topic else ""
    if not topic_text:
        topic_text = "the main topic"
    # The template is expanded with str.format by the pipeline, so literal
    # braces in a caller-supplied topic must be doubled to survive.
    topic_text = topic_text.replace("{", "{{").replace("}", "}}")
    template = "This passage is {} with respect to " + topic_text + "."

    labels = ["Coherent", "Off topic", "Repeated"]
    res = clf(_norm(text), labels, hypothesis_template=template)
    return {lbl: float(score) for lbl, score in zip(res["labels"], res["scores"])}
# -----------------------------
# Decision logic + reasons
# -----------------------------
def decide_label_with_reasons(
    text: str,
    topic_hint: Optional[str],
    bb: Dict[str, Any],
    sentences: List[str],
    zshot_scores: Optional[Dict[str, float]] = None,
    thresholds: Optional[Dict[str, float]] = None,
) -> Dict[str, Any]:
    """
    Turn the bridge metrics into a label with human-readable reasons.

    Args:
        text, topic_hint: accepted for interface compatibility; not read here.
        bb: mapping with "bbscore", "off_idx" and "rep_pairs"
            (rep_pairs items are ``(i, j, similarity)``).
        sentences: the sentences that the indices in *bb* refer to.
        zshot_scores: optional ``{label: score}`` used as a tie-breaker.
        thresholds: optional overrides; keys that are missing keep their
            defaults, so a partial dict is fine. Keys:
            "bb_coherent_min" (0.65), "off_topic_ratio_max" (0.20),
            "repeat_pairs_min" (1), "repeat_bb_min" (0.5),
            "fallback_coherent_min" (0.6).

    Returns:
        {
          "label": "Coherent" | "Off topic" | "Repeated",
          "reasons": [ "...", "..."],
          "evidence": { "off_topic_examples": [...], "repeated_examples": [...] },
          "bbscore": 0.74
        }
    """
    defaults = {
        "bb_coherent_min": 0.65,        # >= coherent
        "off_topic_ratio_max": 0.20,    # <= coherent
        "repeat_pairs_min": 1,          # >= repeated (if any)
        "repeat_bb_min": 0.5,           # minimum score for the "Repeated" rule
        "fallback_coherent_min": 0.6,   # last-resort cut-off without zero-shot
    }
    # Merge rather than replace, so a partial override cannot cause KeyError.
    thr = {**defaults, **(thresholds or {})}

    off_idx = bb["off_idx"]
    # Sorted once; reused for both the reason line and the evidence list.
    ranked_pairs = sorted(bb["rep_pairs"], key=lambda pair: pair[2], reverse=True)

    n = max(1, len(sentences))
    off_ratio = len(off_idx) / n
    has_repeat = len(ranked_pairs) >= thr["repeat_pairs_min"]
    bbscore = bb["bbscore"]

    # Rule-based primary decision
    if off_ratio > thr["off_topic_ratio_max"] and bbscore < thr["bb_coherent_min"]:
        label = "Off topic"
    elif has_repeat and bbscore >= thr["repeat_bb_min"]:
        label = "Repeated"
    elif (
        bbscore >= thr["bb_coherent_min"]
        and off_ratio <= thr["off_topic_ratio_max"]
        and not has_repeat
    ):
        label = "Coherent"
    elif zshot_scores:
        # Tie-breaker using zero-shot if provided
        label = max(zshot_scores.items(), key=lambda kv: kv[1])[0]
    else:
        # Fallback: prefer coherence if the score is okay, else off-topic
        label = "Coherent" if bbscore >= thr["fallback_coherent_min"] else "Off topic"

    # Reasons
    reasons = [f"BBScore={bbscore:.3f}."]
    if off_idx:
        reasons.append(
            f"Off-topic fraction={off_ratio:.2f} ({len(off_idx)}/{n} sentences below main-idea similarity)."
        )
    # The extra check covers repeat_pairs_min <= 0 with no pairs at all.
    if has_repeat and ranked_pairs:
        reasons.append(f"Repeated content detected (top sim={ranked_pairs[0][2]:.2f}).")
    if zshot_scores:
        top = sorted(zshot_scores.items(), key=lambda kv: kv[1], reverse=True)[:2]
        reasons.append("Zero-shot support: " + ", ".join([f"{k}={v:.2f}" for k, v in top]))

    # Evidence snippets (at most two of each kind)
    ev_off = [f'{i}: "{sentences[i]}"' for i in off_idx[:2]]
    ev_rep = [
        f'({i},{j}) sim={sim:.2f}: "{sentences[i]}", "{sentences[j]}"'
        for (i, j, sim) in ranked_pairs[:2]
    ]

    return {
        "label": label,
        "reasons": reasons,
        "evidence": {"off_topic_examples": ev_off, "repeated_examples": ev_rep},
        "bbscore": bbscore,
    }
def _display_title(meta: Dict[str, Any], fallback: str) -> str:
    """Pick a non-empty display title for a document from its metadata.

    Preference order:
      1. ``meta["title"]``, stripped;
      2. the file name from ``meta["source"]`` or ``meta["path"]`` with a
         trailing ``.pdf`` (any case) removed;
      3. ``meta["doc_id"]``;
      4. *fallback*.

    A candidate that is blank (whitespace-only title, or a source ending in
    a path separator) is skipped rather than returned as ``""``. The result
    is always a ``str``, including when ``doc_id`` is a number.
    """
    title = str(meta.get("title") or "").strip()
    if title:
        return title

    src = meta.get("source") or meta.get("path")
    if src:
        base = os.path.basename(str(src))
        stem = re.sub(r"\.pdf$", "", base, flags=re.I)
        if stem.strip():
            return stem

    return str(meta.get("doc_id") or fallback)
def _page_label(meta: Dict[str, Any]) -> str:
    """Return the page label for display, or ``"?"`` when none is recorded.

    ``meta["page_label"]`` wins over ``meta["page"]``. Values are checked
    against ``None`` and the empty string instead of by truthiness, so a
    page number of ``0`` is reported as ``"0"`` rather than being mistaken
    for "missing".
    """
    for key in ("page_label", "page"):
        value = meta.get(key)
        if value is not None and str(value) != "":
            return str(value)
    return "?"
def to_std_doc(item: Any, idx: int = 0) -> Dict[str, Any]:
    """
    Convert a retrieved item into the standard dict {title, page_label, text}.

    Two shapes are accepted:
    - any object with a ``page_content`` attribute (e.g. a LangChain
      Document); its ``metadata`` attribute, if present, supplies title/page;
    - a dict; its own "title" / "page_label" / "text" (or "page_content")
      entries win, with its "metadata" entry as the fallback source.

    Anything else raises TypeError. ``idx`` is the item's position, used for
    the default title ``doc<idx+1>`` and in the error message.
    """
    if hasattr(item, "page_content"):  # LangChain Document
        doc_meta = getattr(item, "metadata", {}) or {}
        return {
            "title": _display_title(doc_meta, f"doc{idx+1}"),
            "page_label": _page_label(doc_meta),
            "text": _norm(item.page_content),
        }

    if not isinstance(item, dict):
        raise TypeError(f"Unsupported doc type at index {idx}: {type(item)}")

    doc_meta = item.get("metadata", {}) or {}
    title = item.get("title")
    if not title:
        title = _display_title(doc_meta, item.get("doc_id", f"doc{idx+1}"))
    page = item.get("page_label")
    if not page:
        page = _page_label(doc_meta)
    raw_text = item.get("text") or item.get("page_content", "")
    return {"title": title, "page_label": page, "text": _norm(raw_text)}
def coherence_assessment_std(
    std_doc: Dict[str, Any],
    embedder,
    topic_hint: Optional[str] = None,
    run_zero_shot: bool = False,
    thresholds: Optional[Dict[str, float]] = None
) -> Dict[str, Any]:
    """Assess the coherence of one standardized document dict.

    Args:
        std_doc: dict with "text" and optionally "title" / "page_label".
        embedder: object whose ``encode(list_of_sentences)`` returns the
            embedding matrix handed to ``bb_coherence``.
        topic_hint: topic for the optional zero-shot verdict; when it is
            empty the labeler's own default topic is used.
        run_zero_shot: whether to also compute zero-shot scores.
        thresholds: passed through to ``decide_label_with_reasons``.

    Returns:
        A dict with "title", "page_label", "label", "bbscore", "reasons" and
        "evidence". The same keys, and the same "evidence" sub-keys, are
        present when the text is empty, so consumers can treat every report
        uniformly.
    """
    title = std_doc.get("title", "Document")
    page_label = std_doc.get("page_label", "?")
    text = std_doc.get("text") or ""

    # Blank text has no sentences; skip the splitter for it.
    sents = split_sentences(text) if str(text).strip() else []
    if not sents:
        return {
            "title": title,
            "page_label": page_label,
            "label": "Off topic",
            "bbscore": 0.0,
            "reasons": ["Empty text."],
            "evidence": {"off_topic_examples": [], "repeated_examples": []},
        }

    E = embedder.encode(sents)
    bb = bb_coherence(sents, E)

    zshot_scores = None
    if run_zero_shot:
        # Pass the topic only when there is one; an explicit None would end
        # up in the hypothesis text instead of the labeler's default topic.
        zshot_scores = zshot_label(text, topic_hint) if topic_hint else zshot_label(text)

    decision = decide_label_with_reasons(text, topic_hint, bb, sents, zshot_scores, thresholds)
    return {
        "title": title,
        "page_label": page_label,
        "label": decision["label"],
        "bbscore": decision["bbscore"],
        "reasons": decision["reasons"],
        "evidence": decision["evidence"],
    }
# Get the coherence report
def coherence_report(embedder="MoritzLaurer/deberta-v3-base-zeroshot-v2.0",
                     input_text=None,
                     reranked_results=None,
                     run_zero_shot=True):
    """Build one coherence report per retrieved document.

    Args:
        embedder: an object with an ``encode`` method, or a model name from
            which an ``Embedder`` is built.
        input_text: the query; used to retrieve documents when
            *reranked_results* is None, and as the topic hint for each report.
        reranked_results: documents to assess (anything ``to_std_doc``
            accepts). ``None`` means "retrieve them from *input_text*".
        run_zero_shot: forwarded to ``coherence_assessment_std``.

    Returns:
        A list of report dicts, empty when there are no documents.
    """
    # NOTE(review): the default model id is the one zshot_label uses for
    # zero-shot classification, while Embedder itself defaults to
    # "BAAI/bge-m3" - confirm this checkpoint is really meant for embeddings.
    if reranked_results is None:
        reranked_results = retrieve_and_rerank(input_text)
    if not reranked_results:
        return []

    # Convert reranked_results to standardized documents
    std_results = [to_std_doc(doc, i) for i, doc in enumerate(reranked_results)]

    # Load the embedding model only once there is something to embed.
    if isinstance(embedder, str):
        embedder = Embedder(embedder)

    return [
        coherence_assessment_std(d, embedder, topic_hint=input_text, run_zero_shot=run_zero_shot)
        for d in std_results
    ]