File size: 9,840 Bytes
b6f9fa8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
"""
FR-17: src/modules/contradiction.py — Module 4: Cross-Document Contradiction Detection
========================================================================================
Uses the same DeBERTa NLI cross-encoder (cross-encoder/nli-deberta-v3-small) to
detect contradictions between the LLM answer and retrieved context passages.

Algorithm (SRS Section 6.4):
    1. Split answer into sentences (claims)
    2. Split each context chunk into sentences
    3. For each (answer_sentence, context_sentence) pair that shares at least
       MIN_KEYWORD_OVERLAP content words (a topical pre-filter):
        - Run NLI → get contradiction score
        - If contradiction_score ≥ CONTRADICTION_THRESHOLD → flag pair
    4. score = 1.0 - (flagged_pairs / total_pairs)

This module shares the NLI model instance with faithfulness.py when both
run in the same process (the model is cached at the faithfulness module level).

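Design note:
    To keep latency manageable, context sentences are limited to
    MAX_CONTEXT_SENTS per chunk and total pairs are capped at MAX_PAIRS.
    When there is nothing to compare (missing input, no sentences, or no
    topically overlapping pairs) the module returns a neutral score of 0.5.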
"""
from __future__ import annotations

import logging
import re
import time

import pysbd

from src.modules.base import EvalResult

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# A pair is flagged once the model's contradiction probability reaches this
# value (original note: chosen to balance recall against over-flagging).
CONTRADICTION_THRESHOLD = 0.5

# Pre-filter: a pair must share at least this many content words before the
# NLI model is run on it.
MIN_KEYWORD_OVERLAP = 3

# Only the first N sentences of each context chunk are considered.
MAX_CONTEXT_SENTS = 4

# Upper bound on (answer, context) pairs sent to the model, to bound latency
# (original estimate: ~2-3 s at this size).
MAX_PAIRS = 200

# Sentence segmenter, created lazily by _get_segmenter().
_segmenter = None

# Words ignored by the keyword-overlap pre-filter: common English function
# words plus a handful of short tokens ("mg", "iv", "od", "per", "day", ...)
# that presumably recur throughout dosing text and carry no topical signal.
_STOPWORDS = set(
    (
        "the a an this that it all each which who what "
        "is are was were be been have has had do does should may can will "
        "in of to for and or at by if as on with but so from than more "
        "when after before not no "
        "1 2 3 mg iv od per day based using include"
    ).split()
)


def _get_segmenter():
    """Return the shared sentence segmenter, building it on first use.

    On success this is a ``pysbd.Segmenter`` for English with
    ``clean=False``. If pysbd cannot be imported or initialised, the string
    ``"stub"`` is returned instead; ``_segment`` treats that value as a
    request for naive period-based splitting. Whatever is produced, including
    the ``"stub"`` fallback, is stored in the module-level ``_segmenter``, so
    initialisation is attempted only once per process.

    NOTE(review): ``pysbd`` is also imported at module top level, so as the
    file stands a missing package fails at module import time and the
    ImportError branch below is reachable only if that import is removed.
    """
    global _segmenter
    if _segmenter is None:
        # Start from the naive-splitting sentinel; replaced on success.
        instance = "stub"
        try:
            import pysbd

            instance = pysbd.Segmenter(language="en", clean=False)
        except ImportError:
            logger.warning("pysbd not installed, falling back to naive sentence splitting.")
        except Exception as exc:
            logger.error("Failed to initialize pysbd segmenter: %s", exc)
        _segmenter = instance
    return _segmenter


# Word tokens for the overlap pre-filter. Using runs of word characters means
# trailing punctuation ("aspirin." / "aspirin,") no longer hides a shared word
# and punctuation never inflates a short word past the length filter.
_WORD_RE = re.compile(r"\w+")


def _keyword_overlap(
    sent_a: str,
    sent_b: str,
    *,
    stopwords: frozenset[str] | set[str] | None = None,
) -> int:
    """Count content words shared by two sentences.

    Both sentences are lower-cased and tokenised on word characters, so
    punctuation attached to a word does not prevent a match. Tokens of two
    characters or fewer, and tokens found in *stopwords*, are ignored.

    Args:
        sent_a: First sentence.
        sent_b: Second sentence.
        stopwords: Lower-case words to ignore. Defaults to the module-level
            ``_STOPWORDS``.

    Returns:
        Number of distinct content words present in both sentences.
    """
    ignore = _STOPWORDS if stopwords is None else stopwords

    def content_words(sentence: str) -> set[str]:
        # Same filters as before (length > 2, not a stopword), applied to
        # clean tokens instead of raw whitespace-separated chunks.
        return {
            tok
            for tok in _WORD_RE.findall(sentence.lower())
            if len(tok) > 2 and tok not in ignore
        }

    return len(content_words(sent_a) & content_words(sent_b))


def _segment(text: str) -> list[str]:
    """Split *text* into stripped, non-empty sentences.

    Uses the segmenter from ``_get_segmenter()``. When that is the ``"stub"``
    sentinel, or when segmentation raises, the text is instead split on
    every ``"."`` (periods are dropped from the pieces).
    """

    def _naive_split(raw: str) -> list[str]:
        return [piece.strip() for piece in raw.split(".") if piece.strip()]

    segmenter = _get_segmenter()
    try:
        if segmenter == "stub":
            return _naive_split(text)
        pieces = segmenter.segment(text)
        return [piece.strip() for piece in pieces if piece.strip()]
    except Exception:
        # Best effort: any segmenter failure degrades to naive splitting.
        return _naive_split(text)


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------

def score_contradiction(
    answer: str,
    context_docs: list[str],
    max_chunks: int = 5,
) -> EvalResult:
    """
    Detect contradictions between the LLM answer and retrieved context.

    The answer is stripped of citations/markdown, both sides are split into
    sentences, sentence pairs sharing at least MIN_KEYWORD_OVERLAP content
    words are scored by the NLI cross-encoder, and pairs whose contradiction
    probability reaches CONTRADICTION_THRESHOLD are flagged.

    Args:
        answer       : LLM-generated answer text.
        context_docs : List of retrieved context passage strings. Empty or
                       ``None`` entries are skipped.
        max_chunks   : Max number of context chunks to evaluate (taken from
                       the front of ``context_docs``).

    Returns:
        EvalResult with module_name="contradiction" and score in [0, 1]:

        * 1.0 - flagged/checked when at least one pair was checked
          (1.0 = no contradictions, 0.0 = every checked pair contradicted);
        * 0.5 (neutral) when there is nothing to compare: missing input,
          no sentences, or no topically overlapping pairs;
        * 1.0 with ``error`` set when the NLI model cannot be loaded or
          inference fails.

        ``details["pairs"]`` lists up to 20 flagged pairs, most severe first.
    """
    t0 = time.perf_counter()

    if not answer or not context_docs:
        return _neutral_result(0, t0)

    # Model loading failures (either path) become an error result, not a crash.
    try:
        model, label_contradiction = _load_nli_model()
    except ImportError:
        logger.error("sentence-transformers not installed. Cannot run NLI model.")
        return _error_result("NLI model (sentence-transformers) not installed.", t0)
    except Exception as exc:
        logger.error("Failed to load NLI model: %s", exc)
        return _error_result(f"Failed to load NLI model: {exc}", t0)

    # Citations and markdown markers are not claims; remove them before NLI.
    answer_sents = _segment(_strip_markdown(answer))
    if not answer_sents:
        return _neutral_result(0, t0)

    context_sents: list[str] = []
    for doc in context_docs[:max_chunks]:
        if doc:  # tolerate None / empty chunks instead of failing in _segment
            context_sents.extend(_segment(doc)[:MAX_CONTEXT_SENTS])
    if not context_sents:
        return _neutral_result(len(answer_sents), t0)

    pairs = _candidate_pairs(answer_sents, context_sents)
    if not pairs:
        # Topically unrelated: nothing to check for contradictions.
        return _neutral_result(len(answer_sents), t0)

    try:
        scores_raw = model.predict(pairs, apply_softmax=True)
    except Exception as exc:
        logger.error("Contradiction NLI inference failed: %s", exc)
        return _error_result(f"Model inference error: {exc}", t0)

    flagged: list[dict] = []
    for (a_sent, c_sent), probs in zip(pairs, scores_raw):
        con_score = float(probs[label_contradiction])
        if con_score >= CONTRADICTION_THRESHOLD:
            flagged.append(
                {
                    "sentence_a": a_sent[:120],
                    "sentence_b": c_sent[:120],
                    "contradiction_score": round(con_score, 4),
                    "flagged": True,
                }
            )

    total = len(pairs)
    contradicted = len(flagged)
    # total > 0 here: an empty `pairs` list returned early above.
    score = 1.0 - contradicted / total

    # Most severe first, so the reporting cap keeps the worst contradictions.
    flagged.sort(key=lambda item: item["contradiction_score"], reverse=True)
    details = {
        "total_sentences": len(answer_sents),
        "checked_pairs": total,
        "contradicted_pairs": contradicted,
        "pairs": flagged[:_MAX_REPORTED_PAIRS],
    }

    latency_ms = _elapsed_ms(t0)
    logger.info(
        "Contradiction: %.3f (%d/%d pairs flagged) in %d ms",
        score, contradicted, total, latency_ms,
    )
    return EvalResult(
        module_name="contradiction",
        score=score,
        details=details,
        latency_ms=latency_ms,
    )


# ---------------------------------------------------------------------------
# Internal helpers for score_contradiction
# ---------------------------------------------------------------------------

_MAX_REPORTED_PAIRS = 20  # cap on flagged pairs copied into details["pairs"]

# Citations ("[Source: ...]" or any bracketed span up to 120 chars), bold /
# italic markers (inner text kept via groups 1 and 2) and inline code spans.
_MARKDOWN_RE = re.compile(
    r'\[Source:[^\]]*\]|\[[^\]]{0,120}\]'  # citations
    r'|\*\*([^*]+)\*\*|\*([^*]+)\*'        # bold/italic → keep text
    r'|`[^`]+`'                             # code
)

# CrossEncoder loaded by _load_nli_model() when src.modules.faithfulness is
# not importable; cached so the checkpoint is not reloaded on every call.
_fallback_nli_model = None


def _elapsed_ms(t0: float) -> int:
    """Whole milliseconds elapsed since the ``time.perf_counter()`` mark *t0*."""
    return int((time.perf_counter() - t0) * 1000)


def _neutral_result(total_sentences: int, t0: float) -> EvalResult:
    """Build the neutral (0.5) result used when there is nothing to compare."""
    return EvalResult(
        module_name="contradiction",
        score=0.5,  # neutral: cannot verify without comparable sentences
        details={
            "total_sentences": total_sentences,
            "checked_pairs": 0,
            "contradicted_pairs": 0,
            "pairs": [],
        },
        latency_ms=_elapsed_ms(t0),
    )


def _error_result(message: str, t0: float) -> EvalResult:
    """Build the result returned when the NLI model cannot be loaded or run.

    NOTE(review): the score stays 1.0 ("no contradictions") as in the original
    code; an aggregator that ignores ``error`` will read this as a clean pass.
    Consider 0.5 for consistency with the neutral results.
    """
    return EvalResult(
        module_name="contradiction",
        score=1.0,
        details={},
        error=message,
        latency_ms=_elapsed_ms(t0),
    )


def _load_nli_model() -> tuple[object, int]:
    """Return ``(model, contradiction_label_index)`` for the NLI cross-encoder.

    Prefers the model exposed by ``src.modules.faithfulness``. If that module
    cannot be imported, loads ``cross-encoder/nli-deberta-v3-small`` directly
    (once per process) and uses label index 0 for "contradiction".

    Raises:
        ImportError: sentence-transformers is not installed.
        Exception: anything raised by the model loader.
    """
    global _fallback_nli_model
    try:
        from src.modules.faithfulness import LABEL_CONTRADICTION, _get_model
    except ImportError:
        pass
    else:
        return _get_model(), LABEL_CONTRADICTION

    if _fallback_nli_model is None:
        # Lazy import so this module still imports without the library.
        from sentence_transformers import CrossEncoder

        _fallback_nli_model = CrossEncoder("cross-encoder/nli-deberta-v3-small")
    # Index 0 is the contradiction label the original code assumes for this
    # checkpoint; confirm against the model card if the checkpoint changes.
    return _fallback_nli_model, 0


def _strip_markdown(text: str) -> str:
    """Drop citations and inline code, unwrap bold/italic, then strip."""
    return _MARKDOWN_RE.sub(lambda m: m.group(1) or m.group(2) or "", text).strip()


def _candidate_pairs(
    answer_sents: list[str],
    context_sents: list[str],
) -> list[tuple[str, str]]:
    """Return topically related (answer, context) pairs, at most MAX_PAIRS."""
    pairs: list[tuple[str, str]] = []
    for a_sent in answer_sents:
        for c_sent in context_sents:
            if _keyword_overlap(a_sent, c_sent) >= MIN_KEYWORD_OVERLAP:
                pairs.append((a_sent, c_sent))
                if len(pairs) >= MAX_PAIRS:
                    return pairs
    return pairs