gemeo-twin-stack / src /gemeo /verifier.py
timmers's picture
GEMEO world-model — initial release (module + NeuralSurv ckpt + RareBench v49 + KG embeddings)
089d665 verified
"""Med-TIV-style verifier loop.
Following "Scaling Medical Reasoning Verification via Tool-Integrated RL"
(arXiv 2601.20221), every Gemeo recommendation is post-processed by a
verifier agent that:
1. Parses the recommendation into atomic claims.
2. For each claim, attempts to ground it via:
- gemeo_state(section) — twin lookup
- gemeo_lookup(query) — GraphRAG
- direct KG cypher (if needed)
3. Flags claims that cannot be grounded (potential hallucination).
4. Returns either OK (all grounded) or a corrected/redacted version.
Lightweight default: regex-based claim extraction + KG grounding via Aura.
Heavy mode (planned; `verify` does not consult `mode` yet): secondary LLM call to verify each claim.
"""
from __future__ import annotations
import logging
import re
from dataclasses import dataclass, field
from typing import Optional
logger = logging.getLogger("gemeo.verifier")
@dataclass
class Claim:
    """One atomic claim pulled out of a recommendation text.

    A claim starts out ungrounded with no evidence and a neutral
    confidence of 0.5; callers overwrite those fields after checking it.
    """

    # The claim as a short string (e.g. an identifier or "<value>%").
    text: str
    # diagnosis | drug | dose | risk | percentage | citation | hpo | gene | other
    kind: str
    # True once some source has confirmed the claim.
    grounded: bool = field(default=False)
    # Supporting items; a fresh list per instance.
    evidence: list = field(default_factory=list)
    # Confidence score attached to the claim.
    confidence: float = field(default=0.5)
@dataclass
class VerifierReport:
    """Summary of one verification pass over a recommendation.

    Every field has a default, so ``VerifierReport()`` is an empty report
    with zero counts, a 0.0 grounding rate and an empty redacted text.
    """

    # list[Claim]: every claim that was examined.
    claims: list = field(default_factory=list)
    # Number of claims examined.
    n_total: int = field(default=0)
    # Number of claims that were grounded.
    n_grounded: int = field(default=0)
    # Number of claims that were not grounded.
    n_unverified: int = field(default=0)
    # Fraction of claims grounded.
    grounding_rate: float = field(default=0.0)
    # Text with unverified claims marked.
    redacted_text: str = field(default="")
# Compiled regexes used to pull candidate claims out of free text, keyed by
# the kind of token each one recognises.
_PATTERNS = {
    # HPO term identifier: "HP:" followed by exactly seven digits.
    "hpo": re.compile(r"\b(HP:\d{7})\b"),
    # Orphanet code: "ORPHA:" (any case) followed by one to seven digits.
    "orpha": re.compile(r"\b(ORPHA:\d{1,7})\b", re.IGNORECASE),
    # Upper-case symbol (2-8 chars) followed by a trigger keyword or an
    # HGVS-style "c."/"p." prefix. Only the trigger is case-insensitive:
    # applying IGNORECASE to the whole pattern made the symbol class match
    # ordinary lowercase words ("the gene" -> THE, "Eugene" -> EU).
    "gene": re.compile(
        r"\b([A-Z][A-Z0-9]{1,7})\s*(?i:gene|mutation|variant|c\.|p\.)"
    ),
    # "<name> <number><unit>" with an optional "/kg" or "/m2" suffix.
    "drug_with_dose": re.compile(
        r"\b([A-Za-z][A-Za-z0-9-]+)\s+(\d+(?:\.\d+)?)\s*(mg|µg|ng|UI|IU|mEq)/?(kg|m2)?\b"
    ),
    # One to three integer digits, optional decimals, then "%".
    "percentage": re.compile(r"\b(\d{1,3}(?:\.\d+)?)\s*%"),
    # Four-digit year starting with "20".
    "year": re.compile(r"\b(20\d{2})\b"),
    # "PMID", optional colon and whitespace, then the numeric id.
    "pmid": re.compile(r"\bPMID:?\s*(\d+)\b"),
}
def extract_claims(text: str) -> list[Claim]:
    """Return the atomic claims found in a free-text recommendation.

    Claims are emitted grouped by pattern, in this order: HPO ids, ORPHA
    codes (upper-cased, kind ``diagnosis``), gene symbols (upper-cased),
    drug doses, percentages and PMID citations. Repeated mentions produce
    repeated claims. An empty or falsy ``text`` yields an empty list.
    """
    if not text:
        return []

    def hits(pattern_key):
        # All matches of one named pattern, in order of appearance.
        return _PATTERNS[pattern_key].findall(text)

    found: list[Claim] = [Claim(text=code, kind="hpo") for code in hits("hpo")]
    found.extend(Claim(text=code.upper(), kind="diagnosis") for code in hits("orpha"))
    found.extend(Claim(text=symbol.upper(), kind="gene") for symbol in hits("gene"))

    for drug, amount, unit, per in hits("drug_with_dose"):
        # The "/kg" or "/m2" suffix is only present when the regex captured it.
        suffix = "/" + per if per else ""
        found.append(Claim(text=f"{drug} {amount} {unit}{suffix}", kind="dose"))

    found.extend(Claim(text=f"{value}%", kind="percentage") for value in hits("percentage"))
    found.extend(Claim(text=f"PMID:{ref}", kind="citation") for ref in hits("pmid"))
    return found
async def _ground_via_kg(claim: Claim) -> tuple[bool, list]:
    """Look a claim up in the knowledge graph with a direct Cypher query.

    Returns ``(grounded, evidence)``. Only ``hpo``, ``diagnosis`` and
    ``gene`` claims are queried; every other kind, and the case where
    ``space_graph`` cannot be imported, returns ``(False, [])``.
    """
    try:
        from space_graph import _safe_query
    except ImportError:
        return False, []

    kind = claim.kind
    if kind == "hpo":
        cypher = "MATCH (p:Phenotype {hpoId: $h}) RETURN p.name AS name LIMIT 1"
        params = {"h": claim.text}
        id_key, id_value, column = "hpo", claim.text, "name"
    elif kind == "diagnosis":
        # The graph is queried with the bare code, without the "ORPHA:" prefix.
        code = claim.text.replace("ORPHA:", "")
        cypher = "MATCH (d:Disease {orphaCode: $o}) RETURN d.name AS name LIMIT 1"
        params = {"o": code}
        id_key, id_value, column = "orpha", code, "name"
    elif kind == "gene":
        cypher = "MATCH (g:Gene {symbol: $s}) RETURN g.location AS loc LIMIT 1"
        params = {"s": claim.text}
        id_key, id_value, column = "gene", claim.text, "loc"
    else:
        # dose, percentage, citation: no KG ground; require LLM verification
        # (skipped in light mode)
        return False, []

    rows = await _safe_query(cypher, params)
    if not rows:
        return False, []
    # Evidence is the queried identifier plus the first row's returned column.
    return True, [{id_key: id_value, column: rows[0].get(column)}]
async def verify(case_id: str, recommendation_text: str, *,
                 mode: str = "light") -> VerifierReport:
    """Verify a recommendation. Light mode = KG-only; heavy mode = + LLM check.

    Claims are extracted with ``extract_claims`` and each one is checked
    with ``_ground_via_kg``. Grounded claims get confidence 0.95 and
    ungrounded ones 0.30; a claim whose check raised keeps its defaults.

    Args:
        case_id: Case identifier. Not used by this function at present.
        recommendation_text: Free-text recommendation to check.
        mode: Verification mode. Accepted for API compatibility but not
            consulted here, so only the KG path runs whatever the value.

    Returns:
        A ``VerifierReport`` with the claims, counts, the grounding rate
        (0.0 when no claims were found) and ``redacted_text``: the input
        with each exact occurrence of an unverified claim's text wrapped
        as ``⚠[...?]``. Claim texts that do not appear verbatim in the
        input are counted but not marked.
    """
    claims = extract_claims(recommendation_text)
    n_total = len(claims)
    if n_total == 0:
        return VerifierReport(claims=[], n_total=0, n_grounded=0,
                              n_unverified=0, grounding_rate=0.0,
                              redacted_text=recommendation_text)

    for c in claims:
        try:
            ok, evidence = await _ground_via_kg(c)
        except Exception as e:
            # Best effort: one failing lookup must not abort the whole pass,
            # but it is logged at warning level so a KG outage is visible.
            logger.warning("verify claim failed for %r: %s", c.text, e)
            continue
        c.grounded = ok
        c.evidence = evidence
        c.confidence = 0.95 if ok else 0.30

    n_grounded = sum(1 for c in claims if c.grounded)

    # Mark unverified claims in ONE pass. Replacing claim by claim re-wrapped
    # text that was already marked: a duplicated claim became "⚠[⚠[x?]?]"
    # and a shorter claim text was rewritten inside a longer marked one.
    unverified_texts = {c.text for c in claims if not c.grounded}
    redacted = recommendation_text
    if unverified_texts:
        # Longest alternative first, so a longer claim text wins over a
        # shorter one that is a prefix of it at the same position.
        ordered = sorted(unverified_texts, key=lambda t: (-len(t), t))
        marker = re.compile("|".join(re.escape(t) for t in ordered))
        redacted = marker.sub(lambda m: f"⚠[{m.group(0)}?]", recommendation_text)

    return VerifierReport(
        claims=claims, n_total=n_total, n_grounded=n_grounded,
        n_unverified=n_total - n_grounded,
        grounding_rate=n_grounded / n_total,
        redacted_text=redacted,
    )