File size: 5,301 Bytes
"""Med-TIV-style verifier loop.

Following "Scaling Medical Reasoning Verification via Tool-Integrated RL"
(arXiv 2601.20221), every Gemeo recommendation is post-processed by a
verifier agent that:

1. Parses the recommendation into atomic claims.
2. For each claim, attempts to ground it via:
   - gemeo_state(section) — twin lookup
   - gemeo_lookup(query) — GraphRAG
   - direct KG Cypher (if needed)
3. Flags claims that cannot be grounded (potential hallucination).
4. Returns either OK (all grounded) or a corrected/redacted version.

Lightweight default: regex-based claim extraction + KG grounding via Aura.
Heavy mode: secondary LLM call to verify each claim.
"""
from __future__ import annotations
import logging
import re
from dataclasses import dataclass, field
from typing import Optional
# Module-level logger, registered under the name "gemeo.verifier".
logger = logging.getLogger("gemeo.verifier")
@dataclass
class Claim:
    """One atomic factual statement pulled out of a recommendation.

    Attributes:
        text: The claim as a string (an identifier, a dose, a percentage, ...).
        kind: Category label. Intended values are diagnosis, drug, dose,
            risk, percentage, citation, hpo, gene or other.
        grounded: Whether supporting evidence was found; starts as False.
        evidence: Supporting items collected while grounding the claim.
        confidence: Score attached to the claim; starts at 0.5.
    """

    text: str
    kind: str
    grounded: bool = False
    # default_factory gives every instance its own list.
    evidence: list = field(default_factory=list)
    confidence: float = 0.5
@dataclass
class VerifierReport:
    """Outcome of verifying one recommendation.

    Attributes:
        claims: The extracted claims (a list of Claim objects).
        n_total: Number of claims extracted.
        n_grounded: Number of claims that were grounded.
        n_unverified: Number of claims that were not grounded.
        grounding_rate: Grounded share of all claims.
        redacted_text: The recommendation with unverified claims marked.
    """

    # default_factory gives every report its own list.
    claims: list = field(default_factory=list)
    n_total: int = 0
    n_grounded: int = 0
    n_unverified: int = 0
    grounding_rate: float = 0.0
    redacted_text: str = ""
# Compiled extractors for the claim kinds that can be recognised lexically.
# Each pattern captures its payload in group 1, except "drug_with_dose",
# which captures (name, amount, unit, optional per-unit) in groups 1-4.
_PATTERNS = {
    "hpo": re.compile(r"\b(HP:\d{7})\b"),
    "orpha": re.compile(r"\b(ORPHA:\d{1,7})\b", re.IGNORECASE),
    # The symbol itself must be upper-case: a pattern-wide IGNORECASE would
    # cancel the [A-Z] constraint and capture ordinary words ("the gene",
    # "this mutation", "chronic."). Only the trailing keyword is matched
    # case-insensitively; the "c." / "p." prefixes stay lower-case.
    "gene": re.compile(
        r"\b([A-Z][A-Z0-9]{1,7})\s*(?:(?i:gene|mutation|variant)|c\.|p\.)"
    ),
    "drug_with_dose": re.compile(r"\b([A-Za-z][A-Za-z0-9-]+)\s+(\d+(?:\.\d+)?)\s*(mg|µg|ng|UI|IU|mEq)/?(kg|m2)?\b"),
    "percentage": re.compile(r"\b(\d{1,3}(?:\.\d+)?)\s*%"),
    "year": re.compile(r"\b(20\d{2})\b"),
    "pmid": re.compile(r"\bPMID:?\s*(\d+)\b"),
}
def extract_claims(text: str) -> list[Claim]:
    """Collect the atomic factual claims found in a free-text recommendation.

    The text is scanned with the module's compiled patterns and one Claim is
    emitted per match, grouped in this order: HPO ids, ORPHA codes, gene
    symbols, doses, percentages, PMID citations. ORPHA codes and gene
    symbols are upper-cased; doses are rebuilt as "name amount unit[/per]".

    Returns an empty list when ``text`` is empty or None.
    """
    if not text:
        return []

    found: list[Claim] = [
        Claim(text=code, kind="hpo")
        for code in _PATTERNS["hpo"].findall(text)
    ]
    found.extend(
        Claim(text=code.upper(), kind="diagnosis")
        for code in _PATTERNS["orpha"].findall(text)
    )
    found.extend(
        Claim(text=symbol.upper(), kind="gene")
        for symbol in _PATTERNS["gene"].findall(text)
    )
    for drug, amount, unit, per in _PATTERNS["drug_with_dose"].findall(text):
        # The per-unit part (kg / m2) is optional in the pattern.
        suffix = "/" + per if per else ""
        found.append(Claim(text=f"{drug} {amount} {unit}{suffix}", kind="dose"))
    found.extend(
        Claim(text=f"{value}%", kind="percentage")
        for value in _PATTERNS["percentage"].findall(text)
    )
    found.extend(
        Claim(text=f"PMID:{number}", kind="citation")
        for number in _PATTERNS["pmid"].findall(text)
    )
    return found
async def _ground_via_kg(claim: Claim) -> tuple[bool, list]:
    """Look a claim up in the knowledge graph with a direct Cypher query.

    Only the "hpo", "diagnosis" and "gene" kinds have a query. Any other
    kind (dose, percentage, citation, ...) has no KG grounding and would
    need LLM verification, so it returns ``(False, [])``. The same result is
    returned when ``space_graph`` cannot be imported or the query yields no
    rows.

    Returns:
        ``(grounded, evidence)``, where evidence is a one-item list holding
        the looked-up value and the column read from the first row.
    """
    try:
        from space_graph import _safe_query
    except ImportError:
        return False, []

    # kind -> (cypher, query parameter name, evidence key, returned column)
    lookups = {
        "hpo": (
            "MATCH (p:Phenotype {hpoId: $h}) RETURN p.name AS name LIMIT 1",
            "h", "hpo", "name",
        ),
        "diagnosis": (
            "MATCH (d:Disease {orphaCode: $o}) RETURN d.name AS name LIMIT 1",
            "o", "orpha", "name",
        ),
        "gene": (
            "MATCH (g:Gene {symbol: $s}) RETURN g.location AS loc LIMIT 1",
            "s", "gene", "loc",
        ),
    }
    spec = lookups.get(claim.kind)
    if spec is None:
        return False, []

    cypher, param, label, column = spec
    value = claim.text
    if claim.kind == "diagnosis":
        # The query is given the code without its "ORPHA:" prefix.
        value = value.replace("ORPHA:", "")

    rows = await _safe_query(cypher, {param: value})
    if not rows:
        return False, []
    return True, [{label: value, column: rows[0].get(column)}]
def _redact_unverified(text: str, claims: list) -> str:
    """Wrap each occurrence of an ungrounded claim's text in a ``⚠[…?]`` marker.

    All distinct ungrounded texts are substituted in a single regex pass.
    Chained ``str.replace`` calls would re-mark text that is already marked
    when a claim is extracted twice, and would let a short claim ("5%")
    corrupt a longer one ("15%").
    """
    targets = {c.text for c in claims if not c.grounded and c.text}
    if not targets:
        return text
    # Longest alternative first: the regex takes the first branch that
    # matches, so a longer claim must be tried before any claim it contains.
    ordered = sorted(targets, key=lambda t: (-len(t), t))
    # The extraction patterns anchor every claim at a word boundary, so
    # rejecting a preceding word character only drops accidental hits inside
    # a longer token.
    pattern = re.compile(r"(?<!\w)(?:" + "|".join(map(re.escape, ordered)) + r")")
    return pattern.sub(lambda m: f"⚠[{m.group(0)}?]", text)


async def verify(case_id: str, recommendation_text: str, *,
                 mode: str = "light") -> VerifierReport:
    """Verify a recommendation by grounding its claims in the knowledge graph.

    Claims are extracted from ``recommendation_text`` and each is looked up
    via ``_ground_via_kg``. Grounded claims get confidence 0.95 and
    ungrounded ones 0.30. A claim whose lookup raises is logged at debug
    level and left ungrounded with its existing confidence.

    Args:
        case_id: Case identifier; not consulted by this function.
        recommendation_text: Free-text recommendation to check.
        mode: "light" or "heavy"; not consulted by this function, so only
            the KG-only path runs.

    Returns:
        A VerifierReport with the claims, counts, grounding rate and a copy
        of the text in which every ungrounded claim is wrapped as
        ``⚠[claim?]``. With no claims, the rate is 0.0 and the text is
        returned unchanged.
    """
    claims = extract_claims(recommendation_text)
    n_total = len(claims)
    if n_total == 0:
        return VerifierReport(claims=[], n_total=0, n_grounded=0,
                              n_unverified=0, grounding_rate=0.0,
                              redacted_text=recommendation_text)

    for c in claims:
        try:
            ok, evidence = await _ground_via_kg(c)
        except Exception as e:
            # Best effort: one failed lookup must not abort the whole report.
            logger.debug("verify claim failed: %s", e)
            continue
        c.grounded = ok
        c.evidence = evidence
        c.confidence = 0.95 if ok else 0.30

    n_grounded = sum(1 for c in claims if c.grounded)
    return VerifierReport(
        claims=claims, n_total=n_total, n_grounded=n_grounded,
        n_unverified=n_total - n_grounded,
        grounding_rate=n_grounded / n_total,
        redacted_text=_redact_unverified(recommendation_text, claims),
    )
|