MediRAG-API / src /modules /faithfulness.py
joytheslothh's picture
Update: backend v3.2 — privacy pipeline, consensus, new scripts
7161baa verified
"""
FR-05: src/modules/faithfulness.py — Module 1: Faithfulness Scoring
=====================================================================
Uses cross-encoder/nli-deberta-v3-small to score how well the LLM answer
is entailed by the retrieved context chunks.
Architecture:
1. Split answer into individual claims (sentences via pysbd)
2. For each claim: compute NLI score against every context chunk
3. Assign claim status: ENTAILED / NEUTRAL / CONTRADICTED
4. score = entailed_count / total_claims
Thresholds (SRS Section 6.1):
entailment ≥ 0.50 → ENTAILED
contradiction ≥ 0.30 → CONTRADICTED
otherwise → NEUTRAL
Model loaded lazily and cached at module level (avoids double-loading
when called multiple times in same process).
"""
from __future__ import annotations
import logging
import time
from functools import lru_cache
from typing import TYPE_CHECKING
from src.modules.base import EvalResult
if TYPE_CHECKING:
pass
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# BioLinkBERT fine-tuned on MedNLI (clinical notes, MIMIC-III)
# Paper 15 (Chen et al. SemEval-2023): best single model for biomedical NLI (F1=0.765)
# Faster on CPU than DeBERTa-large (BERT-base architecture)
MODEL_NAME = "cnut1648/biolinkbert-mednli"
# MedNLI label order (verified): {0: entailment, 1: neutral, 2: contradiction}
LABEL_ENTAILMENT = 0
LABEL_NEUTRAL = 1
LABEL_CONTRADICTION = 2
ENTAILMENT_THRESHOLD = 0.50
CONTRADICTION_THRESHOLD = 0.30
# ---------------------------------------------------------------------------
# Lazy model loader
# ---------------------------------------------------------------------------
_model = None
_segmenter = None
def _get_model():
global _model
if _model is None:
try:
from sentence_transformers import CrossEncoder
logger.info("Loading NLI model: %s (first call only)", MODEL_NAME)
_model = CrossEncoder(MODEL_NAME)
logger.info("NLI model loaded.")
except ImportError:
logger.error("sentence_transformers not installed. Faithfulness will be stubbed.")
_model = "stub"
return _model
def _get_segmenter():
global _segmenter
if _segmenter is None:
try:
import pysbd
_segmenter = pysbd.Segmenter(language="en", clean=False)
except ImportError:
_segmenter = "stub"
return _segmenter
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def score_faithfulness(
answer: str,
context_docs: list[str],
chunk_ids: list[str] | None = None,
max_chunks: int = 3,
config: dict | None = None,
) -> EvalResult:
"""
Score the faithfulness of an answer against retrieved context documents.
Args:
answer : The LLM-generated answer text.
context_docs : List of context passage strings (top-k retrieved chunks).
chunk_ids : Optional IDs matching context_docs for traceability.
max_chunks : Maximum context chunks to consider (to limit API calls).
Returns:
EvalResult with module_name="faithfulness", score in [0,1], and details
dict matching the shape defined in src/modules/__init__.py.
"""
t0 = time.perf_counter()
_faith_cfg = (config or {}).get("modules", {}).get("faithfulness", {})
entailment_threshold = _faith_cfg.get("entailment_threshold", ENTAILMENT_THRESHOLD)
contradiction_threshold = CONTRADICTION_THRESHOLD
if not answer or not context_docs:
return EvalResult(
module_name="faithfulness",
score=0.0,
details={"error": "Empty answer or no context provided"},
error="Empty answer or no context",
latency_ms=0,
)
# Limit context size
docs = context_docs[:max_chunks]
ids = (chunk_ids or [f"chunk_{i}" for i in range(len(docs))])[:max_chunks]
# Strip markdown formatting from guideline/structured chunks before NLI
# DeBERTa NLI was trained on clean prose — markdown confuses it
import re as _re
_MD_CLEAN = _re.compile(r'\[([^\]]+)\]\n|#{1,6}\s+|•\s+|\*\*([^*]+)\*\*|\*([^*]+)\*|`[^`]+`')
docs = [_MD_CLEAN.sub(lambda m: m.group(2) or m.group(3) or '', d) for d in docs]
# Strip inline citations and markdown from the answer before claim splitting.
# LLM answers often include [Source: *title*] citations and **bold** text that
# confuse BioLinkBERT NLI — the model was trained on clean prose.
_CITE_RE = _re.compile(
r'\[Source:[^\]]*\]' # [Source: title] or [Source: *italic title*]
r'|\[[^\]]{0,120}\]' # other short bracket constructs
r'|\*\*([^*]+)\*\*' # **bold** → keep inner text
r'|\*([^*]+)\*' # *italic* → keep inner text
r'|`[^`]+`' # `code`
r'|^\s*[*•]\s+' # bullet points at line start
)
answer_clean = _CITE_RE.sub(lambda m: (m.group(1) or m.group(2) or ''), answer).strip()
# Split answer into claims
seg = _get_segmenter()
try:
if seg == "stub":
claims = [s.strip() for s in answer_clean.split(".") if s.strip()]
else:
claims = [s.strip() for s in seg.segment(answer_clean) if s.strip()]
except Exception:
claims = [s.strip() for s in answer_clean.split(".") if s.strip()]
if not claims:
return EvalResult(
module_name="faithfulness",
score=0.5,
details={"error": "Could not extract claims from answer"},
error="No claims extracted",
latency_ms=0,
)
model = _get_model()
# Limit claims to avoid O(claims×chunks) explosion with the large model
claims = claims[:12]
# ---------------------------------------------------------------------------
# Numerical Bypass (Paper 14: non-optional for clinical NLI)
# NLI models structurally cannot verify numerical comparisons (≥6.5%, 126 mg/dL).
# Use direct string/lexical matching for claims containing clinical measurements.
# ---------------------------------------------------------------------------
import re as _re2
_NUM_PATTERN = _re2.compile(
r'[\d]+[\s]*(mg|mcg|%|mL|mmol|IU|units?|g|kg|≥|≤|>|<|±|mg/dL|mmol/L|mg/kg)',
_re2.IGNORECASE,
)
def _numerical_match(claim: str, context_chunks: list[str]) -> str:
"""
For claims with numerical clinical values, check if the key numbers
appear in any context chunk. Returns ENTAILED or NEUTRAL.
"""
nums = _re2.findall(r'[\d]+\.?[\d]*', claim)
if not nums:
return "NEUTRAL"
combined = " ".join(context_chunks).lower()
matched = sum(1 for n in nums if n in combined)
return "ENTAILED" if matched >= len(nums) * 0.6 else "NEUTRAL"
# Separate numerical claims (bypass NLI) from textual claims (use NLI)
numerical_results: dict[int, str] = {} # claim_idx → status
nli_claim_indices: list[int] = []
for ci, claim in enumerate(claims):
if _NUM_PATTERN.search(claim):
numerical_results[ci] = _numerical_match(claim, docs)
else:
nli_claim_indices.append(ci)
# Build NLI pairs only for non-numerical claims
nli_claims = [claims[ci] for ci in nli_claim_indices]
all_pairs = []
pair_map: list[tuple[int, int]] = [] # (nli_claim_idx, doc_idx)
for nci, claim in enumerate(nli_claims):
for di, doc in enumerate(docs):
all_pairs.append((doc, claim))
pair_map.append((nci, di))
# Batch NLI inference
try:
if model == "stub":
# Provide dummy scores if model is unavailable
scores_raw = [[0.1, 0.1, 0.8] for _ in all_pairs]
else:
scores_raw = model.predict(all_pairs, apply_softmax=True)
except Exception as exc:
logger.error("NLI model inference failed: %s", exc)
return EvalResult(
module_name="faithfulness",
score=0.0,
details={},
error=f"Model inference error: {exc}",
latency_ms=int((time.perf_counter() - t0) * 1000),
)
# Aggregate: for each claim find the context with the highest entailment
claim_results: list[dict] = []
entailed = 0
neutral = 0
contradicted = 0
# Build per-NLI-claim best scores from batch results
nli_best: dict[int, tuple[float, float, int]] = {} # nci → (best_ent, best_con, best_doc)
for idx, (nci, d_i) in enumerate(pair_map):
score_vec = scores_raw[idx]
ent_score = float(score_vec[LABEL_ENTAILMENT])
con_score = float(score_vec[LABEL_CONTRADICTION])
if nci not in nli_best or ent_score > nli_best[nci][0]:
nli_best[nci] = (ent_score, con_score, d_i)
for ci, claim in enumerate(claims):
if ci in numerical_results:
# Numerical bypass — lexical match result
status = numerical_results[ci]
nli_score = 1.0 if status == "ENTAILED" else 0.0
best_doc_idx = 0
method = "numerical_bypass"
else:
# NLI result
nci = nli_claim_indices.index(ci) if ci in nli_claim_indices else -1
best_entailment, best_contradiction, best_doc_idx = nli_best.get(nci, (0.0, 0.0, 0))
if best_entailment >= entailment_threshold:
status = "ENTAILED"
nli_score = best_entailment
elif best_contradiction >= contradiction_threshold:
status = "CONTRADICTED"
nli_score = best_contradiction
else:
status = "NEUTRAL"
nli_score = best_entailment
method = "nli"
if status == "ENTAILED":
entailed += 1
elif status == "CONTRADICTED":
contradicted += 1
else:
neutral += 1
claim_results.append({
"claim": claim,
"status": status,
"best_chunk_id": ids[best_doc_idx],
"nli_score": round(nli_score, 4),
"method": method,
})
total = len(claims)
score = max(0.0, (entailed - contradicted) / total) if total > 0 else 0.0
usr_val = round((neutral + contradicted) / total, 4) if total > 0 else 0.0
details = {
"total_claims": total,
"entailed_count": entailed,
"neutral_count": neutral,
"contradicted_count": contradicted,
"unsupported_sentence_ratio": usr_val,
"claims": claim_results,
}
latency_ms = int((time.perf_counter() - t0) * 1000)
logger.info(
"Faithfulness: %.3f (%d/%d entailed) in %d ms",
score, entailed, total, latency_ms,
)
return EvalResult(
module_name="faithfulness",
score=score,
details=details,
latency_ms=latency_ms,
)