timmers's picture
GEMEO world-model — initial release (module + NeuralSurv ckpt + RareBench v49 + KG embeddings)
089d665 verified
"""Risk + survival.
Bootstrap: wraps `risk_quantifier.quantify_risks` (rule-based scoring over
disease severity + diagnostic certainty + progression + treatment urgency
+ Brazilian context).
Phase 2: NeuralSurv on KG (gemeo/train/neuralsurv.py) — Bayesian survival
with epistemic uncertainty over the patient's KG-walk feature vector.
"""
from __future__ import annotations
import logging
import os
import math
from typing import Optional
from .types import RiskSpec
logger = logging.getLogger("gemeo.risk")
# Path to the NeuralSurv checkpoint. The GEMEO_NEURALSURV_CKPT environment
# variable overrides the default of artifacts/neuralsurv.pt located next to
# this module. The value is resolved once, when the module is imported.
NEURALSURV_CKPT = os.environ.get(
    "GEMEO_NEURALSURV_CKPT",
    os.path.join(os.path.dirname(__file__), "artifacts", "neuralsurv.pt"),
)
def _approx_survival_from_severity(
    severity: float, horizons_months: Optional[list[int]] = None
) -> list[dict]:
    """Bootstrap survival curve used when no NeuralSurv model is available.

    Clamps ``severity`` to [0, 1] and maps it to a yearly hazard
    ``0.01 + 0.49 * severity**1.5``. That gives about 0.01/y at severity 0
    and 0.5/y at severity 1. It then evaluates an exponential survival model
    ``S(t) = exp(-hazard * t)`` at each horizon.

    The band is heuristic, not a statistical CI. It is an ABSOLUTE spread of
    ``min(0.4, 0.05 + 0.05 * years)`` around S(t), clipped to [0, 1], so it
    widens with the horizon.

    Args:
        severity: severity score; values outside [0, 1] are clamped.
        horizons_months: evaluation horizons in months. ``None`` selects the
            default ``[3, 6, 12, 24, 36, 60]``. An explicit empty list yields
            an empty curve.

    Returns:
        One dict per horizon with keys ``month``, ``p_alive``, ``ci_low`` and
        ``ci_high``. The probabilities are rounded to 4 decimals.
    """
    # `is None`, not truthiness: an explicit [] must not be replaced by defaults.
    if horizons_months is None:
        horizons_months = [3, 6, 12, 24, 36, 60]
    severity = max(0.0, min(1.0, float(severity)))
    # Monotonic severity -> hazard map: ~0.01/y at 0.0, ~0.5/y at 1.0.
    hazard_per_year = 0.01 + 0.49 * (severity ** 1.5)
    points = []
    for m in horizons_months:
        years = m / 12.0
        p = math.exp(-hazard_per_year * years)
        # Heuristic band: grows 0.05 per year from 0.05, capped at 0.4.
        spread = min(0.4, 0.05 + 0.05 * years)
        points.append({
            "month": m,
            "p_alive": round(p, 4),
            "ci_low": round(max(0.0, p - spread), 4),
            "ci_high": round(min(1.0, p + spread), 4),
        })
    return points
async def _try_neuralsurv(space, embedding):
    """Best-effort NeuralSurv prediction; returns None when it is unavailable.

    Returns None without raising in three cases:

    * the checkpoint at ``NEURALSURV_CKPT`` does not exist;
    * importing ``.train.neuralsurv`` fails;
    * its ``predict`` raises.

    The two failure cases are logged as warnings with distinct messages, so
    an import problem is not misreported as a prediction failure. Otherwise
    the function returns whatever
    ``neuralsurv.predict(space, embedding, NEURALSURV_CKPT)`` produces.
    """
    if not os.path.exists(NEURALSURV_CKPT):
        return None
    try:
        # Local import: the training package is optional at runtime.
        from .train import neuralsurv as ns_mod
    except Exception as e:  # best effort: any import-time failure disables NeuralSurv
        logger.warning("NeuralSurv import failed: %s", e)
        return None
    try:
        return await ns_mod.predict(space, embedding, NEURALSURV_CKPT)
    except Exception as e:  # best effort: fall back to the rule-based path
        logger.warning("NeuralSurv predict failed: %s", e)
        return None
# Disease-class clinical-severity floors keyed by ORPHA code (curated; covers
# the seeded diseases). Shared by both paths of `assess` so the two can no
# longer drift apart.
_SEVERE_FLOOR: dict[str, float] = {
    "100": 0.65,     # Ataxia-telangiectasia
    "646": 0.65,     # NPC
    "355": 0.55,     # Gaucher
    "324": 0.55,     # Fabry
    "365": 0.85,     # Pompe (infantile) — high mortality untreated
    "579": 0.65,     # MPS I
    "580": 0.65,     # MPS II
    "70": 0.95,      # SMA-1 — without treatment ~90% mortality by age 2
    "905": 0.55,     # Wilson
    "98896": 0.70,   # DMD
    "586": 0.65,     # CF
    "95": 0.60,      # Friedreich
    "183660": 0.85,  # SCID
    "778": 0.70,     # Rett
    "636": 0.50,     # NOTE(review): no disease label in the original table — confirm
    "558": 0.55,     # NOTE(review): no disease label in the original table — confirm
}
# Minimum top-hypothesis probability for the floor to apply.
_FLOOR_MIN_PROB = 0.5
# Values that progression risk / treatment urgency are raised to when the floor applies.
_FLOOR_PROGRESSION = 0.55
_FLOOR_URGENCY = 0.65
# Hypothesis statuses eligible to drive the floor.
_SUPPORTED_STATUSES = ("active", "supported", "confirmed")


def _top_supported_hypothesis(space) -> tuple[Optional[str], float]:
    """Return ``(orpha_code, probability)`` of the most probable eligible hypothesis.

    Reads ``space._hypotheses`` (used via ``.values()``). A hypothesis is
    eligible when its ``status`` is in ``_SUPPORTED_STATUSES`` and it has a
    truthy ``orpha_code``. When probabilities tie, the first one seen wins.
    Hypotheses whose ``probability`` cannot be converted to float are
    skipped. Returns ``(None, 0.0)`` when nothing qualifies.
    """
    top_orpha: Optional[str] = None
    top_prob = 0.0
    for hyp in (getattr(space, "_hypotheses", {}) or {}).values():
        try:
            p = float(getattr(hyp, "probability", 0) or 0)
        except (TypeError, ValueError):
            continue
        orpha = getattr(hyp, "orpha_code", None)
        status = getattr(hyp, "status", "")
        if p > top_prob and status in _SUPPORTED_STATUSES and orpha:
            top_prob = p
            top_orpha = orpha
    return top_orpha, top_prob


def _disease_class_floor(space) -> Optional[float]:
    """Return the severity floor implied by the leading hypothesis, or None.

    The floor applies only when the top eligible hypothesis has probability
    >= ``_FLOOR_MIN_PROB`` and its ORPHA code is a key of ``_SEVERE_FLOOR``.
    It is used at full strength, not scaled by probability. A hypothesis at
    0.70 is therefore treated like one confirmed at 0.85. Lookup failures
    are logged at debug level and treated as "no floor".
    """
    try:
        top_orpha, top_prob = _top_supported_hypothesis(space)
        if top_orpha and top_prob >= _FLOOR_MIN_PROB:
            # NOTE(review): keys are str; an int orpha_code would not match — confirm upstream type.
            return _SEVERE_FLOOR.get(top_orpha)
    except Exception as e:  # best effort: a malformed hypothesis store must not abort assess()
        logger.debug("severity boost failed: %s", e)
    return None


def _phenotype_count_severity(space) -> float:
    """Coarse severity proxy from the current snapshot's phenotype count.

    Returns ``min(0.6, 0.05 + 0.02 * n)``. Here n is the length of
    ``snapshot.phenotypes``, or 0 when there is no snapshot. Returns 0.0
    (logged at debug level) if the snapshot cannot be read.
    """
    try:
        snap = space.get_current_snapshot() if hasattr(space, "get_current_snapshot") else None
        n_pheno = len(snap.phenotypes) if snap else 0
    except Exception as e:
        logger.debug("severity boost failed: %s", e)
        return 0.0
    return min(0.6, 0.05 + 0.02 * n_pheno)


def _normalize_complication(c) -> dict:
    """Map a dict- or attribute-shaped complication to ``{"name", "prob", "when"}``."""
    if isinstance(c, dict):
        return {
            "name": c.get("name"),
            "prob": c.get("probability"),
            "when": c.get("expected_in_months") or c.get("expected_in"),
        }
    return {
        "name": getattr(c, "name", None),
        "prob": getattr(c, "probability", None),
        "when": getattr(c, "expected_in_months", None),
    }


async def _rule_based_profile(space) -> tuple[float, float, float, list]:
    """Best-effort read of ``risk_quantifier.quantify_risks(space)``.

    Returns ``(severity, progression, urgency, complications)``. The values
    default to zeros and ``[]``. Any failure is logged at debug level. This
    includes a missing module, an exception, or an unparsable value. On
    failure the function returns whatever had been parsed before it.
    """
    severity = 0.0
    progression = 0.0
    urgency = 0.0
    complications: list = []
    try:
        from risk_quantifier import quantify_risks
        profile = await quantify_risks(space)
        if isinstance(profile, dict):
            severity = float(profile.get("overall_severity", 0.0) or 0.0)
            progression = float(profile.get("progression_risk", 0.0) or 0.0)
            urgency = float(profile.get("treatment_urgency", 0.0) or 0.0)
            # NOTE(review): dict-profile complications are passed through
            # un-normalized (unlike the object path) — confirm their shape.
            complications = profile.get("top_complications", []) or []
        elif profile is not None:
            severity = float(getattr(profile, "overall_severity", 0.0) or 0.0)
            progression = float(getattr(profile, "progression_risk", 0.0) or 0.0)
            urgency = float(getattr(profile, "treatment_urgency", 0.0) or 0.0)
            comps = getattr(profile, "top_complications", []) or []
            for c in comps[:8]:
                complications.append(_normalize_complication(c))
    except Exception as e:
        logger.debug("risk_quantifier failed: %s", e)
    return severity, progression, urgency, complications


async def assess(space, embedding=None) -> RiskSpec:
    """Compute the risk profile of the digital twin.

    1. Try NeuralSurv (``_try_neuralsurv``). If it returns a spec, MAX-merge
       the disease-class floor into it and return it.
    2. Otherwise read the rule-based profile from ``risk_quantifier``. This
       step is best effort.
    3. MAX-merge the same floor into the rule-based values. If no floor
       applies and every value is zero, use a phenotype-count proxy instead.
    4. Default a still-zero severity to 0.2, attach a heuristic survival
       curve, and return a ``rule_based`` RiskSpec.

    The floor only ever raises severity, progression and urgency; it never
    lowers them.

    Args:
        space: the twin's state space. It is read through ``_hypotheses`` and,
            when present, ``get_current_snapshot()``.
        embedding: optional feature vector forwarded to NeuralSurv.
    """
    ns_spec = await _try_neuralsurv(space, embedding)
    if ns_spec is not None:
        # NeuralSurv predicts MORTALITY hazard from SIM data. Clinical
        # impairment (e.g. wheelchair-by-2026) is a different axis. Taking
        # the MAX keeps the calibrated survival CURVE but lifts severity to
        # reflect the known disease class. Example: an AT child aged 5 with
        # P(alive at 72m) = 0.61 has a mortality-derived severity of 0.39.
        # Yet AT is a severe progressive disease, so by age 11 the child will
        # be wheelchair-borderline regardless of survival.
        floor = _disease_class_floor(space)
        if floor is not None:
            ns_spec.overall_severity = max(ns_spec.overall_severity, floor)
            ns_spec.progression_risk = max(ns_spec.progression_risk, _FLOOR_PROGRESSION)
            ns_spec.treatment_urgency = max(ns_spec.treatment_urgency, _FLOOR_URGENCY)
        return ns_spec

    severity, progression, urgency, complications = await _rule_based_profile(space)

    # Always MAX-merge the disease-class floor, even when risk_quantifier
    # returned non-zero values; it never reduces them.
    floor = _disease_class_floor(space)
    if floor is not None:
        severity = max(severity, floor)
        progression = max(progression, _FLOOR_PROGRESSION)
        urgency = max(urgency, _FLOOR_URGENCY)
    elif severity == 0 and progression == 0 and urgency == 0:
        # No diagnostic hint at all: fall back to the phenotype-count proxy.
        severity = _phenotype_count_severity(space)

    if severity == 0:
        severity = 0.2
    return RiskSpec(
        overall_severity=severity,
        progression_risk=progression,
        treatment_urgency=urgency,
        survival_curve=_approx_survival_from_severity(severity),
        top_complications=complications[:8],
        model="rule_based",
    )