Spaces:
Running
Running
| """ | |
| FR-17: src/modules/contradiction.py — Module 4: Cross-Document Contradiction Detection | |
| ======================================================================================== | |
| Uses the same DeBERTa NLI cross-encoder (cross-encoder/nli-deberta-v3-small) to | |
| detect contradictions between the LLM answer and retrieved context passages. | |
| Algorithm (SRS Section 6.4): | |
| 1. Split answer into sentences (claims) | |
| 2. Split each context chunk into sentences | |
| 3. For each (answer_sentence, context_sentence) pair that shares at least MIN_KEYWORD_OVERLAP content words: | |
| - Run NLI → get contradiction score | |
| - If contradiction_score ≥ CONTRADICTION_THRESHOLD → flag pair | |
| 4. score = 1.0 - (flagged_pairs / total_pairs) | |
| This module shares the NLI model instance with faithfulness.py when both | |
| run in the same process (the model is cached at the faithfulness module level). | |
| Design note: | |
| To keep latency manageable, context sentences are limited to | |
| MAX_CONTEXT_SENTS per chunk and total pairs are capped at MAX_PAIRS. | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| import time | |
| import pysbd | |
| from src.modules.base import EvalResult | |
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Tunables
# ---------------------------------------------------------------------------

# A pair is flagged when its NLI contradiction probability reaches this value.
CONTRADICTION_THRESHOLD = 0.50

# Pairs sharing fewer content words than this are skipped before NLI runs.
MIN_KEYWORD_OVERLAP = 3

# Only the first N sentences of each context chunk are compared.
MAX_CONTEXT_SENTS = 4

# Hard cap on the number of sentence pairs sent to the model (bounds latency).
MAX_PAIRS = 200

# Lazily populated by _get_segmenter(); None means "not initialised yet".
_segmenter = None
| # Common stopwords to ignore in overlap check | |
# Tokens ignored by the keyword-overlap check: common English function words
# plus the short tokens 1, 2, 3, mg, iv, od, per, day, based, using, include.
_STOPWORDS = set(
    (
        "the a an is in of to for and or are "
        "be at by if it as on with this that "
        "was were not no have has had but so from "
        "should may can will than more when which who "
        "what all each after before been do does 1 "
        "2 3 mg iv od per day based using include"
    ).split()
)
def _get_segmenter():
    """Return the shared sentence segmenter, creating it on first use.

    The result is cached in the module-level ``_segmenter``.  When pysbd cannot
    be imported or its segmenter cannot be constructed, the string ``"stub"``
    is cached instead, which tells ``_segment`` to use naive splitting.
    """
    global _segmenter

    # Fast path: already initialised (either a real segmenter or "stub").
    if _segmenter is not None:
        return _segmenter

    try:
        import pysbd
        loaded = pysbd.Segmenter(language="en", clean=False)
    except ImportError:
        logger.warning("pysbd not installed, falling back to naive sentence splitting.")
        loaded = "stub"  # sentinel understood by _segment
    except Exception as err:
        logger.error("Failed to initialize pysbd segmenter: %s", err)
        loaded = "stub"

    _segmenter = loaded
    return _segmenter
def _keyword_overlap(sent_a: str, sent_b: str) -> int:
    """Count the distinct content words shared by two sentences.

    Each sentence is split on whitespace; every token is stripped of
    surrounding ASCII punctuation and lower-cased, so ``"aspirin."`` and
    ``"(Aspirin)"`` are treated as the same word.  Tokens that are in
    ``_STOPWORDS`` or are two characters or shorter (after stripping) are
    ignored.

    Args:
        sent_a: First sentence.
        sent_b: Second sentence.

    Returns:
        Number of distinct content words that occur in both sentences.
    """
    import string

    def content_words(sentence):
        # Strip punctuation BEFORE the stopword/length checks; otherwise
        # "to," would slip past both filters and "dose." would never match "dose".
        cleaned = (raw.strip(string.punctuation).lower() for raw in sentence.split())
        return {word for word in cleaned if len(word) > 2 and word not in _STOPWORDS}

    return len(content_words(sent_a) & content_words(sent_b))
def _naive_sentences(text: str) -> list[str]:
    """Split *text* at runs of '.', '!' or '?' that end a sentence.

    A terminator only counts when it is followed by whitespace or the end of
    the text, so decimals such as ``0.5`` stay intact.  Terminators are
    dropped and empty fragments are discarded.
    """
    import re

    return [part.strip() for part in re.split(r"[.!?]+(?:\s+|$)", text) if part.strip()]


def _segment(text: str) -> list[str]:
    """Segment text into stripped, non-empty sentences.

    Uses the segmenter returned by ``_get_segmenter``.  When that is the
    ``"stub"`` sentinel, or when the segmenter raises, a naive
    punctuation-based splitter is used instead.

    Args:
        text: Text to split; ``None`` or an empty string yields ``[]``.

    Returns:
        List of sentence strings (possibly empty).
    """
    if not text:
        return []

    seg = _get_segmenter()
    if seg == "stub":
        return _naive_sentences(text)

    try:
        return [s.strip() for s in seg.segment(text) if s.strip()]
    except Exception as exc:
        # Best-effort: segmentation problems must not abort scoring, but they
        # should be visible in the logs rather than silently swallowed.
        logger.warning("Sentence segmentation failed, using naive splitting: %s", exc)
        return _naive_sentences(text)
| # --------------------------------------------------------------------------- | |
| # Public API | |
| # --------------------------------------------------------------------------- | |
# CrossEncoder instance cached here only when src.modules.faithfulness cannot
# be imported; otherwise the model cached by that module is used.
_fallback_nli_model = None


def _load_nli_model():
    """Return ``(model, contradiction_label_index)`` for the NLI cross-encoder.

    Prefers the model exposed by ``src.modules.faithfulness``.  If that module
    cannot be imported, a ``CrossEncoder`` is created once and reused on later
    calls.

    Raises:
        ImportError: if neither the faithfulness module nor
            sentence-transformers can be imported.
        Exception: anything raised while loading the model.
    """
    global _fallback_nli_model
    try:
        from src.modules.faithfulness import _get_model, LABEL_CONTRADICTION
    except ImportError:
        # Lazy import so a missing library cannot crash module import.
        from sentence_transformers import CrossEncoder

        if _fallback_nli_model is None:
            _fallback_nli_model = CrossEncoder("cross-encoder/nli-deberta-v3-small")
        # Index 0 is the contradiction label assumed by this fallback
        # NOTE(review): confirm against the model card's label order.
        return _fallback_nli_model, 0
    return _get_model(), LABEL_CONTRADICTION


def _contradiction_result(score, details, started, error=None):
    """Build the module's EvalResult, stamping latency measured from *started*."""
    fields = {
        "module_name": "contradiction",
        "score": score,
        "details": details,
        "latency_ms": int((time.perf_counter() - started) * 1000),
    }
    if error is not None:
        fields["error"] = error
    return EvalResult(**fields)


def score_contradiction(
    answer: str,
    context_docs: list[str],
    max_chunks: int = 5,
) -> EvalResult:
    """
    Detect contradictions between the LLM answer and retrieved context.

    The answer is stripped of markdown/citations and split into sentences.
    Each of the first ``max_chunks`` context chunks contributes its first
    ``MAX_CONTEXT_SENTS`` sentences.  Sentence pairs sharing at least
    ``MIN_KEYWORD_OVERLAP`` content words (at most ``MAX_PAIRS`` of them, in
    answer-sentence order) are scored by the NLI model; a pair is flagged when
    its contradiction probability is >= ``CONTRADICTION_THRESHOLD``.

    Args:
        answer       : LLM-generated answer text.
        context_docs : List of retrieved context passage strings.
        max_chunks   : Max number of context chunks to evaluate.

    Returns:
        EvalResult with module_name="contradiction" and
        score = 1.0 - flagged_pairs / checked_pairs  (1.0 = no contradictions).
        * score 0.5 (neutral) when there is nothing to compare: missing input,
          no sentences, or no topically overlapping pairs.
        * score 1.0 with ``error`` set and empty details when the model cannot
          be loaded or inference fails.
        ``details["pairs"]`` lists up to 20 flagged pairs, most severe first.
    """
    t0 = time.perf_counter()

    def summary(total_sentences, checked=0, contradicted=0, pairs=None):
        return {
            "total_sentences": total_sentences,
            "checked_pairs": checked,
            "contradicted_pairs": contradicted,
            "pairs": pairs if pairs is not None else [],
        }

    if not answer or not context_docs:
        # neutral — cannot verify with missing input
        return _contradiction_result(0.5, summary(0), t0)

    # Load the model (shared with faithfulness.py when available).  Any load
    # failure, on either path, is reported through the result, not raised.
    try:
        model, label_contradiction = _load_nli_model()
    except ImportError:
        logger.error("sentence-transformers not installed. Cannot run NLI model.")
        return _contradiction_result(
            1.0, {}, t0, error="NLI model (sentence-transformers) not installed."
        )
    except Exception as exc:
        logger.error("Failed to load NLI model: %s", exc)
        return _contradiction_result(1.0, {}, t0, error=f"Failed to load NLI model: {exc}")

    # Strip markdown/citations from answer before NLI (same reason as faithfulness.py)
    import re as _re
    _MD = _re.compile(
        r'\[Source:[^\]]*\]|\[[^\]]{0,120}\]'  # citations
        r'|\*\*([^*]+)\*\*|\*([^*]+)\*'  # bold/italic → keep text
        r'|`[^`]+`'  # code
    )
    answer = _MD.sub(lambda m: (m.group(1) or m.group(2) or ''), answer).strip()

    # Segment answer into claims
    answer_sents = _segment(answer)
    if not answer_sents:
        # neutral — cannot verify with no sentences
        return _contradiction_result(0.5, summary(0), t0)

    # Segment context chunks, keeping only the leading sentences of each
    context_sents: list[str] = [
        sent
        for doc in context_docs[:max_chunks]
        for sent in _segment(doc)[:MAX_CONTEXT_SENTS]
    ]
    if not context_sents:
        # neutral — cannot verify with no context sentences
        return _contradiction_result(0.5, summary(len(answer_sents)), t0)

    # Build pairs WITH topical pre-filter (skip unrelated sentence pairs
    # entirely); islice stops the scan as soon as MAX_PAIRS pairs are found.
    from itertools import islice

    candidates = (
        (a_sent, c_sent)
        for a_sent in answer_sents
        for c_sent in context_sents
        if _keyword_overlap(a_sent, c_sent) >= MIN_KEYWORD_OVERLAP
    )
    all_pairs: list[tuple[str, str]] = list(islice(candidates, MAX_PAIRS))
    if not all_pairs:
        # Topically unrelated — neutral, no overlapping pairs to evaluate
        return _contradiction_result(0.5, summary(len(answer_sents)), t0)

    # Batch NLI inference
    try:
        scores_raw = model.predict(all_pairs, apply_softmax=True)
    except Exception as exc:
        logger.error("Contradiction NLI inference failed: %s", exc)
        return _contradiction_result(1.0, {}, t0, error=f"Model inference error: {exc}")

    # Collect flagged pairs
    flagged_pairs: list[dict] = []
    for (a_sent, c_sent), row in zip(all_pairs, scores_raw):
        con_score = float(row[label_contradiction])
        if con_score >= CONTRADICTION_THRESHOLD:
            flagged_pairs.append(
                {
                    "sentence_a": a_sent[:120],
                    "sentence_b": c_sent[:120],
                    "contradiction_score": round(con_score, 4),
                    "flagged": True,
                }
            )

    total = len(all_pairs)
    contradicted = len(flagged_pairs)

    # Most severe first, so the 20-entry cap below keeps the worst offenders
    # rather than whichever pairs happened to be scored first.
    flagged_pairs.sort(key=lambda pair: pair["contradiction_score"], reverse=True)

    # Score: 1.0 = clean, lower = more contradictions found
    score = 1.0 - (contradicted / total)
    details = summary(len(answer_sents), total, contradicted, flagged_pairs[:20])

    result = _contradiction_result(score, details, t0)
    logger.info(
        "Contradiction: %.3f (%d/%d pairs flagged) in %d ms",
        score, contradicted, total, int((time.perf_counter() - t0) * 1000),
    )
    return result