File size: 5,972 Bytes
089d665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""Trajectory prediction — temporal evolution of the patient state.

Two paths:
  - **Bootstrap (today)**: wraps `trajectory_engine.predict_next_state` (LLM
    over disease natural-history KG queries) and structures output for the
    Gemeo viz/UI.
  - **TGNN (Phase 2, gemeo/train/tgnn.py)**: trains a Temporal Graph Network
    (TRANS / TETGN-style) over snapshot chains. When checkpoint exists,
    overrides the LLM path.
"""
from __future__ import annotations
import asyncio
import logging
import os
from typing import Optional

from .types import TrajectorySpec, TrajectoryHorizon

# Module-scoped logger for the trajectory component.
logger = logging.getLogger("gemeo.trajectory")

# Path of the trained TGNN checkpoint. The GEMEO_TGNN_CKPT environment
# variable takes precedence; otherwise the path defaults to
# artifacts/tgnn_trajectory.pt inside the directory holding this module.
TGNN_CKPT = os.getenv(
    "GEMEO_TGNN_CKPT",
    os.path.join(
        os.path.dirname(__file__),
        "artifacts",
        "tgnn_trajectory.pt",
    ),
)


def _has_tgnn() -> bool:
    """Report whether a filesystem entry exists at the ``TGNN_CKPT`` path."""
    checkpoint_present = os.path.exists(TGNN_CKPT)
    return checkpoint_present


async def _predict_with_tgnn(space) -> Optional[TrajectorySpec]:
    """Run the TGNN predictor on ``space`` when a checkpoint is present.

    Args:
        space: Patient state object, forwarded unchanged to the TGNN module's
            ``predict`` coroutine together with ``TGNN_CKPT``.

    Returns:
        The value produced by the TGNN module's ``predict``, or ``None`` when
        no checkpoint exists or when importing/running the model raises (the
        failure is logged as a warning so the caller can fall back).
    """
    if _has_tgnn():
        try:
            # torch is imported only so that a missing install surfaces here
            # as a handled failure; the name itself is not used.
            import torch  # noqa: F401
            from .train import tgnn as tgnn_mod
            prediction = await tgnn_mod.predict(space, TGNN_CKPT)
        except Exception as e:
            logger.warning(f"TGNN predict failed, falling back: {e}")
        else:
            return prediction
    return None


def _as_float(value, default: float) -> float:
    """Coerce ``value`` to float; return ``default`` for None or unparseable input."""
    if value is None:
        return default
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def _first_not_none(*candidates):
    """Return the first candidate that is not None (None if all are)."""
    return next((c for c in candidates if c is not None), None)


def _horizon_from_prediction(months: int, pred) -> TrajectoryHorizon:
    """Normalise one ``predict_next_state`` result into a ``TrajectoryHorizon``.

    ``pred`` may be None (failed call), a dict, or an attribute-style object.
    """
    if pred is None:
        # Placeholder so every requested horizon still appears in the output.
        return TrajectoryHorizon(
            months=months, state="unknown", risk_score=0.0,
            confidence_low=0.0, confidence_high=0.0,
        )

    if isinstance(pred, dict):
        state = pred.get("state") or pred.get("predicted_state") or pred.get("summary", "")
        raw_risk = pred.get("risk_score") or pred.get("risk")
        # Explicit None checks: a lower bound of 0.0 is a valid value and must
        # not be replaced by the risk-derived default.
        raw_lo = _first_not_none(pred.get("ci_low"), pred.get("confidence_low"))
        raw_hi = _first_not_none(pred.get("ci_high"), pred.get("confidence_high"))
        phenos = pred.get("expected_phenotypes") or pred.get("new_phenotypes") or []
        comps = pred.get("expected_complications") or pred.get("complications") or []
    else:
        state = getattr(pred, "predicted_state", "") or getattr(pred, "summary", "")
        raw_risk = getattr(pred, "risk_score", None)
        # Attributes that exist but hold None fall back to the default instead
        # of raising inside float().
        raw_lo = getattr(pred, "ci_low", None)
        raw_hi = getattr(pred, "ci_high", None)
        phenos = getattr(pred, "expected_phenotypes", []) or []
        comps = getattr(pred, "expected_complications", []) or []

    risk = _as_float(raw_risk, 0.0)
    # Without reported bounds, use a +/-0.15 band around the risk, clamped to [0, 1].
    ci_lo = _as_float(raw_lo, max(0.0, risk - 0.15))
    ci_hi = _as_float(raw_hi, min(1.0, risk + 0.15))

    return TrajectoryHorizon(
        months=months, state=str(state)[:200],
        risk_score=risk, confidence_low=ci_lo, confidence_high=ci_hi,
        expected_phenotypes=phenos[:10],
        expected_complications=comps[:10],
    )


def _serialize_complication(comp) -> dict:
    """Reduce a complication (dict or object) to ``{"name", "prob", "when"}``."""
    if isinstance(comp, dict):
        return {
            "name": comp.get("name"),
            "prob": comp.get("probability"),
            "when": comp.get("expected_in"),
        }
    return {
        "name": getattr(comp, "name", None),
        "prob": getattr(comp, "probability", None),
        "when": getattr(comp, "expected_in_months", None),
    }


async def _predict_with_llm(space, horizons_months: list[int]) -> TrajectorySpec:
    """Build a ``TrajectorySpec`` from ``trajectory_engine`` predictions.

    Calls ``predict_next_state`` once per requested horizon (concurrently) and
    ``predict_complications`` once, then normalises the results.

    Args:
        space: Patient state object passed through to ``trajectory_engine``;
            its ``_hypotheses`` mapping (if any) supplies the ORPHA codes.
        horizons_months: Horizons, in months, to predict for.

    Returns:
        A ``TrajectorySpec`` with ``model="llm_fallback"``. It has one horizon
        per entry of ``horizons_months`` in the same order (``state="unknown"``
        where the call failed), or no horizons when ``trajectory_engine``
        cannot be imported.
    """
    horizons = []
    natural_history = []

    try:
        from trajectory_engine import predict_next_state, predict_complications
    except ImportError as e:
        logger.warning("trajectory_engine not importable: %s", e)
        return TrajectorySpec(
            horizons=horizons,
            model="llm_fallback",
            natural_history_basis=natural_history,
        )

    # Natural-history basis: ORPHA codes of the first five hypotheses that have one.
    hypotheses = getattr(space, "_hypotheses", {}) or {}
    orphas = [
        hyp.orpha_code
        for hyp in hypotheses.values()
        if getattr(hyp, "orpha_code", None)
    ]
    natural_history = orphas[:5]

    async def _one_horizon(h):
        """Predict a single horizon; return None instead of raising."""
        try:
            try:
                return await predict_next_state(space, time_horizon_months=h)
            except TypeError:
                # Retry positionally in case the keyword name is not accepted.
                # NOTE(review): a TypeError raised inside predict_next_state
                # also triggers this retry.
                return await predict_next_state(space, h)
        except Exception as e:
            logger.debug("predict_next_state failed for %sm: %s", h, e)
            return None

    # asyncio.gather returns results in the order the awaitables were passed,
    # so the predictions line up with horizons_months without re-sorting.
    predictions = await asyncio.gather(*(_one_horizon(h) for h in horizons_months))
    horizons = [
        _horizon_from_prediction(h, pred)
        for h, pred in zip(horizons_months, predictions)
    ]

    # Top-level complications: attached to the 12-month horizon when present,
    # otherwise to the last horizon, replacing that horizon's own list.
    try:
        top_comps = await predict_complications(space)
        if top_comps and horizons:
            target = next((hz for hz in horizons if hz.months == 12), horizons[-1])
            target.expected_complications = [
                _serialize_complication(c) for c in top_comps[:8]
            ]
    except Exception as e:
        logger.debug("predict_complications failed: %s", e)

    return TrajectorySpec(
        horizons=horizons,
        model="llm_fallback",
        natural_history_basis=natural_history,
    )


async def predict(
    space,
    horizons_months: Optional[list[int]] = None,
) -> TrajectorySpec:
    """Predict patient state at the given horizons.

    The TGNN path is tried first; when it yields nothing, the LLM-backed path
    is used.

    Args:
        space: PatientSpace object (from patient_space.py)
        horizons_months: Horizons in months. ``None`` or an empty list selects
            the default ``[6, 12, 24]``. Only the LLM path receives this
            value; the TGNN path is called with ``space`` alone.

    Returns:
        The ``TrajectorySpec`` from the TGNN path if it returned one,
        otherwise the one built by the LLM path.
    """
    if not horizons_months:
        horizons_months = [6, 12, 24]

    tgnn_spec = await _predict_with_tgnn(space)
    if tgnn_spec is not None:
        return tgnn_spec

    return await _predict_with_llm(space, horizons_months)