| """What-if / counterfactual engine. |
| |
Bootstrap: heuristic — apply the intervention to a deep copy of the space's
current snapshot and re-run `gemeo.trajectory.predict` + `gemeo.risk.assess`
on it. If `trajectory_engine.what_if` is importable it is also called, but only
for a rationale/confidence; the reported deltas always come from the re-run.
| |
| Phase 2: CF-GNNExplainer + do-calculus on the patient subgraph — perturb a |
| node (gene knock-in, drug-on, lab-flipped) and recompute via the trained HGT. |
| |
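Supported interventions (the heuristic path reads only the fields noted):
- {"type": "treatment", "drug": "Cerezyme", "rxcui": "...", "start_in_days": 30}
    reads "drug" and optional "protocol"; "rxcui"/"start_in_days" are ignored.
- {"type": "lab_flip", "test": "AFP", "from": "abnormal", "to": "normal"}
    reads "test", "to" and optional "new_value"; "from" is ignored.
- {"type": "phenotype_resolve", "hpo_id": "HP:0001250"}
- {"type": "phenotype_add", "hpo_id": "HP:0002240"}
    also reads optional "name" and "severity".
- {"type": "gene_test", "symbol": "GBA", "result": "pathogenic"}
    also reads optional "variant" and "zygosity".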
| """ |
| from __future__ import annotations |
| import copy |
| import logging |
| import os |
| from typing import Optional |
|
|
| from .types import WhatIfResult |
|
|
logger: logging.Logger = logging.getLogger("gemeo.whatif")


# Path whose existence `_has_cf_model()` checks; presumably the Phase-2 CF-GNN
# checkpoint (TODO confirm). The GEMEO_CF_CKPT environment variable overrides
# the bundled default under ./artifacts/.
_DEFAULT_CF_CKPT = os.path.join(os.path.dirname(__file__), "artifacts", "cf_gnn.pt")
CF_CKPT = os.environ.get("GEMEO_CF_CKPT", _DEFAULT_CF_CKPT)
|
|
|
|
def _has_cf_model() -> bool:
    """Return True when something exists at the ``CF_CKPT`` path.

    Only existence is tested (``os.path.exists``), so a directory at that path
    also counts; the file is neither opened nor validated here.
    """
    checkpoint = CF_CKPT
    return os.path.exists(checkpoint)
|
|
|
|
async def _apply_to_snapshot(space, intervention: dict):
    """Return a deep copy of *space* with *intervention* applied to its snapshot.

    The caller's ``space`` is never mutated: the whole object is deep-copied and
    only the copy's current snapshot (from ``get_current_snapshot()``) is edited
    in place. If the copy has no ``get_current_snapshot`` attribute, or it
    returns ``None``, the unmodified copy is returned.

    Args:
        space: Object exposing ``get_current_snapshot()``. The snapshot is used
            through list attributes ``treatments``, ``labs``, ``phenotypes`` and
            ``genes`` whose items are dict-like (they are read with ``.get``).
        intervention: Dict with a ``"type"`` key (see the module docstring for
            the supported types and the fields each one reads).

    Returns:
        The (possibly modified) deep copy of ``space``.

    Interventions missing their identifying field (``"test"`` for
    ``lab_flip``, ``"hpo_id"`` for ``phenotype_resolve``) are no-ops instead of
    matching every unnamed lab / phenotype. Unknown types are logged and leave
    the snapshot unchanged.
    """
    space_copy = copy.deepcopy(space)
    snap = (
        space_copy.get_current_snapshot()
        if hasattr(space_copy, "get_current_snapshot")
        else None
    )
    if snap is None:
        return space_copy

    itype = intervention.get("type", "")

    if itype == "treatment":
        snap.treatments.append({
            "name": intervention.get("drug", "unknown"),
            "type": "what_if",
            "start": "in_simulation",
            "response": "pending",
            "protocol": intervention.get("protocol"),
        })
    elif itype == "lab_flip":
        test_name = (intervention.get("test") or "").lower()
        # Without a test name the comparison below would match every lab whose
        # own "test" field is missing/empty and flip unrelated results.
        if test_name:
            # Tolerate an explicit None "to"; anything but "abnormal" means normal.
            make_abnormal = (
                str(intervention.get("to") or "normal").strip().lower() == "abnormal"
            )
            for lab in snap.labs:
                if (lab.get("test") or "").lower() == test_name:
                    lab["abnormal"] = make_abnormal
                    lab["value"] = intervention.get("new_value", lab.get("value"))
    elif itype == "phenotype_resolve":
        target = intervention.get("hpo_id")
        # A missing hpo_id would otherwise drop every phenotype lacking one.
        if target is not None:
            snap.phenotypes = [p for p in snap.phenotypes if p.get("hpo_id") != target]
    elif itype == "phenotype_add":
        snap.phenotypes.append({
            "hpo_id": intervention.get("hpo_id"),
            "name": intervention.get("name", intervention.get("hpo_id")),
            "severity": intervention.get("severity", "moderate"),
            "status": "what_if",
        })
    elif itype == "gene_test":
        snap.genes.append({
            "symbol": intervention.get("symbol"),
            "variant": intervention.get("variant", "unknown"),
            "pathogenicity": intervention.get("result", "uncertain"),
            "zygosity": intervention.get("zygosity"),
        })
    else:
        # Surface this: the caller would otherwise report a zero delta as if the
        # intervention had been evaluated.
        logger.warning(
            "unsupported what-if intervention type %r; snapshot left unchanged", itype
        )

    return space_copy
|
|
|
|
async def simulate(
    space,
    intervention: dict,
    *,
    baseline_risk=None,
    baseline_trajectory=None,
) -> WhatIfResult:
    """Run a counterfactual for *intervention* on *space*.

    Args:
        space: Object passed to ``trajectory.predict`` / ``risk.assess``. It is
            not mutated; the intervention is applied to a deep copy made by
            ``_apply_to_snapshot``.
        intervention: Intervention dict with a ``"type"`` key (see the module
            docstring for supported types).
        baseline_risk: Pre-computed baseline risk; computed with
            ``risk.assess(space)`` when ``None``.
        baseline_trajectory: Pre-computed baseline trajectory; computed with
            ``trajectory.predict(space)`` when ``None``.

    Returns:
        WhatIfResult where
          * ``delta_risk`` is counterfactual minus baseline ``overall_severity``,
            rounded to 4 decimals (0.0 when either side is unavailable);
          * ``delta_trajectory`` lists, for horizons matched on ``months`` in both
            trajectories, the ``risk_score`` delta plus the first 80 characters
            of each ``state``;
          * ``rationale`` / ``confidence`` come from ``trajectory_engine.what_if``
            when it is importable and returns a result (confidence defaults to
            0.6 if it supplies none); otherwise a heuristic summary with
            confidence 0.5 is used.
        If the counterfactual severity cannot be computed, confidence is set to
        0.0 and (absent an engine rationale) the rationale states that, rather
        than claiming the intervention "barely changes" severity.

    Raises:
        Whatever the baseline computations or ``_apply_to_snapshot`` raise.
        Failures of the optional engine and of the counterfactual re-prediction
        are logged and absorbed.
    """
    from . import risk as grisk, trajectory as gtraj

    # --- Baselines (each computed independently if not supplied). ---------
    if baseline_risk is None:
        baseline_risk = await grisk.assess(space)
    if baseline_trajectory is None:
        baseline_trajectory = await gtraj.predict(space)

    # --- Optional external engine: used for rationale/confidence only. ----
    rationale = ""
    confidence = 0.5
    try:
        from trajectory_engine import what_if as te_what_if
        cf_result = await te_what_if(space, intervention)
    except Exception as e:
        # trajectory_engine is optional; its absence or failure is non-fatal.
        logger.debug("trajectory_engine.what_if failed: %s", e)
        cf_result = None

    if cf_result is not None:
        def _field(name):
            # The engine may return either a dict or an attribute-bearing object.
            if isinstance(cf_result, dict):
                return cf_result.get(name)
            return getattr(cf_result, name, None)

        rationale = _field("rationale") or _field("explanation") or ""
        raw_confidence = _field("confidence")
        try:
            confidence = 0.6 if raw_confidence is None else float(raw_confidence)
        except (TypeError, ValueError):
            logger.debug("ignoring non-numeric what_if confidence: %r", raw_confidence)
            confidence = 0.6

    # --- Heuristic counterfactual: mutate a copy and re-predict. ----------
    new_traj = None
    new_risk = None
    space_cf = await _apply_to_snapshot(space, intervention)
    try:
        new_traj = await gtraj.predict(space_cf)
        new_risk = await grisk.assess(space_cf)
    except Exception as e:
        logger.warning("counterfactual re-prediction failed: %s", e)

    # --- Deltas. ----------------------------------------------------------
    severity_known = new_risk is not None and baseline_risk is not None
    delta_risk = 0.0
    if severity_known:
        delta_risk = float(new_risk.overall_severity) - float(baseline_risk.overall_severity)

    delta_trajectory = []
    if new_traj is not None and baseline_trajectory is not None:
        baseline_by_months = {h.months: h for h in baseline_trajectory.horizons}
        for h in new_traj.horizons:
            base = baseline_by_months.get(h.months)
            if base is None:
                continue
            delta_trajectory.append({
                "months": h.months,
                "delta_risk": round(h.risk_score - base.risk_score, 4),
                "from": base.state[:80],
                "to": h.state[:80],
            })

    # --- Rationale / confidence fallback. ---------------------------------
    itype = intervention.get("type")
    if not rationale:
        if severity_known:
            sign = "decreases" if delta_risk < -0.02 else ("increases" if delta_risk > 0.02 else "barely changes")
            rationale = f"Intervention `{itype}` {sign} overall severity by {delta_risk:+.2%}."
        else:
            rationale = (
                f"Intervention `{itype}`: overall severity change could not be computed "
                f"(counterfactual or baseline risk unavailable)."
            )
    if not severity_known:
        # delta_risk is a placeholder 0.0 here, not a measurement.
        confidence = 0.0

    return WhatIfResult(
        intervention=intervention,
        delta_risk=round(delta_risk, 4),
        delta_trajectory=delta_trajectory,
        new_risk=new_risk,
        new_trajectory=new_traj,
        rationale=rationale,
        confidence=confidence,
    )
|
|