"""Med-TIV-style verifier loop.

Following "Scaling Medical Reasoning Verification via Tool-Integrated RL"
(arXiv 2601.20221), every Gemeo recommendation is post-processed by a
verifier agent that:

1. Parses the recommendation into atomic claims.
2. For each claim, attempts to ground it via:
   - gemeo_state(section) — twin lookup
   - gemeo_lookup(query) — GraphRAG
   - direct KG cypher (if needed)
3. Flags claims that cannot be grounded (potential hallucination).
4. Returns either OK (all grounded) or a corrected/redacted version.

Lightweight default: regex-based claim extraction + KG grounding via Aura.
Heavy mode: secondary LLM call to verify each claim.
"""

from __future__ import annotations

import logging
import re
from dataclasses import dataclass, field
from typing import Optional

logger = logging.getLogger("gemeo.verifier")

# Wrapper placed around every span of the recommendation that could not be
# grounded; the placeholder receives the original wording of that span.
_UNVERIFIED_TEMPLATE = "⚠[{}?]"


@dataclass
class Claim:
    """One atomic factual claim extracted from a recommendation."""

    text: str  # normalised claim text (e.g. upper-cased code, "PMID:<n>")
    kind: str  # diagnosis | drug | dose | risk | percentage | citation | hpo | gene | other
    grounded: bool = False
    evidence: list = field(default_factory=list)  # list of supporting items
    confidence: float = 0.5
    # (start, end) offsets of the claim in the text it was extracted from.
    # None when the claim was built by hand rather than by extract_claims().
    span: Optional[tuple[int, int]] = None


@dataclass
class VerifierReport:
    """Aggregate outcome of verifying one recommendation."""

    claims: list = field(default_factory=list)  # list[Claim]
    n_total: int = 0
    n_grounded: int = 0
    n_unverified: int = 0
    grounding_rate: float = 0.0
    redacted_text: str = ""  # text with unverified claims marked


# NOTE(review): IGNORECASE on "gene" also applies to the symbol group, so
# ordinary words followed by "gene"/"mutation"/... are captured as symbols
# (e.g. "the gene" -> "THE"). Left unchanged here; confirm whether lower-case
# symbols are meant to be accepted before tightening.
# NOTE(review): "year" is not referenced by any function in this module.
_PATTERNS = {
    "hpo": re.compile(r"\b(HP:\d{7})\b"),
    "orpha": re.compile(r"\b(ORPHA:\d{1,7})\b", re.IGNORECASE),
    "gene": re.compile(r"\b([A-Z][A-Z0-9]{1,7})\s*(?:gene|mutation|variant|c\.|p\.)", re.IGNORECASE),
    "drug_with_dose": re.compile(r"\b([A-Za-z][A-Za-z0-9-]+)\s+(\d+(?:\.\d+)?)\s*(mg|µg|ng|UI|IU|mEq)/?(kg|m2)?\b"),
    "percentage": re.compile(r"\b(\d{1,3}(?:\.\d+)?)\s*%"),
    "year": re.compile(r"\b(20\d{2})\b"),
    "pmid": re.compile(r"\bPMID:?\s*(\d+)\b"),
}


def extract_claims(text: str) -> list[Claim]:
    """Pull atomic factual claims from a free-text recommendation.

    Claims are returned grouped by kind in a fixed order (hpo, diagnosis,
    gene, dose, percentage, citation), each in order of appearance. Repeated
    mentions yield repeated claims. Every claim carries the ``span`` of the
    text it came from, so callers can mark the exact original wording even
    when ``Claim.text`` has been normalised.

    Args:
        text: The recommendation text; empty or ``None`` yields no claims.

    Returns:
        The extracted claims, all initially ungrounded.
    """
    if not text:
        return []

    claims: list[Claim] = []

    for m in _PATTERNS["hpo"].finditer(text):
        claims.append(Claim(text=m.group(1), kind="hpo", span=m.span(1)))

    for m in _PATTERNS["orpha"].finditer(text):
        claims.append(Claim(text=m.group(1).upper(), kind="diagnosis", span=m.span(1)))

    for m in _PATTERNS["gene"].finditer(text):
        # Only the symbol is the claim; the trailing keyword is context.
        claims.append(Claim(text=m.group(1).upper(), kind="gene", span=m.span(1)))

    for m in _PATTERNS["drug_with_dose"].finditer(text):
        # default="" mirrors findall(), which reports an absent group as "".
        name, dose, unit, per = m.groups(default="")
        claims.append(
            Claim(
                text=f"{name} {dose} {unit}{('/' + per) if per else ''}",
                kind="dose",
                span=m.span(),
            )
        )

    for m in _PATTERNS["percentage"].finditer(text):
        claims.append(Claim(text=f"{m.group(1)}%", kind="percentage", span=m.span()))

    for m in _PATTERNS["pmid"].finditer(text):
        claims.append(Claim(text=f"PMID:{m.group(1)}", kind="citation", span=m.span()))

    return claims


async def _ground_via_kg(claim: Claim) -> tuple[bool, list]:
    """Try to ground a claim via direct Aura cypher.

    Only ``hpo``, ``diagnosis`` and ``gene`` claims are looked up. Any other
    kind, or an unavailable ``space_graph`` module, is reported as ungrounded.

    Returns:
        ``(grounded, evidence)`` where ``evidence`` is a one-item list
        describing the matched node, or an empty list when nothing matched.
    """
    try:
        from space_graph import _safe_query
    except ImportError:
        return False, []

    if claim.kind == "hpo":
        rows = await _safe_query(
            "MATCH (p:Phenotype {hpoId: $h}) RETURN p.name AS name LIMIT 1",
            {"h": claim.text},
        )
        evidence = [{"hpo": claim.text, "name": rows[0].get("name")}] if rows else []
        return bool(rows), evidence

    if claim.kind == "diagnosis":
        orpha = claim.text.replace("ORPHA:", "")
        rows = await _safe_query(
            "MATCH (d:Disease {orphaCode: $o}) RETURN d.name AS name LIMIT 1",
            {"o": orpha},
        )
        evidence = [{"orpha": orpha, "name": rows[0].get("name")}] if rows else []
        return bool(rows), evidence

    if claim.kind == "gene":
        rows = await _safe_query(
            "MATCH (g:Gene {symbol: $s}) RETURN g.location AS loc LIMIT 1",
            {"s": claim.text},
        )
        evidence = [{"gene": claim.text, "loc": rows[0].get("loc")}] if rows else []
        return bool(rows), evidence

    # dose, percentage, citation: no KG ground; require LLM verification
    # (skipped in light mode)
    return False, []


def _redact_unverified(text: str, claims: list) -> str:
    """Wrap the source span of every ungrounded claim in the warning marker.

    Works on recorded offsets in a single left-to-right pass, so a marker is
    never re-wrapped, a short claim never matches inside a longer token, and
    claims whose normalised text differs from the original wording are still
    marked. Claims without a span are skipped.
    """
    # Longest span first when two start at the same offset.
    spans = sorted(
        {c.span for c in claims if not c.grounded and c.span is not None},
        key=lambda s: (s[0], -s[1]),
    )

    pieces = []
    cursor = 0
    for start, end in spans:
        if start < cursor:
            # Overlaps a region that has already been marked.
            continue
        pieces.append(text[cursor:start])
        pieces.append(_UNVERIFIED_TEMPLATE.format(text[start:end]))
        cursor = end
    pieces.append(text[cursor:])
    return "".join(pieces)


async def verify(case_id: str, recommendation_text: str, *, mode: str = "light") -> VerifierReport:
    """Verify a recommendation. Light mode = KG-only; heavy mode = + LLM check.

    Args:
        case_id: Identifier of the case; not consulted by this function.
        recommendation_text: Free-text recommendation to check.
        mode: Accepted for interface compatibility. This implementation only
            performs KG grounding and does not branch on the value.

    Returns:
        A report with per-claim grounding results, counts, the grounding
        rate, and the text with each ungrounded claim wrapped as
        ``⚠[...?]``. With no extractable claims the text is returned
        unchanged and the grounding rate is 0.0.
    """
    claims = extract_claims(recommendation_text)
    n_total = len(claims)
    if n_total == 0:
        return VerifierReport(
            claims=[],
            n_total=0,
            n_grounded=0,
            n_unverified=0,
            grounding_rate=0.0,
            redacted_text=recommendation_text,
        )

    for c in claims:
        try:
            ok, evidence = await _ground_via_kg(c)
        except Exception as e:
            # Best effort: a failed lookup leaves the claim at its defaults
            # (ungrounded), so it is still flagged in the redacted text.
            logger.debug("verify claim failed: %s", e)
            continue
        c.grounded = ok
        c.evidence = evidence
        c.confidence = 0.95 if ok else 0.30

    n_grounded = sum(1 for c in claims if c.grounded)

    return VerifierReport(
        claims=claims,
        n_total=n_total,
        n_grounded=n_grounded,
        n_unverified=n_total - n_grounded,
        grounding_rate=n_grounded / n_total,
        redacted_text=_redact_unverified(recommendation_text, claims),
    )