| """Med-TIV-style verifier loop. |
| |
| Following "Scaling Medical Reasoning Verification via Tool-Integrated RL" |
| (arXiv 2601.20221), every Gemeo recommendation is post-processed by a |
| verifier agent that: |
| 1. Parses the recommendation into atomic claims. |
| 2. For each claim, attempts to ground it via: |
| - gemeo_state(section) — twin lookup |
| - gemeo_lookup(query) — GraphRAG |
| - direct KG cypher (if needed) |
| 3. Flags claims that cannot be grounded (potential hallucination). |
| 4. Returns either OK (all grounded) or a corrected/redacted version. |
| |
| Lightweight default: regex-based claim extraction + KG grounding via Aura. |
| Heavy mode: secondary LLM call to verify each claim. |
| """ |
| from __future__ import annotations |
| import logging |
| import re |
| from dataclasses import dataclass, field |
| from typing import Optional |
|
|
# Module-level logger, registered under a fixed name rather than __name__.
logger: logging.Logger = logging.getLogger(name="gemeo.verifier")
|
|
|
|
@dataclass
class Claim:
    """One atomic factual assertion pulled out of a recommendation.

    A claim starts unverified; ``verify`` later fills in ``grounded``,
    ``evidence`` and ``confidence`` from the knowledge-graph lookup.
    """

    # Normalized claim text, e.g. "HP:0001250", "ORPHA:558", "PMID:123".
    text: str
    # Claim category as produced by extract_claims
    # ("hpo", "diagnosis", "gene", "dose", "percentage", "citation").
    kind: str
    # True once a KG lookup confirmed the claim.
    grounded: bool = False
    # Supporting records returned by the lookup (fresh list per instance).
    evidence: list = field(default_factory=list)
    # Verifier confidence; 0.5 is the untouched default before grounding.
    confidence: float = 0.5
|
|
|
|
@dataclass
class VerifierReport:
    """Outcome of a single verification pass over one recommendation."""

    # Claim objects in the order they were extracted.
    claims: list = field(default_factory=list)
    # Number of extracted claims.
    n_total: int = 0
    # Claims confirmed by a knowledge-graph lookup.
    n_grounded: int = 0
    # Claims left unconfirmed (n_total - n_grounded).
    n_unverified: int = 0
    # n_grounded / n_total; 0.0 when nothing was extracted.
    grounding_rate: float = 0.0
    # Input text with unconfirmed claims visibly marked.
    redacted_text: str = ""
|
|
|
|
| _PATTERNS = { |
| "hpo": re.compile(r"\b(HP:\d{7})\b"), |
| "orpha": re.compile(r"\b(ORPHA:\d{1,7})\b", re.IGNORECASE), |
| "gene": re.compile(r"\b([A-Z][A-Z0-9]{1,7})\s*(?:gene|mutation|variant|c\.|p\.)", re.IGNORECASE), |
| "drug_with_dose": re.compile(r"\b([A-Za-z][A-Za-z0-9-]+)\s+(\d+(?:\.\d+)?)\s*(mg|µg|ng|UI|IU|mEq)/?(kg|m2)?\b"), |
| "percentage": re.compile(r"\b(\d{1,3}(?:\.\d+)?)\s*%"), |
| "year": re.compile(r"\b(20\d{2})\b"), |
| "pmid": re.compile(r"\bPMID:?\s*(\d+)\b"), |
| } |
|
|
|
|
def extract_claims(text: str) -> list[Claim]:
    """Extract atomic factual claims from a free-text recommendation.

    Every pattern in ``_PATTERNS`` that is listed below is applied with
    ``findall`` and each hit becomes one Claim. Claims are grouped by kind in
    this order: HPO terms, ORPHA diagnoses, gene symbols, drug doses,
    percentages, PMID citations. Text is normalized per kind: ORPHA codes
    and gene symbols are upper-cased, doses are rendered as
    "name dose unit[/per]", percentages as "<n>%", citations as "PMID:<n>".
    Duplicates are kept. Empty or falsy input returns an empty list.
    """
    if not text:
        return []

    def _render_dose(hit: tuple) -> str:
        drug, amount, unit, per = hit
        suffix = "/" + per if per else ""
        return f"{drug} {amount} {unit}{suffix}"

    # (pattern key, claim kind, normalizer) -- tuple order fixes output order.
    plan = (
        ("hpo", "hpo", lambda hit: hit),
        ("orpha", "diagnosis", str.upper),
        ("gene", "gene", str.upper),
        ("drug_with_dose", "dose", _render_dose),
        ("percentage", "percentage", lambda hit: f"{hit}%"),
        ("pmid", "citation", lambda hit: f"PMID:{hit}"),
    )
    return [
        Claim(text=render(hit), kind=kind)
        for key, kind, render in plan
        for hit in _PATTERNS[key].findall(text)
    ]
|
|
|
|
| async def _ground_via_kg(claim: Claim) -> tuple[bool, list]: |
| """Try to ground a claim via direct Aura cypher.""" |
| try: |
| from space_graph import _safe_query |
| except ImportError: |
| return False, [] |
|
|
| if claim.kind == "hpo": |
| rows = await _safe_query("MATCH (p:Phenotype {hpoId: $h}) RETURN p.name AS name LIMIT 1", |
| {"h": claim.text}) |
| return (bool(rows), [{"hpo": claim.text, "name": rows[0].get("name")}] if rows else []) |
| if claim.kind == "diagnosis": |
| orpha = claim.text.replace("ORPHA:", "") |
| rows = await _safe_query("MATCH (d:Disease {orphaCode: $o}) RETURN d.name AS name LIMIT 1", |
| {"o": orpha}) |
| return (bool(rows), [{"orpha": orpha, "name": rows[0].get("name")}] if rows else []) |
| if claim.kind == "gene": |
| rows = await _safe_query("MATCH (g:Gene {symbol: $s}) RETURN g.location AS loc LIMIT 1", |
| {"s": claim.text}) |
| return (bool(rows), [{"gene": claim.text, "loc": rows[0].get("loc")}] if rows else []) |
| |
| return False, [] |
|
|
|
|
def _redact_unverified(text: str, unverified: set[str]) -> str:
    """Return *text* with each occurrence of an unverified claim wrapped as ``⚠[claim?]``.

    All claim texts are matched in one ``re.sub`` pass, longest first. A
    repeated claim is therefore marked once per occurrence and never
    re-wrapped. A claim that is a substring of another ("5%" vs "25%")
    cannot corrupt the longer one.
    """
    # Empty needles would match everywhere; drop them.
    needles = sorted((t for t in unverified if t), key=lambda t: (-len(t), t))
    if not needles:
        return text
    pattern = re.compile("|".join(map(re.escape, needles)))
    return pattern.sub(lambda m: f"⚠[{m.group(0)}?]", text)


async def verify(case_id: str, recommendation_text: str, *,
                 mode: str = "light") -> VerifierReport:
    """Verify a recommendation by grounding its extracted claims in the KG.

    Args:
        case_id: Case identifier. Accepted for API compatibility; not used
            by this function.
        recommendation_text: Free-text recommendation to verify.
        mode: Verification mode. Only the KG-only ("light") path exists
            here. NOTE(review): the "heavy" secondary-LLM check described
            in the module docstring is not implemented, so ``mode`` is
            currently ignored.

    Returns:
        A VerifierReport. ``redacted_text`` is the input with every
        ungrounded claim marked ``⚠[claim?]``. Kinds that
        ``_ground_via_kg`` has no lookup for (doses, percentages,
        citations) are always reported as ungrounded. With no extractable
        claims, all counts are zero, ``grounding_rate`` is 0.0 and the
        text is returned unchanged.

    A lookup that raises is logged at WARNING level. The claim is left
    unverified (``grounded=False``, default confidence) instead of
    aborting the whole verification.
    """
    claims = extract_claims(recommendation_text)
    n_total = len(claims)
    if n_total == 0:
        return VerifierReport(claims=[], n_total=0, n_grounded=0,
                              n_unverified=0, grounding_rate=0.0,
                              redacted_text=recommendation_text)

    # One KG round-trip per distinct (kind, text); None marks a failed lookup.
    lookups: dict[tuple[str, str], Optional[tuple[bool, list]]] = {}
    for c in claims:
        key = (c.kind, c.text)
        if key not in lookups:
            try:
                ok, evidence = await _ground_via_kg(c)
                lookups[key] = (ok, list(evidence))
            except Exception as exc:
                logger.warning("verify claim failed (%s %r): %s",
                               c.kind, c.text, exc)
                lookups[key] = None
        result = lookups[key]
        if result is None:
            continue
        ok, evidence = result
        c.grounded = ok
        c.evidence = list(evidence)  # per-claim copy; duplicates don't alias
        c.confidence = 0.95 if ok else 0.30

    n_grounded = sum(1 for c in claims if c.grounded)
    redacted = _redact_unverified(
        recommendation_text, {c.text for c in claims if not c.grounded})

    return VerifierReport(
        claims=claims, n_total=n_total, n_grounded=n_grounded,
        n_unverified=n_total - n_grounded,
        grounding_rate=n_grounded / n_total,
        redacted_text=redacted,
    )
|
|