# gemeo-twin-stack — src/gemeo/extractor.py
# Author: timmers · commit 089d665 (verified)
# "GEMEO world-model — initial release (module + NeuralSurv ckpt + RareBench v49 + KG embeddings)"
"""LLM-based clinical entity extractor — replaces the regex absorb path.

Designed for PT-BR clinical text: synonyms (sinonímias), negation (negação —
"paciente NÃO tem ataxia" should NOT yield HP:0001251 as a present finding),
age qualifiers, and the family-history vs. proband distinction.

Falls back to the regex extractor (`gemeo.llm_context._HPO_RE` etc.) when
none of the LLM extraction paths returns a result, so callers still get an
answer when no LLM is reachable.
"""
from __future__ import annotations

import json
import logging
import re
from dataclasses import dataclass, field
from typing import Optional
# Module logger under a fixed name, so log filtering does not depend on how
# this module happens to be imported.
logger = logging.getLogger(name="gemeo.extractor")
@dataclass
class ClinicalEntities:
    """Structured result of the extractor: one list per entity kind plus the raw payload.

    Every field defaults to a fresh, independent empty container. The item
    shapes noted below are the schema the extractors aim for; entries are
    plain dicts and are not validated by this class.
    """

    # [{hpo_id, name, status: present|absent|family|past, severity?}]
    phenotypes: list = field(default_factory=list)
    # [{orpha, name, status: confirmed|suspected|ruled_out}]
    diseases: list = field(default_factory=list)
    # [{symbol, variant?, zygosity?, pathogenicity?, status}]
    genes: list = field(default_factory=list)
    # [{test, value, unit, abnormal?, date?}]
    labs: list = field(default_factory=list)
    # [{name, dose?, frequency?, status: current|past|stopped}]
    medications: list = field(default_factory=list)
    # [{name, type?, response?}]
    treatments: list = field(default_factory=list)
    # The unprocessed dict the lists above were built from.
    raw: dict = field(default_factory=dict)
_SYSTEM_PROMPT = (
"You extract structured clinical entities from a clinical message. "
"Return STRICT JSON matching this schema:\n"
"{\n"
' "phenotypes": [{"hpo_id": "HP:NNNNNNN", "name": "...", "status": "present|absent|family|past", "severity": "mild|moderate|severe|null"}],\n'
' "diseases": [{"orpha": "NNNN", "name": "...", "status": "confirmed|suspected|ruled_out"}],\n'
' "genes": [{"symbol": "GENE", "variant": "c.X>Y or null", "zygosity": "het|hom|null", "pathogenicity": "benign|likely_benign|VUS|likely_pathogenic|pathogenic|null", "status": "present|absent"}],\n'
' "labs": [{"test": "AFP", "value": "280", "unit": "ng/mL", "abnormal": true, "date": "YYYY-MM-DD or null"}],\n'
' "medications": [{"name": "IVIG", "dose": "...", "frequency": "...", "status": "current|past|stopped"}],\n'
' "treatments": [{"name": "ERT", "type": "...", "response": "good|partial|poor|null"}]\n'
"}\n"
"Critical rules:\n"
" - Detect NEGATION: 'paciente NÃO tem ataxia' → status='absent'.\n"
" - Detect FAMILY HISTORY: 'irmão com hemofilia' → status='family'.\n"
" - PT-BR synonyms: 'convulsão'→HP:0001250, 'ataxia'→HP:0001251, 'telangiectasia'→HP:0001009.\n"
" - Use STRICT HP/ORPHA codes; if unsure, omit the entity.\n"
" - Empty arrays are fine. NEVER include explanations outside the JSON."
)
def _get_llm():
    """Pick the best available cloud LLM for structured extraction.

    Tiers, tried in order:

    1. ``llm_router.get_check_llm(temperature=0.0)`` — the router's "check" model
    2. ``llm_router.get_orchestrator_llm(temperature=0.0)``
    3. ``ChatGoogleGenerativeAI(model="gemini-2.5-flash-lite")`` built directly
       from ``GEMINI_API_KEY`` (or ``GOOGLE_API_KEY``)

    A router tier is skipped (and the next one tried) when importing or
    calling it raises, when it returns ``None``, or when its model name
    contains ``rarasnet`` (local models that need a Modal backend).

    Returns:
        The LLM object (callers use its ``.ainvoke``), or ``None`` if no tier
        produced one.
    """
    import os

    def _usable(llm) -> bool:
        # A router may hand back None, or a local rarasnet-* model; neither can
        # serve this call, so both must fall through to the next tier instead
        # of ending the search.
        if llm is None:
            return False
        model_attr = getattr(llm, "model_name", None) or getattr(llm, "model", None) or ""
        return "rarasnet" not in str(model_attr).lower()

    # Tier 1: the swarm's check LLM.
    try:
        from llm_router import get_check_llm

        llm = get_check_llm(temperature=0.0)
        if _usable(llm):
            return llm
    except Exception as e:
        logger.debug(f"get_check_llm unavailable: {e}")

    # Tier 2: the orchestrator LLM.
    try:
        from llm_router import get_orchestrator_llm

        llm = get_orchestrator_llm(temperature=0.0)
        if _usable(llm):
            return llm
    except Exception as e:
        logger.debug(f"get_orchestrator_llm unavailable: {e}")

    # Tier 3: Gemini directly through LangChain.
    api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        return None
    try:
        from langchain_google_genai import ChatGoogleGenerativeAI

        return ChatGoogleGenerativeAI(
            model="gemini-2.5-flash-lite",
            google_api_key=api_key,
            temperature=0.0,
        )
    except Exception as e:
        logger.debug(f"direct Gemini unavailable: {e}")
        return None
_PHRASE_PROMPT = (
"Extract CLINICAL PHRASES (as written, PT-BR ok) from a doctor's free-text case. "
"Do NOT invent HPO/ORPHA codes. Return STRICT JSON:\n"
"{\n"
' "phenotype_phrases": ["ataxia troncular","telangiectasia bulbar","IgA baixa", …],\n'
' "gene_symbols": ["ATM","CFTR"],\n'
' "disease_mentions": ["Ataxia-telangiectasia","ORPHA:100"],\n'
' "labs": [{"test":"AFP","value":"280","unit":"ng/mL","abnormal":true}],\n'
' "medications": [{"name":"…","status":"current|past|stopped"}],\n'
' "negated_phrases": ["sem hepatomegalia"],\n'
' "family_phrases": ["irmão com hemofilia"]\n'
"}\n"
"Rules:\n"
" - One phrase per clinical sign — don't merge multiple findings.\n"
" - Keep PT-BR wording verbatim; downstream KG maps to HPO codes.\n"
" - Negation goes to `negated_phrases`; proband-only signs to `phenotype_phrases`.\n"
" - Family history goes to `family_phrases` (don't double-count as proband).\n"
" - DO NOT include explanations outside the JSON."
)
async def _phrase_extract(message: str) -> Optional[dict]:
"""Step 1: LLM extracts PT-BR clinical phrases (NOT HPO codes).
The phrases are then normalized by `_kg_normalize_phrases` against
the raras-app `phenotype_search` FULLTEXT index. Pipeline:
text → LLM phrases → KG fulltext → HPO codes
This is the DeepRare-style 2-stage approach the user asked for
(proper hpo-brasil pipeline) — the KG already has PT-BR names,
synonyms, cultural-PT variants, and BioLORD embeddings for every
one of the 11.652 Phenotype nodes, so we don't need to ship the
sentence-transformers model + npz inside the orch image.
"""
import os
import httpx
providers = [
("DEEPSEEK_API_KEY", "https://api.deepseek.com/v1/chat/completions", "deepseek-chat"),
("CEREBRAS_API_KEY", "https://api.cerebras.ai/v1/chat/completions", "llama-3.3-70b"),
("GROQ_API_KEY", "https://api.groq.com/openai/v1/chat/completions", "llama-3.3-70b-versatile"),
]
async with httpx.AsyncClient(timeout=45.0) as http:
for env_key, url, model in providers:
key = os.environ.get(env_key)
if not key:
continue
try:
r = await http.post(
url,
json={
"model": model,
"messages": [
{"role": "system", "content": _PHRASE_PROMPT},
{"role": "user", "content": message},
],
"temperature": 0.0,
"response_format": {"type": "json_object"},
},
headers={"Authorization": f"Bearer {key}", "Content-Type": "application/json"},
)
if r.status_code != 200:
continue
txt = (r.json().get("choices") or [{}])[0].get("message", {}).get("content", "").strip()
if txt.startswith("```"):
txt = txt.strip("`")
if txt.lower().startswith("json"): txt = txt[4:]
txt = txt.strip().rstrip("`").rstrip()
s, e = txt.find("{"), txt.rfind("}")
if s >= 0 and e > s:
return json.loads(txt[s : e + 1])
except Exception as exc:
logger.debug(f"phrase extractor {env_key}: {exc}")
continue
return None
async def _kg_normalize_phrases(phrases: list[str], gene_symbols: list[str], status: str = "present") -> dict:
"""Step 2: phrases → HPO via raras-app KG `phenotype_search` FULLTEXT.
Index covers: name, namePt, synonymsPt, culturalSynonymsPt, definitionPt.
Lucene score threshold ≥ 4.0 (empirically: clearly-named phenotypes
score 8-12; weak matches fall below 4).
"""
out_phenos: list[dict] = []
out_genes: list[dict] = []
seen: set[str] = set()
try:
from tools import run_query
except ImportError:
return {"phenotypes": [], "genes": []}
for phrase in phrases[:30]:
try:
rows = await run_query(
"""
CALL db.index.fulltext.queryNodes('phenotype_search', $q)
YIELD node, score
WHERE score >= 4.0
RETURN node.hpoId AS hpo, node.name AS name,
coalesce(node.namePt, node.name) AS name_pt, score
ORDER BY score DESC LIMIT 1
""",
{"q": phrase}, timeout=8.0,
)
except Exception as exc:
logger.debug(f"KG normalize '{phrase}': {exc}")
rows = []
if rows and rows[0].get("hpo") and rows[0]["hpo"] not in seen:
hpo = rows[0]["hpo"]
seen.add(hpo)
out_phenos.append({
"hpo_id": hpo,
"name": rows[0].get("name_pt") or rows[0].get("name") or hpo,
"status": status,
"_source": "kg-fulltext",
"_score": float(rows[0]["score"]),
"_phrase": phrase,
})
# Validate gene symbols against KG so we don't store hallucinated names
for symbol in gene_symbols[:10]:
try:
rows = await run_query(
"MATCH (g:Gene {symbol: $s}) RETURN g.symbol AS symbol LIMIT 1",
{"s": symbol.upper()}, timeout=5.0,
)
if rows:
out_genes.append({"symbol": symbol.upper(), "status": "present"})
except Exception:
pass
return {"phenotypes": out_phenos, "genes": out_genes}
async def _resolve_diseases(mentions: list[str]) -> list[dict]:
"""Map disease mentions → ORPHA via Disease.name fulltext."""
out: list[dict] = []
try:
from tools import run_query
except ImportError:
return out
seen: set[str] = set()
for m in mentions[:6]:
# Explicit ORPHA:NNNN
import re
explicit = re.search(r"ORPHA[:\s]*(\d+)", m, re.I)
if explicit:
orpha = explicit.group(1)
if orpha not in seen:
seen.add(orpha)
out.append({"orpha": orpha, "name": m, "status": "suspected"})
continue
try:
rows = await run_query(
"""
CALL db.index.fulltext.queryNodes('disease_search', $q)
YIELD node, score
WHERE score >= 4.0
RETURN node.orphaCode AS orpha, node.name AS name, score
ORDER BY score DESC LIMIT 1
""",
{"q": m}, timeout=8.0,
)
except Exception:
rows = []
if rows and rows[0].get("orpha") and rows[0]["orpha"] not in seen:
seen.add(rows[0]["orpha"])
out.append({"orpha": rows[0]["orpha"], "name": rows[0]["name"], "status": "suspected"})
return out
async def _llm_extract_direct(message: str) -> Optional[dict]:
    """Run the 2-stage extractor: LLM phrases, then raras-app KG normalization.

    `_phrase_extract` returns verbatim phrases. Each phrase list is then
    normalized through `_kg_normalize_phrases`, tagged by its origin:

    - proband phrases → status "present"
    - family phrases → status "family"
    - negated phrases → status "absent"

    Disease mentions go through `_resolve_diseases`. Labs and medications
    are passed through as the phrase LLM returned them.

    Per the author, the KG carries PT-BR names, synonyms, cultural variants
    and BioLORD embeddings, giving hpo-brasil-grade normalization without
    shipping torch + a ~350MB model in the orch image.

    Returns:
        An entity dict with phenotypes/diseases/genes/labs/medications/
        treatments, or ``None`` when the phrase stage produced nothing.
    """
    extracted = await _phrase_extract(message)
    if not extracted:
        return None

    def _listed(key):
        # Missing or null keys become an empty list.
        return extracted.get(key, []) or []

    proband = await _kg_normalize_phrases(_listed("phenotype_phrases"), _listed("gene_symbols"))
    # Family-history and negated phrases are normalized as well, carrying their
    # own status, so a later differential can use them.
    relatives = await _kg_normalize_phrases(_listed("family_phrases"), [], status="family")
    absent = await _kg_normalize_phrases(_listed("negated_phrases"), [], status="absent")
    diseases = await _resolve_diseases(_listed("disease_mentions"))

    n_present, n_family, n_absent = (len(part["phenotypes"]) for part in (proband, relatives, absent))
    logger.info(
        f"extractor: kg-normalized {n_present} phenotypes "
        f"+ {n_family} family + {n_absent} negated "
        f"+ {len(diseases)} diseases"
    )
    return {
        "phenotypes": [*proband["phenotypes"], *relatives["phenotypes"], *absent["phenotypes"]],
        "diseases": diseases,
        "genes": proband["genes"],
        "labs": _listed("labs"),
        "medications": _listed("medications"),
        "treatments": [],
    }
async def _llm_extract(message: str) -> Optional[dict]:
"""Try to extract via the configured LLM. Robust to JSON-with-fences output."""
try:
from langchain_core.messages import SystemMessage, HumanMessage
except ImportError:
# No langchain installed (slim orch image) → go direct.
return await _llm_extract_direct(message)
llm = _get_llm()
if llm is None:
return await _llm_extract_direct(message)
try:
msgs = [
SystemMessage(content=_SYSTEM_PROMPT),
HumanMessage(content=f"Extract from this clinical message:\n\n{message}"),
]
resp = await llm.ainvoke(msgs)
text = getattr(resp, "content", None) or str(resp)
text = text.strip()
# Strip ```json ... ``` fences
if text.startswith("```"):
text = text.strip("`")
if text.lower().startswith("json"):
text = text[4:]
text = text.strip()
if text.endswith("```"):
text = text[:-3].strip()
# Extract the JSON object body
start = text.find("{")
end = text.rfind("}")
if start >= 0 and end > start:
text = text[start : end + 1]
return json.loads(text)
except Exception as e:
logger.debug(f"LLM extractor call failed: {e} — trying direct httpx fallback")
return await _llm_extract_direct(message)
def _regex_extract(message: str) -> dict:
    """Fallback extractor: pick explicit HPO / ORPHA / gene mentions out of *message*.

    Uses the patterns ``_HPO_RE``, ``_ORPHA_RE`` and ``_GENE_RE`` from
    `.llm_context`, taking group 1 of each match as the code or symbol.
    Results are labelled as follows:

    - phenotypes and genes get status "present";
    - diseases get status "suspected";
    - gene symbols are upper-cased.

    Negation and family history are NOT detected here. Repeated mentions of
    the same code are reported once, in first-seen order, so they are not fed
    into the twin twice.

    Returns:
        An entity dict with the same keys as the LLM extractors; labs,
        medications and treatments are always empty.
    """
    from .llm_context import _GENE_RE, _HPO_RE, _ORPHA_RE

    def _first_seen(pattern, upper=False):
        hits = (m.group(1) for m in pattern.finditer(message))
        if upper:
            hits = (h.upper() for h in hits)
        # dict.fromkeys keeps first-occurrence order while dropping repeats.
        return list(dict.fromkeys(hits))

    hpo_ids = [f"HP:{digits}" for digits in _first_seen(_HPO_RE)]
    return {
        "phenotypes": [{"hpo_id": h, "name": h, "status": "present"} for h in hpo_ids],
        "diseases": [{"orpha": code, "status": "suspected"} for code in _first_seen(_ORPHA_RE)],
        "genes": [{"symbol": sym, "status": "present"} for sym in _first_seen(_GENE_RE, upper=True)],
        "labs": [], "medications": [], "treatments": [],
    }
async def extract(message: str) -> ClinicalEntities:
"""Top-level: try LLM, fall back to regex. Always returns ClinicalEntities."""
if not message or not message.strip():
return ClinicalEntities()
raw = await _llm_extract(message)
if raw is None:
raw = _regex_extract(message)
return ClinicalEntities(
phenotypes=raw.get("phenotypes", []) or [],
diseases=raw.get("diseases", []) or [],
genes=raw.get("genes", []) or [],
labs=raw.get("labs", []) or [],
medications=raw.get("medications", []) or [],
treatments=raw.get("treatments", []) or [],
raw=raw,
)
async def absorb(case_id: str, message: str, *, source: str = "user") -> dict:
    """Extract entities from *message* and feed the proband's own findings into the twin.

    Findings are written through ``core.evolve_gemeo`` (called only when
    there is something to add):

    - phenotypes with status "present" and an ``hpo_id`` → ``new_phenotypes``
    - genes with status "present" and a ``symbol`` → ``new_genes``
    - labs with a ``test`` name → ``new_labs``
    - treatments and medications with a ``name`` → ``new_treatments``

    Phenotypes with status "absent", "family" or "past" are NOT fed in.
    Absent and family ones are only counted in the result. Diseases are
    extracted but not written to the twin here; they appear only under
    ``"extracted"``.

    Args:
        case_id: Twin/case identifier; an empty value aborts immediately.
        message: Free-text clinical message.
        source: Provenance tag copied onto every added item.

    Returns:
        - ``{"absorbed": False, "reason": "no case_id"}`` when *case_id* is empty.
        - ``{"absorbed": False, "error": ..., "extracted": ...}`` when the
          twin update raises.
        - Otherwise ``{"absorbed": True, "source", "added", "skipped_negated",
          "skipped_family", "extracted"}``.
    """
    if not case_id:
        return {"absorbed": False, "reason": "no case_id"}
    ents = await extract(message)

    def _dicts(items) -> list:
        # Tolerate malformed LLM output: keep only dict entries.
        return [item for item in items or [] if isinstance(item, dict)]

    phenotypes = _dicts(ents.phenotypes)
    new_phenotypes = [
        {
            "hpo_id": p.get("hpo_id"),
            "name": p.get("name") or p.get("hpo_id"),
            "severity": p.get("severity"),
            "source": source,
            "status": "extracted",
        }
        for p in phenotypes
        if p.get("status") == "present" and p.get("hpo_id")
    ]
    new_genes = [
        {
            "symbol": g.get("symbol"),
            "variant": g.get("variant"),
            "zygosity": g.get("zygosity"),
            "pathogenicity": g.get("pathogenicity"),
            "source": source,
            "status": "extracted",
        }
        for g in _dicts(ents.genes)
        if g.get("status") == "present" and g.get("symbol")
    ]
    new_labs = [
        {
            "test": lab.get("test"),
            "value": lab.get("value"),
            "unit": lab.get("unit"),
            "abnormal": lab.get("abnormal"),
            "date": lab.get("date"),
            "source": source,
        }
        for lab in _dicts(ents.labs)
        if lab.get("test")
    ]
    # Medications are folded into treatments: the evolve_gemeo call below has
    # no separate medications channel.
    new_treatments = [
        {"name": t.get("name"), "type": t.get("type"), "response": t.get("response"), "source": source}
        for t in _dicts(ents.treatments) + _dicts(ents.medications)
        if t.get("name")
    ]

    try:
        from . import core as gcore

        if new_phenotypes or new_genes or new_labs or new_treatments:
            await gcore.evolve_gemeo(
                case_id,
                new_phenotypes=new_phenotypes,
                new_genes=new_genes,
                new_labs=new_labs,
                new_treatments=new_treatments,
            )
    except Exception as e:
        logger.warning(f"evolve_gemeo failed during absorb: {e}")
        return {"absorbed": False, "error": str(e), "extracted": ents.raw}

    return {
        "absorbed": True,
        "source": source,
        "added": {
            "phenotypes": len(new_phenotypes),
            "genes": len(new_genes),
            "labs": len(new_labs),
            "treatments": len(new_treatments),
        },
        "skipped_negated": sum(1 for p in phenotypes if p.get("status") == "absent"),
        "skipped_family": sum(1 for p in phenotypes if p.get("status") == "family"),
        "extracted": ents.raw,
    }