"""Trajectory prediction — temporal evolution of the patient state. Two paths: - **Bootstrap (today)**: wraps `trajectory_engine.predict_next_state` (LLM over disease natural-history KG queries) and structures output for the Gemeo viz/UI. - **TGNN (Phase 2, gemeo/train/tgnn.py)**: trains a Temporal Graph Network (TRANS / TETGN-style) over snapshot chains. When checkpoint exists, overrides the LLM path. """ from __future__ import annotations import asyncio import logging import os from typing import Optional from .types import TrajectorySpec, TrajectoryHorizon logger = logging.getLogger("gemeo.trajectory") TGNN_CKPT = os.environ.get( "GEMEO_TGNN_CKPT", os.path.join(os.path.dirname(__file__), "artifacts", "tgnn_trajectory.pt"), ) def _has_tgnn() -> bool: return os.path.exists(TGNN_CKPT) async def _predict_with_tgnn(space) -> Optional[TrajectorySpec]: """Call the trained TGNN. Returns None if model unavailable.""" if not _has_tgnn(): return None try: import torch # noqa: F401 from .train import tgnn as tgnn_mod return await tgnn_mod.predict(space, TGNN_CKPT) except Exception as e: logger.warning(f"TGNN predict failed, falling back: {e}") return None async def _predict_with_llm(space, horizons_months: list[int]) -> TrajectorySpec: """Wraps trajectory_engine.predict_next_state, structures into TrajectorySpec.""" horizons = [] natural_history = [] try: from trajectory_engine import predict_next_state, predict_complications # natural history basis: top hypotheses' ORPHA codes orphas = [] for hyp in (getattr(space, "_hypotheses", {}) or {}).values(): if getattr(hyp, "orpha_code", None): orphas.append(hyp.orpha_code) natural_history = orphas[:5] # PARALLELIZE the per-horizon LLM calls — was sequential, now concurrent. async def _one_horizon(h): try: try: return h, await predict_next_state(space, time_horizon_months=h) except TypeError: return h, await predict_next_state(space, h) except Exception as e: logger.debug(f"predict_next_state failed for {h}m: {e}") return h, None horizon_results = await asyncio.gather(*[_one_horizon(h) for h in horizons_months]) # Preserve input order horizon_results.sort(key=lambda x: horizons_months.index(x[0])) for h, pred in horizon_results: if pred is None: horizons.append(TrajectoryHorizon( months=h, state="unknown", risk_score=0.0, confidence_low=0.0, confidence_high=0.0, )) continue # adapt various return shapes if isinstance(pred, dict): state = pred.get("state") or pred.get("predicted_state") or pred.get("summary", "") risk = float(pred.get("risk_score") or pred.get("risk", 0.0) or 0.0) ci_lo = float(pred.get("ci_low") or pred.get("confidence_low") or max(0, risk - 0.15)) ci_hi = float(pred.get("ci_high") or pred.get("confidence_high") or min(1, risk + 0.15)) phenos = pred.get("expected_phenotypes") or pred.get("new_phenotypes") or [] comps = pred.get("expected_complications") or pred.get("complications") or [] else: state = getattr(pred, "predicted_state", "") or getattr(pred, "summary", "") risk = float(getattr(pred, "risk_score", 0.0) or 0.0) ci_lo = float(getattr(pred, "ci_low", max(0, risk - 0.15))) ci_hi = float(getattr(pred, "ci_high", min(1, risk + 0.15))) phenos = getattr(pred, "expected_phenotypes", []) or [] comps = getattr(pred, "expected_complications", []) or [] horizons.append(TrajectoryHorizon( months=h, state=str(state)[:200], risk_score=risk, confidence_low=ci_lo, confidence_high=ci_hi, expected_phenotypes=phenos[:10], expected_complications=comps[:10], )) # complications top-level (12m default) try: comps = await predict_complications(space) if comps and horizons: # merge into 12m horizon if available target = next((h for h in horizons if h.months == 12), horizons[-1]) serialized = [] for c in comps[:8]: if isinstance(c, dict): serialized.append({"name": c.get("name"), "prob": c.get("probability"), "when": c.get("expected_in")}) else: serialized.append({ "name": getattr(c, "name", None), "prob": getattr(c, "probability", None), "when": getattr(c, "expected_in_months", None), }) target.expected_complications = serialized except Exception as e: logger.debug(f"predict_complications failed: {e}") except ImportError as e: logger.warning(f"trajectory_engine not importable: {e}") return TrajectorySpec( horizons=horizons, model="llm_fallback", natural_history_basis=natural_history, ) async def predict( space, horizons_months: list[int] = None, ) -> TrajectorySpec: """Predict patient state at the given horizons. Args: space: PatientSpace object (from patient_space.py) horizons_months: defaults to [6, 12, 24] """ horizons_months = horizons_months or [6, 12, 24] # try TGNN first spec = await _predict_with_tgnn(space) if spec is not None: return spec return await _predict_with_llm(space, horizons_months)