# gemeo-twin-stack: src/gemeo/simulate.py
# Uploaded by timmers — commit 089d665 (verified):
# "GEMEO world-model — initial release (module + NeuralSurv ckpt + RareBench v49 + KG embeddings)"
"""In-silico Monte-Carlo simulation of a patient's trajectory.

Computes one base forecast (`trajectory.predict` + `risk.assess`) and then
draws N cheap stochastic perturbations around it: horizon-dependent noise on
risk scores, noise on overall severity, and an intervention effect scaled by
sampled treatment adherence. The resulting outcome distributions are then
aggregated. This gives the clinician a sense of the *spread* of possible
futures, not a single point estimate.

This is a thin Monte-Carlo wrapper over the existing trajectory and risk
models — no new model. The value is in the aggregation: empirical
percentiles (p05/p50/p95) of risk per horizon and of survival probability
per month, i.e. the modal vs tail outcomes.

NOTE: intervention-timing jitter (`_stochastic_intervention`), symptomatic
event arrivals and `whatif.simulate` are not yet wired into `run`.
"""
from __future__ import annotations
import asyncio
import logging
import random
from typing import Optional
from .types import SimulationSpec, SimulationOutcome
# Module logger. `run` reports failures of the base trajectory/risk model
# calls here as warnings instead of raising them.
logger = logging.getLogger("gemeo.simulate")
def _stochastic_intervention(
    base: dict,
    jitter: float = 0.2,
    rng: Optional[random.Random] = None,
) -> dict:
    """Return a randomly perturbed copy of an intervention spec.

    Args:
        base: Intervention dict. It is shallow-copied and never mutated.
        jitter: Relative standard deviation applied to ``start_in_days``
            (sd = ``start_in_days * jitter``).
        rng: Optional ``random.Random`` instance for reproducible draws.
            Defaults to the module-level ``random`` functions, which was the
            previous behavior.

    Returns:
        A new dict in which:

        * ``start_in_days`` (if present and numeric) is jittered and clamped
          at 0. A non-numeric value is left unchanged.
        * ``adherence`` is always (re)set to a draw from N(0.85, 0.1),
          clamped to [0.4, 1.0].

    NOTE: ``run`` in this module does not currently call this helper.
    """
    draw = random if rng is None else rng
    out = dict(base)
    if "start_in_days" in out:
        try:
            days = float(out["start_in_days"])
        except (TypeError, ValueError, OverflowError):
            # Best-effort: keep unparseable timing as-is rather than fail the
            # whole simulation.
            pass
        else:
            out["start_in_days"] = max(0.0, days + draw.gauss(0, days * jitter))
    out["adherence"] = max(0.4, min(1.0, draw.gauss(0.85, 0.1)))
    return out
def _summarize_curve(curves: list, month: int) -> dict:
    """Summarize P(alive) at one month across simulated survival curves.

    From each curve, the first point whose ``"month"`` equals ``month`` is
    used; a missing ``"p_alive"`` counts as 0.0. Curves with no such point
    are skipped.

    Returns:
        ``{"month", "n"}`` when no curve contains the month. Otherwise the
        dict also has the ``mean`` and the empirical ``p05``/``p50``/``p95``.
        Each percentile is taken at index ``int(q * n)`` of the sorted
        sample, capped at ``n - 1``. All values are rounded to 4 decimals.
    """
    samples = []
    for curve in curves:
        hit = next((pt for pt in curve if pt.get("month") == month), None)
        if hit is not None:
            samples.append(float(hit.get("p_alive", 0.0)))

    count = len(samples)
    if count == 0:
        return {"month": month, "n": 0}

    ranked = sorted(samples)

    def quantile(q: float) -> float:
        return round(ranked[min(count - 1, int(q * count))], 4)

    return {
        "month": month,
        "n": count,
        "mean": round(sum(ranked) / count, 4),
        "p05": quantile(0.05),
        "p50": quantile(0.5),
        "p95": quantile(0.95),
    }
def _sample_intervention_shift(intervention: Optional[dict]) -> float:
    """Draw one run's additive shift on risk/severity implied by `intervention`.

    The shift depends on ``intervention["type"]``:

    * ``"treatment"``: -0.10 scaled by an adherence drawn per call from
      N(0.85, 0.10), clamped to [0.4, 1.0].
    * ``"phenotype_resolve"``: fixed -0.05.
    * ``"phenotype_add"``: fixed +0.08.
    * Anything else, including no intervention: 0.0.
    """
    if not intervention:
        return 0.0
    kind = intervention.get("type")
    if kind == "treatment":
        # Sampled per run so adherence uncertainty widens the output spread.
        adherence = max(0.4, min(1.0, random.gauss(0.85, 0.10)))
        return -0.10 * adherence  # ~10% absolute reduction at full adherence
    if kind == "phenotype_resolve":
        return -0.05
    if kind == "phenotype_add":
        return 0.08
    return 0.0


def _perturbed_risk_scores(horizons, shift: float) -> list:
    """Return ``(months, risk_score)`` pairs for one noisy draw of the base horizons.

    Each base ``risk_score`` is moved by ``shift`` plus Gaussian noise. The
    noise sd grows with horizon length (0.05 + 0.02 per 12 months), because
    longer forecasts are less certain. The result is clamped to [0, 1].
    """
    draws = []
    for h in horizons:
        noise = random.gauss(0, 0.05 + 0.02 * (h.months / 12))
        draws.append((h.months, max(0.0, min(1.0, h.risk_score + shift + noise))))
    return draws


def _risk_score_outcome(horizon_months: int, vals: list) -> SimulationOutcome:
    """Summarize sampled risk scores at one horizon.

    An empty sample yields an all-zero outcome with ``n=0``. Otherwise the
    outcome holds the mean and the empirical p05/p50/p95. Each percentile is
    taken at index ``int(q * n)`` of the sorted sample, capped at ``n - 1``.
    All values are rounded to 4 decimals.
    """
    if not vals:
        return SimulationOutcome(
            metric="risk_score", horizon_months=horizon_months,
            n=0, mean=0.0, p05=0.0, p50=0.0, p95=0.0,
        )
    ranked = sorted(vals)
    n = len(ranked)
    return SimulationOutcome(
        metric="risk_score",
        horizon_months=horizon_months,
        n=n,
        mean=round(sum(ranked) / n, 4),
        p05=round(ranked[int(0.05 * n)], 4),
        p50=round(ranked[int(0.5 * n)], 4),
        p95=round(ranked[min(n - 1, int(0.95 * n))], 4),
    )


async def run(
    space,
    *,
    n_runs: int = 30,
    intervention: Optional[dict] = None,
    horizons_months: Optional[list[int]] = None,
) -> SimulationSpec:
    """Run a Monte-Carlo simulation of the patient's trajectory.

    The expensive models are called ONCE: ``gemeo.trajectory.predict`` for the
    base trajectory and ``gemeo.risk.assess`` for the base risk. Each of the
    ``n_runs`` samples then perturbs those bases cheaply:

    1. Samples an intervention shift. For ``type == "treatment"`` the shift
       scales with an adherence drawn per run.
    2. Adds the shift plus horizon-dependent Gaussian noise to every base
       horizon risk score.
    3. Adds the shift plus Gaussian noise to the base overall severity, then
       recomputes a survival curve from the perturbed severity.

    Args:
        space: Patient state, passed unchanged to the trajectory and risk
            models.
        n_runs: Number of samples. Clamped to [5, 200].
        intervention: Optional intervention dict. Its ``"type"`` key selects
            the shift (``treatment``, ``phenotype_resolve``, ``phenotype_add``).
        horizons_months: Horizons to summarize. Defaults to ``[6, 12, 24]``.

    Returns:
        A ``SimulationSpec`` with per-horizon risk-score percentiles,
        per-month survival percentiles and the median sampled severity.
        If both base calls fail, an empty spec with ``n_runs=0`` is returned.
    """
    horizons_months = horizons_months or [6, 12, 24]
    n_runs = max(5, min(200, int(n_runs)))
    from . import trajectory as gtraj, risk as grisk

    # Compute the BASE prediction once and perturb it per run. Calling the
    # (possibly LLM-backed) models per run would multiply cost and latency and
    # hit rate limits. A failing base is logged and the other one is still used.
    try:
        base_traj = await gtraj.predict(space, horizons_months)
    except Exception as e:
        logger.warning("simulate: base trajectory failed: %s", e)
        base_traj = None
    try:
        base_risk = await grisk.assess(space)
    except Exception as e:
        logger.warning("simulate: base risk failed: %s", e)
        base_risk = None
    if base_traj is None and base_risk is None:
        return SimulationSpec(n_runs=0, intervention=intervention,
                              horizon_outcomes=[], survival_summary=[],
                              median_severity=0.0)

    # Loop-invariant survival inputs, resolved once instead of on every run.
    if base_risk:
        from .risk import _approx_survival_from_severity
        survival_months = [p["month"] for p in (base_risk.survival_curve or [])]

    by_horizon: dict = {h: [] for h in horizons_months}
    survival_curves = []
    severities = []
    for _ in range(n_runs):
        shift = _sample_intervention_shift(intervention)
        if base_traj:
            for months, score in _perturbed_risk_scores(base_traj.horizons, shift):
                if months in by_horizon:
                    by_horizon[months].append(score)
        if base_risk:
            sev_noise = random.gauss(0, 0.04)
            severity = max(0.0, min(1.0, base_risk.overall_severity + shift + sev_noise))
            severities.append(severity)
            curve = _approx_survival_from_severity(severity, survival_months)
            if curve:
                survival_curves.append(curve)

    horizon_outcomes = [_risk_score_outcome(h, vals) for h, vals in by_horizon.items()]

    all_months = sorted({pt["month"] for curve in survival_curves for pt in curve})
    survival_summary = [_summarize_curve(survival_curves, m) for m in all_months]

    severity_p50 = sorted(severities)[len(severities) // 2] if severities else 0.0
    return SimulationSpec(
        n_runs=n_runs,
        intervention=intervention,
        horizon_outcomes=horizon_outcomes,
        survival_summary=survival_summary,
        median_severity=round(float(severity_p50), 4),
    )