| """In-silico Monte-Carlo simulation of trajectory. |
| |
| Runs the trajectory module N times under stochastic perturbations |
| (intervention timing, treatment adherence, symptomatic event arrivals) |
| and aggregates outcome distributions. Gives the clinician a sense of |
| the *spread* of possible futures, not a single point estimate. |
| |
| This is a thin Monte-Carlo wrapper over `trajectory.predict` and |
| `whatif.simulate` — no new model. The value is in the aggregation: |
| empirical CDFs over time-to-event, the modal vs tail outcomes, and |
| the sensitivity of the trajectory to small intervention timing shifts. |
| """ |
| from __future__ import annotations |
| import asyncio |
| import logging |
| import random |
| from typing import Optional |
|
|
| from .types import SimulationSpec, SimulationOutcome |
|
|
| logger = logging.getLogger("gemeo.simulate") |
|
|
|
|
def _stochastic_intervention(base: dict, jitter: float = 0.2) -> dict:
    """Return a randomly perturbed shallow copy of an intervention.

    Args:
        base: Intervention description. It is copied, never mutated.
        jitter: Relative standard deviation of the Gaussian noise added to
            ``start_in_days`` (sigma = ``start_in_days * jitter``).

    Returns:
        A new dict in which:
          * ``start_in_days`` (when present and numeric) has been jittered
            and floored at ``0.0``; the result is always a float.
          * ``adherence`` is set to a draw from N(0.85, 0.1) clamped to
            [0.4, 1.0], overwriting any value already present.
        No other key is modified.
    """
    out = dict(base)
    if "start_in_days" in out:
        try:
            days = float(out["start_in_days"])
            # Clamp with a float literal so the field never degrades to int 0.
            out["start_in_days"] = max(0.0, days + random.gauss(0, days * jitter))
        except (TypeError, ValueError, OverflowError):
            # Deliberate best-effort: a non-numeric start time is passed
            # through untouched instead of aborting the simulation.
            pass
    out["adherence"] = max(0.4, min(1.0, random.gauss(0.85, 0.1)))
    return out
|
|
|
|
def _summarize_curve(curves: list, month: int) -> dict:
    """Summarise P(alive) at a single month across simulated survival curves.

    Each curve is an iterable of dict points. For every curve, the first
    point whose ``"month"`` equals `month` contributes its ``"p_alive"``
    (0.0 when that key is absent); curves with no point at that month are
    ignored.

    Returns:
        ``{"month": month, "n": 0}`` when no curve has a matching point.
        Otherwise a dict that also carries ``"mean"``, ``"p05"``, ``"p50"``
        and ``"p95"``, each rounded to 4 decimals. Percentiles are picked by
        index ``int(q * n)`` into the sorted samples, capped at the last one.
    """
    not_found = object()
    samples = []
    for curve in curves:
        hit = next((pt for pt in curve if pt.get("month") == month), not_found)
        if hit is not not_found:
            samples.append(float(hit.get("p_alive", 0.0)))

    count = len(samples)
    if count == 0:
        return {"month": month, "n": 0}

    ordered = sorted(samples)
    last = count - 1

    def at_rank(q: float) -> float:
        # The cap only ever binds for the upper tail; lower ranks are
        # already within range for any count >= 1.
        return round(ordered[min(last, int(q * count))], 4)

    return {
        "month": month,
        "n": count,
        "mean": round(sum(ordered) / count, 4),
        "p05": at_rank(0.05),
        "p50": at_rank(0.5),
        "p95": at_rank(0.95),
    }
|
|
|
|
def _clamp_unit(value: float) -> float:
    """Clamp `value` to the closed interval [0.0, 1.0]."""
    return max(0.0, min(1.0, value))


def _sample_intervention_shift(intervention: Optional[dict]) -> float:
    """Draw the additive risk/severity shift for one simulation run.

    A fresh perturbed copy of `intervention` is obtained from
    `_stochastic_intervention`, so adherence is re-sampled on every call.
    Returns 0.0 when there is no intervention or its "type" is unrecognised.
    """
    if not intervention:
        return 0.0
    perturbed = _stochastic_intervention(intervention)
    kind = perturbed.get("type")
    if kind == "treatment":
        # Treatment benefit scales with the adherence sampled for this run.
        return -0.10 * perturbed["adherence"]
    if kind == "phenotype_resolve":
        return -0.05
    if kind == "phenotype_add":
        return 0.08
    return 0.0


def _risk_score_outcome(horizon: int, scores: list) -> SimulationOutcome:
    """Summarise the simulated risk scores of one horizon.

    An empty `scores` list yields an all-zero outcome with ``n=0``.
    Percentiles are picked by index ``int(q * n)`` into the sorted scores.
    """
    if not scores:
        return SimulationOutcome(
            metric="risk_score", horizon_months=horizon,
            n=0, mean=0.0, p05=0.0, p50=0.0, p95=0.0,
        )
    ordered = sorted(scores)
    n = len(ordered)
    return SimulationOutcome(
        metric="risk_score",
        horizon_months=horizon,
        n=n,
        mean=round(sum(ordered) / n, 4),
        p05=round(ordered[int(0.05 * n)], 4),
        p50=round(ordered[int(0.5 * n)], 4),
        p95=round(ordered[min(n - 1, int(0.95 * n))], 4),
    )


async def run(
    space,
    *,
    n_runs: int = 30,
    intervention: Optional[dict] = None,
    horizons_months: Optional[list[int]] = None,
) -> SimulationSpec:
    """Run a Monte-Carlo simulation around the baseline trajectory.

    The baseline is computed once via `trajectory.predict` and
    `risk.assess`. Either call may fail (logged as a warning) as long as
    the other one produces a result. Each run then:
      1. Draws a perturbed copy of `intervention` (if any) and derives an
         additive shift from its "type" and sampled adherence.
      2. Adds that shift plus Gaussian noise to every baseline horizon risk
         score, clamped to [0, 1]; the noise sigma grows with the horizon.
      3. Does the same for the baseline overall severity and rebuilds a
         survival curve from it at the months of the baseline curve.

    Args:
        space: Object forwarded unchanged to `trajectory.predict` and
            `risk.assess`.
        n_runs: Requested number of runs; clamped to [5, 200].
        intervention: Optional intervention dict. Its "type" may be
            "treatment", "phenotype_resolve" or "phenotype_add"; any other
            value produces no shift.
        horizons_months: Horizons to summarise; defaults to [6, 12, 24].

    Returns:
        A SimulationSpec with one risk-score summary per requested horizon,
        one survival summary per month found in the simulated curves, and
        the median simulated severity. When no baseline is available,
        ``n_runs`` is 0 and the summaries are empty.
    """
    horizons_months = horizons_months or [6, 12, 24]
    n_runs = max(5, min(200, int(n_runs)))

    # Imported at call time, as the original code did.
    # NOTE(review): presumably this avoids a circular import — confirm.
    from . import trajectory as gtraj, risk as grisk

    try:
        base_traj = await gtraj.predict(space, horizons_months)
    except Exception as e:
        logger.warning("simulate: base trajectory failed: %s", e)
        base_traj = None
    try:
        base_risk = await grisk.assess(space)
    except Exception as e:
        logger.warning("simulate: base risk failed: %s", e)
        base_risk = None

    # Nothing usable to perturb: report an empty simulation rather than
    # n_runs runs that carry no data.
    if not base_traj and not base_risk:
        return SimulationSpec(n_runs=0, intervention=intervention,
                              horizon_outcomes=[], survival_summary=[],
                              median_severity=0.0)

    # Loop-invariant pieces of the risk branch, resolved once.
    if base_risk:
        from .risk import _approx_survival_from_severity
        curve_months = [p["month"] for p in (base_risk.survival_curve or [])]

    runs_traj = []
    runs_risk = []
    for _ in range(n_runs):
        # One perturbed intervention per run, shared by both branches so
        # the trajectory and severity of a run stay consistent.
        shift = _sample_intervention_shift(intervention)

        if base_traj:
            perturbed_horizons = []
            for h in base_traj.horizons:
                noise = random.gauss(0, 0.05 + 0.02 * (h.months / 12))
                perturbed_horizons.append(type(h)(
                    months=h.months, state=h.state,
                    risk_score=_clamp_unit(h.risk_score + shift + noise),
                    confidence_low=h.confidence_low,
                    confidence_high=h.confidence_high,
                    expected_phenotypes=h.expected_phenotypes,
                    expected_complications=h.expected_complications,
                ))
            runs_traj.append(perturbed_horizons)

        if base_risk:
            severity = _clamp_unit(
                base_risk.overall_severity + shift + random.gauss(0, 0.04)
            )
            runs_risk.append({
                "severity": severity,
                "curve": _approx_survival_from_severity(severity, curve_months),
            })

    # Aggregate risk scores per requested horizon; horizons the baseline
    # does not provide keep an empty list and get an n=0 outcome.
    by_horizon: dict = {h: [] for h in horizons_months}
    for horizons in runs_traj:
        for h in horizons:
            if h.months in by_horizon:
                by_horizon[h.months].append(h.risk_score)
    horizon_summaries = [
        _risk_score_outcome(h, scores) for h, scores in by_horizon.items()
    ]

    severities = [r["severity"] for r in runs_risk]
    survival_curves = [r["curve"] for r in runs_risk if r["curve"]]

    all_months = sorted({pt["month"] for curve in survival_curves for pt in curve})
    survival_summary = [_summarize_curve(survival_curves, m) for m in all_months]

    # Upper median, consistent with the p50 index used in the summaries.
    severity_p50 = sorted(severities)[len(severities) // 2] if severities else 0.0

    return SimulationSpec(
        n_runs=n_runs,
        intervention=intervention,
        horizon_outcomes=horizon_summaries,
        survival_summary=survival_summary,
        median_severity=round(float(severity_p50), 4),
    )
|
|