""" FR-09: src/modules/entity_verifier.py — Module 2: Medical Entity Verification ============================================================================== Uses SciSpaCy NER (en_core_sci_lg) to extract medical entities from the answer, then verifies drug entities against the RxNorm cache and/or REST API. Verification pipeline (SRS Section 6.2): 1. NER: extract DRUG, DOSAGE, CONDITION, PROCEDURE entities from answer 2. For each DRUG entity: a. Look up in local rxnorm_cache.csv (fast, offline) b. If not found, query RxNorm REST API /approximateTerm (live fallback) c. If still not found, mark as NOT_FOUND 3. Cross-check entity presence in context docs (optional validation) 4. Score = verified_drug_count / total_drug_count (non-drug entities have no score impact) Entity status values: VERIFIED — drug found in RxNorm cache or API with rxcui FLAGGED — entity found but has a known dangerous synonym conflict NOT_FOUND — drug name not resolvable via any layer Severity mapping (for FLAGGED): brand ↔ generic mismatch → CRITICAL dosage discrepancy → MODERATE minor synonym variant → MINOR """ from __future__ import annotations import logging import re import time from functools import lru_cache from pathlib import Path from typing import Optional import pandas as pd import requests from src.modules.base import EvalResult logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Constants # --------------------------------------------------------------------------- RXNORM_APPROX_URL = "https://rxnav.nlm.nih.gov/REST/approximateTerm.json" DEFAULT_CACHE_PATH = "data/rxnorm_cache.csv" NER_MODEL = "en_ner_bc5cdr_md" DOSAGE_TOLERANCE_PCT = 10 # flag if answer dose differs from context dose by > 10% # Matches clinical dose values: "500 mg", "2.5 mcg/kg", "10 IU", etc. _DOSE_RE = re.compile( r'(\d+(?:\.\d+)?)\s*(?:mg|mcg|g\b|ml|iu|units?|mg/kg|mg/dl)', re.IGNORECASE, ) # Map spacy entity labels to our schema types _ENTITY_TYPE_MAP = { # en_core_sci_lg (CRAFT corpus) labels "CHEBI": "DRUG", # Chemical Entities of Biological Interest — covers drugs "GGP": "CONDITION", # Gene or Gene Product "SO": "CONDITION", # Sequence Ontology "TAXON": "CONDITION", "GO": "CONDITION", # Gene Ontology "CL": "CONDITION", # Cell Line "DNA": "CONDITION", "RNA": "CONDITION", "CELL_TYPE": "CONDITION", "CELL_LINE": "CONDITION", "PROTEIN": "CONDITION", # BC5CDR labels (used by some scispacy models) "Chemical": "DRUG", "Disease": "CONDITION", # Generic / fallback labels "CHEMICAL": "DRUG", "DRUG": "DRUG", "COMPOUND": "DRUG", "DISEASE": "CONDITION", "SYMPTOM": "CONDITION", "PROCEDURE": "PROCEDURE", "DOSAGE": "DOSAGE", } DRUG_TYPES = {"DRUG"} # only these get verified against RxNorm # --------------------------------------------------------------------------- # Module-level caches # --------------------------------------------------------------------------- _spacy_model = None _rxnorm_cache: dict[str, str] | None = None # drug_name -> rxcui _rxnorm_cache_path: str = DEFAULT_CACHE_PATH def _get_spacy_model(): global _spacy_model if _spacy_model is None: import spacy logger.info("Loading SciSpaCy NER model: %s (first call only)", NER_MODEL) try: _spacy_model = spacy.load(NER_MODEL) logger.info("SciSpaCy model loaded.") except OSError as exc: logger.error( "Failed to load '%s'. Install with: " "pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/" "releases/v0.5.4/en_ner_bc5cdr_md-0.5.4.tar.gz\nError: %s", NER_MODEL, exc, ) raise return _spacy_model def _load_rxnorm_cache(cache_path: str) -> dict[str, str]: """Load the RxNorm cache CSV into a lowercase drug_name → rxcui dict.""" path = Path(cache_path) if not path.exists(): logger.warning( "RxNorm cache not found at '%s'. Live API only will be used.", cache_path ) return {} try: df = pd.read_csv(path, dtype=str) cache = { str(row["drug_name"]).strip().lower(): str(row["rxcui"]).strip() for _, row in df.iterrows() if pd.notna(row.get("drug_name")) and pd.notna(row.get("rxcui")) and str(row.get("rxcui", "")).strip() } logger.info("RxNorm cache loaded: %d entries from %s", len(cache), cache_path) return cache except Exception as exc: logger.warning("Failed to load RxNorm cache: %s", exc) return {} def _get_rxnorm_cache(cache_path: str) -> dict[str, str]: global _rxnorm_cache, _rxnorm_cache_path if _rxnorm_cache is None or cache_path != _rxnorm_cache_path: _rxnorm_cache_path = cache_path _rxnorm_cache = _load_rxnorm_cache(cache_path) return _rxnorm_cache def _extract_doses_near(text: str, drug_name: str, window: int = 180) -> list[float]: """Return numeric dose values found within `window` chars of `drug_name` in `text`.""" idx = text.lower().find(drug_name.lower()) if idx == -1: return [] vicinity = text[max(0, idx - window // 2): idx + len(drug_name) + window] return [float(m.group(1)) for m in _DOSE_RE.finditer(vicinity)] def _lookup_rxnorm_api(drug_name: str, timeout: int = 4) -> Optional[str]: """Query RxNorm REST API. Returns rxcui string or None.""" try: resp = requests.get( RXNORM_APPROX_URL, params={"term": drug_name, "maxEntries": "1", "option": "1"}, timeout=timeout, ) if resp.status_code != 200: return None candidates = ( resp.json() .get("approximateGroup", {}) .get("candidate", []) ) if candidates: return str(candidates[0].get("rxcui", "")).strip() or None except Exception: pass return None # --------------------------------------------------------------------------- # Public API # --------------------------------------------------------------------------- def verify_entities( answer: str, question: str = "", context_docs: list[str] | None = None, rxnorm_cache_path: str = DEFAULT_CACHE_PATH, use_api_fallback: bool = True, ) -> EvalResult: """ Extract and verify medical entities from the LLM answer. Args: answer : LLM-generated answer text. question : Original question (NER'd alongside answer for richer entity set). context_docs : Retrieved context passages (used for cross-checking). rxnorm_cache_path : Path to rxnorm_cache.csv. use_api_fallback : Whether to call RxNorm REST API for cache misses. Returns: EvalResult with module_name="entity_verifier", score in [0,1], and details matching the shape from src/modules/__init__.py. """ t0 = time.perf_counter() # --- NER ----------------------------------------------------------------- try: nlp = _get_spacy_model() except Exception as exc: return EvalResult( module_name="entity_verifier", score=0.5, # neutral fallback — don't penalise if model not available details={"error": str(exc), "entities": []}, error=f"NER model unavailable: {exc}", latency_ms=int((time.perf_counter() - t0) * 1000), ) # Combine question + answer for richer entity extraction combined_text = f"{question} {answer}" if question else answer doc = nlp(combined_text) # Collect entities with deduplication seen: set[str] = set() raw_entities: list[tuple[str, str]] = [] # (text, type) for ent in doc.ents: key = ent.text.strip().lower() if not key or key in seen: continue seen.add(key) entity_type = _ENTITY_TYPE_MAP.get(ent.label_, "CONDITION") raw_entities.append((ent.text.strip(), entity_type)) if not raw_entities: return EvalResult( module_name="entity_verifier", score=0.5, # neutral — cannot verify what isn't there details={"total_entities": 0, "verified_count": 0, "flagged_count": 0, "entities": []}, latency_ms=int((time.perf_counter() - t0) * 1000), ) # --- RxNorm verification for DRUG entities ------------------------------- cache = _get_rxnorm_cache(rxnorm_cache_path) context_text = " ".join(context_docs or []).lower() entity_results: list[dict] = [] drug_total = 0 drug_verified = 0 drug_flagged = 0 for entity_text, entity_type in raw_entities: result = { "entity": entity_text, "type": entity_type, "status": "NOT_FOUND", "severity": None, "answer_value": entity_text, "context_value": None, "rxcui": None, } if entity_type in DRUG_TYPES: drug_total += 1 key = entity_text.lower() # Layer 1: Local cache lookup rxcui = cache.get(key) # Layer 2: API fallback if not rxcui and use_api_fallback: rxcui = _lookup_rxnorm_api(entity_text) if rxcui: result["rxcui"] = rxcui # Check for dosage discrepancy before marking VERIFIED answer_doses = _extract_doses_near(answer, entity_text) context_doses = _extract_doses_near(context_text, entity_text) flagged_dose = False if answer_doses and context_doses: a_dose = answer_doses[0] c_dose = min(context_doses, key=lambda d: abs(d - a_dose)) pct_diff = abs(a_dose - c_dose) / max(c_dose, 1e-9) * 100 if pct_diff > DOSAGE_TOLERANCE_PCT: result["status"] = "FLAGGED" result["severity"] = "MODERATE" result["answer_value"] = f"{a_dose} (answer)" result["context_value"] = f"{c_dose} (context, Δ{pct_diff:.0f}%)" drug_flagged += 1 flagged_dose = True logger.warning( "Dosage discrepancy for '%s': answer=%.1f context=%.1f (%.0f%%)", entity_text, a_dose, c_dose, pct_diff, ) if not flagged_dose: result["status"] = "VERIFIED" drug_verified += 1 if key in context_text: result["context_value"] = entity_text else: result["status"] = "NOT_FOUND" elif entity_type in ("CONDITION", "PROCEDURE"): # Non-drug entities: check presence in context only if entity_text.lower() in context_text: result["status"] = "VERIFIED" result["context_value"] = entity_text else: result["status"] = "NOT_FOUND" entity_results.append(result) # --- Score --------------------------------------------------------------- # Score is based on drug entities only (per SRS Section 6.2) if drug_total == 0: score = 0.5 # neutral — no drug entities to verify else: score = drug_verified / drug_total details = { "total_entities": len(raw_entities), "drug_total": drug_total, "verified_count": drug_verified, "flagged_count": drug_flagged, "entities": entity_results, } latency_ms = int((time.perf_counter() - t0) * 1000) logger.info( "Entity verification: %.3f (%d/%d drugs verified) in %d ms", score, drug_verified, drug_total, latency_ms, ) return EvalResult( module_name="entity_verifier", score=score, details=details, latency_ms=latency_ms, )