# gemeo-sus / src / eval_sota.py
# timmers's picture
# GEMEO/SUS v6 recurrence-aware (RAVEN) — new-onset Top-1 60.1% vs baseline 38.2%, defeats autocorrelation trap. GEMEO Arch v2.0 Principle 7 proven.
# 908ea05 verified
"""SOTA evaluation suite for CDFv13 — audit-proof.
Per the May 2026 SOTA audit, replaces "Top-1 mid-position" (not recognized)
with the canonical EHR foundation model metric stack:
Classification (next-event, downstream tasks):
- AUROC + AUPRC + Brier
- Calibration: ICI (Austin & Steyerberg 2019)
- Decision-curve analysis (Vickers)
- Bootstrap 95% CI (≥2000 resamples) — required for rare disease
Survival (DATASUS SIM mortality):
- Uno's C (concordance_index_ipcw) — preferred over Harrell at high censoring
- Integrated Brier Score (1/3/5y)
- Time-dependent AUC
Counterfactual / causal:
- ATE with bootstrap CI
- E-value (VanderWeele)
- Negative-control outcome + exposure
- Tipping-point analysis
Generation fidelity (CoMET / SynthEHRella):
- Dim-wise probability match
- MMD (Maximum Mean Discrepancy) with RBF kernel
- TSTR (Train-on-Synthetic-Test-on-Real)
Subgroup fairness (npj DM requirement):
- Stratified metrics: sex, age band, UF region
Split strategy (DATASUS rare disease):
- Temporal: train ≤2022, val 2023, test 2024-2025
- Geographic: train SE+S, test N+NE (UF cross-region = "external")
- Patient-level 5-fold CV (variance estimation)
"""
from __future__ import annotations
import math
import numpy as np
import torch
from typing import Callable
# ---------- Classification ----------
def auroc(y: np.ndarray, p: np.ndarray) -> float:
    """Area under the ROC curve for labels ``y`` and scores ``p``.

    Args:
        y: Ground-truth labels.
        p: Predicted scores aligned with ``y``.

    Returns:
        The AUROC as a built-in ``float``, or NaN when ``y`` contains fewer
        than two distinct values (the curve is undefined in that case).
    """
    # Cheap degenerate check first: the NaN path does not need scikit-learn.
    if len(np.unique(y)) < 2:
        return float("nan")

    from sklearn.metrics import roc_auc_score

    # Coerce so the return value matches the annotated type.
    return float(roc_auc_score(y, p))
def auprc(y: np.ndarray, p: np.ndarray) -> float:
    """Area under the precision-recall curve (average precision).

    Args:
        y: Ground-truth labels.
        p: Predicted scores aligned with ``y``.

    Returns:
        The average precision as a built-in ``float``, or NaN when ``y``
        contains fewer than two distinct values.
    """
    # Cheap degenerate check first: the NaN path does not need scikit-learn.
    if len(np.unique(y)) < 2:
        return float("nan")

    from sklearn.metrics import average_precision_score

    # Coerce so the return value matches the annotated type.
    return float(average_precision_score(y, p))
def brier(y: np.ndarray, p: np.ndarray) -> float:
    """Brier score (lower is better) of probabilities ``p`` against ``y``.

    Delegates to ``sklearn.metrics.brier_score_loss`` and returns the result
    as a built-in ``float`` so it matches the annotated return type.
    """
    from sklearn.metrics import brier_score_loss

    return float(brier_score_loss(y, p))
def ici(y: np.ndarray, p: np.ndarray, frac: float = 0.75) -> float:
    """Integrated Calibration Index (Austin & Steyerberg 2019).

    Fits a LOWESS curve of the outcome ``y`` against the prediction ``p`` and
    returns the mean absolute gap between the smoothed outcome and the
    prediction. ``frac`` is forwarded to LOWESS as its smoothing span.
    """
    from statsmodels.nonparametric.smoothers_lowess import lowess

    # With return_sorted=True, column 0 holds the (sorted) predictions and
    # column 1 the smoothed outcome at each of them.
    curve = lowess(y, p, frac=frac, return_sorted=True)
    predicted = curve[:, 0]
    smoothed = curve[:, 1]
    gap = np.abs(smoothed - predicted)
    return float(np.mean(gap))
def net_benefit(y: np.ndarray, p: np.ndarray, threshold: float) -> float:
    """Net benefit at a given decision threshold (Vickers DCA).

    Computes ``TP/n - (FP/n) * threshold / (1 - threshold)``, where a case
    counts as flagged when ``p >= threshold``.

    Args:
        y: Outcomes, compared against 1 (event) and 0 (non-event);
            any array-like.
        p: Predicted probabilities aligned with ``y``; any array-like.
        threshold: Decision threshold applied to ``p``.

    Returns:
        The net benefit as a built-in ``float``. Returns 0.0 when
        ``threshold >= 1`` (the odds weight would divide by zero or flip
        sign) and NaN when the input is empty.
    """
    # Guard first: nothing below is needed when the odds weight is undefined.
    if threshold >= 1.0:
        return 0.0

    # Accept lists / Series as well as ndarrays.
    y = np.asarray(y)
    p = np.asarray(p)
    n = len(y)
    if n == 0:
        # 0/0 — report it explicitly instead of via a RuntimeWarning.
        return float("nan")

    flagged = p >= threshold
    tp = int(np.count_nonzero(flagged & (y == 1)))
    fp = int(np.count_nonzero(flagged & (y == 0)))
    return float(tp / n - (fp / n) * (threshold / (1 - threshold)))
def decision_curve(y: np.ndarray, p: np.ndarray,
                   thresholds: list[float] = None) -> dict:
    """Decision-curve analysis across a grid of thresholds.

    For each threshold, reports the model's net benefit next to the two
    reference strategies: treat-all and treat-none. When ``thresholds`` is
    omitted, a fixed grid from 0.01 to 0.5 is used.

    Returns a dict with keys ``thresholds``, ``model``, ``treat_all`` and
    ``treat_none``; the three curves are lists parallel to ``thresholds``.
    """
    if thresholds is None:
        thresholds = [0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5]

    def treat_everyone(cutoff):
        # Everyone is flagged, so TP/n is y.mean() and FP/n is 1 - y.mean().
        if cutoff < 1:
            return (y.mean()) - (1 - y.mean()) * (cutoff / (1 - cutoff))
        return 0

    model_curve = []
    for cutoff in thresholds:
        model_curve.append(net_benefit(y, p, cutoff))

    all_curve = []
    for cutoff in thresholds:
        all_curve.append(treat_everyone(cutoff))

    # Treating nobody yields zero net benefit at every threshold.
    none_curve = [0.0] * len(thresholds)

    result = {}
    result["thresholds"] = thresholds
    result["model"] = model_curve
    result["treat_all"] = all_curve
    result["treat_none"] = none_curve
    return result
def bootstrap_ci(y: np.ndarray, p: np.ndarray, metric_fn: Callable,
                 n_boot: int = 2000, seed: int = 0,
                 ci: tuple[float, float] = (2.5, 97.5)) -> tuple[float, float, float]:
    """Percentile-bootstrap interval for any ``(y, p) -> scalar`` metric.

    Args:
        y: Labels; any array-like (converted to an ndarray so that resampling
            is always positional, including for pandas Series).
        p: Predictions aligned with ``y``; any array-like.
        metric_fn: Callable taking ``(y_resampled, p_resampled)``.
        n_boot: Number of resamples drawn.
        seed: Seed for the NumPy random generator.
        ci: Lower and upper percentiles to report (default 2.5 / 97.5,
            i.e. a 95% interval).

    Returns:
        ``(lower, median, upper)`` over the resamples that were kept.
        Resamples with fewer than two distinct labels, and resamples on which
        ``metric_fn`` raises, are skipped. If none are kept (or the input is
        empty) all three values are NaN.
    """
    nan_triple = (float("nan"),) * 3

    y = np.asarray(y)
    p = np.asarray(p)
    n = len(y)
    if n == 0:
        # Every resample would be empty; skip the loop entirely.
        return nan_triple

    rng = np.random.default_rng(seed)
    stats = []
    for _ in range(n_boot):
        idx = rng.integers(0, n, n)
        y_boot = y[idx]
        # Single-class resamples leave ranking metrics undefined.
        if len(np.unique(y_boot)) < 2:
            continue
        try:
            stats.append(metric_fn(y_boot, p[idx]))
        except Exception:
            # Deliberate best-effort: one failing resample must not abort
            # the whole interval estimate.
            continue

    if not stats:
        return nan_triple
    return (
        float(np.percentile(stats, ci[0])),
        float(np.median(stats)),
        float(np.percentile(stats, ci[1])),
    )
# ---------- Survival ----------
def uno_c_index(y_train_event, y_train_time, y_test_event, y_test_time,
                risk_score, tau: float = None) -> float:
    """Uno's C-index (IPCW concordance), preferred at high censoring.

    Packs the train and test event/time arrays into structured arrays with
    fields ``event`` (bool) and ``time`` (float64) and passes them to
    ``sksurv.metrics.concordance_index_ipcw``. When ``tau`` is not given it
    defaults to 95% of the largest test time.

    Returns NaN if scikit-survival cannot be imported.
    """
    try:
        from sksurv.metrics import concordance_index_ipcw
    except ImportError:
        return float("nan")

    record_dtype = [("event", "?"), ("time", "<f8")]

    def to_records(event, time):
        # One (event, time) record per subject.
        pairs = list(zip(event.astype(bool), time.astype(float)))
        return np.array(pairs, dtype=record_dtype)

    train_records = to_records(y_train_event, y_train_time)
    test_records = to_records(y_test_event, y_test_time)

    if tau is None:
        tau = float(y_test_time.max()) * 0.95

    result = concordance_index_ipcw(train_records, test_records, risk_score, tau=tau)
    # The first element of the returned tuple is the concordance estimate.
    return float(result[0])
def integrated_brier_score(y_train_event, y_train_time, y_test_event, y_test_time,
                           surv_pred: np.ndarray, times: np.ndarray) -> float:
    """Integrated Brier Score (lower is better).

    Packs the train and test event/time arrays into structured arrays with
    fields ``event`` (bool) and ``time`` (float64) and forwards them, with
    ``surv_pred`` and ``times``, to
    ``sksurv.metrics.integrated_brier_score``.

    Returns NaN if scikit-survival cannot be imported.
    """
    try:
        from sksurv.metrics import integrated_brier_score as sksurv_ibs
    except ImportError:
        return float("nan")

    record_dtype = [("event", "?"), ("time", "<f8")]

    def to_records(event, time):
        # One (event, time) record per subject.
        pairs = list(zip(event.astype(bool), time.astype(float)))
        return np.array(pairs, dtype=record_dtype)

    train_records = to_records(y_train_event, y_train_time)
    test_records = to_records(y_test_event, y_test_time)
    score = sksurv_ibs(train_records, test_records, surv_pred, times)
    return float(score)
# ---------- Causal / Counterfactual ----------
def e_value(rr: float) -> float:
    """E-value (VanderWeele & Ding 2017) for an observed risk ratio.

    The E-value is the minimum strength of association an unmeasured
    confounder would need to fully explain away ``rr``. It is computed as
    ``RR + sqrt(RR * (RR - 1))``, with ratios below 1 inverted first.
    Inputs below 1e-9 are clamped to 1e-9 before inversion.
    """
    clamped = max(rr, 1e-9)
    # Protective effects (ratio < 1) are measured on the inverted scale.
    ratio = clamped if clamped >= 1.0 else 1.0 / clamped
    return ratio + math.sqrt(ratio * (ratio - 1))
def negative_control_check(nc_ate: float, threshold: float = 0.02) -> bool:
    """Negative-control check: the ATE on a control outcome should be ~0.

    Returns True when the absolute value of ``nc_ate`` is strictly below
    ``threshold``.
    """
    magnitude = abs(nc_ate)
    return magnitude < threshold
def tipping_point(observed_effect: float, ci_half_width: float) -> float:
    """Shift in effect that unmeasured confounding would need to nullify it.

    Returns the amount by which the absolute effect exceeds the CI
    half-width, or 0.0 when the absolute effect is within the half-width.
    """
    magnitude = abs(observed_effect)
    return 0.0 if magnitude <= ci_half_width else float(magnitude - ci_half_width)
# ---------- Generation fidelity (SynthEHRella triad) ----------
def dim_wise_probability(real_seq: torch.Tensor, synth_seq: torch.Tensor,
                         vocab_size: int) -> float:
    """Compare per-token occurrence rates between real and synthetic batches.

    Each tensor of token ids is one-hot encoded, and the rate of every
    vocabulary entry is taken over all positions of the tensor (whatever its
    shape). The two rate vectors are then compared.

    Args:
        real_seq: Integer token ids from real data, each in
            ``[0, vocab_size)``.
        synth_seq: Integer token ids from synthetic data, same constraint.
        vocab_size: Number of classes used for the one-hot encoding.

    Returns:
        Mean absolute difference between the two per-token rate vectors
        (lower = closer match).
    """
    # Imported locally: without it, `F` is unbound in this function's scope.
    import torch.nn.functional as F

    def token_rates(seq: torch.Tensor) -> torch.Tensor:
        # one_hot requires int64 indices; .long() is a no-op when already so.
        encoded = F.one_hot(seq.long(), vocab_size).float()
        # Collapse every leading axis so 1-D and batched inputs both reduce
        # over positions only, never over the vocabulary axis.
        return encoded.reshape(-1, vocab_size).mean(dim=0)

    gap = token_rates(real_seq) - token_rates(synth_seq)
    return float(gap.abs().mean())
def mmd_rbf(x: torch.Tensor, y: torch.Tensor, sigma: float = 1.0) -> float:
    """Squared Maximum Mean Discrepancy under an RBF kernel.

    ``x`` and ``y`` are ``(B, D)`` batches of flattened embeddings. The
    result is ``mean k(x, x) + mean k(y, y) - 2 * mean k(x, y)`` with
    ``k(a, b) = exp(-||a - b||^2 / (2 * sigma^2))``; lower means closer.
    """
    bandwidth = 2 * sigma ** 2

    def mean_kernel(a, b):
        # Pairwise squared Euclidean distances via broadcasting.
        diff = a.unsqueeze(1) - b.unsqueeze(0)
        sq_dist = diff.pow(2).sum(-1)
        return torch.exp(-sq_dist / bandwidth).mean()

    within_x = mean_kernel(x, x)
    within_y = mean_kernel(y, y)
    across = mean_kernel(x, y)
    return float(within_x + within_y - 2 * across)
# ---------- Subgroup fairness ----------
def stratified_metrics(y: np.ndarray, p: np.ndarray,
                       groups: np.ndarray,
                       metric_fn: Callable = auroc) -> dict[str, float]:
    """Evaluate ``metric_fn`` separately on each subgroup of ``groups``.

    Subgroups (e.g. sex, age band, UF region) with 10 or fewer members are
    left out of the result. If ``metric_fn`` raises for a subgroup, that
    subgroup is reported as NaN. Keys are the ``str()`` of each group value.
    """
    per_group = {}
    for label in np.unique(groups):
        members = groups == label
        # Skip small strata.
        if members.sum() <= 10:
            continue
        key = str(label)
        try:
            value = metric_fn(y[members], p[members])
        except Exception:
            value = float("nan")
        per_group[key] = value
    return per_group
# ---------- DATASUS split strategies ----------
def temporal_split(events: list[dict], train_until: int = 2022,
                   val_year: int = 2023):
    """Temporal split for DATASUS: train ≤2022, val 2023, test 2024+.

    Each event is routed by its ``"year"`` entry: years up to and including
    ``train_until`` go to train, ``val_year`` goes to val, and everything
    else goes to test. Events whose year is missing or falsy are treated as
    2020.

    NOTE(review): with non-default arguments, a year strictly between
    ``train_until`` and ``val_year`` also lands in test — confirm intended.

    Returns ``(train, val, test)`` lists, each preserving input order.
    """
    train, val, test = [], [], []
    for event in events:
        year = event.get("year") or 2020
        if year <= train_until:
            bucket = train
        elif year == val_year:
            bucket = val
        else:
            bucket = test
        bucket.append(event)
    return train, val, test
def geographic_split(patients: list[dict], external_ufs: set = None):
    """Geographic split: train on SE+S, test on N+NE.

    For DATASUS this is the closest analog to "external validation."

    A patient is assigned by the first truthy ``"uf_code"`` found among the
    patient's ``"events"``: to test when that code is in ``external_ufs``,
    otherwise to train. Patients with no UF code at all go to train. When
    ``external_ufs`` is omitted, a built-in set of 16 two-letter UF
    abbreviations is used.

    Returns ``(train, test)`` lists, each preserving input order.
    """
    if external_ufs is None:
        external_ufs = {
            "AC", "AL", "AP", "AM", "BA", "CE", "MA", "PA",
            "PB", "PE", "PI", "RN", "SE", "TO", "RR", "RO",
        }

    train, test = [], []
    for patient in patients:
        first_uf = None
        for event in patient.get("events", []):
            code = event.get("uf_code")
            if code:
                first_uf = code
                break
        if first_uf in external_ufs:
            test.append(patient)
        else:
            train.append(patient)
    return train, test
# ---------- Combined eval report ----------
def full_eval_report(y: np.ndarray, p: np.ndarray,
                     groups_sex: np.ndarray = None,
                     groups_age: np.ndarray = None,
                     groups_uf: np.ndarray = None,
                     n_boot: int = 2000) -> dict:
    """Generate a full audit-proof report for a binary classification task.

    Args:
        y: Binary outcomes (ndarray; ``y.mean()`` is reported as prevalence).
        p: Predicted probabilities aligned with ``y``.
        groups_sex: Optional per-sample group labels; adds ``fairness_sex``.
        groups_age: Optional per-sample group labels; adds ``fairness_age``.
        groups_uf: Optional per-sample group labels; adds ``fairness_uf``.
        n_boot: Number of bootstrap resamples per metric.

    Returns:
        A dict with ``n_eval``, ``prevalence``, one entry each for
        ``auroc`` / ``auprc`` / ``brier`` (``point``, ``ci95``, ``median``),
        ``ici``, ``decision_curve``, and a per-subgroup AUROC dict for every
        group array that was supplied.
    """
    report = {
        "n_eval": len(y),
        "prevalence": float(y.mean()),
    }

    # Point estimate plus bootstrap interval for each headline metric.
    for name, metric_fn in (("auroc", auroc), ("auprc", auprc), ("brier", brier)):
        lo, med, hi = bootstrap_ci(y, p, metric_fn, n_boot)
        report[name] = {"point": metric_fn(y, p), "ci95": [lo, hi], "median": med}

    report["ici"] = ici(y, p)
    report["decision_curve"] = decision_curve(y, p)

    # Subgroup AUROC only for the stratifications the caller provided.
    fairness_inputs = (
        ("fairness_sex", groups_sex),
        ("fairness_age", groups_age),
        ("fairness_uf", groups_uf),
    )
    for key, groups in fairness_inputs:
        if groups is not None:
            report[key] = stratified_metrics(y, p, groups, auroc)
    return report