File size: 8,416 Bytes
"""EHRWorld-protocol evaluation metrics, ported for any EHR foundation model.
EHRWorld (arXiv 2602.03569, SJTU Jan 2026) reports three headline metrics on
their 579-episode MIMIC-IV test set:
- S@25 (success at 25%): for numerical values (lab results, vitals), the
fraction of predictions within +-25% relative error of the ground truth.
- SMAPE: symmetric mean absolute percentage error,
SMAPE = (1/N) * sum( |y_hat - y| / ((|y| + |y_hat|)/2) )
bounded in [0,2], lower is better.
- Label-F1: precision/recall F1 over discrete diagnostic labels (set-based,
multi-label).
We re-implement them here so that any GEMEO-CWM checkpoint can be scored on
its own DATASUS test slice with numbers comparable to EHRWorld's. We cannot
evaluate on EHRWorld's actual test set without MIMIC-IV credentialing, so the
results are not a head-to-head comparison; aligning the methodology still makes
them meaningful for our "rare-disease + PT-BR" framing.
Reference: Section 4.2 + Appendix C of EHRWorld paper.
"""
from __future__ import annotations
import math
import logging
from dataclasses import dataclass
import torch
# Module-level logger for this evaluation module.
log = logging.getLogger(name="gemeo.cwm.eval_protocol")
@dataclass
class EHRWorldMetrics:
    """Scores produced by one EHRWorld-protocol evaluation run.

    Fields are grouped as numeric-value metrics, discrete-label metrics, and
    the sample counts those metrics were computed over.
    """

    # --- numeric-value metrics -------------------------------------------
    s_at_25: float  # share of numeric predictions within +-25% relative error
    smape: float  # symmetric mean absolute percentage error, in [0, 2]
    # --- discrete-label metrics (pooled over all episodes) ---------------
    label_f1: float
    label_precision: float
    label_recall: float
    # --- sample counts ---------------------------------------------------
    n_episodes: int
    n_numeric_eval: int  # number of (prediction, truth) numeric pairs scored
    n_label_eval: int  # number of episodes that contributed label sets
def s_at_k(y_pred: torch.Tensor, y_true: torch.Tensor, k: float = 0.25) -> float:
    """Return the share of predictions whose relative error is at most ``k``.

    Relative error is ``|y_pred - y_true| / |y_true|``. Entries whose true
    value has magnitude <= 1e-9 are left out, since the ratio is undefined
    there. Both inputs are expected to be matching 1-D tensors; the selected
    entries are cast to float before comparing.

    Returns ``nan`` when ``y_pred`` is empty or when every true value is ~0.
    """
    if y_pred.numel() == 0:
        return float("nan")

    keep = y_true.abs() > 1e-9
    if not keep.any():
        return float("nan")

    truth = y_true[keep].float()
    guess = y_pred[keep].float()
    hits = (guess - truth).abs() / truth.abs() <= k
    return hits.float().mean().item()
def smape(y_pred: torch.Tensor, y_true: torch.Tensor) -> float:
    """Symmetric MAPE: mean of ``|yp - yt| / ((|yt| + |yp|) / 2)``, in [0, 2].

    A term whose prediction and truth are both exactly 0 is defined as 0
    (perfect prediction). Unlike a clamped denominator, this keeps very small
    but non-zero values correct: e.g. ``yp=1e-12, yt=0`` scores 2, not 1e-3.
    NaN inputs propagate to a NaN result.

    Args:
        y_pred: predicted values.
        y_true: ground-truth values; must have the same shape as ``y_pred``.

    Returns:
        The mean SMAPE as a Python float, or ``nan`` if ``y_pred`` is empty.

    Raises:
        ValueError: if the shapes differ (prevents silent broadcasting, e.g.
            ``(N,)`` vs ``(N, 1)`` averaging over an ``N x N`` grid).
    """
    if y_pred.numel() == 0:
        return float("nan")
    if y_pred.shape != y_true.shape:
        raise ValueError(
            f"y_pred and y_true must have the same shape, "
            f"got {tuple(y_pred.shape)} and {tuple(y_true.shape)}"
        )
    num = (y_pred - y_true).abs()
    den = (y_pred.abs() + y_true.abs()) / 2.0
    ratio = num / den
    # den == 0 only when both values are exactly 0 -> 0/0, scored as 0.
    # Comparing with == (not > 0) keeps NaN inputs propagating as NaN.
    ratio = torch.where(den == 0, torch.zeros_like(ratio), ratio)
    return ratio.mean().item()
def label_f1(pred_sets: list[set], true_sets: list[set]) -> tuple[float, float, float]:
    """Micro-averaged multi-label precision, recall and F1 over discrete labels.

    True positives, false positives and false negatives are pooled across all
    episodes before the ratios are taken (micro-averaging), so episodes with
    more labels carry more weight. Each entry may be any iterable of hashable
    labels (label strings or ids); it is converted to a set first.

    Args:
        pred_sets: predicted labels, one collection per episode.
        true_sets: ground-truth labels, aligned with ``pred_sets``.

    Returns:
        ``(precision, recall, f1)``. Precision is 0.0 when nothing was
        predicted, recall is 0.0 when nothing was true, and F1 is 0.0 when
        both are 0.

    Raises:
        ValueError: if ``pred_sets`` and ``true_sets`` differ in length
            (replaces an ``assert`` that disappears under ``python -O``).
    """
    if len(pred_sets) != len(true_sets):
        raise ValueError(
            f"pred_sets and true_sets must have the same length, "
            f"got {len(pred_sets)} and {len(true_sets)}"
        )
    tp = fp = fn = 0
    for pred, gold in zip(pred_sets, true_sets):
        pred, gold = set(pred), set(gold)
        tp += len(pred & gold)
        fp += len(pred - gold)
        fn += len(gold - pred)
    prec = tp / max(tp + fp, 1)
    rec = tp / max(tp + fn, 1)
    f1 = 2 * prec * rec / max(prec + rec, 1e-9)
    return prec, rec, f1
def long_horizon_retention(
    model_predictions: list[torch.Tensor],
    truth: list[torch.Tensor],
    horizons: tuple[int, ...] | list[int] = (1, 5, 10, 20),
    *,
    relative: bool = False,
) -> dict[int, float]:
    """Top-1 next-event accuracy at each prediction horizon.

    For a horizon ``h`` (1-based), every (prediction, truth) pair whose
    sequences both have at least ``h`` elements contributes one comparison
    ``pred_seq[h - 1] == true_seq[h - 1]``. A horizon with no eligible pairs
    reports 0.0.

    EHRWorld reports a 92.6% long-horizon retention (a decay measure from
    horizon 1 to horizon N; its exact definition is not reproduced here). By
    default this returns the raw per-horizon accuracies, as before; pass
    ``relative=True`` to get the retention ratio ``acc(h) / acc(1)`` instead
    (``nan`` when ``acc(1)`` is 0).

    Args:
        model_predictions: per-episode predicted token sequences, indexed
            element-wise (assumed 1-D — TODO confirm with callers).
        truth: per-episode true token sequences, aligned with
            ``model_predictions``.
        horizons: 1-based horizons to evaluate.
        relative: if True, divide each accuracy by the accuracy at horizon 1.

    Returns:
        Mapping from horizon to accuracy (or retention ratio).

    Raises:
        ValueError: if the two lists differ in length, or any horizon is < 1
            (horizon 0 would otherwise silently index the last element).
    """
    if len(model_predictions) != len(truth):
        raise ValueError(
            f"model_predictions and truth must have the same length, "
            f"got {len(model_predictions)} and {len(truth)}"
        )
    horizons = tuple(horizons)
    bad = [h for h in horizons if h < 1]
    if bad:
        raise ValueError(f"horizons must be >= 1, got {bad}")

    def _accuracy(h: int) -> float:
        # Fraction of long-enough pairs whose (h-1)-th elements match.
        n_correct = n_total = 0
        for pred_seq, true_seq in zip(model_predictions, truth):
            if true_seq.numel() < h or pred_seq.numel() < h:
                continue
            n_total += 1
            if pred_seq[h - 1].item() == true_seq[h - 1].item():
                n_correct += 1
        return n_correct / max(n_total, 1)

    acc = {h: _accuracy(h) for h in horizons}
    if not relative:
        return acc
    base = acc[1] if 1 in acc else _accuracy(1)
    return {h: (a / base if base > 0 else float("nan")) for h, a in acc.items()}
# Midpoint value for each LOS / age bucket token, used to turn discrete bucket
# predictions into numbers for S@25 and SMAPE. LOS midpoints are presumably in
# days and age midpoints in years (inferred from the bucket names — TODO
# confirm against the tokenizer).
_BUCKET_VALUES: dict[str, float] = {
    "los_short": 1.0, "los_week": 4.0, "los_month": 18.0, "los_long": 60.0,
    "age_0_1": 0.5, "age_1_2": 1.5, "age_2_5": 3.5, "age_5_12": 8.5,
    "age_12_18": 15.0, "age_18_30": 24.0, "age_30_50": 40.0,
    "age_50_70": 60.0, "age_70plus": 80.0,
}


def _load_eval_tensors(test_dataset, device) -> tuple[torch.Tensor, torch.Tensor]:
    """Return ``(tokens, cond)`` on ``device`` from a dataset or a raw token tensor.

    A raw tensor gets an all-zero long conditioning vector. Any other object
    must implement ``.to(device)`` returning a ``(tokens, cond)`` pair.
    """
    if isinstance(test_dataset, torch.Tensor):
        tokens = test_dataset.to(device)
        cond = torch.zeros(tokens.size(0), dtype=torch.long, device=device)
        return tokens, cond
    tokens, cond = test_dataset.to(device)
    return tokens, cond


def _default_numeric_ids(vocab) -> list[int]:
    """Ids of vocab entries that look like LOS / age bucket tokens."""
    return [
        i for i, tok in enumerate(vocab)
        if isinstance(tok, str) and tok.startswith(("los_", "age_"))
    ]


def _default_label_ids(vocab) -> list[int]:
    """Ids of vocab entries treated as discrete labels (cid_*, EV_*, *outcome*)."""
    return [
        i for i, tok in enumerate(vocab)
        if isinstance(tok, str)
        and (tok.startswith(("cid_", "EV_")) or "outcome" in tok)
    ]


def _last_valid_length(seq: torch.Tensor, pad_id: int) -> int:
    """Return 1 + index of the last non-pad token of 1-D ``seq`` (0 if all pad)."""
    nonpad = (seq != pad_id).nonzero()
    if nonpad.numel() == 0:
        return 0
    return int(nonpad[-1, 0].item()) + 1


@torch.no_grad()
def evaluate_ehrworld_protocol(
    model,
    test_dataset,
    numeric_token_ids: list[int] | None = None,
    label_token_ids: list[int] | None = None,
    *,
    window: int = 5,
    t_eval: float = 0.05,
) -> EHRWorldMetrics:
    """Run EHRWorld-protocol evaluation on a CWMDataset slice.

    For GEMEO-CWM:
      - Numeric tokens = LOS buckets (``los_*``) and age buckets (``age_*``).
      - Label tokens = CID prefixes (``cid_*``), event tokens (``EV_*``) and
        any token whose name contains ``"outcome"`` (e.g. death/discharge).

    Strategy:
      - For each sequence, replace the ``window`` tokens ending at the last
        non-pad token with the model's mask token and predict them all in one
        forward pass (argmax per position; the mask token itself is excluded).
        Sequences shorter than ``window`` are skipped.
      - Numeric tokens: map predicted and true bucket tokens to bucket
        midpoints and score them with S@25 and SMAPE.
      - Label tokens: compare the set of label tokens predicted in the window
        with the set in the true window (micro-averaged precision/recall/F1).

    This is a SIMPLIFIED port — full EHRWorld eval requires their internal
    timestamp + observation schema. Use this as a within-DATASUS comparison
    metric, not as direct head-to-head.

    Args:
        model: network exposing ``.parameters()``, ``.cfg.mask_token`` and a
            forward ``model(tokens, t, cond)`` whose output is indexed here as
            ``(batch, seq_len, vocab)`` logits.
        test_dataset: either a token tensor (one sequence per row) or a
            dataset whose ``.to(device)`` returns ``(tokens, cond)``. Optional
            ``vocab`` (id -> token string) and ``tok2id`` (with ``"<PAD>"``)
            attributes are used when present; the pad id defaults to 0.
        numeric_token_ids: ids scored as numeric; defaults to ``los_*`` /
            ``age_*`` entries of ``test_dataset.vocab``.
        label_token_ids: ids scored as labels; defaults to ``cid_*`` /
            ``EV_*`` / ``*outcome*`` entries of ``test_dataset.vocab``.
        window: number of trailing tokens masked per sequence (default 5).
        t_eval: value passed as the model's time argument (default 0.05) —
            presumably a low noise level for a diffusion-style model; TODO
            confirm against the model's forward.

    Returns:
        An ``EHRWorldMetrics``. ``s_at_25`` / ``smape`` are ``nan`` when no
        numeric pairs were scored; label metrics are 0.0 when no episode had
        label tokens. ``n_episodes`` counts every sequence, including those
        skipped for being shorter than ``window``.

    Raises:
        ValueError: if ``window`` < 1.
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")

    device = next(model.parameters()).device
    x_all, cond_all = _load_eval_tensors(test_dataset, device)

    # Resolve the vocab once. Without one (e.g. a raw tensor), no token can be
    # mapped to a bucket value, so numeric pairs are simply not scored.
    vocab = test_dataset.vocab if hasattr(test_dataset, "vocab") else []
    if numeric_token_ids is None:
        numeric_token_ids = _default_numeric_ids(vocab)
    if label_token_ids is None:
        label_token_ids = _default_label_ids(vocab)
    numeric_set = set(numeric_token_ids)
    label_set = set(label_token_ids)

    pad_id = test_dataset.tok2id.get("<PAD>", 0) if hasattr(test_dataset, "tok2id") else 0
    mask_token = model.cfg.mask_token
    t_batch = torch.tensor([t_eval], device=device)

    def bucket_value(token_id: int) -> float | None:
        # Midpoint for a bucket token id, or None if unknown / out of vocab.
        return _BUCKET_VALUES.get(vocab[token_id]) if token_id < len(vocab) else None

    numeric_preds: list[float] = []
    numeric_truths: list[float] = []
    pred_label_sets: list[set[int]] = []
    true_label_sets: list[set[int]] = []

    for i in range(x_all.size(0)):
        seq = x_all[i]
        valid = _last_valid_length(seq, pad_id)
        if valid < window:
            continue
        start = valid - window

        masked_seq = seq.clone()
        masked_seq[start:valid] = mask_token
        logits = model(masked_seq.unsqueeze(0), t_batch, cond_all[i:i + 1])
        # Never let the model "predict" the mask token itself.
        logits[:, :, mask_token] = -1e9
        pred_window = logits[0].argmax(dim=-1)[start:valid].tolist()
        truth_window = seq[start:valid].tolist()

        for pt, tt in zip(pred_window, truth_window):
            if pt not in numeric_set and tt not in numeric_set:
                continue
            pv, tv = bucket_value(pt), bucket_value(tt)
            # NOTE(review): a pair where only one side maps to a bucket value
            # (e.g. a CID token predicted where a LOS bucket was expected) is
            # dropped rather than scored as a miss, which can flatter S@25 /
            # SMAPE. Kept as-is to preserve the metric's current definition.
            if pv is not None and tv is not None:
                numeric_preds.append(pv)
                numeric_truths.append(tv)

        pred_lbl = {tok for tok in pred_window if tok in label_set}
        true_lbl = {tok for tok in truth_window if tok in label_set}
        if pred_lbl or true_lbl:
            pred_label_sets.append(pred_lbl)
            true_label_sets.append(true_lbl)

    yp = torch.tensor(numeric_preds, dtype=torch.float)
    yt = torch.tensor(numeric_truths, dtype=torch.float)
    s25 = s_at_k(yp, yt, k=0.25) if yp.numel() else float("nan")
    sm = smape(yp, yt) if yp.numel() else float("nan")
    prec, rec, f1 = (label_f1(pred_label_sets, true_label_sets)
                     if pred_label_sets else (0.0, 0.0, 0.0))
    return EHRWorldMetrics(
        s_at_25=s25, smape=sm, label_f1=f1,
        label_precision=prec, label_recall=rec,
        n_episodes=x_all.size(0), n_numeric_eval=len(numeric_preds),
        n_label_eval=len(pred_label_sets),
    )
|