| """Trajectory prediction — temporal evolution of the patient state. |
| |
| Two paths: |
| - **Bootstrap (today)**: wraps `trajectory_engine.predict_next_state` (LLM |
| over disease natural-history KG queries) and structures output for the |
| Gemeo viz/UI. |
| - **TGNN (Phase 2, gemeo/train/tgnn.py)**: trains a Temporal Graph Network |
| (TRANS / TETGN-style) over snapshot chains. When checkpoint exists, |
| overrides the LLM path. |
| """ |
| from __future__ import annotations |
| import asyncio |
| import logging |
| import os |
| from typing import Optional |
|
|
| from .types import TrajectorySpec, TrajectoryHorizon |
|
|
# Module-scoped logger for everything emitted by the trajectory predictor.
logger = logging.getLogger("gemeo.trajectory")

# Location of the trained TGNN checkpoint. The GEMEO_TGNN_CKPT environment
# variable takes precedence; otherwise the path points at
# artifacts/tgnn_trajectory.pt in the directory that contains this module.
TGNN_CKPT = os.getenv(
    "GEMEO_TGNN_CKPT",
    os.path.join(os.path.dirname(__file__), "artifacts", "tgnn_trajectory.pt"),
)
|
|
|
|
def _has_tgnn() -> bool:
    """Report whether something exists on disk at the ``TGNN_CKPT`` path."""
    checkpoint_path = TGNN_CKPT
    return os.path.exists(checkpoint_path)
|
|
|
|
async def _predict_with_tgnn(space) -> Optional[TrajectorySpec]:
    """Try the trained TGNN model for ``space``.

    Returns:
        Whatever ``train.tgnn.predict(space, TGNN_CKPT)`` yields, or ``None``
        when no checkpoint is present or when importing/running the model
        raises any ``Exception`` (the failure is logged as a warning so the
        caller can fall back to another predictor).
    """
    if _has_tgnn():
        try:
            # NOTE(review): torch is not referenced below; presumably imported
            # here so a missing install fails inside this try — confirm.
            import torch  # noqa: F401
            from .train import tgnn as tgnn_mod

            return await tgnn_mod.predict(space, TGNN_CKPT)
        except Exception as e:
            logger.warning(f"TGNN predict failed, falling back: {e}")
    return None
|
|
|
|
def _to_float(value, default: float) -> float:
    """Coerce ``value`` to float; return ``default`` if it is None or not numeric."""
    if value is None:
        return default
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def _pick(pred, names, *, truthy: bool = False):
    """Return the first usable field of ``pred`` among ``names``, else None.

    ``pred`` may be a dict (keys are read) or any other object (attributes
    are read). By default a value is usable when it is not None, so an
    explicit ``0`` / ``0.0`` is honoured; with ``truthy=True`` empty values
    (``""``, ``[]``, ...) are skipped as well.
    """
    is_mapping = isinstance(pred, dict)
    for name in names:
        value = pred.get(name) if is_mapping else getattr(pred, name, None)
        usable = bool(value) if truthy else value is not None
        if usable:
            return value
    return None


def _head(items, limit: int) -> list:
    """Return at most ``limit`` leading entries of ``items`` as a list.

    Empty/None becomes ``[]``; a bare string or non-iterable scalar is wrapped
    in a one-element list instead of being sliced character by character.
    """
    if not items:
        return []
    if isinstance(items, (str, bytes)) or not hasattr(items, "__iter__"):
        return [items]
    return list(items)[:limit]


def _horizon_from_prediction(months: int, pred) -> TrajectoryHorizon:
    """Convert one ``predict_next_state`` result into a ``TrajectoryHorizon``.

    ``pred`` may be None (prediction failed), a dict, or an attribute-style
    object. Dicts are read with the wider alias set the original code used;
    objects only with their canonical attribute names.
    """
    if pred is None:
        # Placeholder so every requested horizon is still represented.
        return TrajectoryHorizon(
            months=months, state="unknown", risk_score=0.0,
            confidence_low=0.0, confidence_high=0.0,
        )

    if isinstance(pred, dict):
        state_names = ("state", "predicted_state", "summary")
        risk_names = ("risk_score", "risk")
        low_names = ("ci_low", "confidence_low")
        high_names = ("ci_high", "confidence_high")
        pheno_names = ("expected_phenotypes", "new_phenotypes")
        comp_names = ("expected_complications", "complications")
    else:
        state_names = ("predicted_state", "summary")
        risk_names = ("risk_score",)
        low_names = ("ci_low",)
        high_names = ("ci_high",)
        pheno_names = ("expected_phenotypes",)
        comp_names = ("expected_complications",)

    # A missing/None state becomes "" rather than the literal string "None".
    state = _pick(pred, state_names, truthy=True) or ""
    risk = _to_float(_pick(pred, risk_names), 0.0)
    # Explicit bounds (including 0.0) win; otherwise use risk +/- 0.15,
    # clipped to the [0, 1] interval.
    ci_lo = _to_float(_pick(pred, low_names), max(0.0, risk - 0.15))
    ci_hi = _to_float(_pick(pred, high_names), min(1.0, risk + 0.15))

    return TrajectoryHorizon(
        months=months, state=str(state)[:200],
        risk_score=risk, confidence_low=ci_lo, confidence_high=ci_hi,
        expected_phenotypes=_head(_pick(pred, pheno_names, truthy=True), 10),
        expected_complications=_head(_pick(pred, comp_names, truthy=True), 10),
    )


def _serialize_complication(item) -> dict:
    """Flatten one complication (dict or object) to name/prob/when keys."""
    if isinstance(item, dict):
        return {
            "name": item.get("name"),
            "prob": item.get("probability"),
            "when": item.get("expected_in"),
        }
    return {
        "name": getattr(item, "name", None),
        "prob": getattr(item, "probability", None),
        "when": getattr(item, "expected_in_months", None),
    }


async def _predict_with_llm(space, horizons_months: list[int]) -> TrajectorySpec:
    """Build a ``TrajectorySpec`` from ``trajectory_engine`` predictions.

    Calls ``predict_next_state`` concurrently for each entry of
    ``horizons_months`` and ``predict_complications`` once, then structures
    the results.

    Args:
        space: patient state object; its ``_hypotheses`` mapping (if any) is
            scanned for ``orpha_code`` values.
        horizons_months: horizons to predict, in months. The output keeps
            this order.

    Returns:
        A spec with ``model="llm_fallback"``. ``horizons`` is empty when
        ``trajectory_engine`` cannot be imported; a horizon whose prediction
        failed is reported with state ``"unknown"`` and zero risk.
    """
    try:
        from trajectory_engine import predict_next_state, predict_complications
    except ImportError as e:
        logger.warning("trajectory_engine not importable: %s", e)
        return TrajectorySpec(
            horizons=[],
            model="llm_fallback",
            natural_history_basis=[],
        )

    # Up to five ORPHA codes from the current hypotheses.
    hypotheses = getattr(space, "_hypotheses", None) or {}
    orphas = [
        hyp.orpha_code
        for hyp in hypotheses.values()
        if getattr(hyp, "orpha_code", None)
    ]
    natural_history = orphas[:5]

    async def _one_horizon(months):
        """Predict a single horizon; any failure is logged and yields None."""
        try:
            try:
                return await predict_next_state(space, time_horizon_months=months)
            except TypeError:
                # Retry positionally, presumably for engine versions that do
                # not accept the keyword name.
                return await predict_next_state(space, months)
        except Exception as e:
            logger.debug("predict_next_state failed for %sm: %s", months, e)
            return None

    # gather() returns results in the order the awaitables were passed, so
    # zipping against horizons_months keeps the caller's order.
    preds = await asyncio.gather(*(_one_horizon(m) for m in horizons_months))
    horizons = [
        _horizon_from_prediction(months, pred)
        for months, pred in zip(horizons_months, preds)
    ]

    # Best effort: attach the engine's complication list to the 12-month
    # horizon, or to the last horizon when 12 months was not requested.
    try:
        comps = await predict_complications(space)
        if comps and horizons:
            target = next((hz for hz in horizons if hz.months == 12), horizons[-1])
            target.expected_complications = [
                _serialize_complication(c) for c in comps[:8]
            ]
    except Exception as e:
        logger.debug("predict_complications failed: %s", e)

    return TrajectorySpec(
        horizons=horizons,
        model="llm_fallback",
        natural_history_basis=natural_history,
    )
|
|
|
|
async def predict(
    space,
    horizons_months: Optional[list[int]] = None,
) -> TrajectorySpec:
    """Predict patient state at the given horizons.

    The TGNN predictor is tried first; when it returns ``None`` the
    LLM-backed predictor is used instead.

    Args:
        space: PatientSpace object (from patient_space.py)
        horizons_months: horizons in months; ``None`` or an empty list
            defaults to ``[6, 12, 24]``. Only the LLM path receives this
            value — the TGNN path is called with ``space`` alone.

    Returns:
        The ``TrajectorySpec`` produced by whichever predictor answered.
    """
    if not horizons_months:
        horizons_months = [6, 12, 24]

    tgnn_spec = await _predict_with_tgnn(space)
    if tgnn_spec is not None:
        return tgnn_spec

    return await _predict_with_llm(space, horizons_months)
|
|