gemeo-twin-stack / src /gemeo /cwm /eval_protocol.py
timmers's picture
GEMEO world-model — initial release (module + NeuralSurv ckpt + RareBench v49 + KG embeddings)
089d665 verified
"""EHRWorld-protocol evaluation metrics, ported for any EHR foundation model.
EHRWorld (arXiv 2602.03569, SJTU Jan 2026) reports three headline metrics on
their 579-episode MIMIC-IV test set:
- S@25 (success at 25%): for numerical values (lab results, vitals),
fraction predicted within +-25% relative error of ground truth.
- SMAPE: symmetric mean absolute percentage error,
SMAPE = (1/N) * sum( |y_hat - y| / ((|y| + |y_hat|)/2) )
bounded in [0,2], lower is better.
- Label-F1: precision/recall F1 over discrete diagnostic labels (set-based,
multi-label).
We re-implement them here so any GEMEO-CWM checkpoint can be scored on its
own DATASUS test slice with EHRWorld-comparable numbers. We cannot run on
their actual test set without MIMIC-IV credentialing, but the methodological
alignment is legitimate for a "rare-disease + PT-BR" framing.
Reference: Section 4.2 + Appendix C of EHRWorld paper.
"""
from __future__ import annotations
import math
import logging
from dataclasses import dataclass
import torch
# Module-level logger, registered under a fixed dotted name rather than __name__.
log = logging.getLogger("gemeo.cwm.eval_protocol")
@dataclass
class EHRWorldMetrics:
    """Result record produced by one EHRWorld-protocol evaluation run.

    Groups the numeric-value metrics, the discrete-label metrics and the
    sample counts behind them so a run can be reported as a single object.
    """

    # --- numeric-value metrics -------------------------------------------
    # Fraction of numeric predictions within +-25% relative error.
    s_at_25: float
    # Symmetric mean absolute percentage error (lower is better).
    smape: float

    # --- discrete-label metrics ------------------------------------------
    label_f1: float
    label_precision: float
    label_recall: float

    # --- bookkeeping -----------------------------------------------------
    # Number of sequences handed to the evaluation.
    n_episodes: int
    # Number of (prediction, truth) numeric pairs that were scored.
    n_numeric_eval: int
    # Number of per-sequence label-set pairs that were scored.
    n_label_eval: int
def s_at_k(y_pred: torch.Tensor, y_true: torch.Tensor, k: float = 0.25) -> float:
    """Return the share of predictions whose relative error is at most ``k``.

    Relative error is ``|y_pred - y_true| / |y_true|``. Positions whose
    ground-truth magnitude is not above 1e-9 are excluded, because the
    relative error is undefined there.

    Args:
        y_pred: 1-D tensor of predicted values.
        y_true: 1-D tensor of ground-truth values, same length as ``y_pred``.
        k: Relative-error tolerance (0.25 gives S@25).

    Returns:
        The fraction in [0, 1], or NaN when ``y_pred`` is empty or no
        position has a usable (non-zero) ground truth.
    """
    undefined = float("nan")
    if y_pred.numel() == 0:
        return undefined

    usable = y_true.abs() > 1e-9
    if not bool(usable.any()):
        return undefined

    truth = y_true[usable].float()
    guess = y_pred[usable].float()
    within_tolerance = ((guess - truth).abs() / truth.abs()) <= k
    return within_tolerance.float().mean().item()
def smape(y_pred: torch.Tensor, y_true: torch.Tensor) -> float:
    """Compute the symmetric mean absolute percentage error.

    Each element contributes ``|y_pred - y_true| / ((|y_pred| + |y_true|) / 2)``;
    the result is the mean of those terms and lies in [0, 2].

    Args:
        y_pred: Tensor of predicted values.
        y_true: Tensor of ground-truth values, same shape as ``y_pred``.

    Returns:
        The SMAPE as a Python float, or NaN when ``y_pred`` is empty.
    """
    if y_pred.numel() == 0:
        return float("nan")

    abs_error = torch.abs(y_pred - y_true)
    mean_magnitude = (torch.abs(y_pred) + torch.abs(y_true)) / 2.0
    # Floor the denominator so a pair of zeros yields 0 instead of 0/0.
    safe_magnitude = torch.clamp(mean_magnitude, min=1e-9)
    per_item = abs_error / safe_magnitude
    return per_item.mean().item()
def label_f1(pred_sets: list[set], true_sets: list[set]) -> tuple[float, float, float]:
    """Micro-averaged multi-label precision, recall and F1.

    True positives, false positives and false negatives are pooled over all
    episodes before the ratios are taken (micro averaging), so episodes with
    more labels carry more weight.

    Args:
        pred_sets: One set of predicted labels (strings or ids) per episode.
        true_sets: One set of ground-truth labels per episode, aligned with
            ``pred_sets``.

    Returns:
        ``(precision, recall, f1)``. Each is 0.0 when its denominator would
        be zero (e.g. no episodes, or no labels at all).

    Raises:
        ValueError: If the two lists have different lengths.
    """
    # Explicit check instead of `assert`: asserts vanish under `python -O`,
    # after which zip() would silently truncate to the shorter list.
    if len(pred_sets) != len(true_sets):
        raise ValueError(
            f"pred_sets and true_sets must have the same length "
            f"(got {len(pred_sets)} and {len(true_sets)})"
        )

    tp = fp = fn = 0
    for predicted, expected in zip(pred_sets, true_sets):
        tp += len(predicted & expected)
        fp += len(predicted - expected)
        fn += len(expected - predicted)

    # max(..., 1) / max(..., 1e-9) keep the empty cases at 0.0 instead of
    # raising ZeroDivisionError.
    prec = tp / max(tp + fp, 1)
    rec = tp / max(tp + fn, 1)
    f1 = 2 * prec * rec / max(prec + rec, 1e-9)
    return prec, rec, f1
def long_horizon_retention(
    model_predictions: list[torch.Tensor],
    truth: list[torch.Tensor],
    horizons: list[int] = (1, 5, 10, 20),
) -> dict[int, float]:
    """Top-1 next-event accuracy at each requested horizon.

    EHRWorld reports 92.6% long-horizon retention (some definition of decay
    from horizon 1 to horizon N). This function returns the raw accuracy
    ``acc(h)`` per horizon; it does NOT divide by ``acc(1)``. Callers who
    want the retention ratio compute ``out[h] / out[1]`` themselves.

    Args:
        model_predictions: One 1-D tensor of predicted event ids per episode.
        truth: One 1-D tensor of ground-truth event ids per episode, aligned
            with ``model_predictions``.
        horizons: 1-based step indices to score.

    Returns:
        Mapping ``horizon -> accuracy``. An episode contributes to horizon
        ``h`` only when both its prediction and its truth have at least ``h``
        elements; a horizon with no contributing episode maps to 0.0.

    Raises:
        ValueError: If any horizon is smaller than 1.
    """
    # A horizon of 0 or below would turn `seq[h - 1]` into a negative index
    # and silently score the tail of the sequence instead.
    for h in horizons:
        if h < 1:
            raise ValueError(f"horizons must be >= 1, got {h}")

    out: dict[int, float] = {}
    for h in horizons:
        n_correct = 0
        n_total = 0
        for pred_seq, true_seq in zip(model_predictions, truth):
            # Skip episodes that do not reach this horizon on either side.
            if true_seq.numel() < h or pred_seq.numel() < h:
                continue
            n_total += 1
            if pred_seq[h - 1].item() == true_seq[h - 1].item():
                n_correct += 1
        out[h] = n_correct / max(n_total, 1)
    return out
@torch.no_grad()
def evaluate_ehrworld_protocol(
    model,
    test_dataset,
    numeric_token_ids: list[int] | None = None,
    label_token_ids: list[int] | None = None,
    window: int = 5,
) -> EHRWorldMetrics:
    """Run EHRWorld-protocol evaluation on a CWMDataset slice.

    For GEMEO-CWM:
      - Numeric tokens = LOS buckets (los_short/week/month/long), age buckets
      - Label tokens   = CID prefixes + outcome tokens (death/discharge)

    Strategy:
      - Replace the last ``window`` non-pad tokens of each sequence with the
        model's mask token and let the model fill them in (argmax decoding,
        one forward pass per sequence).
      - Numeric tokens: convert bucket -> midpoint value for S@25 / SMAPE.
      - Label tokens: compare the set of predicted label ids in the window
        against the set of true label ids.

    This is a SIMPLIFIED port — full EHRWorld eval requires their internal
    timestamp + observation schema. Use this as a within-DATASUS comparison
    metric, not as direct head-to-head.

    Args:
        model: Module exposing ``parameters()``, ``cfg.mask_token`` and a
            call signature ``model(tokens, t, cond) -> logits``.
            NOTE(review): the model is used in whatever train/eval mode the
            caller left it in — confirm callers invoke ``model.eval()`` first.
        test_dataset: Object whose ``.to(device)`` returns either
            ``(tokens, cond)`` or a bare token tensor. Optional attributes
            ``vocab`` (id -> token string) and ``tok2id`` are used when
            present.
        numeric_token_ids: Ids treated as numeric buckets. Defaults to every
            vocab entry starting with ``los_`` or ``age_``.
        label_token_ids: Ids treated as discrete labels. Defaults to every
            vocab entry starting with ``cid_`` / ``EV_`` or containing
            ``outcome``.
        window: Number of trailing non-pad tokens to mask per sequence.
            Sequences with fewer non-pad tokens are skipped.

    Returns:
        An ``EHRWorldMetrics``. ``n_episodes`` counts every row of the
        dataset, including rows skipped for being shorter than ``window``.

    Raises:
        ValueError: If ``window`` is smaller than 1.
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")

    device = next(model.parameters()).device

    # The dataset's .to() may hand back (tokens, cond) or just tokens; branch
    # on the result instead of on hasattr(), whose fallback could never work.
    moved = test_dataset.to(device)
    if isinstance(moved, (tuple, list)):
        x_all, cond_all = moved
    else:
        x_all = moved
        cond_all = torch.zeros(x_all.size(0), dtype=torch.long, device=device)

    # Fetch the vocab once; an absent vocab simply yields no default ids and
    # no bucket lookups.
    vocab = getattr(test_dataset, "vocab", [])

    if numeric_token_ids is None:
        # Default: identify LOS / age bucket tokens by string prefix.
        numeric_token_ids = [
            i for i, tok in enumerate(vocab)
            if isinstance(tok, str) and tok.startswith(("los_", "age_"))
        ]
    if label_token_ids is None:
        label_token_ids = [
            i for i, tok in enumerate(vocab)
            if isinstance(tok, str)
            and (tok.startswith(("cid_", "EV_")) or "outcome" in tok)
        ]
    numeric_set = set(numeric_token_ids)
    label_set = set(label_token_ids)

    # Bucket value map (midpoints).
    bucket_values = {
        "los_short": 1.0, "los_week": 4.0, "los_month": 18.0, "los_long": 60.0,
        "age_0_1": 0.5, "age_1_2": 1.5, "age_2_5": 3.5, "age_5_12": 8.5,
        "age_12_18": 15.0, "age_18_30": 24.0, "age_30_50": 40.0,
        "age_50_70": 60.0, "age_70plus": 80.0,
    }

    def bucket_midpoint(token_id):
        """Midpoint value for a bucket token id, or None if it has none."""
        if token_id < len(vocab):
            return bucket_values.get(vocab[token_id], None)
        return None

    numeric_preds, numeric_truths = [], []
    pred_label_sets, true_label_sets = [], []

    if hasattr(test_dataset, "tok2id"):
        pad_id = test_dataset.tok2id.get("<PAD>", 0)
    else:
        pad_id = 0

    # Loop invariants.
    mask_token = model.cfg.mask_token
    t_zero = torch.tensor([0.05], device=device)

    for i in range(x_all.size(0)):
        seq = x_all[i]
        # Count of non-pad tokens; used as the end of the real sequence.
        # NOTE(review): this assumes right-padding with no interior pads.
        valid = int((seq != pad_id).sum().item())
        if valid < window:
            continue
        start = valid - window

        truth_ids = seq[start:valid].tolist()
        masked_seq = seq.clone()
        masked_seq[start:valid] = mask_token

        logits = model(masked_seq.unsqueeze(0), t_zero, cond_all[i:i + 1])
        # Never let the model "predict" the mask token itself.
        logits[:, :, mask_token] = -1e9
        pred_ids = logits[0].argmax(dim=-1)[start:valid].tolist()

        # Numeric: score a position when either side is a numeric token and
        # both sides map to a known bucket midpoint.
        for pt, tt in zip(pred_ids, truth_ids):
            if pt in numeric_set or tt in numeric_set:
                pv = bucket_midpoint(pt)
                tv = bucket_midpoint(tt)
                if pv is not None and tv is not None:
                    numeric_preds.append(pv)
                    numeric_truths.append(tv)

        # Labels: one (predicted set, true set) pair per sequence, skipped
        # when neither side contains any label token.
        pred_lbl = {tid for tid in pred_ids if tid in label_set}
        true_lbl = {tid for tid in truth_ids if tid in label_set}
        if pred_lbl or true_lbl:
            pred_label_sets.append(pred_lbl)
            true_label_sets.append(true_lbl)

    yp = torch.tensor(numeric_preds, dtype=torch.float)
    yt = torch.tensor(numeric_truths, dtype=torch.float)
    s25 = s_at_k(yp, yt, k=0.25) if yp.numel() else float("nan")
    sm = smape(yp, yt) if yp.numel() else float("nan")
    if pred_label_sets:
        prec, rec, f1 = label_f1(pred_label_sets, true_label_sets)
    else:
        prec, rec, f1 = 0.0, 0.0, 0.0

    return EHRWorldMetrics(
        s_at_25=s25, smape=sm, label_f1=f1,
        label_precision=prec, label_recall=rec,
        n_episodes=x_all.size(0), n_numeric_eval=len(numeric_preds),
        n_label_eval=len(pred_label_sets),
    )