File size: 8,186 Bytes
089d665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
"""Risk + survival.

Bootstrap: wraps `risk_quantifier.quantify_risks` (rule-based scoring over
disease severity + diagnostic certainty + progression + treatment urgency
+ Brazilian context).

Phase 2: NeuralSurv on KG (gemeo/train/neuralsurv.py) — Bayesian survival
with epistemic uncertainty over the patient's KG-walk feature vector.
"""
from __future__ import annotations
import logging
import os
import math
from typing import Optional

from .types import RiskSpec

# Module-level logger; all messages from this file go to the "gemeo.risk" channel.
logger = logging.getLogger("gemeo.risk")

# Filesystem path of the NeuralSurv checkpoint. Read once at import time from
# the GEMEO_NEURALSURV_CKPT environment variable; when unset it defaults to
# "artifacts/neuralsurv.pt" inside the directory containing this file.
NEURALSURV_CKPT = os.environ.get(
    "GEMEO_NEURALSURV_CKPT",
    os.path.join(os.path.dirname(__file__), "artifacts", "neuralsurv.pt"),
)


def _approx_survival_from_severity(
    severity: float, horizons_months: Optional[list[int]] = None
) -> list[dict]:
    """Build a bootstrap survival curve when no NeuralSurv model is available.

    Severity is clamped to [0, 1] and mapped to a constant yearly hazard
    ``0.01 + 0.49 * severity ** 1.5`` (0.01/year at severity 0, 0.5/year at
    severity 1). Survival at each horizon follows the exponential model
    ``exp(-hazard * years)``.

    The confidence band is a heuristic, not a statistical interval: an
    absolute half-width of ``0.05 + 0.05 * years`` (capped at 0.4) around the
    point estimate, clipped to [0, 1].

    Args:
        severity: Overall severity score; values outside [0, 1] are clamped.
        horizons_months: Horizons in months. ``None`` or an empty list selects
            the default horizons 3, 6, 12, 24, 36 and 60.

    Returns:
        One dict per horizon, in input order, with keys ``month``,
        ``p_alive``, ``ci_low`` and ``ci_high``; probabilities are rounded to
        four decimals.
    """
    horizons = horizons_months or [3, 6, 12, 24, 36, 60]
    severity = max(0.0, min(1.0, float(severity)))
    # Monotonic severity -> hazard transform; the 1.5 exponent keeps the
    # hazard low for mild severities and ramps up towards 0.5/year at 1.0.
    hazard_per_year = 0.01 + 0.49 * (severity ** 1.5)

    points = []
    for month in horizons:
        years = month / 12.0
        p_alive = math.exp(-hazard_per_year * years)
        # Heuristic band that widens with the horizon (absolute, not relative).
        spread = min(0.4, 0.05 + 0.05 * years)
        points.append({
            "month": month,
            "p_alive": round(p_alive, 4),
            "ci_low": round(max(0.0, p_alive - spread), 4),
            "ci_high": round(min(1.0, p_alive + spread), 4),
        })
    return points


async def _try_neuralsurv(space, embedding):
    """Attempt a NeuralSurv prediction, returning None when it is unavailable.

    When no file exists at ``NEURALSURV_CKPT`` the model code is not imported
    and None is returned. Otherwise ``.train.neuralsurv`` is imported lazily
    and its ``predict(space, embedding, NEURALSURV_CKPT)`` coroutine is
    awaited. Any exception raised by the import or the prediction is logged
    as a warning and converted to None so the caller can fall back.
    """
    checkpoint_present = os.path.exists(NEURALSURV_CKPT)
    if checkpoint_present:
        try:
            # Imported here so the training module is only loaded when a
            # checkpoint is actually on disk.
            from .train import neuralsurv as neuralsurv_module
            prediction = await neuralsurv_module.predict(
                space, embedding, NEURALSURV_CKPT
            )
        except Exception as exc:
            logger.warning(f"NeuralSurv predict failed: {exc}")
        else:
            return prediction
    return None


# Disease-class clinical-severity floors, keyed by bare ORPHA code (string).
# One shared table for both the NeuralSurv and the rule-based path so the two
# cannot drift apart.
_SEVERE_FLOOR = {
    "100": 0.65,     # Ataxia-telangiectasia
    "646": 0.65,     # NPC
    "355": 0.55,     # Gaucher
    "324": 0.55,     # Fabry
    "365": 0.85,     # Pompe (infantile) - high mortality untreated
    "579": 0.65,     # MPS I
    "580": 0.65,     # MPS II
    "70": 0.95,      # SMA-1 - without treatment ~90% mortality by age 2
    "905": 0.55,     # Wilson
    "98896": 0.70,   # DMD
    "586": 0.65,     # CF
    "95": 0.60,      # Friedreich
    "183660": 0.85,  # SCID
    "778": 0.70,     # Rett
    "636": 0.50,     # disease name not recorded here - TODO(review): label
    "558": 0.55,     # disease name not recorded here - TODO(review): label
}

# Hypothesis statuses that are allowed to trigger the floor.
_FLOOR_STATUSES = ("active", "supported", "confirmed")
# Minimum probability of the leading hypothesis for the floor to apply. The
# floor is applied at full strength from this threshold on; it is not scaled
# by the probability.
_FLOOR_MIN_PROBABILITY = 0.5
# Lower bounds applied to progression risk / treatment urgency with the floor.
_FLOOR_PROGRESSION = 0.55
_FLOOR_URGENCY = 0.65


def _normalize_orpha(code) -> str:
    """Return *code* as a bare ORPHA number string ("ORPHA:100" or 100 -> "100")."""
    text = str(code).strip()
    if text.upper().startswith("ORPHA:"):
        text = text[len("ORPHA:"):].strip()
    return text


def _disease_class_floor(space) -> Optional[float]:
    """Return the severity floor for the leading hypothesis of *space*, or None.

    The leading hypothesis is the highest-probability entry of
    ``space._hypotheses`` whose status is in ``_FLOOR_STATUSES`` and which has
    an ORPHA code. None is returned when there is no such hypothesis, its
    probability is below ``_FLOOR_MIN_PROBABILITY``, or its code is not in
    ``_SEVERE_FLOOR``.
    """
    hypotheses = getattr(space, "_hypotheses", {}) or {}
    top_orpha = None
    top_prob = 0.0
    for hyp in hypotheses.values():
        prob = float(getattr(hyp, "probability", 0) or 0)
        orpha = getattr(hyp, "orpha_code", None)
        status = getattr(hyp, "status", "")
        if prob > top_prob and status in _FLOOR_STATUSES and orpha:
            top_prob, top_orpha = prob, orpha
    if top_orpha is None or top_prob < _FLOOR_MIN_PROBABILITY:
        return None
    return _SEVERE_FLOOR.get(_normalize_orpha(top_orpha))


def _complication_entry(item) -> dict:
    """Map one complication (dict or object) to a ``{name, prob, when}`` dict."""
    if isinstance(item, dict):
        return {
            "name": item.get("name"),
            "prob": item.get("probability"),
            "when": item.get("expected_in_months") or item.get("expected_in"),
        }
    return {
        "name": getattr(item, "name", None),
        "prob": getattr(item, "probability", None),
        "when": getattr(item, "expected_in_months", None),
    }


async def _rule_based_scores(space):
    """Return ``(severity, progression, urgency, complications)`` from risk_quantifier.

    Best-effort: if ``risk_quantifier`` cannot be imported or raises, the
    failure is logged at debug level and whatever was read so far (zeros and
    an empty list by default) is returned.
    """
    severity = 0.0
    progression = 0.0
    urgency = 0.0
    complications = []
    try:
        from risk_quantifier import quantify_risks
        profile = await quantify_risks(space)
        if isinstance(profile, dict):
            severity = float(profile.get("overall_severity", 0.0) or 0.0)
            progression = float(profile.get("progression_risk", 0.0) or 0.0)
            urgency = float(profile.get("treatment_urgency", 0.0) or 0.0)
            # Dict profiles pass their complication entries through as-is.
            complications = profile.get("top_complications", []) or []
        elif profile is not None:
            severity = float(getattr(profile, "overall_severity", 0.0) or 0.0)
            progression = float(getattr(profile, "progression_risk", 0.0) or 0.0)
            urgency = float(getattr(profile, "treatment_urgency", 0.0) or 0.0)
            raw = getattr(profile, "top_complications", []) or []
            complications.extend(_complication_entry(c) for c in raw[:8])
    except Exception as e:
        logger.debug("risk_quantifier failed: %s", e)
    return severity, progression, urgency, complications


def _phenotype_count_severity(space) -> float:
    """Coarse severity proxy from the phenotype count of the current snapshot.

    Returns ``min(0.6, 0.05 + 0.02 * n_phenotypes)``; falls back to 0.2 when
    the snapshot cannot be read.
    """
    try:
        snap = space.get_current_snapshot() if hasattr(space, "get_current_snapshot") else None
        n_pheno = len(snap.phenotypes) if snap else 0
        return min(0.6, 0.05 + 0.02 * n_pheno)
    except Exception as e:
        logger.debug("severity boost failed: %s", e)
        return 0.2


async def assess(space, embedding=None) -> RiskSpec:
    """Compute the risk profile of the digital twin.

    NeuralSurv is tried first. When it yields a spec, that spec is returned
    with its survival curve untouched; only ``overall_severity``,
    ``progression_risk`` and ``treatment_urgency`` may be raised (never
    lowered) to the disease-class floor of the leading hypothesis.

    Otherwise the rule-based ``risk_quantifier`` output is used, max-merged
    with the same floor. If there is no floor and all three scores are zero,
    severity falls back to a phenotype-count proxy. The survival curve is then
    derived from the final severity and the spec is tagged ``"rule_based"``.

    Args:
        space: Patient space; read via ``_hypotheses`` and, for the
            phenotype proxy, ``get_current_snapshot()``.
        embedding: Optional feature vector forwarded to NeuralSurv.

    Returns:
        The ``RiskSpec`` for the twin.
    """
    ns_spec = await _try_neuralsurv(space, embedding)
    if ns_spec is not None:
        # The NeuralSurv severity is mortality-derived; clinical impairment is
        # a different axis (a patient can be alive yet severely impaired), so
        # the known disease class lifts severity while the curve is kept.
        try:
            floor = _disease_class_floor(space)
            if floor is not None:
                ns_spec.overall_severity = max(ns_spec.overall_severity, floor)
                ns_spec.progression_risk = max(ns_spec.progression_risk, _FLOOR_PROGRESSION)
                ns_spec.treatment_urgency = max(ns_spec.treatment_urgency, _FLOOR_URGENCY)
        except Exception as e:
            # Best-effort: a malformed hypothesis must not discard a valid
            # NeuralSurv prediction.
            logger.warning("severity floor on NeuralSurv output failed: %s", e)
        return ns_spec

    # Bootstrap via the rule-based risk_quantifier.
    severity, progression, urgency, complications = await _rule_based_scores(space)

    # The floor is looked up independently of the snapshot so that a snapshot
    # failure cannot suppress it.
    try:
        floor = _disease_class_floor(space)
    except Exception as e:
        logger.debug("severity boost failed: %s", e)
        floor = None
        if severity == 0:
            severity = 0.2

    if floor is not None:
        severity = max(severity, floor)
        progression = max(progression, _FLOOR_PROGRESSION)
        urgency = max(urgency, _FLOOR_URGENCY)
    elif severity == 0 and progression == 0 and urgency == 0:
        # No diagnostic hint at all - coarse phenotype-count proxy.
        severity = _phenotype_count_severity(space)

    return RiskSpec(
        overall_severity=severity,
        progression_risk=progression,
        treatment_urgency=urgency,
        survival_curve=_approx_survival_from_severity(severity),
        top_complications=complications[:8],
        model="rule_based",
    )