"""Risk + survival.

Bootstrap: wraps `risk_quantifier.quantify_risks` (rule-based scoring over
disease severity + diagnostic certainty + progression + treatment urgency +
Brazilian context).

Phase 2: NeuralSurv on KG (gemeo/train/neuralsurv.py) — Bayesian survival with
epistemic uncertainty over the patient's KG-walk feature vector.
"""
from __future__ import annotations

import logging
import math
import os
from typing import Optional

from .types import RiskSpec

logger = logging.getLogger("gemeo.risk")

# Path to the NeuralSurv checkpoint; overridable via GEMEO_NEURALSURV_CKPT.
NEURALSURV_CKPT = os.environ.get(
    "GEMEO_NEURALSURV_CKPT",
    os.path.join(os.path.dirname(__file__), "artifacts", "neuralsurv.pt"),
)

# Disease-class clinical-severity floor keyed by ORPHA code (curated; covers
# seeded diseases). Shared by the NeuralSurv and rule-based paths: previously
# each path carried its own copy and they had drifted apart (only the
# NeuralSurv copy listed "636" and "558").
_SEVERE_FLOOR: dict[str, float] = {
    "100": 0.65,     # Ataxia-telangiectasia
    "646": 0.65,     # NPC
    "355": 0.55,     # Gaucher
    "324": 0.55,     # Fabry
    "365": 0.85,     # Pompe (infantile) — high mortality untreated
    "579": 0.65,     # MPS I
    "580": 0.65,     # MPS II
    "70": 0.95,      # SMA-1 — without treatment ~90% mortality by age 2
    "905": 0.55,     # Wilson
    "98896": 0.70,   # DMD
    "586": 0.65,     # CF
    "95": 0.60,      # Friedreich
    "183660": 0.85,  # SCID
    "778": 0.70,     # Rett
    "636": 0.50,     # NOTE(review): disease label not recorded in source — confirm
    "558": 0.55,     # NOTE(review): disease label not recorded in source — confirm
}
# Only hypotheses in one of these statuses can trigger the floor.
_FLOOR_STATUSES = ("active", "supported", "confirmed")
# Minimum probability of the leading hypothesis for the floor to apply.
_FLOOR_MIN_PROB = 0.5
# Progression / urgency floors applied together with the severity floor.
_FLOOR_PROGRESSION = 0.55
_FLOOR_URGENCY = 0.65

# Default horizons (months) for the bootstrap survival curve.
_DEFAULT_HORIZONS_MONTHS = (3, 6, 12, 24, 36, 60)


def _approx_survival_from_severity(
    severity: float, horizons_months: Optional[list[int]] = None
) -> list[dict]:
    """Bootstrap survival curve used when no NeuralSurv model is available.

    Maps severity (clamped to 0..1) to a constant yearly hazard
    ``0.01 + 0.49 * severity**1.5`` and evaluates the exponential survival
    model ``exp(-hazard * years)`` at each horizon.

    The CI is heuristic: an absolute band of ``±(0.05 + 0.05 * years)``,
    capped at ±0.4 and clipped to [0, 1], so it widens with the horizon.

    Args:
        severity: overall severity; values outside [0, 1] are clamped.
        horizons_months: horizons to evaluate, in months. ``None`` (or an
            empty list) selects 3, 6, 12, 24, 36 and 60 months.

    Returns:
        One dict per horizon with keys ``month``, ``p_alive``, ``ci_low`` and
        ``ci_high`` (probabilities rounded to 4 decimals).
    """
    horizons = horizons_months or _DEFAULT_HORIZONS_MONTHS
    severity = max(0.0, min(1.0, float(severity)))
    # hazard per year — severity 0.0 → 0.01/y, severity 1.0 → 0.5/y
    hazard_per_year = 0.01 + 0.49 * (severity ** 1.5)
    points = []
    for m in horizons:
        years = m / 12.0
        p = math.exp(-hazard_per_year * years)
        # heuristic absolute CI half-width: widens with horizon, capped at 0.4
        spread = min(0.4, 0.05 + 0.05 * years)
        points.append({
            "month": m,
            "p_alive": round(p, 4),
            "ci_low": round(max(0.0, p - spread), 4),
            "ci_high": round(min(1.0, p + spread), 4),
        })
    return points


async def _try_neuralsurv(space, embedding):
    """Return the NeuralSurv prediction for ``space``, or ``None``.

    ``None`` is returned when the checkpoint file at ``NEURALSURV_CKPT`` does
    not exist, or when importing / running ``train.neuralsurv.predict`` raises
    (the error is logged as a warning).
    """
    if not os.path.exists(NEURALSURV_CKPT):
        return None
    try:
        from .train import neuralsurv as ns_mod

        return await ns_mod.predict(space, embedding, NEURALSURV_CKPT)
    except Exception as e:
        logger.warning("NeuralSurv predict failed: %s", e)
        return None


def _disease_severity_floor(space) -> Optional[float]:
    """Return the curated severity floor for the twin's leading hypothesis.

    The leading hypothesis is the highest-probability entry of
    ``space._hypotheses`` whose status is one of ``_FLOOR_STATUSES`` and that
    carries an ``orpha_code``. The floor is returned only when that
    probability is >= ``_FLOOR_MIN_PROB`` and the code is in
    ``_SEVERE_FLOOR``. It is applied at full strength, not diluted by
    probability, so "suspected" hypotheses extracted at 0.70 trigger it too,
    not just confirmed-at-0.85 ones.

    Best-effort: returns ``None`` (logged at debug) if the hypotheses cannot
    be read; individual hypotheses with unparseable probabilities are skipped.
    """
    try:
        top_orpha = None
        top_prob = 0.0
        for hyp in (getattr(space, "_hypotheses", {}) or {}).values():
            try:
                p = float(getattr(hyp, "probability", 0) or 0)
            except (TypeError, ValueError):
                continue
            orpha = getattr(hyp, "orpha_code", None)
            status = getattr(hyp, "status", "")
            if p > top_prob and status in _FLOOR_STATUSES and orpha:
                top_prob = p
                top_orpha = orpha
        if top_orpha and top_prob >= _FLOOR_MIN_PROB:
            # str() so integer ORPHA codes still match the string keys
            return _SEVERE_FLOOR.get(str(top_orpha))
    except Exception as e:
        logger.debug("severity floor failed: %s", e)
    return None


def _phenotype_count(space) -> Optional[int]:
    """Count phenotypes in the current snapshot; ``None`` if it can't be read.

    Returns 0 when ``space`` has no ``get_current_snapshot`` or the snapshot
    is falsy.
    """
    try:
        snap = (
            space.get_current_snapshot()
            if hasattr(space, "get_current_snapshot")
            else None
        )
        return len(snap.phenotypes) if snap else 0
    except Exception as e:
        logger.debug("phenotype count failed: %s", e)
        return None


async def _rule_based_scores(space) -> tuple[float, float, float, list]:
    """Run ``risk_quantifier.quantify_risks`` and normalise its output.

    Accepts either a dict profile or an attribute-style profile. Returns
    ``(severity, progression, urgency, complications)``. Values stay at 0 /
    empty when the quantifier is unavailable or returns ``None``. If it
    raises, the error is logged at debug and any values read before the
    failure are kept. For attribute-style profiles, up to 8 complications are
    normalised to ``{"name", "prob", "when"}`` dicts. Dict profiles pass
    ``top_complications`` through unchanged.
    """
    severity = 0.0
    progression = 0.0
    urgency = 0.0
    complications: list = []
    try:
        from risk_quantifier import quantify_risks

        profile = await quantify_risks(space)
        if profile is None:
            pass
        elif isinstance(profile, dict):
            severity = float(profile.get("overall_severity", 0.0) or 0.0)
            progression = float(profile.get("progression_risk", 0.0) or 0.0)
            urgency = float(profile.get("treatment_urgency", 0.0) or 0.0)
            complications = profile.get("top_complications", []) or []
        else:
            severity = float(getattr(profile, "overall_severity", 0.0) or 0.0)
            progression = float(getattr(profile, "progression_risk", 0.0) or 0.0)
            urgency = float(getattr(profile, "treatment_urgency", 0.0) or 0.0)
            comps = getattr(profile, "top_complications", []) or []
            for c in comps[:8]:
                if isinstance(c, dict):
                    complications.append({
                        "name": c.get("name"),
                        "prob": c.get("probability"),
                        "when": c.get("expected_in_months") or c.get("expected_in"),
                    })
                else:
                    complications.append({
                        "name": getattr(c, "name", None),
                        "prob": getattr(c, "probability", None),
                        "when": getattr(c, "expected_in_months", None),
                    })
    except Exception as e:
        logger.debug("risk_quantifier failed: %s", e)
    return severity, progression, urgency, complications


async def assess(space, embedding=None) -> RiskSpec:
    """Compute the risk profile of the digital twin.

    Tries NeuralSurv first (calibrated survival curve from real DATASUS). If
    no prediction is available, falls back to the rule-based
    ``risk_quantifier`` scores with a bootstrap exponential survival curve.
    In both paths a curated disease-class severity floor is MAX-merged in:
    it can only raise severity / progression / urgency, never lower them.
    """
    ns_spec = await _try_neuralsurv(space, embedding)
    if ns_spec is not None:
        # NeuralSurv predicts MORTALITY hazard from SIM data; the
        # clinical-impairment severity (e.g. wheelchair-by-2026) is a
        # different axis. Take MAX so we keep the calibrated survival CURVE
        # but lift severity to reflect the known disease class. E.g. an AT
        # child aged 5 has P(alive at 72m) = 0.61 → mortality-derived
        # severity = 0.39, yet AT is a severe progressive disease where the
        # 11-year-old will be wheelchair-borderline regardless of survival.
        floor = _disease_severity_floor(space)
        if floor is not None:
            ns_spec.overall_severity = max(ns_spec.overall_severity, floor)
            ns_spec.progression_risk = max(ns_spec.progression_risk, _FLOOR_PROGRESSION)
            ns_spec.treatment_urgency = max(ns_spec.treatment_urgency, _FLOOR_URGENCY)
        return ns_spec

    # bootstrap via rule-based risk_quantifier
    severity, progression, urgency, complications = await _rule_based_scores(space)

    # ALWAYS apply the disease-class floor when a qualifying hypothesis points
    # at a known severe rare disease — MAX-merged, never reduces the output.
    floor = _disease_severity_floor(space)
    if floor is not None:
        severity = max(severity, floor)
        progression = max(progression, _FLOOR_PROGRESSION)
        urgency = max(urgency, _FLOOR_URGENCY)
    elif severity == 0 and progression == 0 and urgency == 0:
        # No dx hint at all — coarse phenotype-count proxy
        n_pheno = _phenotype_count(space)
        if n_pheno is not None:
            severity = min(0.6, 0.05 + 0.02 * n_pheno)

    if severity == 0:
        severity = 0.2

    return RiskSpec(
        overall_severity=severity,
        progression_risk=progression,
        treatment_urgency=urgency,
        survival_curve=_approx_survival_from_severity(severity),
        top_complications=complications[:8],
        model="rule_based",
    )