"""Gemeo orchestrator — the Gemeo class and high-level entry points. This is the public face of the module. Wraps every capability behind one clean async surface: twin = await build_gemeo(case_text="...", patient_info={...}, context={...}) print(twin.cohort, twin.subgraph, twin.trajectory, twin.risk, ...) new = await evolve_gemeo(twin.id, new_phenotypes=[...]) cf = await what_if(twin.id, intervention={"type": "treatment", "drug": "Cerezyme"}) """ from __future__ import annotations import asyncio import logging from typing import Optional from .types import GemeoTwin, VizData from . import encoder as gencoder from . import cohort as gcohort from . import subgraph as gsub from . import trajectory as gtraj from . import risk as grisk from . import repurpose as gdrugs from . import ask as gask from . import ground_sus as gsus from . import viz as gviz from . import whatif as gwhatif from . import ddi as gddi from . import pharmacogen as gpharm from . import family as gfamily from . import reverse_pheno as grpheno from . import protocol_compliance as gpcdt from . import consult as gconsult from . import simulate as gsim logger = logging.getLogger("gemeo.core") # in-memory twin registry — case_id → GemeoTwin # (Persistent state still lives in PatientSpace + Neo4j; this is a fast cache) _TWINS: dict = {} # ─── helpers ─────────────────────────────────────────────────────────────── async def _ensure_space(case_text: str, patient_info: dict, context: dict, run_diagnosis: bool): """Create or load a PatientSpace via the existing digital_twin_workflow.""" try: from digital_twin_workflow import create_digital_twin except ImportError as e: logger.error(f"digital_twin_workflow not importable: {e}") raise payload = await create_digital_twin( case_text=case_text, patient_info=patient_info or {}, context=context or {}, run_diagnosis=run_diagnosis, run_full_analysis=False, ) return payload async def _load_space(case_id: str): from digital_twin_workflow import _get_or_load_space return await _get_or_load_space(case_id) def _extract_inputs(space): """Pull HPO/Gene lists + dx + sus_region from a PatientSpace.""" hpo_ids, gene_symbols, diagnoses = [], [], [] sus_region = None snap = space.get_current_snapshot() if hasattr(space, "get_current_snapshot") else None if snap is not None: for p in snap.phenotypes: if p.get("hpo_id"): hpo_ids.append(p["hpo_id"]) for g in snap.genes: if g.get("symbol"): gene_symbols.append(str(g["symbol"]).upper()) for d in snap.diagnoses: if d.get("orpha"): diagnoses.append({ "orpha": d["orpha"], "name": d.get("disease") or d.get("name"), "probability": d.get("probability", 0.5), "status": d.get("status", "active"), }) sus_region = (snap.context or {}).get("sus_region") # also pull active hypotheses for hyp in (getattr(space, "_hypotheses", {}) or {}).values(): orpha = getattr(hyp, "orpha_code", None) if orpha and orpha not in {d["orpha"] for d in diagnoses}: diagnoses.append({ "orpha": orpha, "name": getattr(hyp, "disease_name", "") or getattr(hyp, "name", ""), "probability": getattr(hyp, "probability", 0.5), "status": getattr(hyp, "status", "active"), }) diagnoses.sort(key=lambda d: d.get("probability", 0), reverse=True) return hpo_ids, gene_symbols, diagnoses, sus_region # ─── primary entry points ────────────────────────────────────────────────── async def build_gemeo( *, case_text: str, patient_info: dict = None, context: dict = None, run_diagnosis: bool = True, cohort_k: int = 10, horizons_months: list[int] = None, fast: bool = False, ) -> GemeoTwin: """Build a complete digital twin from a clinical case. Args: case_text: free-text or structured case description (PT-BR or EN) patient_info: {age, sex, ethnicity, ...} context: {sus_region, access_level, ...} run_diagnosis: run engine_v2 to populate hypotheses cohort_k: patients-like-mine size horizons_months: trajectory horizons (default [6, 12, 24]) fast: skip slow stages (drugs, trials) for snappy demos """ horizons_months = horizons_months or [6, 12, 24] # 1) PatientSpace creation (existing pipeline) payload = await _ensure_space(case_text, patient_info, context, run_diagnosis) case_id = payload.get("case_id") or payload.get("id") space = await _load_space(case_id) if case_id else None if space is None: logger.warning(f"could not load PatientSpace for case {case_id}") # 2) extract working set hpo_ids, gene_symbols, diagnoses, sus_region = ([], [], [], None) if space is not None: hpo_ids, gene_symbols, diagnoses, sus_region = _extract_inputs(space) if context and context.get("sus_region"): sus_region = sus_region or context["sus_region"] # 3) embed embedding, enc_quality = (None, "skipped") if space is not None: try: embedding, enc_quality = gencoder.encode_patient_space(space) except Exception as e: logger.warning(f"encode failed: {e}") # 4) parallelize the read-only stages # Prefer an explicit override (passed via patient_info or context) so # callers can pre-seed the Orphanet code BEFORE the diagnosis pipeline # has had a chance to run. Falls back to the top diagnosis. target_orpha = None if context and context.get("target_orpha"): target_orpha = str(context["target_orpha"]) elif patient_info and patient_info.get("target_orpha"): target_orpha = str(patient_info["target_orpha"]) elif diagnoses: target_orpha = diagnoses[0].get("orpha") async def _safe(awaitable): try: return await awaitable except Exception as e: logger.warning(f"stage failed: {e}") return None cohort_task = _safe(gcohort.find_cohort( embedding=embedding, hpo_ids=hpo_ids, orpha_codes=[d["orpha"] for d in diagnoses if d.get("orpha")], k=cohort_k, include_literature=True, )) if (embedding is not None or hpo_ids or diagnoses) else asyncio.sleep(0, result=None) subgraph_task = _safe(gsub.extract( patient_id=case_id or "anon", hpo_ids=hpo_ids, gene_symbols=gene_symbols, target_orpha=target_orpha, )) if hpo_ids or gene_symbols or target_orpha else asyncio.sleep(0, result=None) # trajectory uses 3 LLM calls — heaviest stage by 5x. Skip in fast mode. trajectory_task = ( _safe(gtraj.predict(space, horizons_months)) if (space and not fast) else asyncio.sleep(0, result=None) ) risk_task = _safe(grisk.assess(space, embedding=embedding)) if space else asyncio.sleep(0, result=None) ask_task = _safe(gask.recommend(space, top_n=5)) if space else asyncio.sleep(0, result=None) if not fast: drugs_task = _safe(gdrugs.find(space, embedding=embedding, sus_region=sus_region)) if space else asyncio.sleep(0, result=None) else: drugs_task = asyncio.sleep(0, result=None) # Case-driven additions: family / reverse-pheno / pcdt / ddi (medications) medications_in_snapshot = [] if space is not None: snap_now = space.get_current_snapshot() if hasattr(space, "get_current_snapshot") else None if snap_now is not None: medications_in_snapshot = list(snap_now.medications or []) family_task = _safe(gfamily.assess( orpha=target_orpha, family_history=(snap_now.family_history if (space and hasattr(space, "get_current_snapshot") and space.get_current_snapshot()) else []), sex=(patient_info or {}).get("sex"), )) if target_orpha else asyncio.sleep(0, result=None) rpheno_task = _safe(grpheno.look_for( orpha=target_orpha, already_present=hpo_ids, top_n=10, )) if target_orpha else asyncio.sleep(0, result=None) async def _pcdt_async(): return gpcdt.assess( orpha=target_orpha, current_treatments=medications_in_snapshot, current_labs=(snap_now.labs if snap_now else []), current_imaging=(snap_now.imaging if snap_now else []), ) pcdt_task = _safe(_pcdt_async()) if target_orpha else asyncio.sleep(0, result=None) ddi_task = _safe(gddi.predict( medications=medications_in_snapshot, )) if (medications_in_snapshot and len(medications_in_snapshot) >= 2) else asyncio.sleep(0, result=None) cohort_v, subgraph_v, traj_v, risk_v, ask_v, drugs_v, family_v, rpheno_v, pcdt_v, ddi_v = await asyncio.gather( cohort_task, subgraph_task, trajectory_task, risk_task, ask_task, drugs_task, family_task, rpheno_task, pcdt_task, ddi_task, ) # Pharmacogenomics — needs both genes and drug candidates pharm_v = None if space is not None and (snap_now.genes if snap_now else []) and drugs_v and drugs_v.candidates: try: pharm_v = await gpharm.assess( genes=snap_now.genes, drug_candidates=drugs_v.candidates[:8], ) except Exception as e: logger.debug(f"pharmacogen failed: {e}") # 5) trials — wrap existing # trial_matcher.TrialMatchResult uses `trials: list[TrialMatch]`, # NOT `matches`. Each TrialMatch has trial_id (NCT number), # match_score (not "score"), conditions/interventions/url/etc. # The previous mapping looked for nct_id/score/eligibility_summary # which don't exist on the dataclass → every field serialized as # None and the count came out 0 even when ClinicalTrials.gov # returned real hits. trials_v = None if space and not fast: try: from .types import TrialSpec from trial_matcher import match_trials tm = await match_trials(space, max_trials=8) matches = [] if tm is not None: raw = ( getattr(tm, "trials", None) or getattr(tm, "matches", None) or (tm.get("trials") if isinstance(tm, dict) else None) or (tm.get("matches") if isinstance(tm, dict) else None) or [] ) for m in raw[:8]: if isinstance(m, dict): matches.append(m) else: matches.append({ "nct_id": getattr(m, "trial_id", None) or getattr(m, "nct_id", None), "title": getattr(m, "title", None), "phase": getattr(m, "phase", None), "status": getattr(m, "status", None), "score": getattr(m, "match_score", None) or getattr(m, "score", None), "conditions": getattr(m, "conditions", []) or [], "interventions": getattr(m, "interventions", []) or [], "explanation": getattr(m, "explanation", None), "url": getattr(m, "url", None), "has_brazil_site": getattr(m, "has_brazil_site", False), }) n_searched = int(getattr(tm, "total_found", 0) or 0) if tm is not None else 0 trials_v = TrialSpec(matches=matches, model="trialgpt_bootstrap", n_searched=max(n_searched, len(matches))) except Exception as e: logger.debug(f"trial_matcher failed: {e}") # 6) SUS grounding sus_v = None if target_orpha: try: sus_v = gsus.check( orpha=target_orpha, disease_name=diagnoses[0].get("name") if diagnoses else None, uf=sus_region, ) except Exception as e: logger.warning(f"SUS grounding failed: {e}") # 7) viz viz_v = None if subgraph_v is not None: try: viz_v = gviz.from_subgraph(subgraph_v, center_id=f"patient:{case_id}") except Exception as e: logger.warning(f"viz formatting failed: {e}") # 8) assemble snap = space.get_current_snapshot() if (space and hasattr(space, "get_current_snapshot")) else None twin = GemeoTwin( case_id=case_id, patient_id=(patient_info or {}).get("id"), embedding=(embedding.tolist() if embedding is not None and hasattr(embedding, "tolist") else None), embedding_dim=(int(embedding.shape[0]) if embedding is not None and hasattr(embedding, "shape") else 0), diagnoses=diagnoses, cohort=cohort_v, subgraph=subgraph_v, trajectory=traj_v, risk=risk_v, drugs=drugs_v, trials=trials_v, next_questions=ask_v or [], sus_check=sus_v, viz_data=viz_v, ddi=ddi_v, pharmacogen=pharm_v, family=family_v, reverse_pheno=rpheno_v, protocol_compliance=pcdt_v, snapshot_versions=[s.version for s in (space.get_trajectory() if space and hasattr(space, "get_trajectory") else [])], n_phenotypes=len(snap.phenotypes) if snap else 0, n_genes=len(snap.genes) if snap else 0, n_labs=len(snap.labs) if snap else 0, extra={ "encoder_quality": enc_quality, "sus_region": sus_region, "target_orpha": target_orpha, }, ) if case_id: _TWINS[case_id] = twin return twin async def evolve_gemeo( case_id: str, *, new_phenotypes: list = None, new_genes: list = None, new_labs: list = None, new_treatments: list = None, cohort_k: int = 10, horizons_months: list[int] = None, ) -> Optional[GemeoTwin]: """Add new clinical data to an existing twin and re-run all stages.""" try: from digital_twin_workflow import evolve_digital_twin except ImportError as e: logger.error(f"evolve_digital_twin not importable: {e}") return None await evolve_digital_twin( case_id=case_id, new_phenotypes=new_phenotypes or [], new_genes=new_genes or [], new_labs=new_labs or [], new_treatments=new_treatments or [], ) space = await _load_space(case_id) if space is None: return None hpo_ids, gene_symbols, diagnoses, sus_region = _extract_inputs(space) embedding, enc_quality = gencoder.encode_patient_space(space) target_orpha = diagnoses[0]["orpha"] if diagnoses else None cohort_v, subgraph_v, traj_v, risk_v, ask_v = await asyncio.gather( gcohort.find_cohort(embedding=embedding, hpo_ids=hpo_ids, k=cohort_k), gsub.extract(patient_id=case_id, hpo_ids=hpo_ids, gene_symbols=gene_symbols, target_orpha=target_orpha), gtraj.predict(space, horizons_months or [6, 12, 24]), grisk.assess(space, embedding=embedding), gask.recommend(space, top_n=5), return_exceptions=True, ) # exceptions → None def _ok(v): return None if isinstance(v, Exception) else v twin = _TWINS.get(case_id) or GemeoTwin(case_id=case_id) twin.embedding = embedding.tolist() if hasattr(embedding, "tolist") else None twin.embedding_dim = int(embedding.shape[0]) if hasattr(embedding, "shape") else 0 twin.diagnoses = diagnoses twin.cohort = _ok(cohort_v) twin.subgraph = _ok(subgraph_v) twin.trajectory = _ok(traj_v) twin.risk = _ok(risk_v) twin.next_questions = _ok(ask_v) or [] if twin.subgraph: twin.viz_data = gviz.from_subgraph(twin.subgraph, center_id=f"patient:{case_id}") snap = space.get_current_snapshot() if snap: twin.n_phenotypes = len(snap.phenotypes) twin.n_genes = len(snap.genes) twin.n_labs = len(snap.labs) if target_orpha: twin.sus_check = gsus.check(orpha=target_orpha, disease_name=diagnoses[0].get("name") if diagnoses else None, uf=sus_region) from datetime import datetime, timezone twin.updated_at = datetime.now(timezone.utc).isoformat() _TWINS[case_id] = twin return twin async def what_if(case_id: str, intervention: dict) -> Optional[dict]: """Run a counterfactual on the twin. Returns serialized WhatIfResult.""" space = await _load_space(case_id) if space is None: return None twin = _TWINS.get(case_id) base_risk = twin.risk if twin else None base_traj = twin.trajectory if twin else None result = await gwhatif.simulate( space, intervention, baseline_risk=base_risk, baseline_trajectory=base_traj, ) from dataclasses import asdict return asdict(result) async def query_gemeo(case_id: str) -> Optional[GemeoTwin]: """Return the cached twin or rebuild lazily from PatientSpace.""" twin = _TWINS.get(case_id) if twin is not None: return twin space = await _load_space(case_id) if space is None: return None return await build_gemeo( case_text="", patient_info=None, context=None, run_diagnosis=False, fast=True, ) def get_gemeo(case_id: str) -> Optional[GemeoTwin]: """Synchronous in-memory lookup (no Neo4j fallback).""" return _TWINS.get(case_id) async def consult(case_id: str, panel: list[str] = None, question: str = None) -> Optional[dict]: """Run a multi-specialist consult on the twin.""" twin = await query_gemeo(case_id) if twin is None: return None spec = await gconsult.consult( twin, panel=panel, question=question or "Synthesise your opinion on this case.", ) twin.consult = spec if case_id: _TWINS[case_id] = twin from dataclasses import asdict return asdict(spec) async def simulate( case_id: str, *, n_runs: int = 30, intervention: dict = None, horizons_months: list[int] = None, ) -> Optional[dict]: """Monte-Carlo simulation of trajectory under stochastic intervention.""" space = await _load_space(case_id) if space is None: return None spec = await gsim.run( space, n_runs=n_runs, intervention=intervention, horizons_months=horizons_months, ) from dataclasses import asdict return asdict(spec)