File size: 5,915 Bytes
089d665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
"""What-if / counterfactual engine.

Bootstrap: heuristic — apply the intervention to a deep copy of the space's
current snapshot and re-run `gemeo.trajectory.predict` + `gemeo.risk.assess`
on that copy. If `trajectory_engine.what_if` is available, its rationale and
confidence are used in the result; the deltas always come from the re-run.

Phase 2: CF-GNNExplainer + do-calculus on the patient subgraph — perturb a
node (gene knock-in, drug-on, lab-flipped) and recompute via the trained HGT.

Supported interventions:
  - {"type": "treatment", "drug": "Cerezyme", "rxcui": "...", "start_in_days": 30}
  - {"type": "lab_flip", "test": "AFP", "from": "abnormal", "to": "normal"}
  - {"type": "phenotype_resolve", "hpo_id": "HP:0001250"}
  - {"type": "phenotype_add", "hpo_id": "HP:0002240"}
  - {"type": "gene_test", "symbol": "GBA", "result": "pathogenic"}
"""
from __future__ import annotations
import copy
import logging
import os
from typing import Optional

from .types import WhatIfResult

# Module logger; records are emitted under the "gemeo.whatif" name.
logger = logging.getLogger("gemeo.whatif")

# Default checkpoint location: artifacts/cf_gnn.pt next to this module.
_DEFAULT_CF_CKPT = os.path.join(
    os.path.dirname(__file__), "artifacts", "cf_gnn.pt"
)

# Path of the counterfactual model checkpoint. The GEMEO_CF_CKPT environment
# variable, when set, takes precedence over the bundled default.
CF_CKPT = os.getenv("GEMEO_CF_CKPT", _DEFAULT_CF_CKPT)


def _has_cf_model() -> bool:
    """Report whether anything exists on disk at the `CF_CKPT` path."""
    checkpoint_path = CF_CKPT
    return os.path.exists(checkpoint_path)


async def _apply_to_snapshot(space, intervention: dict):
    """Apply `intervention` to a deep copy of `space` and return that copy.

    The original `space` is never modified. If the copy has no
    `get_current_snapshot` method, or that method returns None, the untouched
    copy is returned. Intervention types other than the ones handled below
    leave the snapshot unchanged.

    Args:
        space: Object to copy; expected to expose `get_current_snapshot()`
            returning an object with `treatments`, `labs`, `phenotypes` and
            `genes` lists (lab and phenotype entries are read as dicts).
        intervention: Dict whose "type" key selects the mutation: "treatment",
            "lab_flip", "phenotype_resolve", "phenotype_add" or "gene_test".

    Returns:
        The deep copy of `space`, with its current snapshot mutated in place.
    """
    space_copy = copy.deepcopy(space)
    snap = space_copy.get_current_snapshot() if hasattr(space_copy, "get_current_snapshot") else None
    if snap is None:
        return space_copy

    itype = intervention.get("type", "")

    if itype == "treatment":
        # NOTE(review): only "drug" and "protocol" are read from the
        # intervention here; other keys (e.g. rxcui, start_in_days) are not
        # carried over to the simulated treatment entry.
        snap.treatments.append({
            "name": intervention.get("drug", "unknown"),
            "type": "what_if",
            "start": "in_simulation",
            "response": "pending",
            "protocol": intervention.get("protocol"),
        })
    elif itype == "lab_flip":
        target_test = str(intervention.get("test") or "").lower()
        # Without a test name there is nothing to match; comparing "" == ""
        # would otherwise flip every lab entry that lacks a test name.
        if target_test:
            # An explicit None for "to" is treated like a missing key.
            make_abnormal = str(intervention.get("to") or "normal").lower() == "abnormal"
            for lab in snap.labs:
                if str(lab.get("test") or "").lower() != target_test:
                    continue
                lab["abnormal"] = make_abnormal
                # Keep the current value unless the caller supplies a new one.
                lab["value"] = intervention.get("new_value", lab.get("value"))
    elif itype == "phenotype_resolve":
        target = intervention.get("hpo_id")
        # A missing hpo_id must not match (and drop) phenotypes that
        # themselves have no hpo_id.
        if target is not None:
            snap.phenotypes = [p for p in snap.phenotypes if p.get("hpo_id") != target]
    elif itype == "phenotype_add":
        snap.phenotypes.append({
            "hpo_id": intervention.get("hpo_id"),
            "name": intervention.get("name", intervention.get("hpo_id")),
            "severity": intervention.get("severity", "moderate"),
            "status": "what_if",
        })
    elif itype == "gene_test":
        snap.genes.append({
            "symbol": intervention.get("symbol"),
            "variant": intervention.get("variant", "unknown"),
            "pathogenicity": intervention.get("result", "uncertain"),
            "zygosity": intervention.get("zygosity"),
        })

    return space_copy


async def simulate(
    space,
    intervention: dict,
    *,
    baseline_risk=None,
    baseline_trajectory=None,
) -> WhatIfResult:
    """Run a counterfactual for `intervention` against `space`.

    Steps:
      1. Compute the baseline risk / trajectory unless the caller supplied them.
      2. Ask `trajectory_engine.what_if` (if importable) for a rationale and a
         confidence. Any failure here is logged at debug level and ignored.
      3. Apply the intervention to a copy of the space and re-run
         `gemeo.trajectory.predict` + `gemeo.risk.assess` on it. A failure in
         the re-run is logged as a warning; the result then carries no deltas.
      4. Compute the overall severity delta and per-horizon risk deltas.

    Args:
        space: The patient space to simulate on; it is not modified.
        intervention: Intervention dict (see the module docstring for shapes).
        baseline_risk: Optional precomputed baseline risk assessment.
        baseline_trajectory: Optional precomputed baseline trajectory.

    Returns:
        A `WhatIfResult` with the deltas, the counterfactual risk / trajectory
        (None when the re-run failed), a rationale and a confidence.

    Raises:
        Whatever the baseline computation or the snapshot copy raises; only
        steps 2 and 3 are best-effort.
    """
    # NOTE(review): `genc` is not used below. The import is kept in case
    # loading the encoder module has side effects; confirm and drop if not.
    from . import trajectory as gtraj, risk as grisk, encoder as genc  # noqa: F401

    # 1) baseline (if not supplied)
    if baseline_risk is None:
        baseline_risk = await grisk.assess(space)
    if baseline_trajectory is None:
        baseline_trajectory = await gtraj.predict(space)

    # 2) try the trajectory_engine.what_if first (LLM-based counterfactual)
    new_traj = None
    new_risk = None
    rationale = ""
    confidence = 0.5

    try:
        from trajectory_engine import what_if as te_what_if
        cf_result = await te_what_if(space, intervention)
        if cf_result is not None:
            if isinstance(cf_result, dict):
                rationale = cf_result.get("rationale") or cf_result.get("explanation") or ""
                # A present-but-None confidence is treated like a missing key
                # instead of raising inside float().
                raw_confidence = cf_result.get("confidence")
                confidence = 0.6 if raw_confidence is None else float(raw_confidence)
            else:
                rationale = (
                    getattr(cf_result, "rationale", "")
                    or getattr(cf_result, "explanation", "")
                    or ""
                )
                confidence = float(getattr(cf_result, "confidence", 0.6) or 0.6)
    except Exception as e:
        # Best-effort: the engine is optional, so fall through to the re-run.
        logger.debug("trajectory_engine.what_if failed: %s", e)

    # 3) re-run gemeo predictions on a mutated snapshot
    space_cf = await _apply_to_snapshot(space, intervention)
    try:
        new_traj = await gtraj.predict(space_cf)
        new_risk = await grisk.assess(space_cf)
    except Exception as e:
        logger.warning("counterfactual re-prediction failed: %s", e)

    # 4) compute deltas
    delta_risk = 0.0
    risk_delta_available = bool(new_risk and baseline_risk)
    if risk_delta_available:
        delta_risk = float(new_risk.overall_severity) - float(baseline_risk.overall_severity)

    delta_trajectory = []
    if new_traj and baseline_trajectory:
        baseline_by_h = {h.months: h for h in baseline_trajectory.horizons}
        for h in new_traj.horizons:
            base = baseline_by_h.get(h.months)
            if base:
                delta_trajectory.append({
                    "months": h.months,
                    "delta_risk": round(h.risk_score - base.risk_score, 4),
                    # `or ""` keeps a missing state from raising on the slice.
                    "from": (base.state or "")[:80],
                    "to": (h.state or "")[:80],
                })

    if not rationale:
        if risk_delta_available:
            # build a short canned rationale
            sign = "decreases" if delta_risk < -0.02 else ("increases" if delta_risk > 0.02 else "barely changes")
            rationale = f"Intervention `{intervention.get('type')}` {sign} overall severity by {delta_risk:+.2%}."
        else:
            # No delta was computed, so do not report it as a measured
            # "no change".
            rationale = (
                f"Intervention `{intervention.get('type')}` could not be re-evaluated; "
                "no risk delta is available."
            )

    return WhatIfResult(
        intervention=intervention,
        delta_risk=round(delta_risk, 4),
        delta_trajectory=delta_trajectory,
        new_risk=new_risk,
        new_trajectory=new_traj,
        rationale=rationale,
        confidence=confidence,
    )