timmers's picture
GEMEO world-model — initial release (module + NeuralSurv ckpt + RareBench v49 + KG embeddings)
089d665 verified
"""Gemeo orchestrator — the Gemeo class and high-level entry points.
This is the public face of the module. Wraps every capability behind one
clean async surface:
twin = await build_gemeo(case_text="...", patient_info={...}, context={...})
print(twin.cohort, twin.subgraph, twin.trajectory, twin.risk, ...)
new = await evolve_gemeo(twin.id, new_phenotypes=[...])
cf = await what_if(twin.id, intervention={"type": "treatment", "drug": "Cerezyme"})
"""
from __future__ import annotations
import asyncio
import logging
from typing import Optional
from .types import GemeoTwin, VizData
from . import encoder as gencoder
from . import cohort as gcohort
from . import subgraph as gsub
from . import trajectory as gtraj
from . import risk as grisk
from . import repurpose as gdrugs
from . import ask as gask
from . import ground_sus as gsus
from . import viz as gviz
from . import whatif as gwhatif
from . import ddi as gddi
from . import pharmacogen as gpharm
from . import family as gfamily
from . import reverse_pheno as grpheno
from . import protocol_compliance as gpcdt
from . import consult as gconsult
from . import simulate as gsim
logger = logging.getLogger("gemeo.core")
# in-memory twin registry — case_id → GemeoTwin
# (Persistent state still lives in PatientSpace + Neo4j; this is a fast cache)
_TWINS: dict = {}
# ─── helpers ───────────────────────────────────────────────────────────────
async def _ensure_space(case_text: str, patient_info: dict, context: dict, run_diagnosis: bool):
    """Create or load a PatientSpace through digital_twin_workflow.create_digital_twin.

    Missing patient_info/context are sent as empty dicts, and the full
    analysis is always disabled (run_full_analysis=False).

    Returns:
        Whatever ``create_digital_twin`` returns (the caller reads
        ``case_id``/``id`` from it).

    Raises:
        ImportError: re-raised after logging when digital_twin_workflow
            cannot be imported.
    """
    try:
        from digital_twin_workflow import create_digital_twin
    except ImportError as exc:
        logger.error(f"digital_twin_workflow not importable: {exc}")
        raise

    request = {
        "case_text": case_text,
        "patient_info": patient_info if patient_info else {},
        "context": context if context else {},
        "run_diagnosis": run_diagnosis,
        "run_full_analysis": False,
    }
    return await create_digital_twin(**request)
async def _load_space(case_id: str):
from digital_twin_workflow import _get_or_load_space
return await _get_or_load_space(case_id)
def _extract_inputs(space):
"""Pull HPO/Gene lists + dx + sus_region from a PatientSpace."""
hpo_ids, gene_symbols, diagnoses = [], [], []
sus_region = None
snap = space.get_current_snapshot() if hasattr(space, "get_current_snapshot") else None
if snap is not None:
for p in snap.phenotypes:
if p.get("hpo_id"):
hpo_ids.append(p["hpo_id"])
for g in snap.genes:
if g.get("symbol"):
gene_symbols.append(str(g["symbol"]).upper())
for d in snap.diagnoses:
if d.get("orpha"):
diagnoses.append({
"orpha": d["orpha"],
"name": d.get("disease") or d.get("name"),
"probability": d.get("probability", 0.5),
"status": d.get("status", "active"),
})
sus_region = (snap.context or {}).get("sus_region")
# also pull active hypotheses
for hyp in (getattr(space, "_hypotheses", {}) or {}).values():
orpha = getattr(hyp, "orpha_code", None)
if orpha and orpha not in {d["orpha"] for d in diagnoses}:
diagnoses.append({
"orpha": orpha,
"name": getattr(hyp, "disease_name", "") or getattr(hyp, "name", ""),
"probability": getattr(hyp, "probability", 0.5),
"status": getattr(hyp, "status", "active"),
})
diagnoses.sort(key=lambda d: d.get("probability", 0), reverse=True)
return hpo_ids, gene_symbols, diagnoses, sus_region
# ─── primary entry points ──────────────────────────────────────────────────
def _resolve_target_orpha(context, patient_info, diagnoses):
    """Pick the Orphanet code the disease-specific stages should focus on.

    Prefer an explicit override (``context["target_orpha"]``, then
    ``patient_info["target_orpha"]``, both stringified) so callers can
    pre-seed the Orphanet code BEFORE the diagnosis pipeline has had a chance
    to run. Falls back to the top diagnosis; returns None if nothing applies.
    """
    for source in (context, patient_info):
        if source and source.get("target_orpha"):
            return str(source["target_orpha"])
    if diagnoses:
        return diagnoses[0].get("orpha")
    return None


def _disease_name_for(orpha, diagnoses):
    """Return the name of the diagnosis whose Orphanet code matches *orpha*.

    Codes are compared as strings because overrides are stringified by
    _resolve_target_orpha while codes read from the PatientSpace may not be
    (type not visible here). Returns None when nothing matches, so an
    override code is never paired with an unrelated disease name.
    """
    if orpha is None:
        return None
    wanted = str(orpha)
    return next((d.get("name") for d in diagnoses if str(d.get("orpha")) == wanted), None)


def _skipped():
    """Placeholder awaitable for a stage that is not run; resolves to None."""
    return asyncio.sleep(0, result=None)


async def _guarded(stage, awaitable):
    """Await *awaitable*; on any Exception log a warning naming *stage* and return None."""
    try:
        return await awaitable
    except Exception as e:
        logger.warning(f"stage {stage} failed: {e}")
        return None


async def _match_trials(space):
    """Run trial_matcher on *space* and normalize the result into a TrialSpec.

    trial_matcher.TrialMatchResult uses `trials: list[TrialMatch]`, NOT
    `matches`. Each TrialMatch has trial_id (NCT number), match_score (not
    "score"), conditions/interventions/url/etc. Mapping only the old
    nct_id/score names serialized every field as None and reported 0 hits
    even when ClinicalTrials.gov returned real ones, so the new names are
    read first and the old ones kept only as fallbacks (dict results are
    passed through as-is). At most 8 matches are kept.

    Returns None (logged at debug level) if the matcher is unavailable or fails.
    """
    try:
        from .types import TrialSpec
        from trial_matcher import match_trials

        tm = await match_trials(space, max_trials=8)
        matches = []
        if tm is not None:
            raw = (
                getattr(tm, "trials", None)
                or getattr(tm, "matches", None)
                or (tm.get("trials") if isinstance(tm, dict) else None)
                or (tm.get("matches") if isinstance(tm, dict) else None)
                or []
            )
            for m in raw[:8]:
                if isinstance(m, dict):
                    matches.append(m)
                    continue
                matches.append({
                    "nct_id": getattr(m, "trial_id", None) or getattr(m, "nct_id", None),
                    "title": getattr(m, "title", None),
                    "phase": getattr(m, "phase", None),
                    "status": getattr(m, "status", None),
                    "score": getattr(m, "match_score", None) or getattr(m, "score", None),
                    "conditions": getattr(m, "conditions", []) or [],
                    "interventions": getattr(m, "interventions", []) or [],
                    "explanation": getattr(m, "explanation", None),
                    "url": getattr(m, "url", None),
                    "has_brazil_site": getattr(m, "has_brazil_site", False),
                })
        n_searched = int(getattr(tm, "total_found", 0) or 0) if tm is not None else 0
        return TrialSpec(
            matches=matches,
            model="trialgpt_bootstrap",
            n_searched=max(n_searched, len(matches)),
        )
    except Exception as e:
        logger.debug(f"trial_matcher failed: {e}")
        return None


async def _build_from_space(
    space,
    case_id,
    *,
    patient_info: Optional[dict] = None,
    context: Optional[dict] = None,
    cohort_k: int = 10,
    horizons_months: Optional[list[int]] = None,
    fast: bool = False,
) -> GemeoTwin:
    """Run every Gemeo stage against an already-created PatientSpace.

    Shared by build_gemeo() (fresh case) and query_gemeo() (lazy rebuild of
    an existing case). *space* may be None; then only stages that can run
    from context/patient_info overrides are attempted. The assembled twin is
    cached in _TWINS when *case_id* is truthy.
    """
    horizons_months = horizons_months or [6, 12, 24]
    has_space = space is not None

    # 2) extract working set
    hpo_ids, gene_symbols, diagnoses, sus_region = [], [], [], None
    if has_space:
        hpo_ids, gene_symbols, diagnoses, sus_region = _extract_inputs(space)
    if context and context.get("sus_region"):
        sus_region = sus_region or context["sus_region"]

    # 3) embed
    embedding, enc_quality = None, "skipped"
    if has_space:
        try:
            embedding, enc_quality = gencoder.encode_patient_space(space)
        except Exception as e:
            logger.warning(f"encode failed: {e}")

    target_orpha = _resolve_target_orpha(context, patient_info, diagnoses)

    # Read the current snapshot once. `snap_now` is bound on every path so the
    # stage closures below never hit an unassigned local, e.g. when
    # target_orpha comes from an override but no PatientSpace was loaded.
    snap_now = None
    if has_space and hasattr(space, "get_current_snapshot"):
        snap_now = space.get_current_snapshot()
    medications = list(snap_now.medications or []) if snap_now is not None else []

    # 4) parallelize the read-only stages
    cohort_task = (
        _guarded("cohort", gcohort.find_cohort(
            embedding=embedding,
            hpo_ids=hpo_ids,
            orpha_codes=[d["orpha"] for d in diagnoses if d.get("orpha")],
            k=cohort_k,
            include_literature=True,
        ))
        if (embedding is not None or hpo_ids or diagnoses)
        else _skipped()
    )
    subgraph_task = (
        _guarded("subgraph", gsub.extract(
            patient_id=case_id or "anon",
            hpo_ids=hpo_ids,
            gene_symbols=gene_symbols,
            target_orpha=target_orpha,
        ))
        if (hpo_ids or gene_symbols or target_orpha)
        else _skipped()
    )
    # trajectory uses 3 LLM calls — heaviest stage by 5x. Skip in fast mode.
    trajectory_task = (
        _guarded("trajectory", gtraj.predict(space, horizons_months))
        if (has_space and not fast)
        else _skipped()
    )
    risk_task = _guarded("risk", grisk.assess(space, embedding=embedding)) if has_space else _skipped()
    ask_task = _guarded("ask", gask.recommend(space, top_n=5)) if has_space else _skipped()
    # drug repurposing is skipped in fast mode as well
    drugs_task = (
        _guarded("drugs", gdrugs.find(space, embedding=embedding, sus_region=sus_region))
        if (has_space and not fast)
        else _skipped()
    )

    # Case-driven additions: family / reverse-pheno / pcdt / ddi (medications)
    family_task = (
        _guarded("family", gfamily.assess(
            orpha=target_orpha,
            family_history=(snap_now.family_history if snap_now else []),
            sex=(patient_info or {}).get("sex"),
        ))
        if target_orpha
        else _skipped()
    )
    rpheno_task = (
        _guarded("reverse_pheno", grpheno.look_for(
            orpha=target_orpha,
            already_present=hpo_ids,
            top_n=10,
        ))
        if target_orpha
        else _skipped()
    )

    async def _pcdt_async():
        # gpcdt.assess is called synchronously; wrapping it lets it join the gather.
        return gpcdt.assess(
            orpha=target_orpha,
            current_treatments=medications,
            current_labs=(snap_now.labs if snap_now else []),
            current_imaging=(snap_now.imaging if snap_now else []),
        )

    pcdt_task = _guarded("pcdt", _pcdt_async()) if target_orpha else _skipped()
    # interactions need at least two medications
    ddi_task = (
        _guarded("ddi", gddi.predict(medications=medications))
        if len(medications) >= 2
        else _skipped()
    )

    (cohort_v, subgraph_v, traj_v, risk_v, ask_v, drugs_v,
     family_v, rpheno_v, pcdt_v, ddi_v) = await asyncio.gather(
        cohort_task, subgraph_task, trajectory_task, risk_task, ask_task, drugs_task,
        family_task, rpheno_task, pcdt_task, ddi_task,
    )

    # Pharmacogenomics — needs both genes and drug candidates
    pharm_v = None
    if snap_now and snap_now.genes and drugs_v and drugs_v.candidates:
        try:
            pharm_v = await gpharm.assess(
                genes=snap_now.genes,
                drug_candidates=drugs_v.candidates[:8],
            )
        except Exception as e:
            logger.debug(f"pharmacogen failed: {e}")

    # 5) trials — wrap existing trial_matcher (skipped in fast mode)
    trials_v = await _match_trials(space) if (has_space and not fast) else None

    # 6) SUS grounding — use the name of the diagnosis that matches
    # target_orpha, not blindly the top diagnosis (they differ on overrides).
    sus_v = None
    if target_orpha:
        try:
            sus_v = gsus.check(
                orpha=target_orpha,
                disease_name=_disease_name_for(target_orpha, diagnoses),
                uf=sus_region,
            )
        except Exception as e:
            logger.warning(f"SUS grounding failed: {e}")

    # 7) viz
    viz_v = None
    if subgraph_v is not None:
        try:
            viz_v = gviz.from_subgraph(subgraph_v, center_id=f"patient:{case_id}")
        except Exception as e:
            logger.warning(f"viz formatting failed: {e}")

    # 8) assemble — counts come from the snapshot as of assembly time
    snap = space.get_current_snapshot() if (has_space and hasattr(space, "get_current_snapshot")) else None
    history = space.get_trajectory() if (has_space and hasattr(space, "get_trajectory")) else []
    twin = GemeoTwin(
        case_id=case_id,
        patient_id=(patient_info or {}).get("id"),
        embedding=(embedding.tolist() if embedding is not None and hasattr(embedding, "tolist") else None),
        embedding_dim=(int(embedding.shape[0]) if embedding is not None and hasattr(embedding, "shape") else 0),
        diagnoses=diagnoses,
        cohort=cohort_v,
        subgraph=subgraph_v,
        trajectory=traj_v,
        risk=risk_v,
        drugs=drugs_v,
        trials=trials_v,
        next_questions=ask_v or [],
        sus_check=sus_v,
        viz_data=viz_v,
        ddi=ddi_v,
        pharmacogen=pharm_v,
        family=family_v,
        reverse_pheno=rpheno_v,
        protocol_compliance=pcdt_v,
        snapshot_versions=[s.version for s in history],
        n_phenotypes=len(snap.phenotypes) if snap else 0,
        n_genes=len(snap.genes) if snap else 0,
        n_labs=len(snap.labs) if snap else 0,
        extra={
            "encoder_quality": enc_quality,
            "sus_region": sus_region,
            "target_orpha": target_orpha,
        },
    )
    if case_id:
        _TWINS[case_id] = twin
    return twin


async def build_gemeo(
    *,
    case_text: str,
    patient_info: Optional[dict] = None,
    context: Optional[dict] = None,
    run_diagnosis: bool = True,
    cohort_k: int = 10,
    horizons_months: Optional[list[int]] = None,
    fast: bool = False,
) -> GemeoTwin:
    """Build a complete digital twin from a clinical case.

    Args:
        case_text: free-text or structured case description (PT-BR or EN)
        patient_info: {age, sex, ethnicity, ...}; may carry "target_orpha"
        context: {sus_region, access_level, target_orpha, ...}
        run_diagnosis: run engine_v2 to populate hypotheses
        cohort_k: patients-like-mine size
        horizons_months: trajectory horizons (default [6, 12, 24])
        fast: skip slow stages (trajectory, drugs, trials) for snappy demos

    Returns:
        The assembled GemeoTwin (also cached in _TWINS when a case_id exists).
        If the PatientSpace cannot be loaded a warning is logged and a
        partially-populated twin is still returned.

    Raises:
        ImportError: if digital_twin_workflow cannot be imported.
    """
    # 1) PatientSpace creation (existing pipeline); tolerate an empty payload.
    payload = await _ensure_space(case_text, patient_info, context, run_diagnosis) or {}
    case_id = payload.get("case_id") or payload.get("id")
    space = await _load_space(case_id) if case_id else None
    if space is None:
        logger.warning(f"could not load PatientSpace for case {case_id}")
    return await _build_from_space(
        space,
        case_id,
        patient_info=patient_info,
        context=context,
        cohort_k=cohort_k,
        horizons_months=horizons_months,
        fast=fast,
    )


async def evolve_gemeo(
    case_id: str,
    *,
    new_phenotypes: Optional[list] = None,
    new_genes: Optional[list] = None,
    new_labs: Optional[list] = None,
    new_treatments: Optional[list] = None,
    cohort_k: int = 10,
    horizons_months: Optional[list[int]] = None,
) -> Optional[GemeoTwin]:
    """Add new clinical data to an existing twin and re-run the core stages.

    Re-runs cohort, subgraph, trajectory, risk, next-questions, viz and SUS
    grounding, then updates (or creates) the cached twin in place. A failing
    stage is logged and stored as None instead of aborting the update.

    Returns:
        The updated twin, or None if digital_twin_workflow cannot be imported
        or the PatientSpace cannot be loaded.
    """
    try:
        from digital_twin_workflow import evolve_digital_twin
    except ImportError as e:
        logger.error(f"evolve_digital_twin not importable: {e}")
        return None

    await evolve_digital_twin(
        case_id=case_id,
        new_phenotypes=new_phenotypes or [],
        new_genes=new_genes or [],
        new_labs=new_labs or [],
        new_treatments=new_treatments or [],
    )
    space = await _load_space(case_id)
    if space is None:
        return None

    previous = _TWINS.get(case_id)
    previous_extra = (getattr(previous, "extra", None) or {}) if previous is not None else {}

    hpo_ids, gene_symbols, diagnoses, sus_region = _extract_inputs(space)
    # Values that build_gemeo may have taken from caller context are not
    # stored in the PatientSpace; keep them as fallbacks instead of dropping them.
    sus_region = sus_region or previous_extra.get("sus_region")
    target_orpha = diagnoses[0]["orpha"] if diagnoses else previous_extra.get("target_orpha")

    embedding, enc_quality = None, "skipped"
    try:
        embedding, enc_quality = gencoder.encode_patient_space(space)
    except Exception as e:
        logger.warning(f"encode failed: {e}")

    stage_names = ("cohort", "subgraph", "trajectory", "risk", "ask")
    results = await asyncio.gather(
        gcohort.find_cohort(
            embedding=embedding,
            hpo_ids=hpo_ids,
            orpha_codes=[d["orpha"] for d in diagnoses if d.get("orpha")],
            k=cohort_k,
        ),
        gsub.extract(patient_id=case_id, hpo_ids=hpo_ids, gene_symbols=gene_symbols, target_orpha=target_orpha),
        gtraj.predict(space, horizons_months or [6, 12, 24]),
        grisk.assess(space, embedding=embedding),
        gask.recommend(space, top_n=5),
        return_exceptions=True,
    )

    def _ok(stage, value):
        # return_exceptions=True can also hand back BaseException subclasses
        # such as CancelledError, so check BaseException, not Exception.
        if isinstance(value, BaseException):
            logger.warning(f"stage {stage} failed: {value}")
            return None
        return value

    cohort_v, subgraph_v, traj_v, risk_v, ask_v = (
        _ok(name, value) for name, value in zip(stage_names, results)
    )

    twin = previous or GemeoTwin(case_id=case_id)
    twin.embedding = embedding.tolist() if hasattr(embedding, "tolist") else None
    twin.embedding_dim = int(embedding.shape[0]) if hasattr(embedding, "shape") else 0
    twin.diagnoses = diagnoses
    twin.cohort = cohort_v
    twin.subgraph = subgraph_v
    twin.trajectory = traj_v
    twin.risk = risk_v
    twin.next_questions = ask_v or []

    # Reset derived fields first so values from the previous build never
    # survive when the new inputs no longer produce them.
    twin.viz_data = None
    if twin.subgraph:
        try:
            twin.viz_data = gviz.from_subgraph(twin.subgraph, center_id=f"patient:{case_id}")
        except Exception as e:
            logger.warning(f"viz formatting failed: {e}")

    snap = space.get_current_snapshot() if hasattr(space, "get_current_snapshot") else None
    if snap:
        twin.n_phenotypes = len(snap.phenotypes)
        twin.n_genes = len(snap.genes)
        twin.n_labs = len(snap.labs)
    if hasattr(space, "get_trajectory"):
        twin.snapshot_versions = [s.version for s in space.get_trajectory()]

    twin.sus_check = None
    if target_orpha:
        try:
            twin.sus_check = gsus.check(
                orpha=target_orpha,
                disease_name=_disease_name_for(target_orpha, diagnoses),
                uf=sus_region,
            )
        except Exception as e:
            logger.warning(f"SUS grounding failed: {e}")

    extra = dict(previous_extra)
    extra.update(encoder_quality=enc_quality, sus_region=sus_region, target_orpha=target_orpha)
    twin.extra = extra

    from datetime import datetime, timezone
    twin.updated_at = datetime.now(timezone.utc).isoformat()
    _TWINS[case_id] = twin
    return twin


async def what_if(case_id: str, intervention: dict) -> Optional[dict]:
    """Run a counterfactual on the twin. Returns serialized WhatIfResult.

    The cached twin's risk and trajectory (if the twin is cached) are passed
    as the baseline; otherwise the baseline is None. Returns None when the
    PatientSpace cannot be loaded. The simulation result is converted with
    dataclasses.asdict.
    """
    space = await _load_space(case_id)
    if space is None:
        return None
    twin = _TWINS.get(case_id)
    base_risk = twin.risk if twin else None
    base_traj = twin.trajectory if twin else None
    result = await gwhatif.simulate(
        space, intervention,
        baseline_risk=base_risk,
        baseline_trajectory=base_traj,
    )
    from dataclasses import asdict
    return asdict(result)


async def query_gemeo(case_id: str) -> Optional[GemeoTwin]:
    """Return the cached twin or rebuild it lazily from its PatientSpace.

    The rebuild runs in fast mode (no trajectory, drugs or trials) against
    the already-existing space for *case_id* and caches the result.
    Returns None when the PatientSpace cannot be loaded.
    """
    twin = _TWINS.get(case_id)
    if twin is not None:
        return twin
    space = await _load_space(case_id)
    if space is None:
        return None
    # Build directly from the loaded space: going through build_gemeo() would
    # create a brand-new PatientSpace (and case_id) from placeholder text
    # instead of rebuilding this case.
    return await _build_from_space(space, case_id, fast=True)
def get_gemeo(case_id: str) -> Optional[GemeoTwin]:
    """Look up a twin in the in-memory registry only.

    Synchronous; there is no Neo4j/PatientSpace fallback. Returns None when
    *case_id* has not been built (or cached) in this process.
    """
    try:
        return _TWINS[case_id]
    except KeyError:
        return None
async def consult(case_id: str, panel: list[str] = None, question: str = None) -> Optional[dict]:
    """Run a multi-specialist consult on the twin for *case_id*.

    The twin comes from query_gemeo (cache or lazy rebuild). The consult
    result is attached to ``twin.consult``, the twin is re-cached, and the
    result is returned as a dict via dataclasses.asdict. Returns None when
    no twin is available.
    """
    twin = await query_gemeo(case_id)
    if twin is None:
        return None

    prompt = question if question else "Synthesise your opinion on this case."
    opinion = await gconsult.consult(twin, panel=panel, question=prompt)

    twin.consult = opinion
    if case_id:
        _TWINS[case_id] = twin

    from dataclasses import asdict
    return asdict(opinion)
async def simulate(
    case_id: str,
    *,
    n_runs: int = 30,
    intervention: dict = None,
    horizons_months: list[int] = None,
) -> Optional[dict]:
    """Monte-Carlo simulation of trajectory under stochastic intervention.

    Delegates to ``simulate.run`` with the loaded PatientSpace and returns
    its result converted via dataclasses.asdict, or None when the
    PatientSpace cannot be loaded.
    """
    from dataclasses import asdict

    space = await _load_space(case_id)
    if space is None:
        return None

    outcome = await gsim.run(
        space,
        n_runs=n_runs,
        intervention=intervention,
        horizons_months=horizons_months,
    )
    return asdict(outcome)