# gemeo-twin-stack / src/gemeo/trajectory.py
# Last commit by timmers: 089d665 (verified)
#   "GEMEO world-model — initial release (module + NeuralSurv ckpt +
#    RareBench v49 + KG embeddings)"
"""Trajectory prediction — temporal evolution of the patient state.
Two paths:
- **Bootstrap (today)**: wraps `trajectory_engine.predict_next_state` (LLM
over disease natural-history KG queries) and structures output for the
Gemeo viz/UI.
- **TGNN (Phase 2, gemeo/train/tgnn.py)**: trains a Temporal Graph Network
(TRANS / TETGN-style) over snapshot chains. When a checkpoint exists, it
overrides the LLM path.
"""
from __future__ import annotations
import asyncio
import logging
import os
from typing import Optional
from .types import TrajectorySpec, TrajectoryHorizon
logger = logging.getLogger("gemeo.trajectory")

# Location of the trained TGNN checkpoint. The GEMEO_TGNN_CKPT environment
# variable takes precedence; otherwise the file is looked up in the
# ``artifacts`` directory next to this module.
_DEFAULT_TGNN_CKPT = os.path.join(
    os.path.dirname(__file__), "artifacts", "tgnn_trajectory.pt"
)
TGNN_CKPT = os.environ.get("GEMEO_TGNN_CKPT", _DEFAULT_TGNN_CKPT)
def _has_tgnn() -> bool:
    """Report whether anything exists at the configured ``TGNN_CKPT`` path."""
    checkpoint_path = TGNN_CKPT
    return os.path.exists(checkpoint_path)
async def _predict_with_tgnn(space) -> Optional[TrajectorySpec]:
    """Ask the trained TGNN for a trajectory.

    Returns the model's TrajectorySpec, or ``None`` when no checkpoint is
    present or when importing/running the model raises anything, so the
    caller can fall back to the LLM path.
    """
    if _has_tgnn():
        try:
            # Imported only so that a missing torch install raises here and
            # is turned into the fallback by the handler below.
            import torch  # noqa: F401
            from .train import tgnn as tgnn_mod

            return await tgnn_mod.predict(space, TGNN_CKPT)
        except Exception as e:
            logger.warning(f"TGNN predict failed, falling back: {e}")
    return None
# Field aliases accepted from ``predict_next_state`` results, per return shape.
# Dict results accept more aliases than attribute-style (object) results.
_DICT_PREDICTION_FIELDS = {
    "state": ("state", "predicted_state", "summary"),
    "risk": ("risk_score", "risk"),
    "ci_low": ("ci_low", "confidence_low"),
    "ci_high": ("ci_high", "confidence_high"),
    "phenotypes": ("expected_phenotypes", "new_phenotypes"),
    "complications": ("expected_complications", "complications"),
}
_ATTR_PREDICTION_FIELDS = {
    "state": ("predicted_state", "summary"),
    "risk": ("risk_score",),
    "ci_low": ("ci_low",),
    "ci_high": ("ci_high",),
    "phenotypes": ("expected_phenotypes",),
    "complications": ("expected_complications",),
}

# Half-width of the confidence band synthesized around the risk score when a
# prediction carries no explicit bounds.
_DEFAULT_CI_HALF_WIDTH = 0.15


def _field_values(source, names):
    """Yield the value of each name in *names* read from *source*.

    Dicts are read by key, anything else by attribute; missing names yield
    ``None``.
    """
    for name in names:
        if isinstance(source, dict):
            yield source.get(name)
        else:
            yield getattr(source, name, None)


def _first_not_none(source, names, default=None):
    """Return the first value for *names* on *source* that is not ``None``.

    Unlike an ``a or b`` chain this keeps legitimate falsy values such as
    ``0.0``, which is a valid risk score or confidence bound.
    """
    return next((v for v in _field_values(source, names) if v is not None), default)


def _first_truthy(source, names, default):
    """Return the first truthy value for *names* on *source*, else *default*."""
    return next((v for v in _field_values(source, names) if v), default)


def _as_list(value, limit: int) -> list:
    """Coerce a phenotype/complication field into a list of at most *limit* items."""
    if not value:
        return []
    if isinstance(value, str):
        # A lone free-text entry; don't explode it into characters.
        return [value]
    return list(value)[:limit]


def _unknown_horizon(months: int) -> TrajectoryHorizon:
    """Placeholder horizon used when no usable prediction exists for *months*."""
    return TrajectoryHorizon(
        months=months, state="unknown", risk_score=0.0,
        confidence_low=0.0, confidence_high=0.0,
    )


def _horizon_from_prediction(months: int, pred) -> TrajectoryHorizon:
    """Normalize one ``predict_next_state`` result (dict or object) into a horizon.

    Missing confidence bounds default to ``risk - 0.15`` (floored at 0) and
    ``risk + 0.15`` (capped at 1). State text is truncated to 200 chars and
    phenotype/complication lists to 10 items.

    Raises:
        TypeError, ValueError: if a numeric field cannot be converted with
            ``float()`` (or, presumably, if TrajectoryHorizon rejects a value).
    """
    fields = _DICT_PREDICTION_FIELDS if isinstance(pred, dict) else _ATTR_PREDICTION_FIELDS
    risk = float(_first_not_none(pred, fields["risk"], 0.0))
    ci_low = float(
        _first_not_none(pred, fields["ci_low"], max(0.0, risk - _DEFAULT_CI_HALF_WIDTH))
    )
    ci_high = float(
        _first_not_none(pred, fields["ci_high"], min(1.0, risk + _DEFAULT_CI_HALF_WIDTH))
    )
    return TrajectoryHorizon(
        months=months,
        # Free text: skip empty values as well as missing ones.
        state=str(_first_truthy(pred, fields["state"], ""))[:200],
        risk_score=risk,
        confidence_low=ci_low,
        confidence_high=ci_high,
        expected_phenotypes=_as_list(_first_truthy(pred, fields["phenotypes"], None), 10),
        expected_complications=_as_list(_first_truthy(pred, fields["complications"], None), 10),
    )


def _serialize_complication(comp) -> dict:
    """Flatten one ``predict_complications`` record (dict or object) into a plain dict."""
    if isinstance(comp, dict):
        # NOTE(review): dicts are read via "expected_in" but objects via
        # "expected_in_months" -- confirm this asymmetry is intended.
        return {
            "name": comp.get("name"),
            "prob": comp.get("probability"),
            "when": comp.get("expected_in"),
        }
    return {
        "name": getattr(comp, "name", None),
        "prob": getattr(comp, "probability", None),
        "when": getattr(comp, "expected_in_months", None),
    }


async def _predict_with_llm(space, horizons_months: list[int]) -> TrajectorySpec:
    """Predict trajectory horizons through ``trajectory_engine`` (LLM path).

    Runs ``predict_next_state`` for every horizon concurrently and normalizes
    each result into a TrajectoryHorizon. It then attaches the output of
    ``predict_complications`` to the 12-month horizon, or to the last horizon
    if 12 was not requested. A failed engine call or an unparseable result
    degrades that horizon to an ``"unknown"`` placeholder instead of aborting
    the whole spec.

    Args:
        space: PatientSpace object passed through to the engine.
        horizons_months: Horizons to predict, in months; output keeps this order.

    Returns:
        TrajectorySpec with ``model="llm_fallback"``. If ``trajectory_engine``
        cannot be imported, the spec has no horizons and an empty
        natural-history basis.
    """
    try:
        from trajectory_engine import predict_next_state, predict_complications
    except ImportError as e:
        logger.warning("trajectory_engine not importable: %s", e)
        return TrajectorySpec(horizons=[], model="llm_fallback", natural_history_basis=[])

    # Natural-history basis: ORPHA codes of the first five hypotheses (in
    # ``space._hypotheses`` iteration order) that carry one.
    hypotheses = getattr(space, "_hypotheses", {}) or {}
    natural_history = [
        hyp.orpha_code for hyp in hypotheses.values() if getattr(hyp, "orpha_code", None)
    ][:5]

    async def _one_horizon(months):
        try:
            try:
                return await predict_next_state(space, time_horizon_months=months)
            except TypeError:
                # Presumably an engine version without the keyword argument;
                # retry positionally. This also retries on a TypeError raised
                # inside predict_next_state itself.
                return await predict_next_state(space, months)
        except Exception as e:
            logger.debug("predict_next_state failed for %sm: %s", months, e)
            return None

    # Per-horizon engine calls are independent, so run them concurrently.
    # asyncio.gather returns results in argument order, so they line up with
    # horizons_months without any re-sorting.
    predictions = await asyncio.gather(*(_one_horizon(m) for m in horizons_months))

    horizons = []
    for months, pred in zip(horizons_months, predictions):
        horizon = None
        if pred is not None:
            try:
                horizon = _horizon_from_prediction(months, pred)
            except (TypeError, ValueError) as e:
                # One malformed result must not abort the other horizons.
                logger.warning(
                    "Unparseable predict_next_state result for %sm: %s", months, e
                )
        horizons.append(horizon if horizon is not None else _unknown_horizon(months))

    # Top-level complications go onto the 12-month horizon (else the last one).
    try:
        comps = await predict_complications(space)
        if comps and horizons:
            target = next((hz for hz in horizons if hz.months == 12), horizons[-1])
            # Replaces (does not merge with) any complications parsed above.
            target.expected_complications = [_serialize_complication(c) for c in comps[:8]]
    except Exception as e:
        logger.debug("predict_complications failed: %s", e)

    return TrajectorySpec(
        horizons=horizons,
        model="llm_fallback",
        natural_history_basis=natural_history,
    )
async def predict(
    space,
    horizons_months: Optional[list[int]] = None,
) -> TrajectorySpec:
    """Predict the patient state at the given horizons.

    The trained TGNN is tried first. If it is unavailable or fails, the
    LLM-backed ``trajectory_engine`` path is used instead.

    Args:
        space: PatientSpace object (from patient_space.py).
        horizons_months: Horizons in months. ``None`` or an empty list
            selects the default ``[6, 12, 24]``.

    Returns:
        A TrajectorySpec from whichever path produced a result.

    NOTE(review): the TGNN path is not given *horizons_months*, so its result
    may not cover the requested horizons -- confirm against
    ``gemeo.train.tgnn.predict``.
    """
    if not horizons_months:
        horizons_months = [6, 12, 24]

    # Prefer the trained model; None means "unavailable, fall back".
    spec = await _predict_with_tgnn(space)
    if spec is not None:
        return spec
    return await _predict_with_llm(space, horizons_months)