"""LLM-based clinical entity extractor — replaces the regex absorb path.

Robust to PT-BR, synonyms, negation ("paciente NÃO tem ataxia" should NOT
extract HP:0001251), age qualifiers, and the family-history vs proband
distinction.

Falls back to the regex extractor (`gemeo.llm_context._HPO_RE` etc.) when no
LLM route yields a usable result, so it always produces an answer.
"""
from __future__ import annotations

import json
import logging
import os
import re
from dataclasses import dataclass, field
from typing import Optional

logger = logging.getLogger("gemeo.extractor")


@dataclass
class ClinicalEntities:
    """Structured output of the extractor.

    Every list holds plain dicts; the comment above each field lists the keys
    the extraction prompts ask for (`?` marks optional keys).
    """

    # [{hpo_id, name, status: present|absent|family|past, severity?}]
    phenotypes: list[dict] = field(default_factory=list)
    # [{orpha, name, status: confirmed|suspected|ruled_out}]
    diseases: list[dict] = field(default_factory=list)
    # [{symbol, variant?, zygosity?, pathogenicity?, status}]
    genes: list[dict] = field(default_factory=list)
    # [{test, value, unit, abnormal?, date?}]
    labs: list[dict] = field(default_factory=list)
    # [{name, dose?, frequency?, status: current|past|stopped}]
    medications: list[dict] = field(default_factory=list)
    # [{name, type?, response?}]
    treatments: list[dict] = field(default_factory=list)
    # The unmodified payload the entities were built from.
    raw: dict = field(default_factory=dict)


# System prompt for the single-shot (LangChain) extraction path.
_SYSTEM_PROMPT = (
    "You extract structured clinical entities from a clinical message. "
    "Return STRICT JSON matching this schema:\n"
    "{\n"
    ' "phenotypes": [{"hpo_id": "HP:NNNNNNN", "name": "...", "status": "present|absent|family|past", "severity": "mild|moderate|severe|null"}],\n'
    ' "diseases": [{"orpha": "NNNN", "name": "...", "status": "confirmed|suspected|ruled_out"}],\n'
    ' "genes": [{"symbol": "GENE", "variant": "c.X>Y or null", "zygosity": "het|hom|null", "pathogenicity": "benign|likely_benign|VUS|likely_pathogenic|pathogenic|null", "status": "present|absent"}],\n'
    ' "labs": [{"test": "AFP", "value": "280", "unit": "ng/mL", "abnormal": true, "date": "YYYY-MM-DD or null"}],\n'
    ' "medications": [{"name": "IVIG", "dose": "...", "frequency": "...", "status": "current|past|stopped"}],\n'
    ' "treatments": [{"name": "ERT", "type": "...", "response": "good|partial|poor|null"}]\n'
    "}\n"
    "Critical rules:\n"
    " - Detect NEGATION: 'paciente NÃO tem ataxia' → status='absent'.\n"
    " - Detect FAMILY HISTORY: 'irmão com hemofilia' → status='family'.\n"
    " - PT-BR synonyms: 'convulsão'→HP:0001250, 'ataxia'→HP:0001251, 'telangiectasia'→HP:0001009.\n"
    " - Use STRICT HP/ORPHA codes; if unsure, omit the entity.\n"
    " - Empty arrays are fine. NEVER include explanations outside the JSON."
)

# System prompt for stage 1 of the two-stage path (phrases only, no codes).
_PHRASE_PROMPT = (
    "Extract CLINICAL PHRASES (as written, PT-BR ok) from a doctor's free-text case. "
    "Do NOT invent HPO/ORPHA codes. Return STRICT JSON:\n"
    "{\n"
    ' "phenotype_phrases": ["ataxia troncular","telangiectasia bulbar","IgA baixa", …],\n'
    ' "gene_symbols": ["ATM","CFTR"],\n'
    ' "disease_mentions": ["Ataxia-telangiectasia","ORPHA:100"],\n'
    ' "labs": [{"test":"AFP","value":"280","unit":"ng/mL","abnormal":true}],\n'
    ' "medications": [{"name":"…","status":"current|past|stopped"}],\n'
    ' "negated_phrases": ["sem hepatomegalia"],\n'
    ' "family_phrases": ["irmão com hemofilia"]\n'
    "}\n"
    "Rules:\n"
    " - One phrase per clinical sign — don't merge multiple findings.\n"
    " - Keep PT-BR wording verbatim; downstream KG maps to HPO codes.\n"
    " - Negation goes to `negated_phrases`; proband-only signs to `phenotype_phrases`.\n"
    " - Family history goes to `family_phrases` (don't double-count as proband).\n"
    " - DO NOT include explanations outside the JSON."
)

# OpenAI-compatible chat endpoints tried in order by `_phrase_extract`:
# (API-key environment variable, endpoint URL, model name).
_PHRASE_PROVIDERS = (
    ("DEEPSEEK_API_KEY", "https://api.deepseek.com/v1/chat/completions", "deepseek-chat"),
    ("CEREBRAS_API_KEY", "https://api.cerebras.ai/v1/chat/completions", "llama-3.3-70b"),
    ("GROQ_API_KEY", "https://api.groq.com/openai/v1/chat/completions", "llama-3.3-70b-versatile"),
)

# Fulltext lookups; only the best hit with score >= 4.0 is kept.
_PHENOTYPE_QUERY = (
    " CALL db.index.fulltext.queryNodes('phenotype_search', $q) "
    "YIELD node, score WHERE score >= 4.0 "
    "RETURN node.hpoId AS hpo, node.name AS name, "
    "coalesce(node.namePt, node.name) AS name_pt, score "
    "ORDER BY score DESC LIMIT 1 "
)
_DISEASE_QUERY = (
    " CALL db.index.fulltext.queryNodes('disease_search', $q) "
    "YIELD node, score WHERE score >= 4.0 "
    "RETURN node.orphaCode AS orpha, node.name AS name, score "
    "ORDER BY score DESC LIMIT 1 "
)
_GENE_QUERY = "MATCH (g:Gene {symbol: $s}) RETURN g.symbol AS symbol LIMIT 1"

# Explicit "ORPHA:123" / "orpha 123" mention inside free text.
_ORPHA_MENTION_RE = re.compile(r"ORPHA[:\s]*(\d+)", re.I)

# Characters that carry meaning in the Lucene classic query syntax.
_LUCENE_SPECIAL_RE = re.compile(r'([\\+\-!():^\[\]"{}~*?|&/])')


def _lucene_escape(text: str) -> str:
    """Backslash-escape Lucene query metacharacters in *text*.

    The fulltext procedures interpret their argument as a Lucene query, so a
    phrase such as "IgA baixa (7 mg/dL)" would otherwise fail to parse and the
    finding would be silently dropped.
    """
    return _LUCENE_SPECIAL_RE.sub(r"\\\1", text)


def _clean_strings(values: object, limit: int) -> list[str]:
    """Return the stripped, non-empty strings among the first *limit* items.

    LLM output is untrusted: a field expected to be a list of strings may be a
    bare string, contain nulls/objects, or be missing. A bare string is treated
    as a one-item list; anything that is not a list/tuple yields `[]`.
    """
    if isinstance(values, str):
        values = [values]
    if not isinstance(values, (list, tuple)):
        return []
    stripped = (v.strip() for v in values[:limit] if isinstance(v, str))
    return [v for v in stripped if v]


def _as_dicts(value: object) -> list[dict]:
    """Return the dict items of *value* if it is a list, else an empty list."""
    if not isinstance(value, list):
        return []
    return [item for item in value if isinstance(item, dict)]


def _parse_json_object(text: object) -> Optional[dict]:
    """Parse the outermost `{...}` span of *text* as a JSON object.

    Slicing from the first `{` to the last `}` discards any Markdown code
    fences or chatter around the payload. Returns None when *text* is not a
    string, holds no brace pair, is not valid JSON, or decodes to a non-object.
    """
    if not isinstance(text, str):
        return None
    start, end = text.find("{"), text.rfind("}")
    if start < 0 or end <= start:
        return None
    try:
        parsed = json.loads(text[start : end + 1])
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return None
    return parsed if isinstance(parsed, dict) else None


def _is_cloud_model(llm) -> bool:
    """Return False for the local rarasnet-* models, which need a Modal backend."""
    model_attr = getattr(llm, "model_name", None) or getattr(llm, "model", None) or ""
    return "rarasnet" not in str(model_attr).lower()


def _get_llm():
    """Pick the best available cloud LLM for structured extraction.

    Order:
      1. llm_router.get_check_llm — Gemini-flash-lite class (fast + cheap,
         configured for this stack)
      2. llm_router.get_orchestrator_llm — Gemini-flash class (mid)
      3. Direct ChatGoogleGenerativeAI with GEMINI_API_KEY / GOOGLE_API_KEY

    Returns the LLM (with .ainvoke) or None if nothing works.
    """
    # Tiers 1 and 2: the swarm's router, skipping local rarasnet-* models.
    for factory_name in ("get_check_llm", "get_orchestrator_llm"):
        try:
            import llm_router

            llm = getattr(llm_router, factory_name)(temperature=0.0)
            if _is_cloud_model(llm):
                return llm
        except Exception as exc:
            logger.debug("%s unavailable: %s", factory_name, exc)

    # Tier 3: Gemini direct.
    api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        return None
    try:
        from langchain_google_genai import ChatGoogleGenerativeAI

        return ChatGoogleGenerativeAI(
            model="gemini-2.5-flash-lite",
            google_api_key=api_key,
            temperature=0.0,
        )
    except Exception as exc:
        logger.debug("direct Gemini unavailable: %s", exc)
        return None


async def _phrase_extract(message: str) -> Optional[dict]:
    """Step 1: LLM extracts PT-BR clinical phrases (NOT HPO codes).

    Pipeline: text → LLM phrases → KG fulltext → HPO codes. The phrases are
    normalized afterwards by `_kg_normalize_phrases`, so the model is never
    asked to produce codes itself.

    Each provider in `_PHRASE_PROVIDERS` whose API key is set is tried in
    order; the first response that parses as a JSON object wins.

    Returns the parsed object, or None when httpx is not installed, no
    provider is configured, or every provider failed.
    """
    try:
        import httpx
    except ImportError:
        # Without this guard the ImportError would escape `extract` and the
        # regex fallback would never run.
        logger.debug("phrase extractor unavailable: httpx is not installed")
        return None

    async with httpx.AsyncClient(timeout=45.0) as http:
        for env_key, url, model in _PHRASE_PROVIDERS:
            key = os.environ.get(env_key)
            if not key:
                continue
            try:
                r = await http.post(
                    url,
                    json={
                        "model": model,
                        "messages": [
                            {"role": "system", "content": _PHRASE_PROMPT},
                            {"role": "user", "content": message},
                        ],
                        "temperature": 0.0,
                        "response_format": {"type": "json_object"},
                    },
                    headers={"Authorization": f"Bearer {key}", "Content-Type": "application/json"},
                )
                if r.status_code != 200:
                    continue
                choices = r.json().get("choices") or [{}]
                content = (choices[0].get("message") or {}).get("content") or ""
                parsed = _parse_json_object(content)
                if parsed is not None:
                    return parsed
            except Exception as exc:
                # Best effort: any transport/shape problem moves on to the
                # next provider.
                logger.debug("phrase extractor %s: %s", env_key, exc)
                continue
    return None


async def _kg_normalize_phrases(phrases: list[str], gene_symbols: list[str], status: str = "present") -> dict:
    """Step 2: phrases → HPO via raras-app KG `phenotype_search` FULLTEXT.

    Index covers: name, namePt, synonymsPt, culturalSynonymsPt, definitionPt.
    Lucene score threshold ≥ 4.0 (empirically: clearly-named phenotypes score
    8-12; weak matches fall below 4).

    Args:
        phrases: clinical phrases; only the first 30 are looked up.
        gene_symbols: candidate gene symbols; only the first 10 are validated.
        status: value stamped on every phenotype produced by this call.

    Returns:
        `{"phenotypes": [...], "genes": [...]}`; both lists are empty when
        `tools.run_query` cannot be imported.
    """
    out_phenos: list[dict] = []
    out_genes: list[dict] = []
    try:
        from tools import run_query
    except ImportError:
        return {"phenotypes": [], "genes": []}

    seen_hpo: set = set()
    for phrase in _clean_strings(phrases, 30):
        try:
            rows = await run_query(_PHENOTYPE_QUERY, {"q": _lucene_escape(phrase)}, timeout=8.0)
        except Exception as exc:
            logger.debug("KG normalize '%s': %s", phrase, exc)
            rows = []
        top = rows[0] if rows else None
        hpo = top.get("hpo") if top else None
        if not hpo or hpo in seen_hpo:
            continue
        seen_hpo.add(hpo)
        out_phenos.append({
            "hpo_id": hpo,
            "name": top.get("name_pt") or top.get("name") or hpo,
            "status": status,
            "_source": "kg-fulltext",
            "_score": float(top["score"]),
            "_phrase": phrase,
        })

    # Validate gene symbols against KG so we don't store hallucinated names.
    seen_genes: set = set()
    for raw_symbol in _clean_strings(gene_symbols, 10):
        symbol = raw_symbol.upper()
        if symbol in seen_genes:
            continue
        seen_genes.add(symbol)
        try:
            rows = await run_query(_GENE_QUERY, {"s": symbol}, timeout=5.0)
        except Exception as exc:
            logger.debug("KG gene lookup '%s': %s", symbol, exc)
            continue
        if rows:
            out_genes.append({"symbol": symbol, "status": "present"})

    return {"phenotypes": out_phenos, "genes": out_genes}


async def _resolve_diseases(mentions: list[str]) -> list[dict]:
    """Map disease mentions → ORPHA via explicit codes or `disease_search` fulltext.

    Only the first 6 mentions are considered. A mention that contains an
    explicit `ORPHA:NNNN` code is taken at face value (no KG round-trip, so
    it resolves even when `tools.run_query` is unavailable); any other
    mention is resolved to the best fulltext hit scoring ≥ 4.0.

    Returns:
        De-duplicated `[{orpha, name, status: "suspected"}]` entries.
    """
    try:
        from tools import run_query
    except ImportError:
        run_query = None

    resolved: list[dict] = []
    seen: set = set()
    for mention in _clean_strings(mentions, 6):
        explicit = _ORPHA_MENTION_RE.search(mention)
        if explicit:
            orpha = explicit.group(1)
            if orpha not in seen:
                seen.add(orpha)
                resolved.append({"orpha": orpha, "name": mention, "status": "suspected"})
            continue
        if run_query is None:
            continue
        try:
            rows = await run_query(_DISEASE_QUERY, {"q": _lucene_escape(mention)}, timeout=8.0)
        except Exception as exc:
            logger.debug("KG disease lookup '%s': %s", mention, exc)
            rows = []
        top = rows[0] if rows else None
        orpha = top.get("orpha") if top else None
        # Compare as text so an explicit "100" and a KG value 100 de-duplicate.
        if orpha and str(orpha) not in seen:
            seen.add(str(orpha))
            resolved.append({"orpha": orpha, "name": top.get("name"), "status": "suspected"})
    return resolved


async def _llm_extract_direct(message: str) -> Optional[dict]:
    """2-stage extraction: LLM phrases → raras-app KG normalization.

    Stage 1 (`_phrase_extract`) pulls verbatim phrases; stage 2 maps them to
    HPO / gene / ORPHA identifiers through the KG, so this path needs neither
    LangChain nor a local embedding model.

    Returns a dict shaped like the `_SYSTEM_PROMPT` schema (`treatments` is
    always empty on this path), or None when stage 1 produced nothing.
    """
    phrases = await _phrase_extract(message)
    if not phrases:
        return None

    normalized = await _kg_normalize_phrases(
        phrases.get("phenotype_phrases"),
        phrases.get("gene_symbols"),
    )
    # Family-history & negated phrases get normalized too, with status set
    # accordingly, so downstream code can tell them apart from proband signs.
    family = await _kg_normalize_phrases(phrases.get("family_phrases"), [], status="family")
    negated = await _kg_normalize_phrases(phrases.get("negated_phrases"), [], status="absent")
    diseases = await _resolve_diseases(phrases.get("disease_mentions"))

    logger.info(
        "extractor: kg-normalized %d phenotypes + %d family + %d negated + %d diseases",
        len(normalized["phenotypes"]),
        len(family["phenotypes"]),
        len(negated["phenotypes"]),
        len(diseases),
    )
    return {
        "phenotypes": normalized["phenotypes"] + family["phenotypes"] + negated["phenotypes"],
        "diseases": diseases,
        "genes": normalized["genes"],
        "labs": phrases.get("labs", []) or [],
        "medications": phrases.get("medications", []) or [],
        "treatments": [],
    }


async def _llm_extract(message: str) -> Optional[dict]:
    """Try to extract via the configured LLM. Robust to JSON-with-fences output.

    Falls through to `_llm_extract_direct` when LangChain is not installed,
    no LLM is available, the call fails, or the reply is not a JSON object.

    Returns the decoded JSON object, or None when every route failed.
    """
    try:
        from langchain_core.messages import HumanMessage, SystemMessage
    except ImportError:
        # No langchain installed (slim orch image) → go direct.
        return await _llm_extract_direct(message)

    llm = _get_llm()
    if llm is None:
        return await _llm_extract_direct(message)

    parsed: Optional[dict] = None
    try:
        resp = await llm.ainvoke([
            SystemMessage(content=_SYSTEM_PROMPT),
            HumanMessage(content=f"Extract from this clinical message:\n\n{message}"),
        ])
        text = getattr(resp, "content", None) or str(resp)
        parsed = _parse_json_object(text)
        if parsed is None:
            logger.debug("LLM extractor reply was not a JSON object — trying direct httpx fallback")
    except Exception as exc:
        logger.debug("LLM extractor call failed: %s — trying direct httpx fallback", exc)

    if parsed is not None:
        return parsed
    return await _llm_extract_direct(message)


def _regex_extract(message: str) -> dict:
    """Fallback: lightweight regex pass using the `llm_context` patterns.

    Every phenotype/gene hit is marked `present` and every disease hit
    `suspected`; labs, medications and treatments are always empty.
    """
    from .llm_context import _GENE_RE, _HPO_RE, _ORPHA_RE

    hpos = [
        {"hpo_id": f"HP:{m.group(1)}", "name": f"HP:{m.group(1)}", "status": "present"}
        for m in _HPO_RE.finditer(message)
    ]
    orphas = [{"orpha": m.group(1), "status": "suspected"} for m in _ORPHA_RE.finditer(message)]
    genes = [{"symbol": m.group(1).upper(), "status": "present"} for m in _GENE_RE.finditer(message)]
    return {
        "phenotypes": hpos,
        "diseases": orphas,
        "genes": genes,
        "labs": [],
        "medications": [],
        "treatments": [],
    }


async def extract(message: str) -> ClinicalEntities:
    """Top-level: try LLM, fall back to regex. Always returns ClinicalEntities.

    An empty or whitespace-only message yields an empty result without any
    LLM call. Entity lists are sanitized to lists of dicts (malformed LLM
    items are dropped); the untouched payload is kept in `.raw`.
    """
    if not message or not message.strip():
        return ClinicalEntities()

    raw = await _llm_extract(message)
    if not isinstance(raw, dict):
        raw = _regex_extract(message)

    return ClinicalEntities(
        phenotypes=_as_dicts(raw.get("phenotypes")),
        diseases=_as_dicts(raw.get("diseases")),
        genes=_as_dicts(raw.get("genes")),
        labs=_as_dicts(raw.get("labs")),
        medications=_as_dicts(raw.get("medications")),
        treatments=_as_dicts(raw.get("treatments")),
        raw=raw,
    )


async def absorb(case_id: str, message: str, *, source: str = "user") -> dict:
    """Extract + feed into the twin. Honors negation/family-history flags.

    - 'present' phenotypes/genes feed into evolve_gemeo as
      new_phenotypes/new_genes
    - 'absent' / 'family' / 'past' phenotypes are NOT sent as patient
      findings; absent and family ones are counted in the result
    - labs with a `test` name are forwarded; treatments and medications with
      a `name` are merged (treatments first) into new_treatments
    - diseases are extracted but not forwarded by this function; they remain
      available under `extracted`

    Args:
        case_id: twin identifier; a falsy value short-circuits.
        message: free-text clinical message.
        source: provenance tag stamped on every forwarded item.

    Returns:
        `{"absorbed": False, "reason": ...}` without a case id,
        `{"absorbed": False, "error": ..., "extracted": ...}` when the core
        import or `evolve_gemeo` raises, otherwise a summary with per-kind
        `added` counts and the raw extraction.
    """
    if not case_id:
        return {"absorbed": False, "reason": "no case_id"}

    ents = await extract(message)

    new_phenotypes = [
        {
            "hpo_id": p.get("hpo_id"),
            "name": p.get("name") or p.get("hpo_id"),
            "severity": p.get("severity"),
            "source": source,
            "status": "extracted",
        }
        for p in ents.phenotypes
        if p.get("status") == "present" and p.get("hpo_id")
    ]
    new_genes = [
        {
            "symbol": g.get("symbol"),
            "variant": g.get("variant"),
            "zygosity": g.get("zygosity"),
            "pathogenicity": g.get("pathogenicity"),
            "source": source,
            "status": "extracted",
        }
        for g in ents.genes
        if g.get("status") == "present" and g.get("symbol")
    ]
    new_labs = [
        {
            "test": lab.get("test"),
            "value": lab.get("value"),
            "unit": lab.get("unit"),
            "abnormal": lab.get("abnormal"),
            "date": lab.get("date"),
            "source": source,
        }
        for lab in ents.labs
        if lab.get("test")
    ]
    new_treatments = [
        {
            "name": item.get("name"),
            "type": item.get("type"),
            "response": item.get("response"),
            "source": source,
        }
        for item in (*ents.treatments, *ents.medications)
        if item.get("name")
    ]

    try:
        from . import core as gcore

        if new_phenotypes or new_genes or new_labs or new_treatments:
            await gcore.evolve_gemeo(
                case_id,
                new_phenotypes=new_phenotypes,
                new_genes=new_genes,
                new_labs=new_labs,
                new_treatments=new_treatments,
            )
    except Exception as exc:
        logger.warning("evolve_gemeo failed during absorb: %s", exc)
        return {"absorbed": False, "error": str(exc), "extracted": ents.raw}

    return {
        "absorbed": True,
        "source": source,
        "added": {
            "phenotypes": len(new_phenotypes),
            "genes": len(new_genes),
            "labs": len(new_labs),
            "treatments": len(new_treatments),
        },
        "skipped_negated": sum(1 for p in ents.phenotypes if p.get("status") == "absent"),
        "skipped_family": sum(1 for p in ents.phenotypes if p.get("status") == "family"),
        "extracted": ents.raw,
    }