File size: 5,972 Bytes
"""Trajectory prediction — temporal evolution of the patient state.

Two paths:
- **Bootstrap (today)**: wraps `trajectory_engine.predict_next_state` (and
  `predict_complications`) — an LLM over disease natural-history KG queries —
  and structures the output for the Gemeo viz/UI.
- **TGNN (Phase 2, gemeo/train/tgnn.py)**: trains a Temporal Graph Network
  (TRANS / TETGN-style) over snapshot chains. When a checkpoint exists and
  inference succeeds, it overrides the LLM path.
"""
from __future__ import annotations
import asyncio
import logging
import os
from typing import Optional
from .types import TrajectorySpec, TrajectoryHorizon
logger = logging.getLogger("gemeo.trajectory")

# Fallback checkpoint location, resolved relative to this module's directory.
_DEFAULT_TGNN_CKPT = os.path.join(
    os.path.dirname(__file__), "artifacts", "tgnn_trajectory.pt"
)
# Path probed for a trained TGNN; the GEMEO_TGNN_CKPT env var takes precedence.
TGNN_CKPT = os.environ.get("GEMEO_TGNN_CKPT", _DEFAULT_TGNN_CKPT)
def _has_tgnn() -> bool:
    """Report whether anything exists at the configured ``TGNN_CKPT`` path.

    Only existence is checked; the checkpoint is not opened or validated here.
    """
    checkpoint_path = TGNN_CKPT
    present = os.path.exists(checkpoint_path)
    return present
async def _predict_with_tgnn(space) -> Optional[TrajectorySpec]:
    """Run the trained TGNN checkpoint on *space*, if it can be used.

    Returns ``None`` (telling the caller to fall back) when no checkpoint
    exists at ``TGNN_CKPT``, or when importing torch / the TGNN module or
    awaiting its ``predict`` raises; such failures are logged as warnings.
    """
    if _has_tgnn():
        try:
            # Unused import — presumably an availability check so a missing
            # torch install lands in the fallback below.
            import torch  # noqa: F401
            from .train import tgnn as tgnn_mod

            return await tgnn_mod.predict(space, TGNN_CKPT)
        except Exception as e:
            logger.warning(f"TGNN predict failed, falling back: {e}")
    return None
def _first_present(*values):
    """Return the first value that is neither ``None`` nor ``""``, else ``None``.

    Unlike an ``a or b`` chain this keeps legitimate falsy numbers such as 0.0.
    """
    return next((v for v in values if v is not None and v != ""), None)


def _as_float(value, default: float) -> float:
    """Coerce *value* to float; use *default* when it is None or not numeric."""
    if value is None:
        return float(default)
    try:
        return float(value)
    except (TypeError, ValueError):
        return float(default)


def _natural_history_basis(space, limit: int = 5) -> list:
    """Return up to *limit* truthy ``orpha_code`` values from ``space._hypotheses``.

    Codes are taken in the mapping's iteration order (presumably rank order —
    verify against PatientSpace).
    """
    hypotheses = getattr(space, "_hypotheses", {}) or {}
    codes = [
        hyp.orpha_code
        for hyp in hypotheses.values()
        if getattr(hyp, "orpha_code", None)
    ]
    return codes[:limit]


def _horizon_from_prediction(months, pred) -> TrajectoryHorizon:
    """Normalize one ``predict_next_state`` result (dict or object) into a horizon.

    ``None`` (a failed prediction) yields an "unknown" horizon with zero risk.
    Missing, ``None`` or non-numeric risk / confidence values fall back to
    0.0 and to ``risk ± 0.15`` clipped to [0, 1] respectively.
    """
    if pred is None:
        return TrajectoryHorizon(
            months=months, state="unknown", risk_score=0.0,
            confidence_low=0.0, confidence_high=0.0,
        )
    if isinstance(pred, dict):
        state = pred.get("state") or pred.get("predicted_state") or pred.get("summary") or ""
        raw_risk = _first_present(pred.get("risk_score"), pred.get("risk"))
        raw_lo = _first_present(pred.get("ci_low"), pred.get("confidence_low"))
        raw_hi = _first_present(pred.get("ci_high"), pred.get("confidence_high"))
        phenos = pred.get("expected_phenotypes") or pred.get("new_phenotypes") or []
        comps = pred.get("expected_complications") or pred.get("complications") or []
    else:
        state = getattr(pred, "predicted_state", "") or getattr(pred, "summary", "") or ""
        raw_risk = getattr(pred, "risk_score", None)
        raw_lo = getattr(pred, "ci_low", None)
        raw_hi = getattr(pred, "ci_high", None)
        phenos = getattr(pred, "expected_phenotypes", None) or []
        comps = getattr(pred, "expected_complications", None) or []

    risk = _as_float(raw_risk, 0.0)
    ci_lo = _as_float(raw_lo, max(0.0, risk - 0.15))
    ci_hi = _as_float(raw_hi, min(1.0, risk + 0.15))
    return TrajectoryHorizon(
        months=months, state=str(state)[:200],
        risk_score=risk, confidence_low=ci_lo, confidence_high=ci_hi,
        expected_phenotypes=phenos[:10],
        expected_complications=comps[:10],
    )


def _serialize_complication(c) -> dict:
    """Flatten one complication (dict or object) into {name, prob, when}."""
    if isinstance(c, dict):
        return {"name": c.get("name"), "prob": c.get("probability"), "when": c.get("expected_in")}
    return {
        "name": getattr(c, "name", None),
        "prob": getattr(c, "probability", None),
        "when": getattr(c, "expected_in_months", None),
    }


async def _attach_complications(space, horizons, predict_complications) -> None:
    """Best-effort: store up to 8 serialized complications on one horizon.

    The target is the 12-month horizon, or the last horizon if none is 12
    months; its existing ``expected_complications`` are replaced. Nothing is
    attached when there are no horizons or no complications. Any exception is
    logged at debug level and otherwise ignored.
    """
    try:
        comps = await predict_complications(space)
        if not comps or not horizons:
            return
        target = next((hz for hz in horizons if hz.months == 12), horizons[-1])
        target.expected_complications = [_serialize_complication(c) for c in comps[:8]]
    except Exception as e:
        logger.debug(f"predict_complications failed: {e}")


async def _predict_with_llm(space, horizons_months: list[int]) -> TrajectorySpec:
    """Predict per-horizon patient states via ``trajectory_engine`` (LLM path).

    ``predict_next_state`` is awaited once per horizon, concurrently; each
    result is normalized into a ``TrajectoryHorizon``, and complications from
    ``predict_complications`` are attached best-effort.

    Args:
        space: PatientSpace object; only ``_hypotheses`` is read directly here.
        horizons_months: horizons to predict; output keeps this exact order,
            duplicates included.

    Returns:
        A ``TrajectorySpec`` with ``model="llm_fallback"``. If
        ``trajectory_engine`` cannot be imported, it has no horizons and an
        empty natural-history basis.
    """
    # Keep the try minimal so unrelated ImportErrors are not misreported.
    try:
        from trajectory_engine import predict_next_state, predict_complications
    except ImportError as e:
        logger.warning(f"trajectory_engine not importable: {e}")
        return TrajectorySpec(
            horizons=[],
            model="llm_fallback",
            natural_history_basis=[],
        )

    natural_history = _natural_history_basis(space)

    async def _one_horizon(months):
        # Try the keyword form first; on TypeError (presumably a signature
        # mismatch) retry positionally. Any failure becomes None so one bad
        # horizon cannot sink the whole gather.
        try:
            try:
                return await predict_next_state(space, time_horizon_months=months)
            except TypeError:
                return await predict_next_state(space, months)
        except Exception as e:
            logger.debug(f"predict_next_state failed for {months}m: {e}")
            return None

    # gather returns results in argument order, so zipping with the input is
    # enough; sorting by horizons_months.index would misorder duplicates.
    predictions = await asyncio.gather(*(_one_horizon(m) for m in horizons_months))
    horizons = [
        _horizon_from_prediction(months, pred)
        for months, pred in zip(horizons_months, predictions)
    ]

    await _attach_complications(space, horizons, predict_complications)

    return TrajectorySpec(
        horizons=horizons,
        model="llm_fallback",
        natural_history_basis=natural_history,
    )
async def predict(
    space,
    horizons_months: Optional[list[int]] = None,
) -> TrajectorySpec:
    """Predict the patient state at the given horizons.

    The trained TGNN is tried first; if no checkpoint exists or TGNN inference
    fails (it returns None), the LLM-backed ``trajectory_engine`` path is used.

    Args:
        space: PatientSpace object (from patient_space.py).
        horizons_months: horizons in months. ``None`` or an empty list selects
            the default ``[6, 12, 24]``.

    Returns:
        A ``TrajectorySpec``. The LLM path sets ``model="llm_fallback"``.
    """
    horizons_months = horizons_months or [6, 12, 24]
    # NOTE(review): horizons_months is not forwarded to the TGNN path, so a
    # TGNN-produced spec may cover different horizons than requested.
    spec = await _predict_with_tgnn(space)
    if spec is not None:
        return spec
    return await _predict_with_llm(space, horizons_months)
|