File size: 10,593 Bytes
a0fa886
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
"""SOTA evaluation suite for CDFv13 — audit-proof.

Per the May 2026 SOTA audit, this module replaces the "Top-1 mid-position"
metric (not a recognized benchmark metric) with the canonical EHR
foundation-model metric stack:

  Classification (next-event, downstream tasks):
    - AUROC + AUPRC + Brier
    - Calibration: ICI (Austin & Steyerberg 2019)
    - Decision-curve analysis (Vickers)
    - Bootstrap 95% CI (≥2000 resamples) — required for rare disease

  Survival (DATASUS SIM mortality):
    - Uno's C (concordance_index_ipcw) — preferred over Harrell at high censoring
    - Integrated Brier Score (1/3/5y)
    - Time-dependent AUC

  Counterfactual / causal:
    - ATE with bootstrap CI
    - E-value (VanderWeele)
    - Negative-control outcome + exposure
    - Tipping-point analysis

  Generation fidelity (CoMET / SynthEHRella):
    - Dim-wise probability match
    - MMD (Maximum Mean Discrepancy) with RBF kernel
    - TSTR (Train-on-Synthetic-Test-on-Real)

  Subgroup fairness (npj DM requirement):
    - Stratified metrics: sex, age band, UF region

  Split strategy (DATASUS rare disease):
    - Temporal: train ≤2022, val 2023, test 2024-2025
    - Geographic: train SE+S (+CO), test N+NE (UF cross-region = "external")
    - Patient-level 5-fold CV (variance estimation)
"""
from __future__ import annotations
import math
import numpy as np
import torch
from typing import Callable


# ---------- Classification ----------

def auroc(y: np.ndarray, p: np.ndarray) -> float:
    """Area under the ROC curve of scores ``p`` against labels ``y``.

    Returns NaN when ``y`` holds fewer than two distinct labels, because the
    ROC curve is undefined for a single class.
    """
    from sklearn.metrics import roc_auc_score
    n_classes = np.unique(y).size
    if n_classes >= 2:
        return roc_auc_score(y, p)
    return float("nan")


def auprc(y: np.ndarray, p: np.ndarray) -> float:
    """Area under the precision-recall curve (scikit-learn average precision).

    Returns NaN when ``y`` holds fewer than two distinct labels.
    """
    from sklearn.metrics import average_precision_score
    if np.unique(y).size < 2:
        return float("nan")
    score = average_precision_score(y, p)
    return score


def brier(y: np.ndarray, p: np.ndarray) -> float:
    """Brier score: mean squared difference between probability ``p`` and outcome ``y``.

    Lower is better. Delegates to scikit-learn's ``brier_score_loss``.
    """
    from sklearn import metrics as sk_metrics
    return sk_metrics.brier_score_loss(y, p)


def ici(y: np.ndarray, p: np.ndarray, frac: float = 0.75) -> float:
    """Integrated Calibration Index (Austin & Steyerberg 2019).

    Fits a LOWESS curve of the observed outcome ``y`` against the predicted
    probability ``p``. Returns the mean absolute gap between the smoothed
    observed value and the prediction, taken over the sorted points that
    ``lowess`` returns.

    Args:
        y: observed outcomes.
        p: predicted probabilities.
        frac: LOWESS span, i.e. the fraction of points used in each local fit.
    """
    from statsmodels.nonparametric.smoothers_lowess import lowess
    smoothed = lowess(y, p, frac=frac, return_sorted=True)
    # Column 0 is the sorted exog (predictions); column 1 is the fitted endog.
    predicted, observed = smoothed[:, 0], smoothed[:, 1]
    return float(np.mean(np.abs(observed - predicted)))


def net_benefit(y: np.ndarray, p: np.ndarray, threshold: float) -> float:
    """Net benefit of treating everyone with ``p >= threshold`` (Vickers DCA).

    NB = TP/n - FP/n * threshold / (1 - threshold)

    Args:
        y: binary outcomes (1 = event), any array-like.
        p: predicted probabilities, same length as ``y``.
        threshold: decision threshold, expected in [0, 1).

    Returns:
        Net benefit as a Python float. Returns 0.0 when ``threshold >= 1``,
        where the odds weight is undefined. Returns NaN for empty input.
    """
    # Guard before any computation so the odds term never divides by zero.
    if threshold >= 1.0:
        return 0.0
    y = np.asarray(y)
    p = np.asarray(p)
    n = len(y)
    if n == 0:
        return float("nan")
    treated = p >= threshold
    tp = int(np.count_nonzero(treated & (y == 1)))
    fp = int(np.count_nonzero(treated & (y == 0)))
    return tp / n - (fp / n) * (threshold / (1 - threshold))


def decision_curve(y: np.ndarray, p: np.ndarray,
                   thresholds: list[float] = None) -> dict:
    """Decision-curve analysis (Vickers).

    Computes the model's net benefit at each threshold alongside the
    treat-all and treat-none reference strategies.

    Args:
        y: binary outcomes (an array; ``y.mean()`` is used for prevalence).
        p: predicted probabilities.
        thresholds: decision thresholds. Defaults to
            [0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5].

    Returns:
        dict with keys "thresholds", "model", "treat_all" and "treat_none".
        Each list is aligned with the thresholds.
    """
    grid = [0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5] if thresholds is None else thresholds
    curve = {"thresholds": grid, "model": [], "treat_all": [], "treat_none": []}
    for t in grid:
        curve["model"].append(net_benefit(y, p, t))
        # Treat-all marks every patient positive: TP/n = prevalence, FP/n = 1 - prevalence.
        if t < 1:
            curve["treat_all"].append((y.mean()) - (1 - y.mean()) * (t / (1 - t)))
        else:
            curve["treat_all"].append(0)
        curve["treat_none"].append(0.0)
    return curve


def bootstrap_ci(y: np.ndarray, p: np.ndarray, metric_fn: Callable,
                 n_boot: int = 2000, seed: int = 0,
                 ci: tuple[float, float] = (2.5, 97.5)) -> tuple[float, float, float]:
    """Percentile bootstrap interval for any ``metric_fn(y, p) -> scalar``.

    Resamples (y, p) pairs with replacement ``n_boot`` times using a seeded
    generator, so results are reproducible for a given ``seed``.

    Args:
        y: outcomes (array-like).
        p: predictions aligned with ``y`` (array-like).
        metric_fn: metric evaluated on each resample.
        n_boot: number of bootstrap resamples.
        seed: seed for ``np.random.default_rng``.
        ci: lower/upper percentiles; (2.5, 97.5) gives a 95% interval.

    Returns:
        ``(lower, median, upper)``. The middle value is the bootstrap median,
        not the full-data point estimate. All three are NaN when the input is
        empty or no resample produced a value.
    """
    y = np.asarray(y)
    p = np.asarray(p)
    n = len(y)
    all_nan = (float("nan"),) * 3
    if n == 0:
        return all_nan
    rng = np.random.default_rng(seed)
    stats = []
    for _ in range(n_boot):
        idx = rng.integers(0, n, n)
        y_b = y[idx]
        # Single-class resamples make AUROC/AUPRC undefined, so skip them.
        if len(np.unique(y_b)) < 2:
            continue
        try:
            stats.append(metric_fn(y_b, p[idx]))
        except Exception:
            # Best-effort: one failing resample should not abort the interval.
            continue
    if not stats:
        return all_nan
    return (
        float(np.percentile(stats, ci[0])),
        float(np.median(stats)),
        float(np.percentile(stats, ci[1])),
    )


# ---------- Survival ----------

def uno_c_index(y_train_event, y_train_time, y_test_event, y_test_time,
                risk_score, tau: float = None) -> float:
    """Uno's concordance index (IPCW); preferred over Harrell's C under heavy censoring.

    Training outcomes are passed as the reference survival data and test
    outcomes as the evaluation data. When ``tau`` is omitted, it defaults to
    95% of the largest test time. The event/time arguments must support
    ``.astype`` (numpy arrays).

    Returns NaN when scikit-survival is not installed.
    """
    try:
        from sksurv.metrics import concordance_index_ipcw
    except ImportError:
        return float("nan")

    def as_surv(event, time):
        # scikit-survival expects a structured array of (bool event, float time).
        return np.array(
            list(zip(event.astype(bool), time.astype(float))),
            dtype=[("event", "?"), ("time", "<f8")],
        )

    train_surv = as_surv(y_train_event, y_train_time)
    test_surv = as_surv(y_test_event, y_test_time)
    horizon = float(y_test_time.max()) * 0.95 if tau is None else tau
    cindex, *_rest = concordance_index_ipcw(train_surv, test_surv, risk_score, tau=horizon)
    return float(cindex)


def integrated_brier_score(y_train_event, y_train_time, y_test_event, y_test_time,
                            surv_pred: np.ndarray, times: np.ndarray) -> float:
    """Integrated Brier Score over ``times`` (lower is better).

    Thin wrapper around scikit-survival's ``integrated_brier_score``.
    Training outcomes are the reference survival data and test outcomes are
    the evaluation data. ``surv_pred`` holds predicted survival
    probabilities; presumably shape (n_test, len(times)), per the
    scikit-survival docs. The event/time arguments must support ``.astype``.

    Returns NaN when scikit-survival is not installed.
    """
    try:
        from sksurv.metrics import integrated_brier_score as ibs_fn
    except ImportError:
        return float("nan")
    surv_dtype = [("event", "?"), ("time", "<f8")]
    train_surv = np.array(
        list(zip(y_train_event.astype(bool), y_train_time.astype(float))),
        dtype=surv_dtype,
    )
    test_surv = np.array(
        list(zip(y_test_event.astype(bool), y_test_time.astype(float))),
        dtype=surv_dtype,
    )
    score = ibs_fn(train_surv, test_surv, surv_pred, times)
    return float(score)


# ---------- Causal / Counterfactual ----------

def e_value(rr: float) -> float:
    """E-value for an observed risk ratio (VanderWeele & Ding 2017).

    This is the minimum association (on the risk-ratio scale) that an
    unmeasured confounder would need with both exposure and outcome to fully
    explain away ``rr``. Protective ratios (rr < 1) are inverted first.
    ``rr`` is floored at 1e-9 so the inversion never divides by zero.
    """
    ratio = max(rr, 1e-9)
    if ratio < 1.0:
        ratio = 1.0 / ratio
    return ratio + math.sqrt(ratio * (ratio - 1))


def negative_control_check(nc_ate: float, threshold: float = 0.02) -> bool:
    """Negative-control outcome check.

    The estimated effect on a negative-control outcome should be close to
    zero. Returns True when ``|nc_ate|`` is strictly below ``threshold``.
    """
    return -threshold < nc_ate < threshold


def tipping_point(observed_effect: float, ci_half_width: float) -> float:
    """How far the effect estimate must shift before its interval reaches zero.

    Returns ``|observed_effect| - ci_half_width``, or 0.0 when the interval
    already covers zero (``|observed_effect| <= ci_half_width``).
    """
    magnitude = abs(observed_effect)
    return 0.0 if magnitude <= ci_half_width else float(magnitude - ci_half_width)


# ---------- Generation fidelity (SynthEHRella triad) ----------

def dim_wise_probability(real_seq: torch.Tensor, synth_seq: torch.Tensor,
                          vocab_size: int) -> float:
    """Dimension-wise probability match between real and synthetic token batches.

    For each vocabulary entry, computes the fraction of token positions
    holding that token (the one-hot mean over every position) in the real and
    in the synthetic batch. Returns the mean absolute difference of those
    frequencies across the vocabulary (lower = closer match).

    Args:
        real_seq: integer (int64) token-id tensor with values in
            [0, vocab_size). Any shape, e.g. (batch, seq_len).
        synth_seq: same requirements as ``real_seq``; its shape need not match.
        vocab_size: number of token classes.
    """
    # Imported locally: the module header only imports ``torch`` itself.
    import torch.nn.functional as F

    def _token_rates(seq: torch.Tensor) -> torch.Tensor:
        # Flatten all leading dims so any input rank yields a (vocab_size,) vector.
        one_hot = F.one_hot(seq, vocab_size).float()
        return one_hot.reshape(-1, vocab_size).mean(dim=0)

    return float((_token_rates(real_seq) - _token_rates(synth_seq)).abs().mean())


def mmd_rbf(x: torch.Tensor, y: torch.Tensor, sigma: float = 1.0) -> float:
    """Squared Maximum Mean Discrepancy between two samples under an RBF kernel.

    Uses the biased (V-statistic) estimator, with diagonal terms included:
    mean k(x, x') + mean k(y, y') - 2 * mean k(x, y).

    Args:
        x: (B1, D) flattened embeddings.
        y: (B2, D) flattened embeddings; batch sizes may differ.
        sigma: kernel bandwidth, k(a, b) = exp(-||a - b||^2 / (2 sigma^2)).

    Returns:
        MMD^2 as a Python float (lower = closer).
    """
    denom = 2 * sigma ** 2

    def gram(a, b):
        # Pairwise squared distances via broadcasting: (Na, 1, ...) - (1, Nb, ...).
        sq_dist = ((a[:, None] - b[None]) ** 2).sum(-1)
        return (-sq_dist / denom).exp()

    k_xx = gram(x, x).mean()
    k_yy = gram(y, y).mean()
    k_xy = gram(x, y).mean()
    return float(k_xx + k_yy - 2 * k_xy)


# ---------- Subgroup fairness ----------

def stratified_metrics(y: np.ndarray, p: np.ndarray,
                       groups: np.ndarray,
                       metric_fn: Callable = auroc,
                       min_size: int = 10) -> dict[str, float]:
    """Compute ``metric_fn`` separately for each subgroup (e.g. sex, age band, UF region).

    Args:
        y: outcomes (array-like).
        p: predictions aligned with ``y`` (array-like).
        groups: subgroup label per sample, aligned with ``y``.
        metric_fn: ``(y, p) -> float``. Defaults to ``auroc``.
        min_size: subgroups with ``min_size`` or fewer samples are left out
            of the result as too small to estimate. Defaults to 10.

    Returns:
        ``{str(group): metric}``. A group is NaN when ``metric_fn`` raised on it.
    """
    y = np.asarray(y)
    p = np.asarray(p)
    groups = np.asarray(groups)
    out: dict[str, float] = {}
    for g in np.unique(groups):
        mask = groups == g
        if mask.sum() <= min_size:
            continue
        try:
            out[str(g)] = metric_fn(y[mask], p[mask])
        except Exception:
            # Best-effort per group: one failing subgroup must not abort the report.
            out[str(g)] = float("nan")
    return out


# ---------- DATASUS split strategies ----------

def temporal_split(events: list[dict], train_until: int = 2022,
                   val_year: int = 2023):
    """Split events by calendar year for DATASUS temporal validation.

    Events with ``year <= train_until`` go to train and ``year == val_year``
    to validation. Every other year goes to test (2024+ with the defaults).
    Events whose ``"year"`` is missing or falsy are treated as 2020.

    Returns:
        ``(train, val, test)`` lists, each preserving input order.
    """
    train, val, test = [], [], []
    for event in events:
        # NOTE(review): fixed fallback year. It lands in train only while
        # train_until >= 2020; confirm that is intended for custom splits.
        year = event.get("year") or 2020
        if year <= train_until:
            bucket = train
        elif year == val_year:
            bucket = val
        else:
            bucket = test
        bucket.append(event)
    return train, val, test


def geographic_split(patients: list[dict], external_ufs: set = None):
    """Geographic split, the closest DATASUS analog to "external validation".

    Each patient is assigned by the first truthy ``"uf_code"`` among their
    ``"events"``. Patients whose UF is in ``external_ufs`` go to test; all
    others go to train, including patients with no UF. The default external
    set is the North + Northeast UFs, so Southeast, South and Center-West
    patients all end up in train.

    Returns:
        ``(train, test)`` lists, each preserving input order.
    """
    # "SE" here is the state of Sergipe (Northeast), not the Southeast region.
    # NOTE(review): assumes uf_code holds two-letter abbreviations; numeric
    # IBGE codes would never match and would all land in train.
    default_external = {"AC", "AL", "AP", "AM", "BA", "CE", "MA", "PA",
                        "PB", "PE", "PI", "RN", "SE", "TO", "RR", "RO"}
    held_out = default_external if external_ufs is None else external_ufs
    train, test = [], []
    for patient in patients:
        uf = None
        for event in patient.get("events", []):
            code = event.get("uf_code")
            if code:
                uf = code
                break
        if uf in held_out:
            test.append(patient)
        else:
            train.append(patient)
    return train, test


# ---------- Combined eval report ----------

def full_eval_report(y: np.ndarray, p: np.ndarray,
                     groups_sex: np.ndarray = None,
                     groups_age: np.ndarray = None,
                     groups_uf: np.ndarray = None,
                     n_boot: int = 2000) -> dict:
    """Generate the full audit report for a binary classification task.

    Args:
        y: binary outcomes (array-like).
        p: predicted probabilities aligned with ``y`` (array-like).
        groups_sex, groups_age, groups_uf: optional per-sample subgroup labels.
            Each one supplied adds a ``fairness_*`` entry of per-group AUROC.
        n_boot: bootstrap resamples for the AUROC/AUPRC/Brier intervals.

    Returns:
        dict with these keys:
          - ``n_eval`` and ``prevalence``.
          - ``auroc``, ``auprc`` and ``brier``, each shaped
            ``{"point": ..., "ci95": [lo, hi], "median": ...}``. ``point`` is
            the full-data estimate; ``median`` is the bootstrap median.
          - ``ici`` and ``decision_curve``.
          - any requested ``fairness_sex`` / ``fairness_age`` / ``fairness_uf``.
    """
    y = np.asarray(y)
    p = np.asarray(p)

    def _with_ci(metric_fn: Callable) -> dict:
        # Full-data point estimate plus bootstrap interval and median.
        lo, med, hi = bootstrap_ci(y, p, metric_fn, n_boot=n_boot)
        return {"point": metric_fn(y, p), "ci95": [lo, hi], "median": med}

    report = {
        "n_eval": len(y),
        "prevalence": float(y.mean()),
        "auroc": _with_ci(auroc),
        "auprc": _with_ci(auprc),
        "brier": _with_ci(brier),
        "ici": ici(y, p),
        "decision_curve": decision_curve(y, p),
    }
    subgroup_labels = {
        "fairness_sex": groups_sex,
        "fairness_age": groups_age,
        "fairness_uf": groups_uf,
    }
    for key, labels in subgroup_labels.items():
        if labels is not None:
            report[key] = stratified_metrics(y, p, labels, auroc)
    return report