| """Gemeo orchestrator — the Gemeo class and high-level entry points. |
| |
| This is the public face of the module. Wraps every capability behind one |
| clean async surface: |
| |
| twin = await build_gemeo(case_text="...", patient_info={...}, context={...}) |
| print(twin.cohort, twin.subgraph, twin.trajectory, twin.risk, ...) |
| |
| new = await evolve_gemeo(twin.id, new_phenotypes=[...]) |
| cf = await what_if(twin.id, intervention={"type": "treatment", "drug": "Cerezyme"}) |
| """ |
| from __future__ import annotations |
| import asyncio |
| import logging |
| from typing import Optional |
|
|
| from .types import GemeoTwin, VizData |
| from . import encoder as gencoder |
| from . import cohort as gcohort |
| from . import subgraph as gsub |
| from . import trajectory as gtraj |
| from . import risk as grisk |
| from . import repurpose as gdrugs |
| from . import ask as gask |
| from . import ground_sus as gsus |
| from . import viz as gviz |
| from . import whatif as gwhatif |
| from . import ddi as gddi |
| from . import pharmacogen as gpharm |
| from . import family as gfamily |
| from . import reverse_pheno as grpheno |
| from . import protocol_compliance as gpcdt |
| from . import consult as gconsult |
| from . import simulate as gsim |
|
|
| logger = logging.getLogger("gemeo.core") |
|
|
|
|
| |
| |
# In-process cache of assembled twins, keyed by case_id. Written by
# build_gemeo / evolve_gemeo / consult and read by the lookup helpers below.
# Entries are never evicted by this module.
_TWINS: dict[str, GemeoTwin] = {}
|
|
|
|
| |
|
|
async def _ensure_space(case_text: str, patient_info: dict, context: dict, run_diagnosis: bool):
    """Delegate twin creation to ``digital_twin_workflow.create_digital_twin``.

    ``patient_info`` and ``context`` are replaced by empty dicts when falsy,
    and the workflow is always invoked with ``run_full_analysis=False``.

    Returns:
        Whatever ``create_digital_twin`` returns, unmodified.

    Raises:
        ImportError: if ``digital_twin_workflow`` cannot be imported (the
            failure is logged before being re-raised).
    """
    try:
        from digital_twin_workflow import create_digital_twin as make_twin
    except ImportError as exc:
        logger.error(f"digital_twin_workflow not importable: {exc}")
        raise

    request = {
        "case_text": case_text,
        "patient_info": patient_info if patient_info else {},
        "context": context if context else {},
        "run_diagnosis": run_diagnosis,
        "run_full_analysis": False,
    }
    return await make_twin(**request)
|
|
|
|
async def _load_space(case_id: str):
    """Fetch the PatientSpace for ``case_id`` via ``digital_twin_workflow``.

    The workflow module is imported lazily, at call time; the result of its
    ``_get_or_load_space`` coroutine is returned as-is.
    """
    from digital_twin_workflow import _get_or_load_space as fetch_space

    space = await fetch_space(case_id)
    return space
|
|
|
|
| def _extract_inputs(space): |
| """Pull HPO/Gene lists + dx + sus_region from a PatientSpace.""" |
| hpo_ids, gene_symbols, diagnoses = [], [], [] |
| sus_region = None |
| snap = space.get_current_snapshot() if hasattr(space, "get_current_snapshot") else None |
| if snap is not None: |
| for p in snap.phenotypes: |
| if p.get("hpo_id"): |
| hpo_ids.append(p["hpo_id"]) |
| for g in snap.genes: |
| if g.get("symbol"): |
| gene_symbols.append(str(g["symbol"]).upper()) |
| for d in snap.diagnoses: |
| if d.get("orpha"): |
| diagnoses.append({ |
| "orpha": d["orpha"], |
| "name": d.get("disease") or d.get("name"), |
| "probability": d.get("probability", 0.5), |
| "status": d.get("status", "active"), |
| }) |
| sus_region = (snap.context or {}).get("sus_region") |
| |
| for hyp in (getattr(space, "_hypotheses", {}) or {}).values(): |
| orpha = getattr(hyp, "orpha_code", None) |
| if orpha and orpha not in {d["orpha"] for d in diagnoses}: |
| diagnoses.append({ |
| "orpha": orpha, |
| "name": getattr(hyp, "disease_name", "") or getattr(hyp, "name", ""), |
| "probability": getattr(hyp, "probability", 0.5), |
| "status": getattr(hyp, "status", "active"), |
| }) |
| diagnoses.sort(key=lambda d: d.get("probability", 0), reverse=True) |
| return hpo_ids, gene_symbols, diagnoses, sus_region |
|
|
|
|
| |
|
|
async def build_gemeo(
    *,
    case_text: str,
    patient_info: Optional[dict] = None,
    context: Optional[dict] = None,
    run_diagnosis: bool = True,
    cohort_k: int = 10,
    horizons_months: Optional[list[int]] = None,
    fast: bool = False,
) -> GemeoTwin:
    """Build a complete digital twin from a clinical case.

    Every analysis stage is best-effort: a stage that raises is logged and
    its field on the twin is left as ``None`` instead of aborting the build.

    Args:
        case_text: free-text or structured case description (PT-BR or EN)
        patient_info: {age, sex, ethnicity, ...}
        context: {sus_region, access_level, ...}
        run_diagnosis: run engine_v2 to populate hypotheses
        cohort_k: patients-like-mine size
        horizons_months: trajectory horizons (default [6, 12, 24])
        fast: skip slow stages for snappy demos: trajectory, drug
            repurposing (and therefore pharmacogenomics) and trial matching

    Returns:
        The assembled ``GemeoTwin``. It is also cached in ``_TWINS`` when the
        workflow returned a case id.

    Raises:
        ImportError: if ``digital_twin_workflow`` is unavailable. Errors
            raised by the workflow itself also propagate.
    """
    horizons_months = horizons_months or [6, 12, 24]
    patient_info = patient_info or {}
    context = context or {}

    # 1. Create the case through the workflow, then load its PatientSpace.
    payload = await _ensure_space(case_text, patient_info, context, run_diagnosis)
    case_id = payload.get("case_id") or payload.get("id")
    space = await _load_space(case_id) if case_id else None
    has_space = space is not None
    if not has_space:
        logger.warning("could not load PatientSpace for case %s", case_id)

    # 2. Structured inputs. `snap_now` is bound unconditionally so later
    #    stages (notably the PCDT closure) can test it even without a space.
    hpo_ids, gene_symbols, diagnoses, sus_region = [], [], [], None
    snap_now = None
    if has_space:
        hpo_ids, gene_symbols, diagnoses, sus_region = _extract_inputs(space)
        if hasattr(space, "get_current_snapshot"):
            snap_now = space.get_current_snapshot()
    if context.get("sus_region"):
        sus_region = sus_region or context["sus_region"]
    medications_in_snapshot = (
        list(snap_now.medications or []) if snap_now is not None else []
    )

    # 3. Patient embedding (best-effort).
    embedding, enc_quality = None, "skipped"
    if has_space:
        try:
            embedding, enc_quality = gencoder.encode_patient_space(space)
        except Exception as exc:
            logger.warning("encode failed: %s", exc)

    # 4. Target disease: an explicit request wins over the top-ranked diagnosis.
    target_orpha = None
    if context.get("target_orpha"):
        target_orpha = str(context["target_orpha"])
    elif patient_info.get("target_orpha"):
        target_orpha = str(patient_info["target_orpha"])
    elif diagnoses:
        target_orpha = diagnoses[0].get("orpha")

    async def _safe(awaitable):
        """Await a stage; log and return None if it raises."""
        try:
            return await awaitable
        except Exception as exc:
            logger.warning("stage failed: %s", exc)
            return None

    def _skipped():
        """Placeholder awaitable for a stage that does not run."""
        return asyncio.sleep(0, result=None)

    # 5. Independent stages, scheduled together.
    cohort_task = _safe(gcohort.find_cohort(
        embedding=embedding,
        hpo_ids=hpo_ids,
        orpha_codes=[d["orpha"] for d in diagnoses if d.get("orpha")],
        k=cohort_k,
        include_literature=True,
    )) if (embedding is not None or hpo_ids or diagnoses) else _skipped()

    subgraph_task = _safe(gsub.extract(
        patient_id=case_id or "anon",
        hpo_ids=hpo_ids,
        gene_symbols=gene_symbols,
        target_orpha=target_orpha,
    )) if (hpo_ids or gene_symbols or target_orpha) else _skipped()

    trajectory_task = (
        _safe(gtraj.predict(space, horizons_months))
        if (has_space and not fast)
        else _skipped()
    )
    risk_task = _safe(grisk.assess(space, embedding=embedding)) if has_space else _skipped()
    ask_task = _safe(gask.recommend(space, top_n=5)) if has_space else _skipped()
    drugs_task = (
        _safe(gdrugs.find(space, embedding=embedding, sus_region=sus_region))
        if (has_space and not fast)
        else _skipped()
    )

    family_task = _safe(gfamily.assess(
        orpha=target_orpha,
        family_history=(snap_now.family_history if snap_now is not None else []),
        sex=patient_info.get("sex"),
    )) if target_orpha else _skipped()

    rpheno_task = _safe(grpheno.look_for(
        orpha=target_orpha,
        already_present=hpo_ids,
        top_n=10,
    )) if target_orpha else _skipped()

    async def _pcdt_async():
        # gpcdt.assess is synchronous; wrapping it lets it share _safe/gather.
        return gpcdt.assess(
            orpha=target_orpha,
            current_treatments=medications_in_snapshot,
            current_labs=(snap_now.labs if snap_now is not None else []),
            current_imaging=(snap_now.imaging if snap_now is not None else []),
        )
    pcdt_task = _safe(_pcdt_async()) if target_orpha else _skipped()

    # An interaction check needs at least two medications.
    ddi_task = _safe(gddi.predict(
        medications=medications_in_snapshot,
    )) if len(medications_in_snapshot) >= 2 else _skipped()

    (
        cohort_v, subgraph_v, traj_v, risk_v, ask_v, drugs_v,
        family_v, rpheno_v, pcdt_v, ddi_v,
    ) = await asyncio.gather(
        cohort_task, subgraph_task, trajectory_task, risk_task, ask_task, drugs_task,
        family_task, rpheno_task, pcdt_task, ddi_task,
    )

    # 6. Pharmacogenomics depends on the drug candidates found above.
    pharm_v = None
    snapshot_genes = snap_now.genes if snap_now is not None else []
    if snapshot_genes and drugs_v and drugs_v.candidates:
        try:
            pharm_v = await gpharm.assess(
                genes=snapshot_genes,
                drug_candidates=drugs_v.candidates[:8],
            )
        except Exception as exc:
            logger.debug("pharmacogen failed: %s", exc)

    # 7. Clinical-trial matching (skipped in fast mode).
    trials_v = None
    if has_space and not fast:
        try:
            trials_v = await _match_trials_spec(space, max_trials=8)
        except Exception as exc:
            logger.debug("trial_matcher failed: %s", exc)

    # 8. SUS grounding for the target disease.
    sus_v = None
    if target_orpha:
        try:
            sus_v = gsus.check(
                orpha=target_orpha,
                disease_name=diagnoses[0].get("name") if diagnoses else None,
                uf=sus_region,
            )
        except Exception as exc:
            logger.warning("SUS grounding failed: %s", exc)

    # 9. Visualisation payload derived from the subgraph.
    viz_v = None
    if subgraph_v is not None:
        try:
            viz_v = gviz.from_subgraph(subgraph_v, center_id=f"patient:{case_id}")
        except Exception as exc:
            logger.warning("viz formatting failed: %s", exc)

    # 10. Assemble. The snapshot is re-read so the counts reflect the space
    #     as it is after all stages have run.
    snap = (
        space.get_current_snapshot()
        if (has_space and hasattr(space, "get_current_snapshot"))
        else None
    )
    history = (
        space.get_trajectory()
        if (has_space and hasattr(space, "get_trajectory"))
        else []
    )
    has_embedding = embedding is not None
    twin = GemeoTwin(
        case_id=case_id,
        patient_id=patient_info.get("id"),
        embedding=(embedding.tolist() if has_embedding and hasattr(embedding, "tolist") else None),
        embedding_dim=(int(embedding.shape[0]) if has_embedding and hasattr(embedding, "shape") else 0),
        diagnoses=diagnoses,
        cohort=cohort_v,
        subgraph=subgraph_v,
        trajectory=traj_v,
        risk=risk_v,
        drugs=drugs_v,
        trials=trials_v,
        next_questions=ask_v or [],
        sus_check=sus_v,
        viz_data=viz_v,
        ddi=ddi_v,
        pharmacogen=pharm_v,
        family=family_v,
        reverse_pheno=rpheno_v,
        protocol_compliance=pcdt_v,
        snapshot_versions=[s.version for s in history],
        n_phenotypes=len(snap.phenotypes) if snap is not None else 0,
        n_genes=len(snap.genes) if snap is not None else 0,
        n_labs=len(snap.labs) if snap is not None else 0,
        extra={
            "encoder_quality": enc_quality,
            "sus_region": sus_region,
            "target_orpha": target_orpha,
        },
    )
    if case_id:
        _TWINS[case_id] = twin
    return twin


def _trial_match_to_dict(match) -> dict:
    """Normalise one trial_matcher result (dict or object) into a plain dict."""
    if isinstance(match, dict):
        return match
    return {
        "nct_id": getattr(match, "trial_id", None) or getattr(match, "nct_id", None),
        "title": getattr(match, "title", None),
        "phase": getattr(match, "phase", None),
        "status": getattr(match, "status", None),
        "score": getattr(match, "match_score", None) or getattr(match, "score", None),
        "conditions": getattr(match, "conditions", []) or [],
        "interventions": getattr(match, "interventions", []) or [],
        "explanation": getattr(match, "explanation", None),
        "url": getattr(match, "url", None),
        "has_brazil_site": getattr(match, "has_brazil_site", False),
    }


async def _match_trials_spec(space, max_trials: int = 8):
    """Run trial_matcher on the space and wrap the top matches in a TrialSpec."""
    from .types import TrialSpec
    from trial_matcher import match_trials

    tm = await match_trials(space, max_trials=max_trials)
    matches = []
    n_searched = 0
    if tm is not None:
        # The matcher may return an object or a dict, under either key.
        raw = (
            getattr(tm, "trials", None)
            or getattr(tm, "matches", None)
            or (tm.get("trials") if isinstance(tm, dict) else None)
            or (tm.get("matches") if isinstance(tm, dict) else None)
            or []
        )
        matches = [_trial_match_to_dict(m) for m in raw[:max_trials]]
        n_searched = int(getattr(tm, "total_found", 0) or 0)
    return TrialSpec(
        matches=matches,
        model="trialgpt_bootstrap",
        n_searched=max(n_searched, len(matches)),
    )
|
|
|
|
async def evolve_gemeo(
    case_id: str,
    *,
    new_phenotypes: Optional[list] = None,
    new_genes: Optional[list] = None,
    new_labs: Optional[list] = None,
    new_treatments: Optional[list] = None,
    cohort_k: int = 10,
    horizons_months: Optional[list[int]] = None,
) -> Optional[GemeoTwin]:
    """Add new clinical data to an existing twin and refresh its core stages.

    Refreshed: embedding, diagnoses, cohort, subgraph, trajectory, risk,
    next questions, visualisation data, SUS check and the snapshot counts.
    Other fields on a cached twin (drugs, trials, DDI, pharmacogenomics,
    family, reverse phenotyping, protocol compliance) keep their previous
    values.

    Args:
        case_id: id of the case to evolve.
        new_phenotypes / new_genes / new_labs / new_treatments: data handed
            to ``evolve_digital_twin`` (``None`` is sent as an empty list).
        cohort_k: patients-like-mine size.
        horizons_months: trajectory horizons (default [6, 12, 24]).

    Returns:
        The updated twin (also stored in ``_TWINS``), or ``None`` when the
        workflow module is missing or the PatientSpace cannot be loaded.
    """
    try:
        from digital_twin_workflow import evolve_digital_twin
    except ImportError as exc:
        logger.error("evolve_digital_twin not importable: %s", exc)
        return None

    await evolve_digital_twin(
        case_id=case_id,
        new_phenotypes=new_phenotypes or [],
        new_genes=new_genes or [],
        new_labs=new_labs or [],
        new_treatments=new_treatments or [],
    )
    space = await _load_space(case_id)
    if space is None:
        return None

    hpo_ids, gene_symbols, diagnoses, sus_region = _extract_inputs(space)
    target_orpha = diagnoses[0]["orpha"] if diagnoses else None

    # Encoding is best-effort, as in build_gemeo: the stages below accept a
    # missing embedding.
    embedding, enc_quality = None, "skipped"
    try:
        embedding, enc_quality = gencoder.encode_patient_space(space)
    except Exception as exc:
        logger.warning("encode failed: %s", exc)

    stage_labels = ("cohort", "subgraph", "trajectory", "risk", "next_questions")
    outcomes = await asyncio.gather(
        # Same cohort query shape as build_gemeo, so both produce comparable cohorts.
        gcohort.find_cohort(
            embedding=embedding,
            hpo_ids=hpo_ids,
            orpha_codes=[d["orpha"] for d in diagnoses if d.get("orpha")],
            k=cohort_k,
            include_literature=True,
        ),
        gsub.extract(
            patient_id=case_id,
            hpo_ids=hpo_ids,
            gene_symbols=gene_symbols,
            target_orpha=target_orpha,
        ),
        gtraj.predict(space, horizons_months or [6, 12, 24]),
        grisk.assess(space, embedding=embedding),
        gask.recommend(space, top_n=5),
        return_exceptions=True,
    )

    results = {}
    for label, outcome in zip(stage_labels, outcomes):
        # BaseException, not Exception: gather can hand back CancelledError,
        # which must not be stored on the twin as if it were a result.
        if isinstance(outcome, BaseException):
            logger.warning("evolve stage %s failed: %s", label, outcome)
            outcome = None
        results[label] = outcome

    twin = _TWINS.get(case_id) or GemeoTwin(case_id=case_id)
    extra = getattr(twin, "extra", None)
    if not isinstance(extra, dict):
        extra = None
    # build_gemeo may have taken the region from the request context rather
    # than the snapshot; reuse it instead of dropping to None.
    if not sus_region and extra:
        sus_region = extra.get("sus_region")

    twin.embedding = embedding.tolist() if hasattr(embedding, "tolist") else None
    twin.embedding_dim = int(embedding.shape[0]) if hasattr(embedding, "shape") else 0
    twin.diagnoses = diagnoses
    twin.cohort = results["cohort"]
    twin.subgraph = results["subgraph"]
    twin.trajectory = results["trajectory"]
    twin.risk = results["risk"]
    twin.next_questions = results["next_questions"] or []

    if twin.subgraph:
        try:
            twin.viz_data = gviz.from_subgraph(twin.subgraph, center_id=f"patient:{case_id}")
        except Exception as exc:
            logger.warning("viz formatting failed: %s", exc)

    snap = space.get_current_snapshot()
    if snap:
        twin.n_phenotypes = len(snap.phenotypes)
        twin.n_genes = len(snap.genes)
        twin.n_labs = len(snap.labs)

    if target_orpha:
        try:
            twin.sus_check = gsus.check(
                orpha=target_orpha,
                disease_name=diagnoses[0].get("name") if diagnoses else None,
                uf=sus_region,
            )
        except Exception as exc:
            logger.warning("SUS grounding failed: %s", exc)

    if extra is not None:
        # Record the inputs actually used for this refresh.
        extra.update({
            "encoder_quality": enc_quality,
            "sus_region": sus_region,
            "target_orpha": target_orpha,
        })

    from datetime import datetime, timezone
    twin.updated_at = datetime.now(timezone.utc).isoformat()
    _TWINS[case_id] = twin
    return twin
|
|
|
|
async def what_if(case_id: str, intervention: dict) -> Optional[dict]:
    """Simulate a counterfactual intervention for the given case.

    The cached twin, when there is one, supplies the baseline risk and
    trajectory; otherwise both baselines are passed as ``None``.

    Returns:
        The simulation result converted with ``dataclasses.asdict``, or
        ``None`` when no PatientSpace can be loaded for ``case_id``.
    """
    from dataclasses import asdict

    space = await _load_space(case_id)
    if space is None:
        return None

    cached = _TWINS.get(case_id)
    if cached:
        baseline_risk, baseline_trajectory = cached.risk, cached.trajectory
    else:
        baseline_risk, baseline_trajectory = None, None

    outcome = await gwhatif.simulate(
        space,
        intervention,
        baseline_risk=baseline_risk,
        baseline_trajectory=baseline_trajectory,
    )
    return asdict(outcome)
|
|
|
|
async def query_gemeo(case_id: str) -> Optional[GemeoTwin]:
    """Return the cached twin, or rebuild it from the stored PatientSpace.

    On a cache miss the twin is assembled from the space already stored for
    ``case_id``. Routing the reload through ``build_gemeo`` would submit a
    placeholder case text to the workflow and return a twin for a different,
    newly created case.

    Returns:
        The twin for ``case_id``, or ``None`` when no PatientSpace exists.
    """
    twin = _TWINS.get(case_id)
    if twin is not None:
        return twin
    space = await _load_space(case_id)
    if space is None:
        return None
    return await _rebuild_twin_from_space(case_id, space)


async def _rebuild_twin_from_space(case_id: str, space) -> GemeoTwin:
    """Assemble a fast-mode twin (no trajectory, drugs or trials) from a loaded space."""
    hpo_ids, gene_symbols, diagnoses, sus_region = _extract_inputs(space)
    target_orpha = diagnoses[0].get("orpha") if diagnoses else None
    snap = space.get_current_snapshot() if hasattr(space, "get_current_snapshot") else None
    medications = list(snap.medications or []) if snap is not None else []

    embedding, enc_quality = None, "skipped"
    try:
        embedding, enc_quality = gencoder.encode_patient_space(space)
    except Exception as exc:
        logger.warning("encode failed: %s", exc)

    async def _stage(label, factory):
        """Run one stage; a skipped (None) or failing stage yields None."""
        if factory is None:
            return None
        try:
            return await factory()
        except Exception as exc:
            logger.warning("reload stage %s failed: %s", label, exc)
            return None

    async def _pcdt():
        # gpcdt.assess is synchronous; wrap it so it fits the stage runner.
        return gpcdt.assess(
            orpha=target_orpha,
            current_treatments=medications,
            current_labs=(snap.labs if snap is not None else []),
            current_imaging=(snap.imaging if snap is not None else []),
        )

    # Same stage set and preconditions as build_gemeo(fast=True).
    plan = {
        "cohort": (lambda: gcohort.find_cohort(
            embedding=embedding,
            hpo_ids=hpo_ids,
            orpha_codes=[d["orpha"] for d in diagnoses if d.get("orpha")],
            k=10,
            include_literature=True,
        )) if (embedding is not None or hpo_ids or diagnoses) else None,
        "subgraph": (lambda: gsub.extract(
            patient_id=case_id or "anon",
            hpo_ids=hpo_ids,
            gene_symbols=gene_symbols,
            target_orpha=target_orpha,
        )) if (hpo_ids or gene_symbols or target_orpha) else None,
        "risk": lambda: grisk.assess(space, embedding=embedding),
        "ask": lambda: gask.recommend(space, top_n=5),
        "family": (lambda: gfamily.assess(
            orpha=target_orpha,
            family_history=(snap.family_history if snap is not None else []),
            sex=None,  # patient_info is not available on reload
        )) if target_orpha else None,
        "reverse_pheno": (lambda: grpheno.look_for(
            orpha=target_orpha,
            already_present=hpo_ids,
            top_n=10,
        )) if target_orpha else None,
        "pcdt": _pcdt if target_orpha else None,
        "ddi": (lambda: gddi.predict(medications=medications))
        if len(medications) >= 2 else None,
    }
    values = await asyncio.gather(
        *(_stage(label, factory) for label, factory in plan.items())
    )
    done = dict(zip(plan, values))

    sus_v = None
    if target_orpha:
        try:
            sus_v = gsus.check(
                orpha=target_orpha,
                disease_name=diagnoses[0].get("name") if diagnoses else None,
                uf=sus_region,
            )
        except Exception as exc:
            logger.warning("SUS grounding failed: %s", exc)

    viz_v = None
    if done["subgraph"] is not None:
        try:
            viz_v = gviz.from_subgraph(done["subgraph"], center_id=f"patient:{case_id}")
        except Exception as exc:
            logger.warning("viz formatting failed: %s", exc)

    history = space.get_trajectory() if hasattr(space, "get_trajectory") else []
    has_embedding = embedding is not None
    twin = GemeoTwin(
        case_id=case_id,
        embedding=(embedding.tolist() if has_embedding and hasattr(embedding, "tolist") else None),
        embedding_dim=(int(embedding.shape[0]) if has_embedding and hasattr(embedding, "shape") else 0),
        diagnoses=diagnoses,
        cohort=done["cohort"],
        subgraph=done["subgraph"],
        risk=done["risk"],
        next_questions=done["ask"] or [],
        sus_check=sus_v,
        viz_data=viz_v,
        ddi=done["ddi"],
        family=done["family"],
        reverse_pheno=done["reverse_pheno"],
        protocol_compliance=done["pcdt"],
        snapshot_versions=[s.version for s in history],
        n_phenotypes=len(snap.phenotypes) if snap is not None else 0,
        n_genes=len(snap.genes) if snap is not None else 0,
        n_labs=len(snap.labs) if snap is not None else 0,
        extra={
            "encoder_quality": enc_quality,
            "sus_region": sus_region,
            "target_orpha": target_orpha,
        },
    )
    if case_id:
        _TWINS[case_id] = twin
    return twin
|
|
|
|
def get_gemeo(case_id: str) -> Optional[GemeoTwin]:
    """Look up a twin in the in-memory cache only.

    Nothing is loaded or rebuilt; ``None`` is returned when ``case_id`` has
    no cached twin.
    """
    cached_twin = _TWINS.get(case_id)
    return cached_twin
|
|
|
|
async def consult(case_id: str, panel: list[str] = None, question: str = None) -> Optional[dict]:
    """Ask the specialist panel for an opinion on the case's twin.

    The twin is obtained through ``query_gemeo``. The consult result is
    attached to the twin as ``twin.consult`` and the twin is written back to
    the cache. A falsy ``question`` is replaced by a generic synthesis prompt.

    Returns:
        The consult result converted with ``dataclasses.asdict``, or ``None``
        when no twin is available for ``case_id``.
    """
    from dataclasses import asdict

    twin = await query_gemeo(case_id)
    if twin is None:
        return None

    prompt = question if question else "Synthesise your opinion on this case."
    opinion = await gconsult.consult(twin, panel=panel, question=prompt)

    twin.consult = opinion
    if case_id:
        _TWINS[case_id] = twin
    return asdict(opinion)
|
|
|
|
async def simulate(
    case_id: str,
    *,
    n_runs: int = 30,
    intervention: dict = None,
    horizons_months: list[int] = None,
) -> Optional[dict]:
    """Run the Monte-Carlo trajectory simulation for a case.

    All keyword arguments are forwarded unchanged to ``gsim.run`` together
    with the loaded PatientSpace.

    Returns:
        The simulation result converted with ``dataclasses.asdict``, or
        ``None`` when no PatientSpace can be loaded for ``case_id``.
    """
    from dataclasses import asdict

    space = await _load_space(case_id)
    if space is None:
        return None

    run_options = {
        "n_runs": n_runs,
        "intervention": intervention,
        "horizons_months": horizons_months,
    }
    outcome = await gsim.run(space, **run_options)
    return asdict(outcome)
|
|