File size: 7,174 Bytes
089d665 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 | """In-silico Monte-Carlo simulation of trajectory.
Calls the trajectory and risk modules once for a base forecast, then draws
N stochastic perturbations around it (intervention effect and
horizon-dependent noise) and aggregates the outcome distributions. Gives
the clinician a sense of the *spread* of possible futures, not a single
point estimate.
This is a thin Monte-Carlo wrapper over `trajectory.predict` and
`risk.assess` — no new model. The value is in the aggregation: empirical
percentiles of the risk score at each horizon, of the survival probability
at each month, and the median severity across runs.
"""
from __future__ import annotations
import asyncio
import logging
import random
from typing import Optional
from .types import SimulationSpec, SimulationOutcome
logger = logging.getLogger("gemeo.simulate")
def _stochastic_intervention(base: dict, jitter: float = 0.2) -> dict:
    """Return a randomly perturbed shallow copy of an intervention spec.

    - ``start_in_days``: if present and convertible to float, Gaussian noise
      with standard deviation ``value * jitter`` is added and the result is
      clamped at 0.0. Values that cannot be converted (e.g. ``None`` or a
      free-text label) are copied unchanged.
    - ``adherence``: (re)drawn from N(0.85, 0.1) and clamped to [0.4, 1.0].

    ``base`` itself is never modified.
    """
    out = dict(base)
    if "start_in_days" in out:
        try:
            d = float(out["start_in_days"])
        except (TypeError, ValueError, OverflowError):
            # Timing is best-effort: leave non-numeric values as they were
            # rather than masking unrelated errors with a blanket except.
            pass
        else:
            # Clamp with a float so the field is always a float after jitter.
            out["start_in_days"] = max(0.0, d + random.gauss(0, d * jitter))
    out["adherence"] = max(0.4, min(1.0, random.gauss(0.85, 0.1)))
    return out
def _summarize_curve(curves: list, month: int) -> dict:
    """Summarise P(alive at ``month``) across simulation runs.

    Each curve contributes the ``"p_alive"`` value of its first point whose
    ``"month"`` equals ``month`` (0.0 when that key is missing); curves with
    no such point are skipped. With no contributions, only ``month`` and
    ``n=0`` are returned. Otherwise the mean and the 5th/50th/95th empirical
    percentiles (``sorted_values[int(q * n)]``, capped at the last index)
    are added, each rounded to 4 decimals.
    """
    samples = []
    for run_curve in curves:
        hit = next((pt for pt in run_curve if pt.get("month") == month), None)
        if hit is not None:
            samples.append(float(hit.get("p_alive", 0.0)))

    count = len(samples)
    if count == 0:
        return {"month": month, "n": 0}

    ordered = sorted(samples)

    def _pct(q: float) -> float:
        return round(ordered[min(count - 1, int(q * count))], 4)

    return {
        "month": month,
        "n": count,
        "mean": round(sum(ordered) / count, 4),
        "p05": _pct(0.05),
        "p50": _pct(0.5),
        "p95": _pct(0.95),
    }
def _draw_intervention_shift(intervention: Optional[dict]) -> float:
    """Draw the additive risk/severity shift for ONE simulation run.

    ``"treatment"`` gives ``-0.10 * adherence``, with adherence drawn from
    N(0.85, 0.10) and clamped to [0.4, 1.0]. ``"phenotype_resolve"`` gives
    -0.05 and ``"phenotype_add"`` gives +0.08. No intervention, or any other
    ``type``, gives 0.0. Only the treatment branch consumes randomness.
    """
    if not intervention:
        return 0.0
    kind = intervention.get("type")
    if kind == "treatment":
        adherence = max(0.4, min(1.0, random.gauss(0.85, 0.10)))
        return -0.10 * adherence  # ~10% absolute reduction at full adherence
    if kind == "phenotype_resolve":
        return -0.05
    if kind == "phenotype_add":
        return +0.08
    return 0.0


def _quantile_stats(vals: list) -> dict:
    """Return ``n``, mean and empirical p05/p50/p95 of ``vals`` (4 decimals).

    Percentiles are read at index ``int(q * n)`` of the sorted values
    (capped at the last index). Empty input yields ``n=0`` and zeros, so the
    result can always be splatted into a ``SimulationOutcome``.
    """
    if not vals:
        return {"n": 0, "mean": 0.0, "p05": 0.0, "p50": 0.0, "p95": 0.0}
    ordered = sorted(vals)
    n = len(ordered)
    return {
        "n": n,
        "mean": round(sum(ordered) / n, 4),
        "p05": round(ordered[int(0.05 * n)], 4),
        "p50": round(ordered[int(0.5 * n)], 4),
        "p95": round(ordered[min(n - 1, int(0.95 * n))], 4),
    }


async def run(
    space,
    *,
    n_runs: int = 30,
    intervention: Optional[dict] = None,
    horizons_months: Optional[list[int]] = None,
) -> SimulationSpec:
    """Run a Monte-Carlo simulation of the patient's trajectory.

    The expensive forecasters are called only once:
    ``gemeo.trajectory.predict(space, horizons_months)`` and
    ``gemeo.risk.assess(space)``. Each failure is logged and tolerated. If
    both fail, an empty spec with ``n_runs=0`` is returned.

    Each of the ``n_runs`` runs (clamped to [5, 200]) then:
      1. Draws its own intervention shift (per-run adherence for treatments).
      2. Adds Gaussian noise, growing with the horizon, to every base
         horizon's risk score, clamped to [0, 1].
      3. Perturbs the base overall severity (clamped to [0, 1]) and
         recomputes a survival curve at the base curve's months.

    Args:
        space: Patient state passed through to the forecasters.
        n_runs: Requested number of stochastic samples.
        intervention: Optional intervention spec; only its ``"type"`` key is
            read here.
        horizons_months: Horizons to summarise (default ``[6, 12, 24]``).

    Returns:
        SimulationSpec with per-horizon risk-score percentiles, per-month
        survival-probability percentiles and the median perturbed severity.
    """
    horizons_months = horizons_months or [6, 12, 24]
    n_runs = max(5, min(200, int(n_runs)))
    from . import trajectory as gtraj, risk as grisk

    # 1) Compute the BASE prediction ONCE (one trajectory + one risk LLM call).
    #    Then perturb stochastically per-run — no per-run LLM (would blow up
    #    cost + time and hit rate limits). This is the right way to do MC over
    #    an LLM forecaster: cheap perturbation around an expensive base.
    try:
        base_traj = await gtraj.predict(space, horizons_months)
    except Exception as e:
        logger.warning("simulate: base trajectory failed: %s", e)
        base_traj = None
    try:
        base_risk = await grisk.assess(space)
    except Exception as e:
        logger.warning("simulate: base risk failed: %s", e)
        base_risk = None
    if base_traj is None and base_risk is None:
        return SimulationSpec(n_runs=0, intervention=intervention,
                              horizon_outcomes=[], survival_summary=[],
                              median_severity=0.0)

    # Months at which each run re-evaluates survival; fixed by the base curve,
    # so compute them once instead of on every iteration.
    survival_months = (
        [p["month"] for p in (base_risk.survival_curve or [])] if base_risk else []
    )

    # 2) N stochastic samples around the base.
    runs_traj = []
    runs_risk = []
    for _ in range(n_runs):
        # Drawn per run so adherence uncertainty widens the outcome spread
        # instead of shifting every run by the same amount.
        shift = _draw_intervention_shift(intervention)
        if base_traj:
            new_h = []
            for h in base_traj.horizons:
                # Noise grows with horizon: longer forecasts are less certain.
                noise = random.gauss(0, 0.05 + 0.02 * (h.months / 12))
                new_score = max(0.0, min(1.0, h.risk_score + shift + noise))
                new_h.append(type(h)(
                    months=h.months, state=h.state,
                    risk_score=new_score,
                    confidence_low=h.confidence_low, confidence_high=h.confidence_high,
                    expected_phenotypes=h.expected_phenotypes,
                    expected_complications=h.expected_complications,
                ))
            runs_traj.append(new_h)
        if base_risk:
            sev_noise = random.gauss(0, 0.04)
            new_sev = max(0.0, min(1.0, base_risk.overall_severity + shift + sev_noise))
            # Recompute survival under the perturbed severity.
            new_curve = grisk._approx_survival_from_severity(new_sev, survival_months)
            runs_risk.append({"severity": new_sev, "curve": new_curve})
    runs = list(zip(runs_traj or [None] * n_runs, runs_risk or [None] * n_runs))

    # 3) Aggregate across runs.
    by_horizon: dict = {h: [] for h in horizons_months}
    survival_curves = []
    severities = []
    for traj, risk_run in runs:
        if traj:
            for h in traj:
                if h.months in by_horizon:
                    by_horizon[h.months].append(h.risk_score)
        if risk_run:
            severities.append(risk_run["severity"])
            if risk_run["curve"]:
                survival_curves.append(risk_run["curve"])

    horizon_summaries = [
        SimulationOutcome(metric="risk_score", horizon_months=h, **_quantile_stats(vals))
        for h, vals in by_horizon.items()
    ]

    survival_summary = []
    if survival_curves:
        all_months = sorted({pt["month"] for curve in survival_curves for pt in curve})
        survival_summary = [_summarize_curve(survival_curves, m) for m in all_months]

    # Upper median for even counts, matching the p50 index convention above.
    severity_p50 = sorted(severities)[len(severities) // 2] if severities else 0.0
    return SimulationSpec(
        n_runs=len(runs),
        intervention=intervention,
        horizon_outcomes=horizon_summaries,
        survival_summary=survival_summary,
        median_severity=round(float(severity_p50), 4),
    )
|