Spaces:
Running
Running
File size: 9,840 Bytes
b6f9fa8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 | """
FR-17: src/modules/contradiction.py — Module 4: Cross-Document Contradiction Detection
========================================================================================
Uses the same DeBERTa NLI cross-encoder (cross-encoder/nli-deberta-v3-small) to
detect contradictions between the LLM answer and retrieved context passages.
Algorithm (SRS Section 6.4):
1. Split answer into sentences (claims)
2. Split each context chunk into sentences
3. For each (answer_sentence, context_sentence) pair:
- Run NLI → get contradiction score
- If contradiction_score ≥ CONTRADICTION_THRESHOLD → flag pair
4. score = 1.0 - (flagged_pairs / total_pairs)
This module shares the NLI model instance with faithfulness.py when both
run in the same process (the model is cached at the faithfulness module level).
Design note:
To keep latency manageable, context sentences are limited to
MAX_CONTEXT_SENTS per chunk and total pairs are capped at MAX_PAIRS.
"""
from __future__ import annotations
import logging
import time
import pysbd
from src.modules.base import EvalResult
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# A pair is flagged when its softmax contradiction probability is >= this value.
CONTRADICTION_THRESHOLD = 0.50  # Balanced: catches real contradictions without over-flagging
# Sentence pairs sharing fewer content words than this are skipped before NLI runs.
MIN_KEYWORD_OVERLAP = 3  # At least 3 meaningful words in common before running NLI
MAX_CONTEXT_SENTS = 4  # only the first N sentences of each context chunk are kept
MAX_PAIRS = 200  # hard cap to keep latency bounded (~2-3s)
# Cache for the sentence segmenter: None until first use, afterwards either a
# pysbd.Segmenter instance or the string "stub" (naive-splitting fallback).
_segmenter = None
# Common stopwords to ignore in overlap check.
# Besides English function words this also lists digits and short dosage
# tokens ("mg", "iv", "od"); entries of length <= 2 are redundant with the
# length filter in _keyword_overlap but harmless.
_STOPWORDS = {
    "the", "a", "an", "is", "in", "of", "to", "for", "and", "or", "are",
    "be", "at", "by", "if", "it", "as", "on", "with", "this", "that",
    "was", "were", "not", "no", "have", "has", "had", "but", "so", "from",
    "should", "may", "can", "will", "than", "more", "when", "which", "who",
    "what", "all", "each", "after", "before", "been", "do", "does", "1",
    "2", "3", "mg", "iv", "od", "per", "day", "based", "using", "include",
}
def _get_segmenter():
    """Return the cached sentence segmenter, building it on the first call.

    The result is stored in the module-level ``_segmenter`` so construction
    is attempted only once. When pysbd cannot be imported or its constructor
    raises, the string ``"stub"`` is cached instead, which tells callers to
    use naive sentence splitting.
    """
    global _segmenter

    # Fast path: a previous call already decided what to use.
    if _segmenter is not None:
        return _segmenter

    try:
        import pysbd

        chosen = pysbd.Segmenter(language="en", clean=False)
    except ImportError:
        logger.warning("pysbd not installed, falling back to naive sentence splitting.")
        chosen = "stub"  # sentinel string meaning "no real segmenter available"
    except Exception as err:
        logger.error("Failed to initialize pysbd segmenter: %s", err)
        chosen = "stub"

    _segmenter = chosen
    return _segmenter
def _keyword_overlap(sent_a: str, sent_b: str) -> int:
"""Count shared content words between two sentences."""
tokens_a = {w.lower() for w in sent_a.split() if w.lower() not in _STOPWORDS and len(w) > 2}
tokens_b = {w.lower() for w in sent_b.split() if w.lower() not in _STOPWORDS and len(w) > 2}
return len(tokens_a & tokens_b)
def _segment(text: str) -> list[str]:
"""Segment text into sentences using pysbd or a fallback."""
seg = _get_segmenter()
try:
if seg == "stub":
return [s.strip() for s in text.split(".") if s.strip()]
else:
return [s.strip() for s in seg.segment(text) if s.strip()]
except Exception:
return [s.strip() for s in text.split(".") if s.strip()]
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def _load_nli_model():
    """Return ``(model, contradiction_label_index)`` for NLI scoring.

    Prefers the loader exposed by ``src.modules.faithfulness``. If that
    module cannot be imported, a CrossEncoder is built directly and cached
    on this function so it is constructed once rather than on every call.

    Raises:
        ImportError : sentence-transformers is not installed.
        Exception   : anything raised while constructing/fetching the model.
    """
    try:
        from src.modules.faithfulness import _get_model, LABEL_CONTRADICTION
    except ImportError:
        # (Lazy imports to prevent startup crashes when libraries aren't installed yet)
        fallback = getattr(_load_nli_model, "_fallback_model", None)
        if fallback is None:
            from sentence_transformers import CrossEncoder

            fallback = CrossEncoder("cross-encoder/nli-deberta-v3-small")
            _load_nli_model._fallback_model = fallback
        # NOTE(review): index 0 is assumed to be the contradiction class for
        # this checkpoint — confirm against the model card.
        return fallback, 0
    return _get_model(), LABEL_CONTRADICTION


def score_contradiction(
    answer: str,
    context_docs: list[str],
    max_chunks: int = 5,
) -> EvalResult:
    """
    Detect contradictions between the LLM answer and retrieved context.

    Answer sentences are paired with context sentences that share at least
    MIN_KEYWORD_OVERLAP content words (at most MAX_PAIRS pairs), every pair
    is scored by the NLI model, and pairs whose contradiction probability is
    >= CONTRADICTION_THRESHOLD are flagged.

    Args:
        answer       : LLM-generated answer text.
        context_docs : List of retrieved context passage strings.
        max_chunks   : Max number of context chunks to evaluate.
    Returns:
        EvalResult with module_name="contradiction", score in [0,1] where
        1.0 = no contradictions detected, 0.0 = all pairs contradicted.
        A score of 0.5 (neutral) is returned when there is nothing to
        compare. When the model cannot be loaded or inference fails, the
        result has ``error`` set, empty details and score 1.0.
        ``details["pairs"]`` holds up to 20 flagged pairs, highest
        contradiction score first.
    """
    import re

    t0 = time.perf_counter()

    def elapsed_ms() -> int:
        return int((time.perf_counter() - t0) * 1000)

    def neutral(total_sentences: int) -> EvalResult:
        # neutral — nothing could be verified, so neither reward nor penalise
        return EvalResult(
            module_name="contradiction",
            score=0.5,
            details={
                "total_sentences": total_sentences,
                "checked_pairs": 0,
                "contradicted_pairs": 0,
                "pairs": [],
            },
            latency_ms=elapsed_ms(),
        )

    def failed(message: str) -> EvalResult:
        return EvalResult(
            module_name="contradiction",
            score=1.0,
            details={},
            error=message,
            latency_ms=elapsed_ms(),
        )

    if not answer or not context_docs:
        return neutral(0)

    # Obtain the NLI model (shared cache via the faithfulness module when available)
    try:
        model, contradiction_label = _load_nli_model()
    except ImportError:
        logger.error("sentence-transformers not installed. Cannot run NLI model.")
        return failed("NLI model (sentence-transformers) not installed.")
    except Exception as exc:
        logger.error("Failed to load NLI model: %s", exc)
        return failed(f"Failed to load NLI model: {exc}")

    # Strip markdown/citations from answer before NLI (same reason as faithfulness.py)
    markup = re.compile(
        r'\[Source:[^\]]*\]|\[[^\]]{0,120}\]'  # citations
        r'|\*\*([^*]+)\*\*|\*([^*]+)\*'  # bold/italic — keep text
        r'|`[^`]+`'  # code
    )
    answer = markup.sub(lambda m: (m.group(1) or m.group(2) or ''), answer).strip()

    # Segment answer into claims
    answer_sents = _segment(answer)
    if not answer_sents:
        return neutral(0)

    # Segment context chunks, keeping only the first sentences of each
    context_sents: list[str] = []
    for doc in context_docs[:max_chunks]:
        if not doc:
            continue  # empty chunk contributes no sentences
        context_sents.extend(_segment(doc)[:MAX_CONTEXT_SENTS])
    if not context_sents:
        return neutral(len(answer_sents))

    # Build pairs WITH topical pre-filter (skip unrelated sentence pairs entirely)
    all_pairs: list[tuple[str, str]] = []
    for a_sent in answer_sents:
        for c_sent in context_sents:
            if _keyword_overlap(a_sent, c_sent) >= MIN_KEYWORD_OVERLAP:
                all_pairs.append((a_sent, c_sent))
                if len(all_pairs) >= MAX_PAIRS:
                    break
        if len(all_pairs) >= MAX_PAIRS:
            break
    if not all_pairs:
        # Topically unrelated — cannot check for contradictions
        return neutral(len(answer_sents))

    # Batch NLI inference
    try:
        scores_raw = model.predict(all_pairs, apply_softmax=True)
    except Exception as exc:
        logger.error("Contradiction NLI inference failed: %s", exc)
        return failed(f"Model inference error: {exc}")

    # Collect flagged pairs
    flagged_pairs: list[dict] = []
    for (a_sent, c_sent), row in zip(all_pairs, scores_raw):
        con_score = float(row[contradiction_label])
        if con_score >= CONTRADICTION_THRESHOLD:
            flagged_pairs.append(
                {
                    "sentence_a": a_sent[:120],
                    "sentence_b": c_sent[:120],
                    "contradiction_score": round(con_score, 4),
                    "flagged": True,
                }
            )
    # Most severe first, so the cap below really keeps the top entries
    # rather than whichever pairs happened to be evaluated first.
    flagged_pairs.sort(key=lambda pair: pair["contradiction_score"], reverse=True)

    total = len(all_pairs)  # > 0 here: the empty case returned above
    contradicted = len(flagged_pairs)
    # Score: 1.0 = clean, lower = more contradictions found
    score = 1.0 - (contradicted / total)

    details = {
        "total_sentences": len(answer_sents),
        "checked_pairs": total,
        "contradicted_pairs": contradicted,
        "pairs": flagged_pairs[:20],  # cap output to top 20 flagged pairs
    }
    latency_ms = elapsed_ms()
    logger.info(
        "Contradiction: %.3f (%d/%d pairs flagged) in %d ms",
        score, contradicted, total, latency_ms,
    )
    return EvalResult(
        module_name="contradiction",
        score=score,
        details=details,
        latency_ms=latency_ms,
    )
|