"""SOTA evaluation suite for CDFv13 — audit-proof.

Per the May 2026 SOTA audit, replaces "Top-1 mid-position" (not recognized)
with the canonical EHR foundation model metric stack:

Classification (next-event, downstream tasks):
  - AUROC + AUPRC + Brier
  - Calibration: ICI (Austin & Steyerberg 2019)
  - Decision-curve analysis (Vickers)
  - Bootstrap 95% CI (≥2000 resamples) — required for rare disease

Survival (DATASUS SIM mortality):
  - Uno's C (concordance_index_ipcw) — preferred over Harrell at high censoring
  - Integrated Brier Score (1/3/5y)
  - Time-dependent AUC

Counterfactual / causal:
  - ATE with bootstrap CI
  - E-value (VanderWeele)
  - Negative-control outcome + exposure
  - Tipping-point analysis

Generation fidelity (CoMET / SynthEHRella):
  - Dim-wise probability match
  - MMD (Maximum Mean Discrepancy) with RBF kernel
  - TSTR (Train-on-Synthetic-Test-on-Real)

Subgroup fairness (npj DM requirement):
  - Stratified metrics: sex, age band, UF region

Split strategy (DATASUS rare disease):
  - Temporal: train ≤2022, val 2023, test 2024-2025
  - Geographic: train SE+S, test N+NE (UF cross-region = "external")
  - Patient-level 5-fold CV (variance estimation)
"""
from __future__ import annotations

import math
from typing import Callable

import numpy as np
import torch


# ---------- Classification ----------

def auroc(y: np.ndarray, p: np.ndarray) -> float:
    """Return the area under the ROC curve for labels ``y`` and scores ``p``.

    Returns NaN when ``y`` contains fewer than two distinct values, because
    the metric is undefined in that case.
    """
    from sklearn.metrics import roc_auc_score

    if len(np.unique(y)) < 2:
        return float("nan")
    return float(roc_auc_score(y, p))


def auprc(y: np.ndarray, p: np.ndarray) -> float:
    """Return the average precision (area under the PR curve).

    Returns NaN when ``y`` contains fewer than two distinct values.
    """
    from sklearn.metrics import average_precision_score

    if len(np.unique(y)) < 2:
        return float("nan")
    return float(average_precision_score(y, p))


def brier(y: np.ndarray, p: np.ndarray) -> float:
    """Return the Brier score of predictions ``p`` against labels ``y``."""
    from sklearn.metrics import brier_score_loss

    return float(brier_score_loss(y, p))


def ici(y: np.ndarray, p: np.ndarray, frac: float = 0.75) -> float:
    """Integrated Calibration Index (Austin & Steyerberg 2019).

    Lowess-smoothed deviation from perfect calibration: the mean absolute
    difference between the smoothed observed outcome and the prediction.

    Args:
        y: Observed outcomes.
        p: Predicted probabilities aligned with ``y``.
        frac: Lowess span passed through to statsmodels.
    """
    from statsmodels.nonparametric.smoothers_lowess import lowess

    # With return_sorted=True each row is (prediction, smoothed outcome),
    # ordered by prediction.
    smoothed = lowess(y, p, frac=frac, return_sorted=True)
    return float(np.mean(np.abs(smoothed[:, 1] - smoothed[:, 0])))


def net_benefit(y: np.ndarray, p: np.ndarray, threshold: float) -> float:
    """Net benefit at a given decision threshold (Vickers DCA).

    Computes ``TP/n - FP/n * threshold / (1 - threshold)`` where a case is
    treated when ``p >= threshold``.

    Returns:
        0.0 when ``threshold >= 1`` (the odds weight is undefined there),
        NaN for empty input, otherwise the net benefit.
    """
    if threshold >= 1.0:
        return 0.0
    y = np.asarray(y)
    p = np.asarray(p)
    n = len(y)
    if n == 0:
        # 0/0 — report NaN explicitly instead of emitting a divide warning.
        return float("nan")
    treated = p >= threshold
    tp = (treated & (y == 1)).sum()
    fp = (treated & (y == 0)).sum()
    return float(tp / n - (fp / n) * (threshold / (1 - threshold)))


def decision_curve(y: np.ndarray, p: np.ndarray,
                   thresholds: list[float] | None = None) -> dict:
    """Decision-curve analysis: net benefit across thresholds vs treat-all/treat-none.

    Args:
        y: Binary outcomes (1 = event).
        p: Predicted probabilities aligned with ``y``.
        thresholds: Decision thresholds to evaluate; defaults to
            ``[0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5]``.

    Returns:
        Dict with keys ``"thresholds"``, ``"model"``, ``"treat_all"`` and
        ``"treat_none"``; the last three are lists parallel to the thresholds.
    """
    if thresholds is None:
        thresholds = [0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5]
    prevalence = float(np.asarray(y).mean())
    model_nb = [net_benefit(y, p, t) for t in thresholds]
    # Treat-all: everyone is flagged, so TP/n = prevalence and FP/n = 1 - prevalence.
    treat_all_nb = [
        prevalence - (1 - prevalence) * (t / (1 - t)) if t < 1 else 0.0
        for t in thresholds
    ]
    return {
        "thresholds": thresholds,
        "model": model_nb,
        "treat_all": treat_all_nb,
        "treat_none": [0.0] * len(thresholds),
    }


def bootstrap_ci(y: np.ndarray, p: np.ndarray, metric_fn: Callable,
                 n_boot: int = 2000, seed: int = 0,
                 ci: tuple[float, float] = (2.5, 97.5)) -> tuple[float, float, float]:
    """Bootstrap percentile CI (95% by default) for any (y, p) -> scalar metric.

    Args:
        y: Labels.
        p: Predictions aligned with ``y``.
        metric_fn: Callable ``metric_fn(y, p)`` returning a scalar.
        n_boot: Number of resamples (with replacement, same size as ``y``).
        seed: Seed for the NumPy random generator.
        ci: Lower and upper percentiles to report.

    Returns:
        ``(lower, median, upper)`` over the successful resamples, or three
        NaNs when no resample produced a value.
    """
    y = np.asarray(y)
    p = np.asarray(p)
    rng = np.random.default_rng(seed)
    n = len(y)
    stats = []
    for _ in range(n_boot):
        idx = rng.integers(0, n, n)
        y_b = y[idx]
        # Single-class resamples leave rank-based metrics undefined; skip them.
        if len(np.unique(y_b)) < 2:
            continue
        try:
            stats.append(metric_fn(y_b, p[idx]))
        except Exception:
            # Deliberate best-effort: a resample the metric cannot score is
            # dropped rather than aborting the whole bootstrap.
            continue
    if not stats:
        return (float("nan"),) * 3
    return (
        float(np.percentile(stats, ci[0])),
        float(np.median(stats)),
        float(np.percentile(stats, ci[1])),
    )


# ---------- Survival ----------

def _survival_array(event, time) -> np.ndarray:
    """Pack event flags and times into a structured (event, time) array."""
    event = np.asarray(event).astype(bool)
    time = np.asarray(time).astype(float)
    packed = np.empty(len(event), dtype=[("event", "?"), ("time", "<f8")])
    packed["event"] = event
    packed["time"] = time
    return packed


def uno_c_index(y_train_event, y_train_time, y_test_event, y_test_time,
                risk_score, tau: float | None = None) -> float:
    """Uno's C-index (IPCW concordance), preferred at high censoring.

    Requires scikit-survival; returns NaN when it is not installed.

    Args:
        y_train_event, y_train_time: Training event flags and times.
        y_test_event, y_test_time: Test event flags and times.
        risk_score: Predicted risk for the test samples.
        tau: Optional truncation time forwarded to scikit-survival.
    """
    try:
        from sksurv.metrics import concordance_index_ipcw
    except ImportError:
        return float("nan")
    # NOTE(review): the source text of this function was cut off inside the
    # dtype string (after `("time", "`). The "<f8" field type, the test-set
    # array and the final call are reconstructed from the visible signature
    # and the scikit-survival API — confirm against the original file.
    y_train = _survival_array(y_train_event, y_train_time)
    y_test = _survival_array(y_test_event, y_test_time)
    # concordance_index_ipcw returns a tuple whose first element is the index.
    return float(concordance_index_ipcw(y_train, y_test, risk_score, tau=tau)[0])


# NOTE(review): the `def` line of this function was lost in the source text;
# the name `integrated_brier` and the parameter list below are reconstructed
# (mirroring `uno_c_index` and the scikit-survival signature). Verify both
# against the original file before relying on them.
def integrated_brier(y_train_event, y_train_time, y_test_event, y_test_time,
                     surv_probs, times) -> float:
    """Integrated Brier Score (lower is better).

    Requires scikit-survival; returns NaN when it is not installed.
    """
    try:
        from sksurv.metrics import integrated_brier_score as ibs_fn
    except ImportError:
        return float("nan")
    y_train = _survival_array(y_train_event, y_train_time)
    # NOTE(review): everything after the training array was truncated in the
    # source; the lines below are reconstructed — TODO confirm.
    y_test = _survival_array(y_test_event, y_test_time)
    return float(ibs_fn(y_train, y_test, surv_probs, times))


# ---------- Counterfactual / causal ----------

# NOTE(review): the `def` line was lost in the source text. `rr` is the only
# name the visible body reads; the function name is inferred from the
# docstring — TODO confirm.
def e_value(rr: float) -> float:
    """E-value (VanderWeele & Ding 2017): min strength of unmeasured
    confounder needed to explain away an observed RR.

    Risk ratios below 1 are inverted first, so ``e_value(rr)`` equals
    ``e_value(1 / rr)``.
    """
    rr = max(rr, 1e-9)  # guard against division by zero when inverting
    if rr < 1.0:
        rr = 1.0 / rr
    return rr + math.sqrt(rr * (rr - 1))


def negative_control_check(nc_ate: float, threshold: float = 0.02) -> bool:
    """Negative-control outcome: ATE on a control outcome should be ~0.

    Returns True when ``|nc_ate|`` is strictly below ``threshold``.
    """
    return abs(nc_ate) < threshold


def tipping_point(observed_effect: float, ci_half_width: float) -> float:
    """How much would unmeasured confounding need to shift effect to nullify?

    Returns ``|observed_effect| - ci_half_width``, or 0.0 when the effect
    magnitude does not exceed the half-width.
    """
    magnitude = abs(observed_effect)
    if magnitude <= ci_half_width:
        return 0.0
    return float(magnitude - ci_half_width)


# ---------- Generation fidelity (SynthEHRella triad) ----------

def dim_wise_probability(real_seq: torch.Tensor, synth_seq: torch.Tensor,
                         vocab_size: int) -> float:
    """Compare per-token Bernoulli rates between real and synthetic batches.

    Each batch is one-hot encoded and averaged over its first two dimensions,
    giving one rate per vocabulary entry.

    Returns mean abs difference (lower = closer match).
    """
    # Local import: previously `F` was only imported inside
    # `full_eval_report`, so calling this function raised NameError.
    import torch.nn.functional as F

    real_rates = F.one_hot(real_seq, vocab_size).float().mean(dim=(0, 1))
    synth_rates = F.one_hot(synth_seq, vocab_size).float().mean(dim=(0, 1))
    return float((real_rates - synth_rates).abs().mean())


def mmd_rbf(x: torch.Tensor, y: torch.Tensor, sigma: float = 1.0) -> float:
    """Maximum Mean Discrepancy with RBF kernel.

    x, y: (B, D) flattened embeddings.
    Returns MMD^2 (lower = closer). Diagonal kernel terms are included in the
    means, i.e. this is the biased estimate.
    """
    def rbf(a, b):
        # Pairwise squared Euclidean distances via broadcasting.
        sq_dist = (a.unsqueeze(1) - b.unsqueeze(0)).pow(2).sum(-1)
        return torch.exp(-sq_dist / (2 * sigma ** 2))

    return float(rbf(x, x).mean() + rbf(y, y).mean() - 2 * rbf(x, y).mean())


# ---------- Subgroup fairness ----------

def stratified_metrics(y: np.ndarray, p: np.ndarray, groups: np.ndarray,
                       metric_fn: Callable = auroc) -> dict[str, float]:
    """Compute metric per subgroup (sex, age band, UF region).

    Groups with 10 or fewer members are omitted from the result. A group
    whose metric computation raises is reported as NaN.

    Returns:
        Mapping from ``str(group value)`` to the metric for that group.
    """
    out = {}
    for g in np.unique(groups):
        mask = groups == g
        if mask.sum() <= 10:
            continue
        try:
            out[str(g)] = metric_fn(y[mask], p[mask])
        except Exception:
            # Best-effort: one failing subgroup must not sink the report.
            out[str(g)] = float("nan")
    return out


# ---------- DATASUS split strategies ----------

# Year assumed for events whose "year" is missing or falsy.
_MISSING_YEAR_FALLBACK = 2020


def temporal_split(events: list[dict], train_until: int = 2022,
                   val_year: int = 2023):
    """Temporal split for DATASUS: train ≤2022, val 2023, test 2024+.

    Events with ``year <= train_until`` go to train, ``year == val_year`` to
    val, and everything else to test. Events without a year are treated as
    2020.

    Returns:
        ``(train, val, test)`` lists of the original event dicts.
    """
    train, val, test = [], [], []
    for event in events:
        year = event.get("year") or _MISSING_YEAR_FALLBACK
        if year <= train_until:
            train.append(event)
        elif year == val_year:
            val.append(event)
        else:
            test.append(event)
    return train, val, test


# Default held-out UF codes (North + Northeast) for the geographic split.
_EXTERNAL_UFS = frozenset({
    "AC", "AL", "AP", "AM", "BA", "CE", "MA", "PA",
    "PB", "PE", "PI", "RN", "SE", "TO", "RR", "RO",
})


def geographic_split(patients: list[dict], external_ufs: set | None = None):
    """Geographic split: train on SE+S, test on N+NE.

    For DATASUS this is the closest analog to "external validation."

    A patient's UF is the first truthy ``"uf_code"`` among their
    ``"events"``. Patients whose UF is in ``external_ufs`` go to test; all
    others, including patients with no UF at all, go to train.

    Returns:
        ``(train, test)`` lists of the original patient dicts.
    """
    if external_ufs is None:
        external_ufs = _EXTERNAL_UFS
    train, test = [], []
    for patient in patients:
        uf = next(
            (e.get("uf_code") for e in patient.get("events", []) if e.get("uf_code")),
            None,
        )
        (test if uf in external_ufs else train).append(patient)
    return train, test


# ---------- Combined eval report ----------

def full_eval_report(y: np.ndarray, p: np.ndarray,
                     groups_sex: np.ndarray = None,
                     groups_age: np.ndarray = None,
                     groups_uf: np.ndarray = None,
                     n_boot: int = 2000) -> dict:
    """Generate a full audit-proof report for a binary classification task.

    Returns a dict with point estimates + bootstrap CIs + DCA + fairness.

    Args:
        y: Binary outcomes.
        p: Predicted probabilities aligned with ``y``.
        groups_sex, groups_age, groups_uf: Optional per-sample group labels;
            each one supplied adds a ``"fairness_*"`` entry of per-group AUROC.
        n_boot: Number of bootstrap resamples per metric.
    """
    auroc_lo, auroc_med, auroc_hi = bootstrap_ci(y, p, auroc, n_boot)
    auprc_lo, auprc_med, auprc_hi = bootstrap_ci(y, p, auprc, n_boot)
    brier_lo, brier_med, brier_hi = bootstrap_ci(y, p, brier, n_boot)

    report = {
        "n_eval": len(y),
        "prevalence": float(y.mean()),
        "auroc": {"point": auroc(y, p), "ci95": [auroc_lo, auroc_hi], "median": auroc_med},
        "auprc": {"point": auprc(y, p), "ci95": [auprc_lo, auprc_hi], "median": auprc_med},
        "brier": {"point": brier(y, p), "ci95": [brier_lo, brier_hi], "median": brier_med},
        "ici": ici(y, p),
        "decision_curve": decision_curve(y, p),
    }

    fairness_inputs = (
        ("fairness_sex", groups_sex),
        ("fairness_age", groups_age),
        ("fairness_uf", groups_uf),
    )
    for key, groups in fairness_inputs:
        if groups is not None:
            report[key] = stratified_metrics(y, p, groups, auroc)
    return report