"""In-silico Monte-Carlo simulation of trajectory. Runs the trajectory module N times under stochastic perturbations (intervention timing, treatment adherence, symptomatic event arrivals) and aggregates outcome distributions. Gives the clinician a sense of the *spread* of possible futures, not a single point estimate. This is a thin Monte-Carlo wrapper over `trajectory.predict` and `whatif.simulate` — no new model. The value is in the aggregation: empirical CDFs over time-to-event, the modal vs tail outcomes, and the sensitivity of the trajectory to small intervention timing shifts. """ from __future__ import annotations import asyncio import logging import random from typing import Optional from .types import SimulationSpec, SimulationOutcome logger = logging.getLogger("gemeo.simulate") def _stochastic_intervention(base: dict, jitter: float = 0.2) -> dict: """Inject randomness into the intervention (timing, dose, adherence).""" out = dict(base) if "start_in_days" in out: try: d = float(out["start_in_days"]) out["start_in_days"] = max(0, d + random.gauss(0, d * jitter)) except Exception: pass out["adherence"] = max(0.4, min(1.0, random.gauss(0.85, 0.1))) return out def _summarize_curve(curves: list, month: int) -> dict: """Empirical CI over P(alive at `month`) across simulation runs.""" vals = [] for curve in curves: for pt in curve: if pt.get("month") == month: vals.append(float(pt.get("p_alive", 0.0))) break if not vals: return {"month": month, "n": 0} vals.sort() n = len(vals) return { "month": month, "n": n, "mean": round(sum(vals) / n, 4), "p05": round(vals[int(0.05 * n)], 4), "p50": round(vals[int(0.5 * n)], 4), "p95": round(vals[min(n - 1, int(0.95 * n))], 4), } async def run( space, *, n_runs: int = 30, intervention: dict = None, horizons_months: list[int] = None, ) -> SimulationSpec: """Run a Monte-Carlo simulation of the patient's trajectory. Each run: 1. Optionally applies a perturbed copy of `intervention`. 2. Calls `gemeo.trajectory.predict` (which itself routes to TGNN or LLM/rule-based bootstrap). 3. Calls `gemeo.risk.assess`. """ horizons_months = horizons_months or [6, 12, 24] n_runs = max(5, min(200, int(n_runs))) from . import trajectory as gtraj, risk as grisk import random import math # 1) Compute the BASE prediction ONCE (one trajectory + one risk LLM call). # Then perturb stochastically per-run — no per-run LLM (would blow up cost # + time and hit rate limits). This is the right way to do MC over an # LLM forecaster: cheap perturbation around an expensive base. try: base_traj = await gtraj.predict(space, horizons_months) except Exception as e: logger.warning(f"simulate: base trajectory failed: {e}") base_traj = None try: base_risk = await grisk.assess(space) except Exception as e: logger.warning(f"simulate: base risk failed: {e}") base_risk = None if base_traj is None and base_risk is None: return SimulationSpec(n_runs=0, intervention=intervention, horizon_outcomes=[], survival_summary=[], median_severity=0.0) # 2) Apply intervention shift heuristically to the base intervention_shift = 0.0 if intervention: adherence = max(0.4, min(1.0, random.gauss(0.85, 0.10))) # Treatment intent reduces risk; null intent leaves shift = 0 if intervention.get("type") == "treatment": intervention_shift = -0.10 * adherence # ~10% absolute reduction at full adherence elif intervention.get("type") == "phenotype_resolve": intervention_shift = -0.05 elif intervention.get("type") == "phenotype_add": intervention_shift = +0.08 # 3) N stochastic samples around the base runs_traj = [] runs_risk = [] for _ in range(n_runs): # Per-run noise increasing with horizon (longer = more uncertain) if base_traj: new_h = [] for h in base_traj.horizons: noise = random.gauss(0, 0.05 + 0.02 * (h.months / 12)) new_score = max(0.0, min(1.0, h.risk_score + intervention_shift + noise)) new_h.append(type(h)( months=h.months, state=h.state, risk_score=new_score, confidence_low=h.confidence_low, confidence_high=h.confidence_high, expected_phenotypes=h.expected_phenotypes, expected_complications=h.expected_complications, )) runs_traj.append(new_h) if base_risk: sev_noise = random.gauss(0, 0.04) new_sev = max(0.0, min(1.0, base_risk.overall_severity + intervention_shift + sev_noise)) # Recompute survival under the perturbed severity from .risk import _approx_survival_from_severity new_curve = _approx_survival_from_severity(new_sev, [p["month"] for p in (base_risk.survival_curve or [])]) runs_risk.append({"severity": new_sev, "curve": new_curve}) runs = list(zip(runs_traj or [None] * n_runs, runs_risk or [None] * n_runs)) # aggregate by_horizon: dict = {h: [] for h in horizons_months} survival_curves = [] severities = [] for traj, risk in runs: if traj: for h in traj: if h.months in by_horizon: by_horizon[h.months].append(h.risk_score) if risk: severities.append(risk["severity"]) if risk["curve"]: survival_curves.append(risk["curve"]) horizon_summaries = [] for h, vals in by_horizon.items(): if not vals: horizon_summaries.append(SimulationOutcome( metric="risk_score", horizon_months=h, n=0, mean=0.0, p05=0.0, p50=0.0, p95=0.0, )) continue vals_sorted = sorted(vals) n = len(vals_sorted) horizon_summaries.append(SimulationOutcome( metric="risk_score", horizon_months=h, n=n, mean=round(sum(vals_sorted) / n, 4), p05=round(vals_sorted[int(0.05 * n)], 4), p50=round(vals_sorted[int(0.5 * n)], 4), p95=round(vals_sorted[min(n - 1, int(0.95 * n))], 4), )) survival_summary = [] if survival_curves: all_months = sorted({pt["month"] for curve in survival_curves for pt in curve}) for m in all_months: survival_summary.append(_summarize_curve(survival_curves, m)) severity_p50 = sorted(severities)[len(severities) // 2] if severities else 0.0 return SimulationSpec( n_runs=len(runs), intervention=intervention, horizon_outcomes=horizon_summaries, survival_summary=survival_summary, median_severity=round(float(severity_p50), 4), )