| """SOTA evaluation suite for CDFv13 — audit-proof. |
| |
| Per the May 2026 SOTA audit, replaces "Top-1 mid-position" (not recognized) |
| with the canonical EHR foundation model metric stack: |
| |
| Classification (next-event, downstream tasks): |
| - AUROC + AUPRC + Brier |
| - Calibration: ICI (Austin & Steyerberg 2019) |
| - Decision-curve analysis (Vickers) |
| - Bootstrap 95% CI (≥2000 resamples) — required for rare disease |
| |
| Survival (DATASUS SIM mortality): |
| - Uno's C (concordance_index_ipcw) — preferred over Harrell at high censoring |
| - Integrated Brier Score (1/3/5y) |
| - Time-dependent AUC |
| |
| Counterfactual / causal: |
| - ATE with bootstrap CI |
| - E-value (VanderWeele) |
| - Negative-control outcome + exposure |
| - Tipping-point analysis |
| |
| Generation fidelity (CoMET / SynthEHRella): |
| - Dim-wise probability match |
| - MMD (Maximum Mean Discrepancy) with RBF kernel |
| - TSTR (Train-on-Synthetic-Test-on-Real) |
| |
| Subgroup fairness (npj DM requirement): |
| - Stratified metrics: sex, age band, UF region |
| |
| Split strategy (DATASUS rare disease): |
| - Temporal: train ≤2022, val 2023, test 2024-2025 |
| - Geographic: train SE+S, test N+NE (UF cross-region = "external") |
| - Patient-level 5-fold CV (variance estimation) |
| """ |
| from __future__ import annotations |
| import math |
| import numpy as np |
| import torch |
| from typing import Callable |
|
|
|
|
| |
|
|
def auroc(y: np.ndarray, p: np.ndarray) -> float:
    """Area under the ROC curve (via ``sklearn.metrics.roc_auc_score``).

    Returns NaN when ``y`` holds fewer than two distinct labels, since the
    ROC curve is undefined without both classes present.
    """
    from sklearn.metrics import roc_auc_score

    single_class = np.unique(y).size < 2
    return float("nan") if single_class else roc_auc_score(y, p)
|
|
|
|
def auprc(y: np.ndarray, p: np.ndarray) -> float:
    """Area under the precision-recall curve, as sklearn's average precision.

    Returns NaN when ``y`` holds fewer than two distinct labels.
    """
    from sklearn.metrics import average_precision_score

    if np.unique(y).size >= 2:
        return average_precision_score(y, p)
    return float("nan")
|
|
|
|
def brier(y: np.ndarray, p: np.ndarray) -> float:
    """Brier score (lower is better), via ``sklearn.metrics.brier_score_loss``.

    Unlike ``auroc``/``auprc`` there is no single-class guard: the score is
    well defined even when ``y`` contains one label only.
    """
    from sklearn.metrics import brier_score_loss as _brier_score_loss

    score = _brier_score_loss(y, p)
    return score
|
|
|
|
def ici(y: np.ndarray, p: np.ndarray, frac: float = 0.75) -> float:
    """Integrated Calibration Index (Austin & Steyerberg 2019).

    Fits a lowess calibration curve of the observed outcome ``y`` on the
    predicted probability ``p``. Returns the mean absolute vertical distance
    between that curve and the identity line, evaluated at every observed
    prediction.

    Args:
        y: Binary outcomes (0/1).
        p: Predicted probabilities, same length as ``y``.
        frac: Lowess span (fraction of the data used for each local fit).

    Returns:
        ICI as a Python float (0 = perfectly calibrated; lower is better).
    """
    from statsmodels.nonparametric.smoothers_lowess import lowess

    # it=0 disables lowess's robustifying (bisquare) re-weighting. With 0/1
    # outcomes every positive looks like an outlier to that step, so the
    # statsmodels default (it=3) pulls the curve toward 0 and inflates ICI
    # for rare outcomes. Calibration tooling such as Harrell's rms::val.prob
    # uses iter=0 for this reason.
    smoothed = lowess(y, p, frac=frac, it=0, return_sorted=True)
    # Column 0 is the (sorted) prediction, column 1 the smoothed observed rate.
    return float(np.mean(np.abs(smoothed[:, 1] - smoothed[:, 0])))
|
|
|
|
def net_benefit(y: np.ndarray, p: np.ndarray, threshold: float) -> float:
    """Net benefit of the policy "treat if p >= threshold" (Vickers DCA).

    NB = TP/n - FP/n * threshold / (1 - threshold)

    Args:
        y: Binary outcomes (0/1); any array-like.
        p: Predicted probabilities, same length as ``y``; any array-like.
        threshold: Decision threshold.

    Returns:
        Net benefit as a Python float. Returns 0.0 when ``threshold >= 1``,
        where the harm weight t/(1-t) is undefined. Returns NaN when ``y`` is
        empty.
    """
    if threshold >= 1.0:
        return 0.0
    y = np.asarray(y)
    p = np.asarray(p)
    n = y.size
    if n == 0:
        return float("nan")
    treat = p >= threshold
    tp = np.count_nonzero(treat & (y == 1))
    fp = np.count_nonzero(treat & (y == 0))
    return float(tp / n - (fp / n) * (threshold / (1.0 - threshold)))
|
|
|
|
def decision_curve(y: np.ndarray, p: np.ndarray,
                   thresholds: list[float] | None = None) -> dict:
    """Decision-curve analysis (Vickers & Elkin 2006).

    Computes the model's net benefit at each threshold alongside the two
    reference strategies: treat everyone and treat no one.

    Args:
        y: Binary outcomes (0/1); any array-like.
        p: Predicted probabilities, same length as ``y``; any array-like.
        thresholds: Decision thresholds. Defaults to
            [0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5].

    Returns:
        A dict with these keys:

        - ``"thresholds"``: the thresholds used.
        - ``"model"``: the model's net benefit at each threshold.
        - ``"treat_all"``: net benefit of treating everyone.
        - ``"treat_none"``: net benefit of treating no one.

        Each value list is aligned with ``"thresholds"``.
    """
    if thresholds is None:
        thresholds = [0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5]
    y = np.asarray(y)
    p = np.asarray(p)
    prevalence = float(y.mean())  # computed once instead of per threshold

    model_nb = [net_benefit(y, p, t) for t in thresholds]
    # Treat-all: every positive is a TP and every negative an FP, so
    # NB = prevalence - (1 - prevalence) * t / (1 - t). It is 0 by convention
    # at t >= 1, matching net_benefit.
    treat_all_nb = [
        prevalence - (1.0 - prevalence) * (t / (1.0 - t)) if t < 1 else 0.0
        for t in thresholds
    ]
    treat_none_nb = [0.0] * len(thresholds)
    return {
        "thresholds": thresholds,
        "model": model_nb,
        "treat_all": treat_all_nb,
        "treat_none": treat_none_nb,
    }
|
|
|
|
def bootstrap_ci(y: np.ndarray, p: np.ndarray, metric_fn: Callable,
                 n_boot: int = 2000, seed: int = 0,
                 ci: tuple[float, float] = (2.5, 97.5)) -> tuple[float, float, float]:
    """Percentile bootstrap interval for any ``metric_fn(y, p) -> scalar``.

    Draws ``n_boot`` resamples of the (y, p) pairs with replacement. Some
    resamples are skipped:

    - those containing a single class, because ranking metrics such as AUROC
      are undefined there;
    - those on which ``metric_fn`` raises;
    - those on which ``metric_fn`` returns a non-finite value.

    Args:
        y: Labels; any array-like.
        p: Scores, same length as ``y``; any array-like.
        metric_fn: Callable ``(y, p) -> scalar``.
        n_boot: Number of resamples drawn.
        seed: Seed for ``np.random.default_rng``. The same seed gives the same
            resamples across metrics, so their intervals are paired.
        ci: Lower and upper percentiles; (2.5, 97.5) gives a 95% interval.

    Returns:
        ``(lower, median, upper)`` of the retained bootstrap statistics. All
        three are NaN when no resample yields a usable value, which includes
        empty input.

    Warns:
        RuntimeWarning when any resample was skipped. The interval is then
        conditional on the retained resamples, which matters for rare
        outcomes, where many resamples contain no positives.
    """
    import warnings

    y = np.asarray(y)
    p = np.asarray(p)
    n = len(y)
    if n == 0:
        return (float("nan"),) * 3

    rng = np.random.default_rng(seed)
    stats: list[float] = []
    n_single_class = 0
    n_failed = 0
    for _ in range(n_boot):
        idx = rng.integers(0, n, n)
        y_b = y[idx]
        if len(np.unique(y_b)) < 2:
            n_single_class += 1
            continue
        try:
            value = float(metric_fn(y_b, p[idx]))
        except Exception:
            # Best-effort: one failing resample must not sink the whole CI,
            # but it is counted and reported below instead of vanishing.
            n_failed += 1
            continue
        if not np.isfinite(value):
            # A single NaN would make every percentile NaN.
            n_failed += 1
            continue
        stats.append(value)

    if n_single_class or n_failed:
        warnings.warn(
            f"bootstrap_ci: {n_single_class + n_failed}/{n_boot} resamples skipped "
            f"({n_single_class} single-class, {n_failed} failed or non-finite); "
            "the interval is conditional on the remaining resamples",
            RuntimeWarning,
            stacklevel=2,
        )
    if not stats:
        return (float("nan"),) * 3
    return (
        float(np.percentile(stats, ci[0])),
        float(np.median(stats)),
        float(np.percentile(stats, ci[1])),
    )
|
|
|
|
| |
|
|
def uno_c_index(y_train_event, y_train_time, y_test_event, y_test_time,
                risk_score, tau: float | None = None) -> float:
    """Uno's C-index (IPCW concordance), preferred over Harrell's C at high censoring.

    Requires scikit-survival; returns NaN when it is not installed.

    Args:
        y_train_event: Training-set event indicators (truthy = event).
        y_train_time: Training-set follow-up times. Train data is used only to
            estimate the censoring distribution for the IPCW weights.
        y_test_event: Evaluation-set event indicators.
        y_test_time: Evaluation-set follow-up times.
        risk_score: Predicted risk per test subject (higher = earlier event).
        tau: Truncation time. Defaults to 95% of the shorter of the training
            and test follow-up horizons.

    Returns:
        Uno's C as a Python float, or NaN if scikit-survival is missing.
    """
    try:
        from sksurv.metrics import concordance_index_ipcw
    except ImportError:
        return float("nan")

    def _as_surv(event, time):
        # scikit-survival expects a structured array of (bool event, float time).
        time = np.asarray(time, dtype=float)
        out = np.empty(len(time), dtype=[("event", "?"), ("time", "<f8")])
        out["event"] = np.asarray(event).astype(bool)
        out["time"] = time
        return out

    y_train = _as_surv(y_train_event, y_train_time)
    y_test = _as_surv(y_test_event, y_test_time)
    if tau is None:
        # The IPCW weights come from a Kaplan-Meier fit of the *training*
        # censoring distribution. scikit-survival rejects evaluating it beyond
        # the training follow-up, and it can reach zero there. Truncating
        # inside both horizons keeps e.g. a temporal test set with longer
        # follow-up from raising.
        tau = 0.95 * float(min(y_train["time"].max(), y_test["time"].max()))
    c, *_ = concordance_index_ipcw(y_train, y_test, np.asarray(risk_score), tau=tau)
    return float(c)
|
|
|
|
def integrated_brier_score(y_train_event, y_train_time, y_test_event, y_test_time,
                           surv_pred: np.ndarray, times: np.ndarray) -> float:
    """Integrated Brier Score over ``times`` (lower is better).

    Thin wrapper around ``sksurv.metrics.integrated_brier_score``. The
    event/time arrays are packed into scikit-survival's structured
    (event, time) format. ``surv_pred`` holds predicted survival
    probabilities for the test subjects at ``times``. Returns NaN when
    scikit-survival is not installed.
    """
    try:
        from sksurv.metrics import integrated_brier_score as ibs_fn
    except ImportError:
        return float("nan")

    surv_dtype = [("event", "?"), ("time", "<f8")]

    def to_struct(events, durations):
        pairs = zip(events.astype(bool), durations.astype(float))
        return np.array(list(pairs), dtype=surv_dtype)

    score = ibs_fn(
        to_struct(y_train_event, y_train_time),
        to_struct(y_test_event, y_test_time),
        surv_pred,
        times,
    )
    return float(score)
|
|
|
|
| |
|
|
def e_value(rr: float) -> float:
    """E-value for a risk ratio (VanderWeele & Ding 2017).

    The E-value is the minimum strength of association, on the risk-ratio
    scale, that an unmeasured confounder would need with both exposure and
    outcome to fully explain away the observed ``rr``. Protective ratios
    (rr < 1) are inverted first, so RR and 1/RR get the same E-value.

    ``rr`` is floored at 1e-9, so 0 (or negative) input yields a very large
    finite number instead of a ZeroDivisionError.
    """
    ratio = max(rr, 1e-9)
    if ratio < 1.0:
        ratio = 1.0 / ratio
    return ratio + math.sqrt(ratio * (ratio - 1))
|
|
|
|
def negative_control_check(nc_ate: float, threshold: float = 0.02) -> bool:
    """Pass (True) when the effect on a negative-control outcome is ~0.

    The exposure should not affect a negative-control outcome. A non-trivial
    estimated ATE there signals residual confounding or bias. "~0" means
    strictly inside the open interval (-threshold, threshold); NaN fails.
    """
    return -threshold < nc_ate < threshold
|
|
|
|
def tipping_point(observed_effect: float, ci_half_width: float) -> float:
    """Return the margin by which ``|observed_effect|`` exceeds ``ci_half_width``.

    Read this as the shift toward the null that the estimate could absorb
    before a symmetric interval of that half-width would include 0. It is
    0.0 when the interval already covers the null. This is a crude heuristic
    for how much unmeasured confounding would nullify the effect, not a
    formal tipping-point analysis.
    """
    margin = abs(observed_effect) - ci_half_width
    return float(max(margin, 0.0))
|
|
|
|
| |
|
|
def dim_wise_probability(real_seq: torch.Tensor, synth_seq: torch.Tensor,
                         vocab_size: int) -> float:
    """Compare marginal token frequencies of real vs. synthetic batches.

    For each vocabulary id v, computes the fraction of all token positions
    holding v. For a (B, L) batch this equals the mean over batch and sequence
    of the one-hot indicator. Returns the mean absolute difference between
    the two length-``vocab_size`` frequency vectors (lower = closer match).

    Args:
        real_seq: Integer token ids, e.g. shaped (B, L); all positions are pooled.
        synth_seq: Integer token ids from the generator; any shape.
        vocab_size: Number of token ids; valid ids are ``0 .. vocab_size - 1``.

    Returns:
        Mean absolute frequency difference as a Python float. Returns NaN if
        a batch is empty.

    Raises:
        ValueError: If any token id lies outside ``[0, vocab_size)``.
    """
    def _token_freq(seq: torch.Tensor) -> torch.Tensor:
        # bincount yields the same per-token counts as summing a one-hot
        # encoding, without materialising a (B, L, vocab_size) float tensor,
        # and accepts any integer dtype once cast to int64.
        flat = seq.reshape(-1).long()
        if flat.numel() and (int(flat.min()) < 0 or int(flat.max()) >= vocab_size):
            raise ValueError(f"token ids must lie in [0, {vocab_size})")
        counts = torch.bincount(flat, minlength=vocab_size).to(torch.float64)
        return counts / flat.numel()

    diff = _token_freq(real_seq) - _token_freq(synth_seq)
    return float(diff.abs().mean())
|
|
|
|
def mmd_rbf(x: torch.Tensor, y: torch.Tensor, sigma: float = 1.0) -> float:
    """Maximum Mean Discrepancy with an RBF kernel (biased V-statistic).

    MMD^2 = mean k(x, x') + mean k(y, y') - 2 mean k(x, y), where
    k(a, b) = exp(-||a - b||^2 / (2 sigma^2)) and diagonal terms are included.

    Args:
        x: (Bx, D) embeddings. Any trailing dims are flattened per row.
        y: (By, D) embeddings. Must have the same number of features per row
            as ``x``.
        sigma: RBF bandwidth.

    Returns:
        MMD^2 as a Python float (lower = closer).
    """
    # Promote to one common floating dtype. cdist needs matching dtypes, and
    # exp is undefined on integer tensors.
    dtype = torch.promote_types(x.dtype, y.dtype)
    if not dtype.is_floating_point:
        dtype = torch.get_default_dtype()
    x2 = x.reshape(x.shape[0], -1).to(dtype)
    y2 = y.reshape(y.shape[0], -1).to(dtype)

    def _kernel_mean(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # cdist avoids the (Ba, Bb, D) difference tensor of a broadcast
        # formulation. The non-matmul mode keeps distances exact (no
        # cancellation from the ||a||^2 + ||b||^2 - 2ab expansion).
        sq_dist = torch.cdist(a, b, compute_mode="donot_use_mm_for_euclid_dist").pow(2)
        return torch.exp(-sq_dist / (2 * sigma ** 2)).mean()

    return float(_kernel_mean(x2, x2) + _kernel_mean(y2, y2) - 2 * _kernel_mean(x2, y2))
|
|
|
|
| |
|
|
def stratified_metrics(y: np.ndarray, p: np.ndarray,
                       groups: np.ndarray,
                       metric_fn: Callable = auroc,
                       min_group_size: int = 11) -> dict[str, float]:
    """Compute ``metric_fn`` separately within each subgroup (sex, age band, UF region).

    Args:
        y: Labels; any array-like.
        p: Scores aligned with ``y``; any array-like.
        groups: Subgroup label per sample, aligned with ``y``.
        metric_fn: Callable ``(y, p) -> scalar``; defaults to ``auroc``.
        min_group_size: Groups with fewer members are omitted from the result.
            The default of 11 matches the previous hard-coded "more than 10"
            rule.

    Returns:
        Dict mapping ``str(group)`` to the metric value. A group whose metric
        raises is reported as NaN rather than aborting the whole table.
    """
    y = np.asarray(y)
    p = np.asarray(p)
    groups = np.asarray(groups)
    out: dict[str, float] = {}
    for g in np.unique(groups):
        mask = groups == g
        if np.count_nonzero(mask) < min_group_size:
            continue
        try:
            out[str(g)] = metric_fn(y[mask], p[mask])
        except Exception:
            # Best-effort: one failing subgroup becomes NaN instead of
            # hiding every other subgroup's result.
            out[str(g)] = float("nan")
    return out
|
|
|
|
| |
|
|
def temporal_split(events: list[dict], train_until: int = 2022,
                   val_year: int = 2023, default_year: int = 2020):
    """Split events by calendar year (DATASUS default: train ≤2022, val 2023, test 2024+).

    Each event is routed by its ``"year"`` value:

    - ``year <= train_until`` goes to train.
    - ``year == val_year`` goes to validation.
    - Anything else goes to test (2024 onwards with the defaults).

    Events whose ``"year"`` is missing or falsy (``None``, ``0``) are treated as
    ``default_year``. The default of 2020 places them in training; pass a year
    outside the training window to route them elsewhere.

    NOTE(review): the split is per event, not per patient, so one patient's
    history can land in several partitions. Confirm that is intended.

    Returns:
        ``(train, val, test)`` lists, each preserving input order.
    """
    train, val, test = [], [], []
    for event in events:
        year = event.get("year") or default_year
        if year <= train_until:
            train.append(event)
        elif year == val_year:
            val.append(event)
        else:
            test.append(event)
    return train, val, test
|
|
|
|
def geographic_split(patients: list[dict], external_ufs: set = None):
    """Split patients into (train, test) by state (UF) of residence.

    For DATASUS this is the closest analog to "external validation". A
    patient's UF is the first truthy ``"uf_code"`` among their ``"events"``.
    Patients whose UF is in ``external_ufs`` go to test; all others go to
    train.

    By default ``external_ufs`` is every UF of the North region (AC, AM, AP,
    PA, RO, RR, TO) and the Northeast region (AL, BA, CE, MA, PB, PE, PI, RN,
    SE). Here "SE" is Sergipe, not the Southeast region. Training therefore
    receives the Southeast, South, *and* Centre-West, plus any patient with
    no ``uf_code``.

    NOTE(review): assumes ``uf_code`` holds two-letter UF abbreviations. If it
    holds IBGE numeric codes, nothing matches and every patient lands in
    train. Confirm against the loader.
    """
    if external_ufs is None:
        external_ufs = {
            # North
            "AC", "AM", "AP", "PA", "RO", "RR", "TO",
            # Northeast
            "AL", "BA", "CE", "MA", "PB", "PE", "PI", "RN", "SE",
        }
    train, test = [], []
    for patient in patients:
        uf = None
        for event in patient.get("events", []):
            code = event.get("uf_code")
            if code:
                uf = code
                break
        bucket = test if uf in external_ufs else train
        bucket.append(patient)
    return train, test
|
|
|
|
| |
|
|
def full_eval_report(y: np.ndarray, p: np.ndarray,
                     groups_sex: np.ndarray = None,
                     groups_age: np.ndarray = None,
                     groups_uf: np.ndarray = None,
                     n_boot: int = 2000) -> dict:
    """Generate a full audit-proof report for a binary classification task.

    Args:
        y: Binary labels (0/1); any array-like.
        p: Predicted probabilities aligned with ``y``; any array-like.
        groups_sex: Optional subgroup labels aligned with ``y``; adds a
            per-group AUROC section to the report.
        groups_age: Same as ``groups_sex``, for age bands.
        groups_uf: Same as ``groups_sex``, for UF regions.
        n_boot: Bootstrap resamples per metric.

    Returns:
        A dict with these keys:

        - ``"n_eval"``, ``"prevalence"``.
        - ``"auroc"``, ``"auprc"``, ``"brier"``: each
          ``{"point", "ci95", "median"}``, where ``ci95`` is the 2.5/97.5
          bootstrap percentiles.
        - ``"ici"``, ``"decision_curve"``.
        - ``"fairness_sex"``, ``"fairness_age"``, ``"fairness_uf"``: only for
          the groupings that were supplied.
    """
    y = np.asarray(y)
    p = np.asarray(p)

    report: dict = {"n_eval": len(y), "prevalence": float(y.mean())}
    # All three bootstraps use bootstrap_ci's default seed, so they share
    # resamples and their intervals are paired.
    for name, fn in (("auroc", auroc), ("auprc", auprc), ("brier", brier)):
        lo, med, hi = bootstrap_ci(y, p, fn, n_boot)
        report[name] = {"point": fn(y, p), "ci95": [lo, hi], "median": med}
    report["ici"] = ici(y, p)
    report["decision_curve"] = decision_curve(y, p)

    for key, groups in (("fairness_sex", groups_sex),
                        ("fairness_age", groups_age),
                        ("fairness_uf", groups_uf)):
        if groups is not None:
            report[key] = stratified_metrics(y, p, groups, auroc)
    return report
|
|