"""
FR-05: src/modules/faithfulness.py — Module 1: Faithfulness Scoring
=====================================================================
Uses cnut1648/biolinkbert-mednli (BioLinkBERT fine-tuned on MedNLI), loaded as
a sentence-transformers CrossEncoder, to score how well the LLM answer is
entailed by the retrieved context chunks.

Architecture:
    1. Strip inline citations and markdown from the answer (and markdown from
       the context chunks), then split the answer into individual claims
       (sentences via pysbd; falls back to splitting on "." without pysbd).
    2. Claims containing clinical measurements (e.g. "126 mg/dL", "6.5%")
       bypass NLI and are checked by lexical number matching against the
       context ("numerical bypass").
    3. Every other claim is scored with NLI against every context chunk; the
       chunk with the highest entailment probability is kept.
    4. Assign claim status: ENTAILED / NEUTRAL / CONTRADICTED.
    5. score = max(0, (entailed_count - contradicted_count) / total_claims)

Thresholds (SRS Section 6.1), checked in this order:
    entailment    ≥ 0.50 → ENTAILED
    contradiction ≥ 0.30 → CONTRADICTED
    otherwise            → NEUTRAL
The entailment threshold can be overridden via
config["modules"]["faithfulness"]["entailment_threshold"].

The model is loaded lazily and cached at module level (avoids double-loading
when called multiple times in the same process).
"""

from __future__ import annotations

import logging
import time
from functools import lru_cache
from typing import TYPE_CHECKING

from src.modules.base import EvalResult

if TYPE_CHECKING:
    pass

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# BioLinkBERT fine-tuned on MedNLI (clinical notes, MIMIC-III)
# Paper 15 (Chen et al. SemEval-2023): best single model for biomedical NLI (F1=0.765)
# Faster on CPU than DeBERTa-large (BERT-base architecture)
MODEL_NAME = "cnut1648/biolinkbert-mednli"

# MedNLI label order (verified): {0: entailment, 1: neutral, 2: contradiction}
LABEL_ENTAILMENT = 0
LABEL_NEUTRAL = 1
LABEL_CONTRADICTION = 2

# Defaults; both can be overridden via config["modules"]["faithfulness"].
ENTAILMENT_THRESHOLD = 0.50
CONTRADICTION_THRESHOLD = 0.30

# Cap on claims scored per answer: NLI cost grows as claims x chunks.
MAX_CLAIMS = 12

# Share of a numerical claim's numbers that must appear verbatim in the
# context for the numerical bypass to mark the claim ENTAILED.
NUMERIC_MATCH_RATIO = 0.6

# Per-pair [entailment, neutral, contradiction] scores used when the NLI model
# is unavailable. With the default thresholds this marks every NLI-scored
# claim CONTRADICTED.
_STUB_SCORES = [0.1, 0.1, 0.8]

# ---------------------------------------------------------------------------
# Lazy model loader
# ---------------------------------------------------------------------------

_model = None
_segmenter = None


def _get_model():
    """Return the cached NLI CrossEncoder, loading it on first use.

    Returns the string ``"stub"`` if an ImportError occurs while importing or
    constructing the model. Any other load error propagates, and ``_model``
    stays ``None`` so the next call retries.
    """
    global _model
    if _model is None:
        try:
            from sentence_transformers import CrossEncoder
            logger.info("Loading NLI model: %s (first call only)", MODEL_NAME)
            _model = CrossEncoder(MODEL_NAME)
            logger.info("NLI model loaded.")
        except ImportError:
            logger.error("sentence_transformers not installed. Faithfulness will be stubbed.")
            _model = "stub"
    return _model


def _get_segmenter():
    """Return the cached pysbd sentence segmenter, or ``"stub"`` if pysbd is missing."""
    global _segmenter
    if _segmenter is None:
        try:
            import pysbd
            _segmenter = pysbd.Segmenter(language="en", clean=False)
        except ImportError:
            _segmenter = "stub"
    return _segmenter


# ---------------------------------------------------------------------------
# Text preprocessing helpers
# ---------------------------------------------------------------------------

def _strip_context_markdown(docs: list[str]) -> list[str]:
    """Strip markdown from context chunks before NLI.

    The NLI model expects clean prose, and markdown confuses it. Bold/italic
    markers are removed but their inner text is kept. ``[Header]\\n`` lines,
    heading markers, bullets and inline code spans are removed entirely.
    """
    import re

    md_clean = re.compile(r'\[([^\]]+)\]\n|#{1,6}\s+|•\s+|\*\*([^*]+)\*\*|\*([^*]+)\*|`[^`]+`')
    return [md_clean.sub(lambda m: m.group(2) or m.group(3) or '', d) for d in docs]


def _clean_answer(answer: str) -> str:
    """Remove inline citations, short bracket constructs and markdown from the answer.

    LLM answers often include ``[Source: *title*]`` citations and ``**bold**``
    text that confuse BioLinkBERT NLI, which was trained on clean prose.
    """
    import re

    cite_re = re.compile(
        # Bullet markers at the start of ANY line (needs MULTILINE). This is
        # tried first so the italic alternative cannot swallow "* item ... *".
        r'^[^\S\n]*[*•][^\S\n]+'
        r'|\[Source:[^\]]*\]'    # [Source: title] or [Source: *italic title*]
        r'|\[[^\]]{0,120}\]'     # other short bracket constructs
        r'|\*\*([^*]+)\*\*'      # **bold** → keep inner text (group 1)
        r'|\*([^*]+)\*'          # *italic* → keep inner text (group 2)
        r'|`[^`]+`',             # `code`
        re.MULTILINE,
    )
    return cite_re.sub(lambda m: m.group(1) or m.group(2) or '', answer).strip()


def _split_claims(text: str) -> list[str]:
    """Split text into non-empty, stripped sentence claims.

    Uses pysbd when available. Falls back to splitting on "." if pysbd is
    missing or raises.
    """
    seg = _get_segmenter()
    try:
        pieces = text.split(".") if seg == "stub" else seg.segment(text)
    except Exception:
        logger.warning("Sentence segmentation failed; falling back to '.' split", exc_info=True)
        pieces = text.split(".")
    return [s.strip() for s in pieces if s.strip()]


# ---------------------------------------------------------------------------
# Numerical bypass (Paper 14: non-optional for clinical NLI)
# NLI models structurally cannot verify numerical comparisons (≥6.5%, 126 mg/dL).
# Claims containing clinical measurements use lexical matching instead.
# ---------------------------------------------------------------------------

def _has_clinical_number(claim: str) -> bool:
    """Return True if the claim contains a number followed by a clinical unit/comparator.

    NOTE(review): the bare ``g`` alternative has no word boundary, so phrases
    like "2 groups" also take the bypass. This is kept as-is to preserve the
    current routing.
    """
    import re

    pattern = re.compile(
        r'[\d]+[\s]*(mg|mcg|%|mL|mmol|IU|units?|g|kg|≥|≤|>|<|±|mg/dL|mmol/L|mg/kg)',
        re.IGNORECASE,
    )
    return pattern.search(claim) is not None


def _numerical_match(claim: str, context_chunks: list[str]) -> str:
    """Return "ENTAILED" if enough of the claim's numbers appear in the context, else "NEUTRAL".

    Numbers are matched as substrings of all chunks joined together. A claim
    with no numbers is NEUTRAL.
    """
    import re

    # \d+(?:\.\d+)? avoids capturing a trailing sentence period ("7." would
    # otherwise fail to match "7%" in the context).
    nums = re.findall(r'\d+(?:\.\d+)?', claim)
    if not nums:
        return "NEUTRAL"
    combined = " ".join(context_chunks).lower()
    matched = sum(1 for n in nums if n in combined)
    return "ENTAILED" if matched >= len(nums) * NUMERIC_MATCH_RATIO else "NEUTRAL"


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------

def score_faithfulness(
    answer: str,
    context_docs: list[str],
    chunk_ids: list[str] | None = None,
    max_chunks: int = 3,
    config: dict | None = None,
) -> EvalResult:
    """
    Score the faithfulness of an answer against retrieved context documents.

    The answer is cleaned of citations/markdown and split into at most
    ``MAX_CLAIMS`` claims. Claims containing clinical measurements are checked
    lexically (numerical bypass). The remaining claims are scored with NLI
    against each context chunk, and the chunk with the highest entailment
    probability (with that chunk's contradiction probability) decides the
    claim's status.

    Args:
        answer       : The LLM-generated answer text.
        context_docs : List of context passage strings (top-k retrieved chunks).
        chunk_ids    : Optional IDs matching context_docs for traceability.
                       Missing IDs are filled in as "chunk_<i>".
        max_chunks   : Maximum context chunks to consider (bounds NLI cost,
                       which grows as claims x chunks).
        config       : Optional config dict. Reads
                       config["modules"]["faithfulness"]["entailment_threshold"]
                       and ["contradiction_threshold"].

    Returns:
        EvalResult with module_name="faithfulness" and score in [0, 1], computed
        as max(0, (entailed - contradicted) / total_claims). ``details`` holds
        the claim counts, ``unsupported_sentence_ratio`` and per-claim results
        (presumably the shape defined in src/modules/__init__.py — confirm).
        Degenerate cases set ``error``:
            * empty answer / no usable context -> score 0.0
            * no extractable claims            -> score 0.5
            * NLI inference failure            -> score 0.0

    Raises:
        Non-ImportError exceptions raised while loading the NLI model propagate.
    """
    t0 = time.perf_counter()

    faith_cfg = (config or {}).get("modules", {}).get("faithfulness", {})
    entailment_threshold = faith_cfg.get("entailment_threshold", ENTAILMENT_THRESHOLD)
    contradiction_threshold = faith_cfg.get("contradiction_threshold", CONTRADICTION_THRESHOLD)

    # Limit context size. Checking emptiness after slicing also covers
    # max_chunks <= 0, which would otherwise crash on an empty ids list.
    docs = context_docs[:max_chunks] if context_docs else []
    if not answer or not docs:
        return EvalResult(
            module_name="faithfulness",
            score=0.0,
            details={"error": "Empty answer or no context provided"},
            error="Empty answer or no context",
            latency_ms=0,
        )

    # Exactly one ID per doc, so best_chunk_id lookups never go out of range
    # when the caller passes fewer IDs than docs.
    ids = list(chunk_ids or [])[:len(docs)]
    ids.extend(f"chunk_{i}" for i in range(len(ids), len(docs)))

    docs = _strip_context_markdown(docs)

    claims = _split_claims(_clean_answer(answer))
    if not claims:
        return EvalResult(
            module_name="faithfulness",
            score=0.5,
            details={"error": "Could not extract claims from answer"},
            error="No claims extracted",
            latency_ms=0,
        )
    claims = claims[:MAX_CLAIMS]

    # Route numerical claims to the lexical bypass; everything else goes to NLI.
    numerical_results: dict[int, str] = {}  # claim_idx -> status
    nli_claim_indices: list[int] = []
    for ci, claim in enumerate(claims):
        if _has_clinical_number(claim):
            numerical_results[ci] = _numerical_match(claim, docs)
        else:
            nli_claim_indices.append(ci)

    # One (premise=chunk, hypothesis=claim) pair per NLI claim x chunk.
    pair_map = [
        (nci, di) for nci in range(len(nli_claim_indices)) for di in range(len(docs))
    ]
    all_pairs = [(docs[di], claims[nli_claim_indices[nci]]) for nci, di in pair_map]

    scores_raw = []
    if all_pairs:
        # Skip loading and inference entirely when every claim was handled by
        # the numerical bypass; predicting on an empty batch can raise.
        model = _get_model()
        try:
            if isinstance(model, str):  # "stub": model unavailable
                scores_raw = [_STUB_SCORES] * len(all_pairs)
            else:
                scores_raw = model.predict(all_pairs, apply_softmax=True)
        except Exception as exc:
            logger.error("NLI model inference failed: %s", exc)
            return EvalResult(
                module_name="faithfulness",
                score=0.0,
                details={},
                error=f"Model inference error: {exc}",
                latency_ms=int((time.perf_counter() - t0) * 1000),
            )

    # For each NLI claim keep the chunk with the highest entailment probability,
    # together with that same chunk's contradiction probability.
    nli_best: dict[int, tuple[float, float, int]] = {}  # nci -> (ent, con, doc_idx)
    for (nci, di), score_vec in zip(pair_map, scores_raw):
        ent_score = float(score_vec[LABEL_ENTAILMENT])
        con_score = float(score_vec[LABEL_CONTRADICTION])
        if nci not in nli_best or ent_score > nli_best[nci][0]:
            nli_best[nci] = (ent_score, con_score, di)

    nci_by_claim = {ci: nci for nci, ci in enumerate(nli_claim_indices)}

    claim_results: list[dict] = []
    entailed = neutral = contradicted = 0
    for ci, claim in enumerate(claims):
        if ci in numerical_results:
            # Lexical match runs over all chunks combined, so there is no single
            # supporting chunk; report the first one.
            status = numerical_results[ci]
            nli_score = 1.0 if status == "ENTAILED" else 0.0
            best_doc_idx = 0
            method = "numerical_bypass"
        else:
            best_entailment, best_contradiction, best_doc_idx = nli_best.get(
                nci_by_claim[ci], (0.0, 0.0, 0)
            )
            # Entailment wins over contradiction when both clear their thresholds.
            if best_entailment >= entailment_threshold:
                status = "ENTAILED"
                nli_score = best_entailment
            elif best_contradiction >= contradiction_threshold:
                status = "CONTRADICTED"
                nli_score = best_contradiction
            else:
                status = "NEUTRAL"
                nli_score = best_entailment
            method = "nli"

        if status == "ENTAILED":
            entailed += 1
        elif status == "CONTRADICTED":
            contradicted += 1
        else:
            neutral += 1

        claim_results.append({
            "claim": claim,
            "status": status,
            "best_chunk_id": ids[best_doc_idx],
            "nli_score": round(nli_score, 4),
            "method": method,
        })

    total = len(claims)
    # Contradictions are penalised, clamped at 0.
    score = max(0.0, (entailed - contradicted) / total) if total > 0 else 0.0
    usr_val = round((neutral + contradicted) / total, 4) if total > 0 else 0.0

    details = {
        "total_claims": total,
        "entailed_count": entailed,
        "neutral_count": neutral,
        "contradicted_count": contradicted,
        "unsupported_sentence_ratio": usr_val,
        "claims": claim_results,
    }

    latency_ms = int((time.perf_counter() - t0) * 1000)
    logger.info(
        "Faithfulness: %.3f (%d/%d entailed) in %d ms",
        score, entailed, total, latency_ms,
    )

    return EvalResult(
        module_name="faithfulness",
        score=score,
        details=details,
        latency_ms=latency_ms,
    )