File size: 5,301 Bytes
089d665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""Med-TIV-style verifier loop.

Following "Scaling Medical Reasoning Verification via Tool-Integrated RL"
(arXiv 2601.20221), every Gemeo recommendation is post-processed by a
verifier agent that:
  1. Parses the recommendation into atomic claims.
  2. For each claim, attempts to ground it via:
     - gemeo_state(section) — twin lookup
     - gemeo_lookup(query) — GraphRAG
     - direct KG cypher (if needed)
  3. Flags claims that cannot be grounded (potential hallucination).
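  4. Returns a VerifierReport: per-claim grounding results, summary counts,
     the grounding rate, and a copy of the text with unverified claims
     marked.

Lightweight default: regex-based claim extraction + KG grounding via Aura.
Only the direct KG cypher path of step 2 is implemented here, and only for
HPO, ORPHA and gene-symbol claims; dose, percentage and citation claims
therefore always end up unverified.
Heavy mode: secondary LLM call to verify each claim (not implemented yet;
`verify` accepts a mode argument but always runs the light path).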
"""
from __future__ import annotations
import logging
import re
from dataclasses import dataclass, field
from typing import Optional

# Module logger; grounding failures in verify() are reported on this channel.
logger: logging.Logger = logging.getLogger("gemeo.verifier")


@dataclass
class Claim:
    """A single atomic factual assertion extracted from a recommendation.

    ``text`` and ``kind`` describe the claim itself. ``grounded``,
    ``evidence`` and ``confidence`` hold the verification outcome; they start
    out in the "not verified yet" state.
    """

    # Claim text. For some kinds this is a normalized form (e.g. upper-cased
    # codes, "PMID:<n>") rather than the exact span from the source text.
    text: str
    # One of: diagnosis | drug | dose | risk | percentage | citation | hpo |
    # gene | other. extract_claims currently emits hpo, diagnosis, gene,
    # dose, percentage and citation.
    kind: str
    # True once the claim has been matched against a supporting source.
    grounded: bool = field(default=False)
    # Supporting items found for the claim (e.g. KG lookup results).
    evidence: list = field(default_factory=list)
    # Confidence score attached to the claim.
    confidence: float = field(default=0.5)


@dataclass
class VerifierReport:
    """Aggregate outcome of verifying one recommendation.

    ``claims`` carries the individual Claim objects; the remaining fields
    summarize them. ``redacted_text`` is the input text with the unverified
    claims marked.
    """

    # list[Claim]
    claims: list = field(default_factory=list)
    # Total number of extracted claims.
    n_total: int = field(default=0)
    # Claims that could be grounded.
    n_grounded: int = field(default=0)
    # Claims that could not be grounded.
    n_unverified: int = field(default=0)
    # Fraction of claims that were grounded.
    grounding_rate: float = field(default=0.0)
    # Text with unverified claims marked.
    redacted_text: str = field(default="")


_PATTERNS = {
    "hpo":       re.compile(r"\b(HP:\d{7})\b"),
    "orpha":     re.compile(r"\b(ORPHA:\d{1,7})\b", re.IGNORECASE),
    "gene":      re.compile(r"\b([A-Z][A-Z0-9]{1,7})\s*(?:gene|mutation|variant|c\.|p\.)", re.IGNORECASE),
    "drug_with_dose": re.compile(r"\b([A-Za-z][A-Za-z0-9-]+)\s+(\d+(?:\.\d+)?)\s*(mg|µg|ng|UI|IU|mEq)/?(kg|m2)?\b"),
    "percentage": re.compile(r"\b(\d{1,3}(?:\.\d+)?)\s*%"),
    "year":      re.compile(r"\b(20\d{2})\b"),
    "pmid":      re.compile(r"\bPMID:?\s*(\d+)\b"),
}


def extract_claims(text: str) -> list[Claim]:
    """Return the atomic factual claims found in a free-text recommendation.

    Claims are grouped by kind in this order: HPO terms, ORPHA diagnoses,
    gene symbols, drug doses, percentages, PMID citations. Within a kind they
    follow their order of appearance, and repeated mentions produce repeated
    claims. ORPHA codes and gene symbols are upper-cased. Doses, percentages
    and PMIDs are rebuilt in a normalized textual form. Falsy input yields an
    empty list.
    """
    if not text:
        return []

    found: list[Claim] = [Claim(text=term, kind="hpo") for term in _PATTERNS["hpo"].findall(text)]
    found += [Claim(text=code.upper(), kind="diagnosis") for code in _PATTERNS["orpha"].findall(text)]
    found += [Claim(text=symbol.upper(), kind="gene") for symbol in _PATTERNS["gene"].findall(text)]

    for drug, amount, unit, per in _PATTERNS["drug_with_dose"].findall(text):
        suffix = f"/{per}" if per else ""
        found.append(Claim(text=f"{drug} {amount} {unit}{suffix}", kind="dose"))

    found += [Claim(text=f"{value}%", kind="percentage") for value in _PATTERNS["percentage"].findall(text)]
    found += [Claim(text=f"PMID:{number}", kind="citation") for number in _PATTERNS["pmid"].findall(text)]
    return found


async def _ground_via_kg(claim: Claim) -> tuple[bool, list]:
    """Try to ground a claim via direct Aura cypher."""
    try:
        from space_graph import _safe_query
    except ImportError:
        return False, []

    if claim.kind == "hpo":
        rows = await _safe_query("MATCH (p:Phenotype {hpoId: $h}) RETURN p.name AS name LIMIT 1",
                                  {"h": claim.text})
        return (bool(rows), [{"hpo": claim.text, "name": rows[0].get("name")}] if rows else [])
    if claim.kind == "diagnosis":
        orpha = claim.text.replace("ORPHA:", "")
        rows = await _safe_query("MATCH (d:Disease {orphaCode: $o}) RETURN d.name AS name LIMIT 1",
                                  {"o": orpha})
        return (bool(rows), [{"orpha": orpha, "name": rows[0].get("name")}] if rows else [])
    if claim.kind == "gene":
        rows = await _safe_query("MATCH (g:Gene {symbol: $s}) RETURN g.location AS loc LIMIT 1",
                                  {"s": claim.text})
        return (bool(rows), [{"gene": claim.text, "loc": rows[0].get("loc")}] if rows else [])
    # dose, percentage, citation: no KG ground; require LLM verification (skipped in light mode)
    return False, []


def _redact_unverified(text: str, unverified: list) -> str:
    """Wrap every occurrence of each unverified claim string as ``⚠[claim?]``.

    All markers are inserted in one regex pass. Calling ``str.replace`` once
    per claim would (a) wrap the same span twice when a claim was extracted
    more than once, and (b) corrupt an already-inserted marker or a longer
    claim when one claim is a substring of another (``5%`` inside ``45%``).
    Alternatives are tried longest-first. A claim that begins/ends with a
    word character must also begin/end on a word boundary, so ``5%`` does
    not match inside ``45%``.
    """
    targets = sorted({t for t in unverified if t}, key=lambda t: (-len(t), t))
    if not targets:
        return text

    def _alternative(token: str) -> str:
        # Literal match, with \b-like guards only on word-character edges.
        pattern = re.escape(token)
        if re.match(r"\w", token):
            pattern = r"(?<!\w)" + pattern
        if re.search(r"\w\Z", token):
            pattern += r"(?!\w)"
        return pattern

    matcher = re.compile("|".join(_alternative(t) for t in targets))
    return matcher.sub(lambda m: f"⚠[{m.group(0)}?]", text)


async def verify(case_id: str, recommendation_text: str, *,
                 mode: str = "light") -> VerifierReport:
    """Verify the factual claims of a recommendation against the KG.

    Args:
        case_id: Identifier of the case; used only in log messages.
        recommendation_text: Free-text recommendation to check.
        mode: "light" = KG-only grounding. "heavy" (KG + LLM check) is not
            implemented yet; any value other than "light" logs a warning and
            runs the light path.

    Returns:
        A VerifierReport holding every extracted claim (with grounding
        results filled in), the grounded/unverified counts, the grounding
        rate, and the text with each unverified claim wrapped as
        ``⚠[claim?]``. When no claims are extracted, all counts and the rate
        are zero and the text is returned unchanged.

    A grounding error on an individual claim is logged at DEBUG level and
    leaves that claim ungrounded with its default confidence. The remaining
    claims are still checked.
    """
    if mode != "light":
        logger.warning("verifier mode %r is not implemented; using light mode", mode)

    claims = extract_claims(recommendation_text)
    n_total = len(claims)
    if n_total == 0:
        return VerifierReport(claims=[], n_total=0, n_grounded=0,
                              n_unverified=0, grounding_rate=0.0,
                              redacted_text=recommendation_text)

    for c in claims:
        try:
            ok, evidence = await _ground_via_kg(c)
        except Exception as e:
            # Best effort: one failing lookup must not abort the whole report.
            logger.debug("verify claim failed (case=%s, kind=%s, claim=%r): %s",
                         case_id, c.kind, c.text, e)
            continue
        c.grounded = ok
        c.evidence = evidence
        c.confidence = 0.95 if ok else 0.30

    n_grounded = sum(1 for c in claims if c.grounded)
    redacted = _redact_unverified(
        recommendation_text, [c.text for c in claims if not c.grounded])

    return VerifierReport(
        claims=claims, n_total=n_total, n_grounded=n_grounded,
        n_unverified=n_total - n_grounded,
        grounding_rate=n_grounded / n_total,
        redacted_text=redacted,
    )