File size: 8,416 Bytes
089d665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
"""EHRWorld-protocol evaluation metrics, ported for any EHR foundation model.

EHRWorld (arXiv 2602.03569, SJTU Jan 2026) reports three headline metrics on
their 579-episode MIMIC-IV test set:

  - S@25 (success at 25%): for numerical values (lab results, vitals),
    fraction predicted within +-25% relative error of ground truth.
  - SMAPE: symmetric mean absolute percentage error,
        SMAPE = (1/N) * sum( |y_hat - y| / ((|y| + |y_hat|)/2) )
    bounded in [0,2], lower is better.
  - Label-F1: precision/recall F1 over discrete diagnostic labels (set-based,
    multi-label).

We re-implement them here so any GEMEO-CWM checkpoint can be scored on its
own DATASUS test slice with EHRWorld-comparable numbers. We cannot run on
their actual test set without MIMIC-IV credentialing, but the methodological
alignment is legitimate for a "rare-disease + PT-BR" framing.

Reference: Section 4.2 + Appendix C of EHRWorld paper.
"""
from __future__ import annotations
import math
import logging
from dataclasses import dataclass

import torch

log = logging.getLogger("gemeo.cwm.eval_protocol")


@dataclass
class EHRWorldMetrics:
    """Container for the results of one EHRWorld-protocol evaluation run."""
    # S@25: fraction of scored numeric predictions within +-25% relative error.
    s_at_25: float
    # Symmetric mean absolute percentage error over the scored numeric pairs.
    smape: float
    # Label F1 together with the precision / recall it is derived from.
    label_f1: float
    label_precision: float
    label_recall: float
    # Number of sequences in the evaluated dataset.
    n_episodes: int
    # Number of (prediction, truth) numeric pairs that were scored.
    n_numeric_eval: int
    # Number of sequences that contributed a predicted/true label set.
    n_label_eval: int


def s_at_k(y_pred: torch.Tensor, y_true: torch.Tensor, k: float = 0.25) -> float:
    """Return the share of predictions whose relative error is at most ``k``.

    Relative error is ``|y_pred - y_true| / |y_true|``. Entries whose truth
    has magnitude at or below 1e-9 are dropped because the ratio is undefined
    there. ``y_pred`` and ``y_true`` are expected to be same-length 1-D
    tensors.

    Returns NaN when ``y_pred`` is empty or when no entry has a usable truth.
    """
    undefined = float("nan")
    if y_pred.numel() == 0:
        return undefined

    # Keep only positions where a relative error can be formed.
    usable = y_true.abs() > 1e-9
    if not bool(usable.any()):
        return undefined

    truth = y_true[usable].float()
    guess = y_pred[usable].float()
    relative_error = torch.abs(guess - truth) / torch.abs(truth)
    within_tolerance = (relative_error <= k).float()
    return within_tolerance.mean().item()


def smape(y_pred: torch.Tensor, y_true: torch.Tensor) -> float:
    """Return the symmetric mean absolute percentage error, in [0, 2].

    Each term is ``|y_pred - y_true|`` divided by the mean of the two
    magnitudes. The divisor is floored at 1e-9, so a pair of zeros
    contributes 0 instead of dividing by zero. Returns NaN for empty input.
    """
    if y_pred.numel() == 0:
        return float("nan")

    abs_error = torch.abs(y_pred - y_true)
    mean_magnitude = (torch.abs(y_pred) + torch.abs(y_true)) / 2.0
    safe_magnitude = torch.clamp(mean_magnitude, min=1e-9)
    return torch.mean(abs_error / safe_magnitude).item()


def label_f1(pred_sets: list[set], true_sets: list[set]) -> tuple[float, float, float]:
    """Micro-averaged precision, recall and F1 over multi-label sets.

    True positives, false positives and false negatives are pooled across
    every episode before the ratios are taken, so episodes with more labels
    weigh more (micro averaging, not macro).

    Args:
        pred_sets: one set of predicted labels (strings or ids) per episode.
        true_sets: one set of ground-truth labels per episode, aligned with
            ``pred_sets``.

    Returns:
        ``(precision, recall, f1)``. Each is 0.0 when its denominator would
        be zero (e.g. no predictions at all, or empty input).

    Raises:
        ValueError: if the two lists do not have the same length.
    """
    # Explicit check rather than ``assert``: asserts vanish under ``-O`` and
    # ``zip`` would then silently drop the unmatched tail.
    if len(pred_sets) != len(true_sets):
        raise ValueError(
            f"pred_sets and true_sets must have the same length "
            f"(got {len(pred_sets)} and {len(true_sets)})"
        )

    tp = fp = fn = 0
    for predicted, expected in zip(pred_sets, true_sets):
        hits = len(predicted & expected)
        tp += hits
        fp += len(predicted) - hits
        fn += len(expected) - hits

    precision = tp / max(tp + fp, 1)
    recall = tp / max(tp + fn, 1)
    f1 = 2 * precision * recall / max(precision + recall, 1e-9)
    return precision, recall, f1


def long_horizon_retention(
    model_predictions: list[torch.Tensor],
    truth: list[torch.Tensor],
    horizons: list[int] = (1, 5, 10, 20),
) -> dict[int, float]:
    """Top-1 next-event accuracy at each requested horizon.

    EHRWorld reports 92.6% long-horizon retention (some definition of decay
    from horizon 1 to horizon N). This function returns the raw accuracy at
    each horizon; a retention ratio such as ``acc(h) / acc(1)`` is left to
    the caller and is NOT computed here.

    Args:
        model_predictions: one 1-D tensor of predicted event ids per sequence.
        truth: one 1-D tensor of ground-truth event ids per sequence, aligned
            with ``model_predictions``.
        horizons: 1-based step indices to score; horizon ``h`` compares the
            element at index ``h - 1``.

    Returns:
        Mapping ``horizon -> accuracy``. A sequence only counts towards a
        horizon when both its prediction and its truth have at least ``h``
        elements. A horizon with no eligible sequence maps to 0.0.

    Raises:
        ValueError: if any horizon is smaller than 1.
    """
    horizons = list(horizons)
    # A horizon below 1 would turn into a negative index and silently score
    # the wrong position (h == 0 would compare the *last* elements).
    for h in horizons:
        if h < 1:
            raise ValueError(f"horizons are 1-based and must be >= 1 (got {h})")

    out: dict[int, float] = {}
    for h in horizons:
        n_correct = 0
        n_total = 0
        for pred_seq, true_seq in zip(model_predictions, truth):
            if true_seq.numel() < h or pred_seq.numel() < h:
                continue
            n_total += 1
            if pred_seq[h - 1].item() == true_seq[h - 1].item():
                n_correct += 1
        out[h] = n_correct / max(n_total, 1)
    return out


@torch.no_grad()
def evaluate_ehrworld_protocol(
    model,
    test_dataset,
    numeric_token_ids: list[int] | None = None,
    label_token_ids: list[int] | None = None,
) -> EHRWorldMetrics:
    """Run EHRWorld-protocol evaluation on a CWMDataset slice.

    For GEMEO-CWM:
      - Numeric tokens = LOS buckets (los_short/week/month/long), age buckets
      - Label tokens = CID prefixes + outcome tokens (death/discharge)

    Strategy:
      - Replace the last 5 non-pad tokens of each sequence with the model's
        mask token and take the argmax prediction at those positions.
        Sequences with fewer than 5 non-pad tokens are skipped.
      - For numeric tokens: convert bucket -> representative value for
        S@25 / SMAPE.
      - For label tokens: compare the set of predicted label ids in the
        window against the set of true label ids (top-1 per position).

    This is a SIMPLIFIED port — full EHRWorld eval requires their internal
    timestamp + observation schema. Use this as a within-DATASUS comparison
    metric, not as direct head-to-head.

    Args:
        model: network exposing ``parameters()``, ``cfg.mask_token`` and a
            call signature ``model(tokens, t, cond)`` whose output is indexed
            here as ``[batch, position, vocab]``.
        test_dataset: either a raw token tensor (conditions then default to
            zeros) or an object whose ``to(device)`` returns ``(x, cond)``.
            Optional ``vocab`` / ``tok2id`` attributes are used when present.
        numeric_token_ids: ids treated as numeric buckets; defaults to vocab
            entries starting with ``los_`` or ``age_``.
        label_token_ids: ids treated as labels; defaults to vocab entries
            starting with ``cid_`` / ``EV_`` or containing ``outcome``.

    Returns:
        EHRWorldMetrics. ``n_episodes`` counts every sequence in the dataset,
        including those skipped for being too short.
    """
    # NOTE(review): train/eval mode of ``model`` is left untouched; callers
    # presumably switch to eval mode beforehand — confirm.
    device = next(model.parameters()).device

    # The previous ``hasattr(test_dataset, 'to')`` check could never reach a
    # working fallback: its else branch called ``.to`` on an object known not
    # to have it, and a raw tensor was sent down the tuple-unpacking branch.
    if isinstance(test_dataset, torch.Tensor):
        x_all = test_dataset.to(device)
        cond_all = torch.zeros(x_all.size(0), dtype=torch.long, device=device)
    else:
        x_all, cond_all = test_dataset.to(device)

    # Read once with a fallback so datasets without a vocab do not raise
    # inside the scoring loop.
    # NOTE(review): indexed as id -> token string below; assumes a sequence,
    # not a token -> id mapping — confirm.
    vocab = getattr(test_dataset, "vocab", [])

    if numeric_token_ids is None:
        # Default: try to identify LOS / age bucket tokens by string prefix
        numeric_token_ids = [
            i for i, tok in enumerate(vocab)
            if isinstance(tok, str) and tok.startswith(("los_", "age_"))
        ]
    if label_token_ids is None:
        label_token_ids = [
            i for i, tok in enumerate(vocab)
            if isinstance(tok, str)
            and (tok.startswith(("cid_", "EV_")) or "outcome" in tok)
        ]
    numeric_set = set(numeric_token_ids)
    label_set = set(label_token_ids)

    # Bucket value map (midpoints)
    bucket_values = {
        "los_short": 1.0, "los_week": 4.0, "los_month": 18.0, "los_long": 60.0,
        "age_0_1": 0.5, "age_1_2": 1.5, "age_2_5": 3.5, "age_5_12": 8.5,
        "age_12_18": 15.0, "age_18_30": 24.0, "age_30_50": 40.0,
        "age_50_70": 60.0, "age_70plus": 80.0,
    }

    def bucket_value(token_id: int):
        """Numeric value for a bucket token id, or None if it has none."""
        if token_id < len(vocab):
            return bucket_values.get(vocab[token_id])
        return None

    numeric_preds, numeric_truths = [], []
    pred_label_sets, true_label_sets = [], []
    tok2id = getattr(test_dataset, "tok2id", None)
    pad_id = tok2id.get("<PAD>", 0) if tok2id is not None else 0

    window = 5  # number of trailing tokens hidden from the model
    mask_token = model.cfg.mask_token
    # NOTE(review): 0.05 is passed as the model's second argument for every
    # sequence; presumably a (near-zero) time/noise level — confirm.
    t_zero = torch.tensor([0.05], device=device)

    for i in range(x_all.size(0)):
        seq = x_all[i]
        # Count of non-pad tokens; used as the end of the valid region, which
        # assumes padding only appears at the end of a sequence.
        valid = int((seq != pad_id).sum().item())
        if valid < window:
            continue
        start = valid - window

        true_ids = seq[start:valid].tolist()
        masked_seq = seq.clone()
        masked_seq[start:valid] = mask_token

        logits = model(masked_seq.unsqueeze(0), t_zero, cond_all[i:i + 1])
        # Never let the model "predict" the mask token itself.
        logits[:, :, mask_token] = -1e9
        pred_ids = logits[0].argmax(dim=-1)[start:valid].tolist()

        # Numeric: a pair is scored only when BOTH tokens map to a bucket
        # value; a non-bucket prediction for a bucket truth is dropped rather
        # than counted as a miss.
        for pt, tt in zip(pred_ids, true_ids):
            if pt in numeric_set or tt in numeric_set:
                pv = bucket_value(pt)
                tv = bucket_value(tt)
                if pv is not None and tv is not None:
                    numeric_preds.append(pv)
                    numeric_truths.append(tv)

        # Labels: collect sets
        pred_lbl = {p for p in pred_ids if p in label_set}
        true_lbl = {t for t in true_ids if t in label_set}
        if pred_lbl or true_lbl:
            pred_label_sets.append(pred_lbl)
            true_label_sets.append(true_lbl)

    yp = torch.tensor(numeric_preds, dtype=torch.float)
    yt = torch.tensor(numeric_truths, dtype=torch.float)
    s25 = s_at_k(yp, yt, k=0.25) if yp.numel() else float("nan")
    sm = smape(yp, yt) if yp.numel() else float("nan")
    prec, rec, f1 = (label_f1(pred_label_sets, true_label_sets)
                     if pred_label_sets else (0.0, 0.0, 0.0))

    return EHRWorldMetrics(
        s_at_25=s25, smape=sm, label_f1=f1,
        label_precision=prec, label_recall=rec,
        n_episodes=x_all.size(0), n_numeric_eval=len(numeric_preds),
        n_label_eval=len(pred_label_sets),
    )