# gemeo-twin-stack/src/gemeo/whatif.py
# GEMEO world-model — initial release (module + NeuralSurv ckpt + RareBench v49 + KG embeddings)
# Commit 089d665 (verified), author: timmers
"""What-if / counterfactual engine.
Bootstrap: heuristic — apply the intervention to a deep copy of the space's
current snapshot, then re-run `gemeo.trajectory.predict` + `gemeo.risk.assess`
on that copy and diff the results against the baseline. If
`trajectory_engine.what_if` is importable it is called as well, but only to
supply the rationale and confidence of the result.
Phase 2: CF-GNNExplainer + do-calculus on the patient subgraph — perturb a
node (gene knock-in, drug-on, lab-flipped) and recompute via the trained HGT.
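Supported interventions (example payloads):
- {"type": "treatment", "drug": "Cerezyme", "rxcui": "...", "start_in_days": 30}
- {"type": "lab_flip", "test": "AFP", "from": "abnormal", "to": "normal"}
- {"type": "phenotype_resolve", "hpo_id": "HP:0001250"}
- {"type": "phenotype_add", "hpo_id": "HP:0002240"}
- {"type": "gene_test", "symbol": "GBA", "result": "pathogenic"}
Note: the bootstrap snapshot mutation only applies the keys its handlers read;
`rxcui` / `start_in_days` (treatment) and `from` (lab_flip) are currently
ignored, while optional `protocol`, `new_value`, `name`, `severity`, `variant`
and `zygosity` keys are honoured when present.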
"""
from __future__ import annotations
import copy
import logging
import os
from typing import Optional
from .types import WhatIfResult
# Module-level logger shared by the what-if engine.
logger = logging.getLogger("gemeo.whatif")

# Checkpoint path used by `_has_cf_model`. The GEMEO_CF_CKPT environment
# variable takes precedence; otherwise fall back to artifacts/cf_gnn.pt next
# to this module. (Presumably the Phase 2 CF-GNN weights — confirm.)
_DEFAULT_CF_CKPT = os.path.join(os.path.dirname(__file__), "artifacts", "cf_gnn.pt")
CF_CKPT = os.getenv("GEMEO_CF_CKPT", _DEFAULT_CF_CKPT)
def _has_cf_model() -> bool:
    """Report whether the path stored in ``CF_CKPT`` exists on disk."""
    checkpoint_path = CF_CKPT
    return os.path.exists(checkpoint_path)
async def _apply_to_snapshot(space, intervention: dict):
"""Mutate a deep-copy of the space's current snapshot, return the copy."""
space_copy = copy.deepcopy(space)
snap = space_copy.get_current_snapshot() if hasattr(space_copy, "get_current_snapshot") else None
if snap is None:
return space_copy
itype = intervention.get("type", "")
if itype == "treatment":
snap.treatments.append({
"name": intervention.get("drug", "unknown"),
"type": "what_if",
"start": "in_simulation",
"response": "pending",
"protocol": intervention.get("protocol"),
})
elif itype == "lab_flip":
for lab in snap.labs:
if (lab.get("test") or "").lower() == (intervention.get("test") or "").lower():
lab["abnormal"] = (intervention.get("to", "normal").lower() == "abnormal")
lab["value"] = intervention.get("new_value", lab.get("value"))
elif itype == "phenotype_resolve":
target = intervention.get("hpo_id")
snap.phenotypes = [p for p in snap.phenotypes if p.get("hpo_id") != target]
elif itype == "phenotype_add":
snap.phenotypes.append({
"hpo_id": intervention.get("hpo_id"),
"name": intervention.get("name", intervention.get("hpo_id")),
"severity": intervention.get("severity", "moderate"),
"status": "what_if",
})
elif itype == "gene_test":
snap.genes.append({
"symbol": intervention.get("symbol"),
"variant": intervention.get("variant", "unknown"),
"pathogenicity": intervention.get("result", "uncertain"),
"zygosity": intervention.get("zygosity"),
})
return space_copy
def _cf_field(cf_result, name: str):
    """Read ``name`` from a counterfactual result that is a dict or an object."""
    if isinstance(cf_result, dict):
        return cf_result.get(name)
    return getattr(cf_result, name, None)


def _horizon_deltas(baseline_trajectory, new_trajectory) -> list:
    """Pair horizons by ``months`` and report the per-horizon risk change."""
    baseline_by_months = {h.months: h for h in baseline_trajectory.horizons}
    deltas = []
    for horizon in new_trajectory.horizons:
        base = baseline_by_months.get(horizon.months)
        if base is None:
            # No baseline horizon at this month mark, so nothing to compare.
            continue
        deltas.append({
            "months": horizon.months,
            "delta_risk": round(horizon.risk_score - base.risk_score, 4),
            # `or ""` keeps a missing state from crashing the slice.
            "from": (base.state or "")[:80],
            "to": (horizon.state or "")[:80],
        })
    return deltas


async def simulate(
    space,
    intervention: dict,
    *,
    baseline_risk=None,
    baseline_trajectory=None,
) -> WhatIfResult:
    """Run a counterfactual for ``intervention`` against ``space``.

    Steps: (1) compute any baseline not supplied by the caller, (2) ask
    ``trajectory_engine.what_if`` — when importable — for a rationale and
    confidence, (3) re-run ``gemeo.trajectory.predict`` and
    ``gemeo.risk.assess`` on a mutated copy of the space, (4) diff against the
    baseline.

    Args:
        space: The space to simulate on; it is not modified.
        intervention: Intervention payload (see the module docstring).
        baseline_risk: Pre-computed baseline risk; computed here if ``None``.
        baseline_trajectory: Pre-computed baseline trajectory; computed here
            if ``None``.

    Returns:
        A ``WhatIfResult``. If the counterfactual re-prediction fails, the
        failure is logged, ``new_risk`` / ``new_trajectory`` may be ``None``
        and ``delta_risk`` stays ``0.0``.

    Raises:
        Whatever the baseline ``assess`` / ``predict`` calls or the snapshot
        copy raise; only steps (2) and (3) are best-effort.
    """
    # Imported lazily, as in the original. The unused `encoder` import was
    # dropped so that a failure to import it cannot break this function.
    from . import trajectory as gtraj
    from . import risk as grisk

    # 1) baseline (if not supplied)
    if baseline_risk is None:
        baseline_risk = await grisk.assess(space)
    if baseline_trajectory is None:
        baseline_trajectory = await gtraj.predict(space)

    # 2) optional trajectory_engine counterfactual: rationale + confidence only
    rationale = ""
    confidence = 0.5
    try:
        from trajectory_engine import what_if as te_what_if

        cf_result = await te_what_if(space, intervention)
        if cf_result is not None:
            # Rationale is assigned first so it survives a bad confidence value.
            rationale = (
                _cf_field(cf_result, "rationale")
                or _cf_field(cf_result, "explanation")
                or ""
            )
            raw_confidence = _cf_field(cf_result, "confidence")
            # Only a missing/None confidence falls back to 0.6. A reported
            # 0.0 is a real value and is preserved.
            confidence = 0.6 if raw_confidence is None else float(raw_confidence)
    except Exception as e:  # best-effort: the engine is optional
        logger.debug("trajectory_engine.what_if failed: %s", e)

    # 3) re-run gemeo predictions on a mutated snapshot
    space_cf = await _apply_to_snapshot(space, intervention)
    new_traj = None
    new_risk = None
    try:
        new_traj = await gtraj.predict(space_cf)
        new_risk = await grisk.assess(space_cf)
    except Exception as e:  # best-effort: still return what we have
        logger.warning("counterfactual re-prediction failed: %s", e)

    # 4) compute deltas
    have_risk = new_risk is not None and baseline_risk is not None
    delta_risk = 0.0
    if have_risk:
        delta_risk = float(new_risk.overall_severity) - float(baseline_risk.overall_severity)

    delta_trajectory = []
    if new_traj is not None and baseline_trajectory is not None:
        delta_trajectory = _horizon_deltas(baseline_trajectory, new_traj)

    if not rationale:
        # build a short canned rationale
        if have_risk:
            sign = "decreases" if delta_risk < -0.02 else ("increases" if delta_risk > 0.02 else "barely changes")
            rationale = f"Intervention `{intervention.get('type')}` {sign} overall severity by {delta_risk:+.2%}."
        else:
            # Don't report "barely changes ... +0.00%" when nothing was computed.
            rationale = (
                f"Intervention `{intervention.get('type')}` could not be re-evaluated; "
                "change in overall severity is unknown."
            )

    return WhatIfResult(
        intervention=intervention,
        delta_risk=round(delta_risk, 4),
        delta_trajectory=delta_trajectory,
        new_risk=new_risk,
        new_trajectory=new_traj,
        rationale=rationale,
        confidence=confidence,
    )