File size: 8,186 Bytes
"""Risk + survival.
Bootstrap: wraps `risk_quantifier.quantify_risks` (rule-based scoring over
disease severity + diagnostic certainty + progression + treatment urgency
+ Brazilian context).
Phase 2: NeuralSurv on KG (gemeo/train/neuralsurv.py) — Bayesian survival
with epistemic uncertainty over the patient's KG-walk feature vector.
"""
from __future__ import annotations
import logging
import os
import math
from typing import Optional
from .types import RiskSpec
# Module-level logger; every message from this file is emitted under "gemeo.risk".
logger = logging.getLogger("gemeo.risk")

# Path of the NeuralSurv checkpoint. The GEMEO_NEURALSURV_CKPT environment
# variable wins when set; otherwise fall back to artifacts/neuralsurv.pt in
# the directory that contains this module.
NEURALSURV_CKPT = os.getenv(
    "GEMEO_NEURALSURV_CKPT",
    os.path.join(os.path.dirname(__file__), "artifacts", "neuralsurv.pt"),
)
def _approx_survival_from_severity(
    severity: float,
    horizons_months: Optional[list[int]] = None,
) -> list[dict]:
    """Bootstrap survival curve used when no NeuralSurv model is available.

    Severity is clamped to [0, 1] and mapped to a constant yearly hazard via
    ``0.01 + 0.49 * severity ** 1.5``; survival at each horizon is the
    exponential model ``exp(-hazard * years)``.

    The confidence band is heuristic and *absolute* (not relative): its
    half-width is ``0.05 + 0.05 * years``, capped at 0.4, and the resulting
    bounds are clipped to [0, 1].

    Args:
        severity: Overall severity score; values outside [0, 1] are clamped.
        horizons_months: Horizons, in months, at which to evaluate the curve.
            ``None`` or an empty list selects the default horizons
            ``[3, 6, 12, 24, 36, 60]``.

    Returns:
        One dict per horizon, in input order, with keys ``"month"``,
        ``"p_alive"``, ``"ci_low"`` and ``"ci_high"`` (probabilities rounded
        to 4 decimals).
    """
    horizons = horizons_months or [3, 6, 12, 24, 36, 60]
    clamped = max(0.0, min(1.0, float(severity)))
    # Hazard per year: severity 0.0 -> ~0.01/y, severity 1.0 -> ~0.5/y.
    hazard_per_year = 0.01 + 0.49 * (clamped ** 1.5)

    points = []
    for month in horizons:
        # A negative horizon is treated as "now": without this guard the
        # exponential exceeds 1 and the band can invert (ci_low > ci_high).
        years = max(0.0, month / 12.0)
        p_alive = min(1.0, math.exp(-hazard_per_year * years))
        # Heuristic CI half-width, widening with the horizon.
        spread = min(0.4, 0.05 + 0.05 * years)
        points.append({
            "month": month,
            "p_alive": round(p_alive, 4),
            "ci_low": round(max(0.0, p_alive - spread), 4),
            "ci_high": round(min(1.0, p_alive + spread), 4),
        })
    return points
async def _try_neuralsurv(space, embedding):
    """Best-effort NeuralSurv prediction.

    Returns whatever ``neuralsurv.predict(space, embedding, NEURALSURV_CKPT)``
    yields. Returns ``None`` when the checkpoint path does not exist, or when
    importing the predictor or running the prediction raises (the error is
    logged as a warning and swallowed so the caller can fall back).
    """
    checkpoint_present = os.path.exists(NEURALSURV_CKPT)
    if checkpoint_present:
        try:
            # Imported lazily, and only once a checkpoint is actually present.
            from .train import neuralsurv as predictor
            prediction = await predictor.predict(space, embedding, NEURALSURV_CKPT)
        except Exception as exc:
            logger.warning(f"NeuralSurv predict failed: {exc}")
        else:
            return prediction
    return None
# Curated disease-class severity floors, keyed by ORPHA code (as a string).
# Single source of truth for both the NeuralSurv path and the rule-based path;
# previously each path carried its own copy and the two had drifted apart.
_SEVERE_FLOOR = {
    "100": 0.65,     # Ataxia-telangiectasia
    "646": 0.65,     # NPC
    "355": 0.55,     # Gaucher
    "324": 0.55,     # Fabry
    "365": 0.85,     # Pompe (infantile) - high mortality untreated
    "579": 0.65,     # MPS I
    "580": 0.65,     # MPS II
    "70": 0.95,      # SMA-1 - without treatment ~90% mortality by age 2
    "905": 0.55,     # Wilson
    "98896": 0.70,   # DMD
    "586": 0.65,     # CF
    "95": 0.60,      # Friedreich
    "183660": 0.85,  # SCID
    "778": 0.70,     # Rett
    # The next two entries existed only in the NeuralSurv-path copy of the
    # table; they now apply on the rule-based path as well.
    "636": 0.50,
    "558": 0.55,
}
# Hypothesis statuses that are allowed to trigger the floor.
_FLOOR_STATUSES = ("active", "supported", "confirmed")
# Minimum probability of the top hypothesis for the floor to apply. The floor
# is applied at full strength above this threshold (not diluted by probability).
_FLOOR_MIN_PROBABILITY = 0.5
# Lower bounds imposed on progression risk / treatment urgency with the floor.
_FLOOR_PROGRESSION = 0.55
_FLOOR_URGENCY = 0.65


def _normalize_orpha(code) -> str:
    """Return *code* as a bare ORPHA number string (``100``, ``"ORPHA:100"`` -> ``"100"``)."""
    text = str(code).strip()
    if text.upper().startswith("ORPHA:"):
        text = text[len("ORPHA:"):].strip()
    return text


def _known_disease_floor(space) -> Optional[float]:
    """Return the severity floor for the top qualifying hypothesis, or ``None``.

    The top hypothesis is the highest-probability entry of
    ``space._hypotheses`` that has a truthy ``orpha_code`` and a status in
    ``_FLOOR_STATUSES``. A floor is returned only when its probability is at
    least ``_FLOOR_MIN_PROBABILITY`` and its code is listed in ``_SEVERE_FLOOR``.
    """
    best_code = None
    best_prob = 0.0
    for hyp in (getattr(space, "_hypotheses", {}) or {}).values():
        prob = float(getattr(hyp, "probability", 0) or 0)
        code = getattr(hyp, "orpha_code", None)
        status = getattr(hyp, "status", "")
        if code and prob > best_prob and status in _FLOOR_STATUSES:
            best_prob = prob
            best_code = code
    if best_code is None or best_prob < _FLOOR_MIN_PROBABILITY:
        return None
    return _SEVERE_FLOOR.get(_normalize_orpha(best_code))


def _phenotype_count(space) -> int:
    """Return the number of phenotypes in the current snapshot (0 if there is none)."""
    snap = space.get_current_snapshot() if hasattr(space, "get_current_snapshot") else None
    return len(snap.phenotypes) if snap else 0


def _complication_entry(item) -> dict:
    """Convert one complication (dict or object) to a ``name``/``prob``/``when`` dict."""
    if isinstance(item, dict):
        return {
            "name": item.get("name"),
            "prob": item.get("probability"),
            "when": item.get("expected_in_months") or item.get("expected_in"),
        }
    return {
        "name": getattr(item, "name", None),
        "prob": getattr(item, "probability", None),
        "when": getattr(item, "expected_in_months", None),
    }


async def assess(space, embedding=None) -> RiskSpec:
    """Compute the risk profile of the digital twin.

    Resolution order:

    1. NeuralSurv, if ``_try_neuralsurv`` returns a spec. Its survival curve
       is kept as-is; severity, progression risk and treatment urgency are
       only ever raised (MAX-merged) by the disease-class floor. The floor
       exists because mortality-derived severity and clinical-impairment
       severity are different axes. A failure while applying the floor is
       logged and the unmodified NeuralSurv spec is returned.
    2. Otherwise the rule-based ``risk_quantifier.quantify_risks`` bootstrap,
       which accepts ``None``, a dict, or an object with the same field names.
       The same floor is MAX-merged on top; when there is no floor and all
       three scores are zero, a coarse phenotype-count proxy sets severity.
       Severity still at zero afterwards defaults to 0.2, and the survival
       curve comes from ``_approx_survival_from_severity``.

    Args:
        space: Patient workspace. Only ``_hypotheses`` and, optionally,
            ``get_current_snapshot()`` are read here; it is also forwarded to
            NeuralSurv and ``quantify_risks``.
        embedding: Optional feature vector forwarded to NeuralSurv.

    Returns:
        The NeuralSurv spec (possibly floor-adjusted), or a ``RiskSpec`` with
        ``model="rule_based"``.
    """
    # Try NeuralSurv first: it provides the survival curve when a checkpoint exists.
    ns_spec = await _try_neuralsurv(space, embedding)
    if ns_spec is not None:
        try:
            floor = _known_disease_floor(space)
            if floor is not None:
                # Compute all three before assigning so a failure cannot
                # leave the spec half-updated.
                lifted_severity = max(ns_spec.overall_severity, floor)
                lifted_progression = max(ns_spec.progression_risk, _FLOOR_PROGRESSION)
                lifted_urgency = max(ns_spec.treatment_urgency, _FLOOR_URGENCY)
                ns_spec.overall_severity = lifted_severity
                ns_spec.progression_risk = lifted_progression
                ns_spec.treatment_urgency = lifted_urgency
        except Exception as e:
            # A malformed hypothesis must not discard a valid prediction.
            logger.debug("severity boost failed: %s", e)
        return ns_spec

    # Bootstrap via the rule-based risk_quantifier (best effort).
    severity = 0.0
    progression = 0.0
    urgency = 0.0
    complications = []
    try:
        from risk_quantifier import quantify_risks
        profile = await quantify_risks(space)
        if profile is None:
            pass
        elif isinstance(profile, dict):
            severity = float(profile.get("overall_severity", 0.0) or 0.0)
            progression = float(profile.get("progression_risk", 0.0) or 0.0)
            urgency = float(profile.get("treatment_urgency", 0.0) or 0.0)
            # Dict profiles pass their complications through unchanged.
            complications = profile.get("top_complications", []) or []
        else:
            severity = float(getattr(profile, "overall_severity", 0.0) or 0.0)
            progression = float(getattr(profile, "progression_risk", 0.0) or 0.0)
            urgency = float(getattr(profile, "treatment_urgency", 0.0) or 0.0)
            comps = getattr(profile, "top_complications", []) or []
            for comp in comps[:8]:
                complications.append(_complication_entry(comp))
    except Exception as e:
        logger.debug("risk_quantifier failed: %s", e)

    # ALWAYS apply the disease-class floor when a qualifying hypothesis points
    # at a known severe rare disease. It MAX-merges with the rule-based output
    # and never reduces it. The snapshot is read only in the fallback branch,
    # so a snapshot failure cannot suppress the floor.
    try:
        floor = _known_disease_floor(space)
        if floor is not None:
            severity = max(severity, floor)
            progression = max(progression, _FLOOR_PROGRESSION)
            urgency = max(urgency, _FLOOR_URGENCY)
        elif severity == 0 and progression == 0 and urgency == 0:
            # No dx hint at all: coarse phenotype-count proxy.
            severity = min(0.6, 0.05 + 0.02 * _phenotype_count(space))
    except Exception as e:
        logger.debug("severity boost failed: %s", e)

    if severity == 0:
        severity = 0.2

    survival = _approx_survival_from_severity(severity)
    return RiskSpec(
        overall_severity=severity,
        progression_risk=progression,
        treatment_urgency=urgency,
        survival_curve=survival,
        top_complications=complications[:8],
        model="rule_based",
    )
|