# File size: 5,915 Bytes
# 089d665
"""What-if / counterfactual engine.

Bootstrap: heuristic — re-run trajectory + risk after applying the intervention
to a copy of the snapshot. `trajectory_engine.what_if` is tried first, if
available, for a rationale and confidence; the copy is then always re-scored
with `gemeo.trajectory.predict` + `gemeo.risk.assess` to compute the deltas.
Phase 2: CF-GNNExplainer + do-calculus on the patient subgraph — perturb a
node (gene knock-in, drug-on, lab-flipped) and recompute via the trained HGT.
Supported interventions:
- {"type": "treatment", "drug": "Cerezyme", "rxcui": "...", "start_in_days": 30}
- {"type": "lab_flip", "test": "AFP", "from": "abnormal", "to": "normal"}
- {"type": "phenotype_resolve", "hpo_id": "HP:0001250"}
- {"type": "phenotype_add", "hpo_id": "HP:0002240"}
- {"type": "gene_test", "symbol": "GBA", "result": "pathogenic"}
"""
from __future__ import annotations
import copy
import logging
import os
from typing import Optional
from .types import WhatIfResult
# Module logger; the name is a fixed string rather than derived from __name__.
logger = logging.getLogger("gemeo.whatif")

# Location of the counterfactual-GNN checkpoint. The GEMEO_CF_CKPT environment
# variable wins when set; otherwise the path points at artifacts/cf_gnn.pt
# inside the directory that contains this module.
CF_CKPT = os.getenv(
    "GEMEO_CF_CKPT",
    os.path.join(os.path.dirname(__file__), "artifacts", "cf_gnn.pt"),
)
def _has_cf_model() -> bool:
    """Report whether something exists on disk at the ``CF_CKPT`` path."""
    checkpoint_present = os.path.exists(CF_CKPT)
    return checkpoint_present
async def _apply_to_snapshot(space, intervention: dict):
    """Apply ``intervention`` to a deep copy of ``space`` and return the copy.

    The ``space`` passed in is never modified: it is deep-copied first and all
    changes are made on the copy's current snapshot.

    Args:
        space: Object that may expose ``get_current_snapshot()``. The returned
            snapshot is accessed through its ``treatments``, ``labs``,
            ``phenotypes`` and ``genes`` attributes (lists of dicts); only the
            attribute matching the intervention type is touched.
            NOTE(review): assumes ``get_current_snapshot()`` returns the live
            snapshot object held by the space, not a detached copy — confirm.
        intervention: Mapping whose ``"type"`` key selects the mutation:
            ``"treatment"``, ``"lab_flip"``, ``"phenotype_resolve"``,
            ``"phenotype_add"`` or ``"gene_test"``. Any other value leaves the
            snapshot unchanged.

    Returns:
        The deep-copied space. It is returned unmodified when it has no
        ``get_current_snapshot`` method or when that method returns ``None``.
    """
    space_copy = copy.deepcopy(space)
    if not hasattr(space_copy, "get_current_snapshot"):
        return space_copy
    snap = space_copy.get_current_snapshot()
    if snap is None:
        return space_copy

    itype = intervention.get("type", "")

    if itype == "treatment":
        # Only `drug` and `protocol` are read from the intervention; the
        # remaining fields are fixed markers for a simulated treatment.
        snap.treatments.append({
            "name": intervention.get("drug", "unknown"),
            "type": "what_if",
            "start": "in_simulation",
            "response": "pending",
            "protocol": intervention.get("protocol"),
        })

    elif itype == "lab_flip":
        wanted = str(intervention.get("test") or "").lower()
        # `or "normal"` also covers an explicit None, which previously crashed
        # on None.lower().
        to_abnormal = str(intervention.get("to") or "normal").lower() == "abnormal"
        # Without a test name there is nothing to target. Matching on the
        # empty string would flip every lab whose own name is missing.
        if wanted:
            for lab in snap.labs:
                if str(lab.get("test") or "").lower() != wanted:
                    continue
                lab["abnormal"] = to_abnormal
                # Keep the recorded value unless a replacement is supplied.
                lab["value"] = intervention.get("new_value", lab.get("value"))

    elif itype == "phenotype_resolve":
        target = intervention.get("hpo_id")
        snap.phenotypes = [p for p in snap.phenotypes if p.get("hpo_id") != target]

    elif itype == "phenotype_add":
        snap.phenotypes.append({
            "hpo_id": intervention.get("hpo_id"),
            # Fall back to the HPO id when no display name is given.
            "name": intervention.get("name", intervention.get("hpo_id")),
            "severity": intervention.get("severity", "moderate"),
            "status": "what_if",
        })

    elif itype == "gene_test":
        snap.genes.append({
            "symbol": intervention.get("symbol"),
            "variant": intervention.get("variant", "unknown"),
            "pathogenicity": intervention.get("result", "uncertain"),
            "zygosity": intervention.get("zygosity"),
        })

    return space_copy
async def simulate(
    space,
    intervention: dict,
    *,
    baseline_risk=None,
    baseline_trajectory=None,
) -> WhatIfResult:
    """Run a counterfactual for ``intervention`` and report deltas vs. baseline.

    Steps:
      1. Compute the baseline risk / trajectory with ``risk.assess`` and
         ``trajectory.predict`` unless the caller supplied them.
      2. Try ``trajectory_engine.what_if`` for a rationale and confidence.
         This is best-effort: any failure (including the module not being
         importable) is logged at debug level and ignored.
      3. Apply the intervention to a copy of ``space`` and re-run
         ``trajectory.predict`` and ``risk.assess`` on the copy. A failure is
         logged as a warning and leaves the missing results as ``None``.
      4. Compute the overall severity delta and per-horizon deltas, matching
         horizons by their ``months`` value.

    Args:
        space: Patient space forwarded to the prediction helpers.
        intervention: Intervention description; its ``"type"`` key is quoted
            in the generated rationale.
        baseline_risk: Optional precomputed baseline risk (object exposing
            ``overall_severity``).
        baseline_trajectory: Optional precomputed baseline trajectory (object
            exposing ``horizons``).

    Returns:
        A ``WhatIfResult`` carrying the intervention, the rounded risk delta,
        the per-horizon deltas, the counterfactual risk / trajectory (``None``
        when re-prediction failed), a rationale string and a confidence.
        ``confidence`` stays at the 0.5 default unless
        ``trajectory_engine.what_if`` returned a result.
    """
    # `encoder` was imported here before but never used, so it is dropped.
    from . import trajectory as gtraj, risk as grisk

    # 1) baseline (if not supplied)
    if baseline_risk is None:
        baseline_risk = await grisk.assess(space)
    if baseline_trajectory is None:
        baseline_trajectory = await gtraj.predict(space)

    # 2) try the trajectory_engine.what_if first (LLM-based counterfactual)
    new_traj = None
    new_risk = None
    rationale = ""
    confidence = 0.5
    try:
        from trajectory_engine import what_if as te_what_if

        cf_result = await te_what_if(space, intervention)
        if cf_result is not None:
            if isinstance(cf_result, dict):
                rationale = cf_result.get("rationale") or cf_result.get("explanation") or ""
                # An explicit None would make float() raise and throw away the
                # external confidence, so treat it like a missing key.
                raw_confidence = cf_result.get("confidence")
                confidence = 0.6 if raw_confidence is None else float(raw_confidence)
            else:
                rationale = (
                    getattr(cf_result, "rationale", "")
                    or getattr(cf_result, "explanation", "")
                    or ""
                )
                confidence = float(getattr(cf_result, "confidence", 0.6) or 0.6)
    except Exception as e:  # optional dependency: never fatal
        logger.debug("trajectory_engine.what_if failed: %s", e)

    # 3) re-run gemeo predictions on a mutated snapshot
    space_cf = await _apply_to_snapshot(space, intervention)
    try:
        new_traj = await gtraj.predict(space_cf)
        new_risk = await grisk.assess(space_cf)
    except Exception as e:
        logger.warning("counterfactual re-prediction failed: %s", e)

    # 4) compute deltas
    delta_risk = 0.0
    risk_compared = bool(new_risk and baseline_risk)
    if risk_compared:
        delta_risk = float(new_risk.overall_severity) - float(baseline_risk.overall_severity)

    delta_trajectory = []
    if new_traj and baseline_trajectory:
        baseline_by_h = {h.months: h for h in baseline_trajectory.horizons}
        for h in new_traj.horizons:
            base = baseline_by_h.get(h.months)
            if base is None:
                # Horizon has no baseline counterpart, so there is nothing to diff.
                continue
            delta_trajectory.append({
                "months": h.months,
                "delta_risk": round(h.risk_score - base.risk_score, 4),
                # State descriptions are truncated to 80 characters.
                "from": base.state[:80],
                "to": h.state[:80],
            })

    if not rationale:
        if risk_compared:
            # build a short canned rationale; +/-0.02 is the dead band for
            # calling a change meaningful
            sign = "decreases" if delta_risk < -0.02 else ("increases" if delta_risk > 0.02 else "barely changes")
            rationale = f"Intervention `{intervention.get('type')}` {sign} overall severity by {delta_risk:+.2%}."
        else:
            # Without both risk assessments the 0.0 delta is a placeholder,
            # not a measurement, so do not report it as "barely changes".
            rationale = (
                f"Intervention `{intervention.get('type')}` could not be evaluated: "
                "counterfactual risk is unavailable, so no severity delta was computed."
            )

    return WhatIfResult(
        intervention=intervention,
        delta_risk=round(delta_risk, 4),
        delta_trajectory=delta_trajectory,
        new_risk=new_risk,
        new_trajectory=new_traj,
        rationale=rationale,
        confidence=confidence,
    )
|