Spaces:
Sleeping
Sleeping
| # pip install sentence-transformers (if not already) | |
| import math, re, unicodedata | |
| from typing import List, Dict, Any, Optional, Tuple | |
| import numpy as np | |
| import os, re, unicodedata, numpy as np | |
| from utils.retrieve_n_rerank import retrieve_and_rerank | |
| try: | |
| from sentence_transformers import SentenceTransformer | |
| except Exception: | |
| SentenceTransformer = None | |
| # ----------------------------- | |
| # Text utilities | |
| # ----------------------------- | |
| def _norm(t: str) -> str: | |
| if t is None: return "" | |
| t = unicodedata.normalize("NFKC", str(t)) | |
| t = re.sub(r"\s*\n\s*", " ", t) | |
| t = re.sub(r"\s{2,}", " ", t) | |
| return t.strip() | |
def split_sentences(text: str) -> List[str]:
    """Split normalized text into sentences.

    Breaks after '.', '?' or '!' when followed by whitespace and a capital
    letter or an opening quote; drops empty fragments.
    """
    normalized = _norm(text)
    sentences = []
    for fragment in re.split(r"(?<=[\.\?\!])\s+(?=[A-Z“\"'])", normalized):
        fragment = fragment.strip()
        if fragment:
            sentences.append(fragment)
    return sentences
| # ----------------------------- | |
| # Embeddings wrapper | |
| # ----------------------------- | |
class Embedder:
    """Thin wrapper over SentenceTransformer producing unit-normalized float32 embeddings.

    Raises:
        RuntimeError: at construction when sentence-transformers is not installed.
    """

    def __init__(self, model_name: str = "BAAI/bge-m3", device: str = "cpu"):
        if SentenceTransformer is None:
            raise RuntimeError("Install sentence-transformers to enable coherence scoring.")
        self.model = SentenceTransformer(model_name, device=device)

    def encode(self, sentences: List[str]) -> np.ndarray:
        """Encode sentences to a (n, dim) float32 array with normalized rows."""
        if not sentences:
            # Bug fix: the empty path previously hard-coded 768 columns, but the
            # default model (bge-m3) emits 1024-dim vectors. Ask the model for
            # its real width so downstream shapes stay consistent.
            dim = self.model.get_sentence_embedding_dimension() or 768
            return np.zeros((0, dim), dtype=np.float32)
        X = self.model.encode(sentences, normalize_embeddings=True, batch_size=32, show_progress_bar=False)
        return np.asarray(X, dtype=np.float32)
| def _cos(a: np.ndarray, b: np.ndarray) -> float: | |
| return float(np.dot(a, b)) | |
| def _normalize(v: np.ndarray) -> np.ndarray: | |
| v = np.asarray(v, dtype=np.float32) | |
| n = np.linalg.norm(v) + 1e-8 | |
| return v / n | |
| # ----------------------------- | |
| # Brownian-bridge style metric | |
| # ----------------------------- | |
def bb_coherence(sentences: List[str], E: np.ndarray) -> Dict[str, Any]:
    """
    Brownian-bridge–inspired coherence:
      - Build a main-idea vector (intro+outro+centroid)
      - Compare per-sentence sim to target curve that's high at ends, lower mid
      - Map max bridge deviation -> (0,1] score (higher=more coherent)
    """
    n = len(sentences)
    if n == 0:
        return {"bbscore": 0.0, "sims": [], "off_idx": [], "rep_pairs": [], "sim_matrix": None}

    # Main-idea direction: weighted blend of opening, closing, and overall centroid.
    k = max(1, min(3, n // 5))
    head = E[:k].mean(axis=0)
    tail = E[-k:].mean(axis=0)
    centroid = E.mean(axis=0)
    main_vec = _normalize(0.4*head + 0.4*tail + 0.2*centroid)

    sims = np.array([_cos(main_vec, row) for row in E], dtype=np.float32)

    # Target curve: 1 at both ends, 0 at the midpoint, rescaled to the mean similarity.
    t = np.linspace(0.0, 1.0, num=n, dtype=np.float32)
    q = 1.0 - 4.0 * t * (1.0 - t)
    q = q / (q.mean() + 1e-8) * (sims.mean() if sims.size else 0.0)

    # Centered residual -> cumulative walk -> bridge (pinned to 0 at both ends);
    # KS-style max deviation mapped into (0, 1].
    residual = sims - q
    residual = residual - residual.mean()
    walk = np.cumsum(residual)
    bridge = walk - t * (walk[-1] if n > 1 else 0.0)
    denom = (np.std(residual) * math.sqrt(n)) + 1e-8
    ks_stat = float(np.max(np.abs(bridge)) / denom)
    bbscore = float(1.0 / (1.0 + ks_stat))

    # Off-topic candidates: similarity more than one std below the mean.
    off_thr = float(sims.mean() - sims.std())
    off_idx = [i for i in range(n) if sims[i] < off_thr]

    # Repetition: near-duplicate, non-adjacent sentence pairs.
    if n > 1:
        S = E @ E.T  # cosine matrix, assuming unit-norm rows
        rep_pairs = [(i, j, float(S[i, j]))
                     for i in range(n)
                     for j in range(i + 2, n)   # skip adjacent
                     if S[i, j] >= 0.92]        # threshold tunable
    else:
        S = np.zeros((1, 1), dtype=np.float32)
        rep_pairs = []

    return {"bbscore": round(bbscore, 3), "sims": sims, "off_idx": off_idx,
            "rep_pairs": rep_pairs, "sim_matrix": S}
| # ----------------------------- | |
| # Zero-shot labeler (optional) | |
| # ----------------------------- | |
def zshot_label(text: str, topic: str = "the main topic") -> Dict[str, float]:
    """
    Optional zero-shot verdict to complement the rule-based label.

    Labels: Coherent, Off topic, Repeated.

    Returns:
        Mapping of label -> probability, or {} when `transformers` is unavailable.
    """
    try:
        from transformers import pipeline
    except Exception:
        return {}
    # Bug fix: callers may pass topic=None explicitly (it overrides the default),
    # which previously yielded the hypothesis "... with respect to None".
    topic = topic or "the main topic"
    # NOTE(review): the pipeline is rebuilt on every call (model load each time);
    # consider caching it at module level if this becomes hot.
    clf = pipeline("zero-shot-classification",
                   model="MoritzLaurer/deberta-v3-base-zeroshot-v2.0",
                   multi_label=True)
    labels = ["Coherent", "Off topic", "Repeated"]
    res = clf(_norm(text), labels,
              hypothesis_template=f"This passage is {{}} with respect to {topic}.")
    return {lbl: float(score) for lbl, score in zip(res["labels"], res["scores"])}
| # ----------------------------- | |
| # Decision logic + reasons | |
| # ----------------------------- | |
def decide_label_with_reasons(
    text: str,
    topic_hint: Optional[str],
    bb: Dict[str, Any],
    sentences: List[str],
    zshot_scores: Optional[Dict[str, float]] = None,
    thresholds: Optional[Dict[str, float]] = None  # fix: was non-Optional with None default
) -> Dict[str, Any]:
    """
    Combine bridge-coherence signals (and optional zero-shot scores) into a verdict.

    Args:
        text: full passage text (currently unused; kept for interface stability).
        topic_hint: optional topic (currently unused; kept for interface stability).
        bb: output of bb_coherence; "bbscore", "off_idx", "rep_pairs" are read.
        sentences: sentence list aligned with the indices in `bb`.
        zshot_scores: optional label->score map used only as a tie-breaker.
        thresholds: optional overrides for the decision cut-offs.

    Returns:
        {
          "label": "Coherent" | "Off topic" | "Repeated",
          "reasons": ["...", ...],
          "evidence": {"off_topic_examples": [...], "repeated_examples": [...]},
          "bbscore": float
        }
    """
    thr = thresholds or {
        "bb_coherent_min": 0.65,      # bbscore >= this supports "Coherent"
        "off_topic_ratio_max": 0.20,  # off-topic fraction <= this supports "Coherent"
        "repeat_pairs_min": 1         # this many near-duplicate pairs -> "Repeated"
    }
    n = max(1, len(sentences))  # guard against division by zero
    off_ratio = len(bb["off_idx"]) / n
    has_repeat = len(bb["rep_pairs"]) >= thr["repeat_pairs_min"]
    bbscore = bb["bbscore"]

    # Rule-based primary decision
    if off_ratio > thr["off_topic_ratio_max"] and bbscore < thr["bb_coherent_min"]:
        label = "Off topic"
    elif has_repeat and bbscore >= 0.5:
        label = "Repeated"
    elif bbscore >= thr["bb_coherent_min"] and off_ratio <= thr["off_topic_ratio_max"] and not has_repeat:
        label = "Coherent"
    else:
        # Tie-breaker using zero-shot if provided
        if zshot_scores:
            label = max(zshot_scores.items(), key=lambda kv: kv[1])[0]
        else:
            # fallback: prefer coherence if bbscore okay, else off-topic
            label = "Coherent" if bbscore >= 0.6 else "Off topic"

    # Human-readable reasons
    reasons = [f"BBScore={bbscore:.3f}."]
    if bb["off_idx"]:
        reasons.append(f"Off-topic fraction={off_ratio:.2f} ({len(bb['off_idx'])}/{n} sentences below main-idea similarity).")
    if has_repeat:
        top_rep = sorted(bb["rep_pairs"], key=lambda x: x[2], reverse=True)[:2]
        reasons.append(f"Repeated content detected (top sim={top_rep[0][2]:.2f}).")
    if zshot_scores:
        top = sorted(zshot_scores.items(), key=lambda kv: kv[1], reverse=True)[:2]
        reasons.append("Zero-shot support: " + ", ".join([f"{k}={v:.2f}" for k, v in top]))

    # Evidence snippets (at most two examples of each kind)
    ev_off = [f'{i}: "{sentences[i]}"' for i in bb["off_idx"][:2]]
    ev_rep = []
    for (i, j, sim) in sorted(bb["rep_pairs"], key=lambda x: x[2], reverse=True)[:2]:
        ev_rep.append(f'({i},{j}) sim={sim:.2f}: "{sentences[i]}", "{sentences[j]}"')

    return {
        "label": label,
        "reasons": reasons,
        "evidence": {"off_topic_examples": ev_off, "repeated_examples": ev_rep},
        "bbscore": bbscore
    }
| def _display_title(meta: Dict[str, Any], fallback: str) -> str: | |
| if meta.get("title"): return str(meta["title"]).strip() | |
| src = meta.get("source") or meta.get("path") | |
| if src: | |
| base = os.path.basename(str(src)) | |
| return re.sub(r"\.pdf$", "", base, flags=re.I) | |
| return meta.get("doc_id") or fallback | |
| def _page_label(meta: Dict[str, Any]) -> str: | |
| return str(meta.get("page_label") or meta.get("page") or "?") | |
def to_std_doc(item: Any, idx: int = 0) -> Dict[str, Any]:
    """
    Normalize a LangChain Document or a plain dict into a standard shape:
    {title, page_label, text}.

    Raises:
        TypeError: for any other input type.
    """
    if hasattr(item, "page_content"):  # LangChain Document
        meta = getattr(item, "metadata", {}) or {}
        return {
            "title": _display_title(meta, f"doc{idx+1}"),
            "page_label": _page_label(meta),
            "text": _norm(item.page_content),
        }
    if isinstance(item, dict):
        meta = item.get("metadata", {}) or {}
        return {
            "title": item.get("title") or _display_title(meta, item.get("doc_id", f"doc{idx+1}")),
            "page_label": item.get("page_label") or _page_label(meta),
            "text": _norm(item.get("text") or item.get("page_content", "")),
        }
    raise TypeError(f"Unsupported doc type at index {idx}: {type(item)}")
def coherence_assessment_std(
    std_doc: Dict[str, Any],
    embedder,
    topic_hint: Optional[str] = None,
    run_zero_shot: bool = False,
    thresholds: Optional[Dict[str, float]] = None
) -> Dict[str, Any]:
    """Run the full coherence pipeline on a standardized {title, page_label, text} dict.

    Sentences are embedded, scored with bb_coherence, and labeled via
    decide_label_with_reasons (optionally aided by zero-shot scores).
    """
    text = std_doc.get("text", "")
    sentences = split_sentences(text)
    if not sentences:
        # Nothing to score — report the document as off-topic with zero score.
        return {"title": std_doc.get("title","Document"), "label": "Off topic", "bbscore": 0.0,
                "reasons": ["Empty text."], "evidence": {}}
    embeddings = embedder.encode(sentences)
    bridge = bb_coherence(sentences, embeddings)
    zshot = zshot_label(text, topic_hint) if run_zero_shot else None
    verdict = decide_label_with_reasons(text, topic_hint, bridge, sentences, zshot, thresholds)
    return {
        "title": std_doc.get("title","Document"),
        "page_label": std_doc.get("page_label","?"),
        "label": verdict["label"],
        "bbscore": verdict["bbscore"],
        "reasons": verdict["reasons"],
        "evidence": verdict["evidence"],
    }
| # Get the coherence report | |
def coherence_report(embedder="BAAI/bge-m3",
                     input_text=None,
                     reranked_results=None,
                     run_zero_shot=True):
    """
    Build a coherence report for a set of (re)ranked retrieval results.

    Args:
        embedder: an Embedder instance, or a sentence-embedding model name to
            wrap in one. Bug fix: the default was
            "MoritzLaurer/deberta-v3-base-zeroshot-v2.0" — a zero-shot NLI
            classifier, not a sentence-embedding model — now aligned with
            Embedder's own default "BAAI/bge-m3".
        input_text: query text; used to fetch results when reranked_results is
            None, and passed through as the topic hint.
        reranked_results: optional pre-fetched documents (LangChain Documents
            or dicts); retrieved via retrieve_and_rerank(input_text) if None.
        run_zero_shot: also run the optional zero-shot labeler per document.

    Returns:
        A list of per-document report dicts; [] when there is nothing to assess.
    """
    embedder = Embedder(embedder) if isinstance(embedder, str) else embedder
    if reranked_results is None:
        reranked_results = retrieve_and_rerank(input_text)
    if not reranked_results:
        return []
    # Standardize the heterogeneous result objects before scoring.
    std_results = [to_std_doc(doc, i) for i, doc in enumerate(reranked_results)]
    return [coherence_assessment_std(d, embedder, topic_hint=input_text,
                                     run_zero_shot=run_zero_shot)
            for d in std_results]