"""What-if / counterfactual engine.

Bootstrap: heuristic — re-run trajectory + risk after applying the
intervention to the snapshot. `trajectory_engine.what_if` is consulted, when
importable, for a rationale and a confidence value; the deltas themselves
always come from mutating a copy of the snapshot and re-running
`gemeo.trajectory.predict` + `gemeo.risk.assess`.

Phase 2: CF-GNNExplainer + do-calculus on the patient subgraph — perturb a
node (gene knock-in, drug-on, lab-flipped) and recompute via the trained HGT.

Supported interventions:
  - {"type": "treatment", "drug": "Cerezyme", "rxcui": "...", "start_in_days": 30}
  - {"type": "lab_flip", "test": "AFP", "from": "abnormal", "to": "normal"}
  - {"type": "phenotype_resolve", "hpo_id": "HP:0001250"}
  - {"type": "phenotype_add", "hpo_id": "HP:0002240"}
  - {"type": "gene_test", "symbol": "GBA", "result": "pathogenic"}
"""
from __future__ import annotations

import copy
import logging
import os
from typing import Optional

from .types import WhatIfResult

logger = logging.getLogger("gemeo.whatif")

# Path of the Phase-2 counterfactual checkpoint; overridable through the
# GEMEO_CF_CKPT environment variable.
CF_CKPT = os.environ.get(
    "GEMEO_CF_CKPT",
    os.path.join(os.path.dirname(__file__), "artifacts", "cf_gnn.pt"),
)


def _has_cf_model() -> bool:
    """Return True when the checkpoint file at `CF_CKPT` exists on disk."""
    return os.path.exists(CF_CKPT)


async def _apply_to_snapshot(space, intervention: dict):
    """Apply `intervention` to a deep copy of `space` and return the copy.

    The original `space` is never touched. When the copy has no
    `get_current_snapshot` method, or that method returns None, the copy is
    returned unmodified. Interventions that cannot be applied (unknown type,
    missing `test` / `hpo_id` key) are logged and leave the copy unchanged.

    NOTE(review): this relies on `get_current_snapshot()` returning a live
    reference into the copied space rather than a fresh object — confirm.
    """
    space_copy = copy.deepcopy(space)
    snap = (
        space_copy.get_current_snapshot()
        if hasattr(space_copy, "get_current_snapshot")
        else None
    )
    if snap is None:
        return space_copy

    itype = intervention.get("type", "")
    if itype == "treatment":
        # `rxcui` and `start_in_days` from the intervention are currently not
        # carried over into the simulated treatment record.
        snap.treatments.append({
            "name": intervention.get("drug", "unknown"),
            "type": "what_if",
            "start": "in_simulation",
            "response": "pending",
            "protocol": intervention.get("protocol"),
        })
    elif itype == "lab_flip":
        target_test = (intervention.get("test") or "").lower()
        if not target_test:
            # An empty target would otherwise compare equal to every lab whose
            # own test name is missing and flip all of them.
            logger.warning("lab_flip intervention has no `test`; snapshot left unchanged")
        else:
            # `or "normal"` also covers an explicit `"to": None`.
            now_abnormal = str(intervention.get("to") or "normal").lower() == "abnormal"
            for lab in snap.labs:
                if (lab.get("test") or "").lower() == target_test:
                    lab["abnormal"] = now_abnormal
                    lab["value"] = intervention.get("new_value", lab.get("value"))
    elif itype == "phenotype_resolve":
        target = intervention.get("hpo_id")
        if target is None:
            # A None target would match (and drop) every phenotype that simply
            # lacks an `hpo_id`.
            logger.warning("phenotype_resolve intervention has no `hpo_id`; snapshot left unchanged")
        else:
            snap.phenotypes = [p for p in snap.phenotypes if p.get("hpo_id") != target]
    elif itype == "phenotype_add":
        snap.phenotypes.append({
            "hpo_id": intervention.get("hpo_id"),
            "name": intervention.get("name", intervention.get("hpo_id")),
            "severity": intervention.get("severity", "moderate"),
            "status": "what_if",
        })
    elif itype == "gene_test":
        snap.genes.append({
            "symbol": intervention.get("symbol"),
            "variant": intervention.get("variant", "unknown"),
            "pathogenicity": intervention.get("result", "uncertain"),
            "zygosity": intervention.get("zygosity"),
        })
    else:
        logger.warning("unsupported intervention type %r; snapshot left unchanged", itype)
    return space_copy


async def simulate(
    space,
    intervention: dict,
    *,
    baseline_risk=None,
    baseline_trajectory=None,
) -> WhatIfResult:
    """Run a counterfactual for `intervention` against `space`.

    If `baseline_risk` and `baseline_trajectory` are provided, deltas are
    computed against them; otherwise each missing baseline is computed on the
    fly from the unmodified `space`.

    Args:
        space: The patient space to simulate on. It is deep-copied before the
            intervention is applied, so the caller's object is not mutated.
        intervention: Intervention description; see the module docstring for
            the supported shapes.
        baseline_risk: Optional precomputed result of `risk.assess(space)`.
        baseline_trajectory: Optional precomputed result of
            `trajectory.predict(space)`.

    Returns:
        A `WhatIfResult`. When the counterfactual re-prediction fails,
        `new_risk` / `new_trajectory` may be None, `delta_risk` is 0.0 and the
        generated rationale states that the intervention could not be
        evaluated.

    Raises:
        Whatever `risk.assess` / `trajectory.predict` raise while computing a
        missing baseline; failures of the counterfactual steps are logged and
        swallowed.
    """
    from . import trajectory as gtraj
    from . import risk as grisk

    # 1) baseline (if not supplied)
    if baseline_risk is None:
        baseline_risk = await grisk.assess(space)
    if baseline_trajectory is None:
        baseline_trajectory = await gtraj.predict(space)

    # 2) try trajectory_engine.what_if first (LLM-based counterfactual); it
    #    only contributes the rationale and the confidence.
    new_traj = None
    new_risk = None
    rationale = ""
    confidence = 0.5
    try:
        from trajectory_engine import what_if as te_what_if

        cf_result = await te_what_if(space, intervention)
        if cf_result is not None:
            if isinstance(cf_result, dict):
                rationale = cf_result.get("rationale") or cf_result.get("explanation") or ""
                raw_confidence = cf_result.get("confidence")
                # An explicit None falls back to the default instead of
                # raising inside float().
                confidence = 0.6 if raw_confidence is None else float(raw_confidence)
            else:
                rationale = getattr(cf_result, "rationale", "") or getattr(cf_result, "explanation", "")
                confidence = float(getattr(cf_result, "confidence", 0.6) or 0.6)
    except Exception as e:
        # Best-effort: the engine is optional, so a missing module or a failed
        # call falls through to the heuristic path below.
        logger.debug("trajectory_engine.what_if failed: %s", e)

    # 3) re-run gemeo predictions on a mutated snapshot. Building the mutated
    #    copy sits inside the same guard because deep-copying the space can
    #    fail as well.
    try:
        space_cf = await _apply_to_snapshot(space, intervention)
        new_traj = await gtraj.predict(space_cf)
        new_risk = await grisk.assess(space_cf)
    except Exception as e:
        logger.warning("counterfactual re-prediction failed: %s", e)

    # 4) compute deltas
    risk_evaluated = bool(new_risk and baseline_risk)
    delta_risk = 0.0
    if risk_evaluated:
        delta_risk = float(new_risk.overall_severity) - float(baseline_risk.overall_severity)

    delta_trajectory = []
    if new_traj and baseline_trajectory:
        baseline_by_h = {h.months: h for h in baseline_trajectory.horizons}
        for h in new_traj.horizons:
            base = baseline_by_h.get(h.months)
            if base is None:
                # Horizon has no counterpart in the baseline; nothing to diff.
                continue
            delta_trajectory.append({
                "months": h.months,
                "delta_risk": round(h.risk_score - base.risk_score, 4),
                "from": (base.state or "")[:80],
                "to": (h.state or "")[:80],
            })

    if not rationale:
        # build a short canned rationale
        if risk_evaluated:
            sign = "decreases" if delta_risk < -0.02 else ("increases" if delta_risk > 0.02 else "barely changes")
            rationale = f"Intervention `{intervention.get('type')}` {sign} overall severity by {delta_risk:+.2%}."
        else:
            # Without both risk assessments the 0.0 delta is a placeholder,
            # not a finding; do not report it as "barely changes".
            rationale = (
                f"Intervention `{intervention.get('type')}` could not be evaluated: "
                "no risk delta was computed."
            )

    return WhatIfResult(
        intervention=intervention,
        delta_risk=round(delta_risk, 4),
        delta_trajectory=delta_trajectory,
        new_risk=new_risk,
        new_trajectory=new_traj,
        rationale=rationale,
        confidence=confidence,
    )