| """EHRWorld-protocol evaluation metrics, ported for any EHR foundation model. |
| |
| EHRWorld (arXiv 2602.03569, SJTU Jan 2026) reports three headline metrics on |
| their 579-episode MIMIC-IV test set: |
| |
| - S@25 (success at 25%): for numerical values (lab results, vitals), |
| fraction predicted within +-25% relative error of ground truth. |
| - SMAPE: symmetric mean absolute percentage error, |
| SMAPE = (1/N) * sum( |y_hat - y| / ((|y| + |y_hat|)/2) ) |
| bounded in [0,2], lower is better. |
| - Label-F1: precision/recall F1 over discrete diagnostic labels (set-based, |
| multi-label). |
| |
| We re-implement them here so any GEMEO-CWM checkpoint can be scored on its |
| own DATASUS test slice with EHRWorld-comparable numbers. We cannot run on |
| their actual test set without MIMIC-IV credentialing, but the methodological |
| alignment is legitimate for a "rare-disease + PT-BR" framing. |
| |
| Reference: Section 4.2 + Appendix C of EHRWorld paper. |
| """ |
| from __future__ import annotations |
| import math |
| import logging |
| from dataclasses import dataclass |
|
|
| import torch |
|
|
| log = logging.getLogger("gemeo.cwm.eval_protocol") |
|
|
|
|
@dataclass
class EHRWorldMetrics:
    """Result record for one EHRWorld-protocol evaluation run.

    Field order is part of the positional constructor signature and must not
    be rearranged.
    """

    # --- numeric-value metrics -------------------------------------------
    s_at_25: float          # success rate at 25% relative error
    smape: float            # symmetric mean absolute percentage error

    # --- discrete-label metrics ------------------------------------------
    label_f1: float
    label_precision: float
    label_recall: float

    # --- sample counts backing the metrics above -------------------------
    n_episodes: int         # number of episodes (sequences) in the run
    n_numeric_eval: int     # number of numeric (pred, truth) pairs scored
    n_label_eval: int       # number of label-set comparisons scored
|
|
|
|
def s_at_k(y_pred: torch.Tensor, y_true: torch.Tensor, k: float = 0.25) -> float:
    """Return the share of predictions whose relative error is at most ``k``.

    Relative error is ``|y_pred - y_true| / |y_true|``. Entries whose truth
    has magnitude <= 1e-9 are dropped, because relative error is undefined
    there. The result is NaN when ``y_pred`` is empty or when no entry
    survives that filter.

    Both tensors are expected to have the same shape (1-D float vectors in
    practice).
    """
    if y_pred.numel() == 0:
        return float("nan")

    # Keep only positions where the denominator is usable.
    usable = y_true.abs() > 1e-9
    if not bool(usable.any()):
        return float("nan")

    truth = y_true[usable].float()
    guess = y_pred[usable].float()
    relative_error = torch.abs(guess - truth) / torch.abs(truth)
    within_tolerance = relative_error <= k
    return within_tolerance.float().mean().item()
|
|
|
|
def smape(y_pred: torch.Tensor, y_true: torch.Tensor) -> float:
    """Return the symmetric mean absolute percentage error.

    Each term is ``|y_pred - y_true| / ((|y_pred| + |y_true|) / 2)``, so the
    mean lies in [0, 2]; lower is better. The denominator is floored at 1e-9,
    which makes a pair of zeros contribute 0 instead of 0/0. An empty
    ``y_pred`` gives NaN.
    """
    if y_pred.numel() == 0:
        return float("nan")

    abs_diff = torch.abs(y_pred - y_true)
    mean_magnitude = (torch.abs(y_pred) + torch.abs(y_true)) / 2.0
    # Floor the denominator so that (0, 0) pairs do not divide by zero.
    safe_magnitude = torch.clamp(mean_magnitude, min=1e-9)
    return torch.mean(abs_diff / safe_magnitude).item()
|
|
|
|
def label_f1(pred_sets: list[set], true_sets: list[set]) -> tuple[float, float, float]:
    """Multi-label micro-averaged precision / recall / F1 over discrete labels.

    True positives, false positives and false negatives are pooled over all
    episodes before the ratios are taken (micro-averaging), so episodes with
    more labels carry more weight.

    Args:
        pred_sets: One set of predicted labels (strings or ids) per episode.
        true_sets: One set of ground-truth labels per episode, aligned with
            ``pred_sets``.

    Returns:
        ``(precision, recall, f1)``. Each value is 0.0 when its denominator
        would be zero (e.g. no predictions at all, or empty input lists).

    Raises:
        ValueError: If the two lists have different lengths. An explicit
            raise is used instead of ``assert`` so the check survives
            ``python -O``; otherwise ``zip`` would silently truncate.
    """
    if len(pred_sets) != len(true_sets):
        raise ValueError(
            f"pred_sets and true_sets must have the same length "
            f"(got {len(pred_sets)} and {len(true_sets)})"
        )

    tp = fp = fn = 0
    for predicted, expected in zip(pred_sets, true_sets):
        tp += len(predicted & expected)
        fp += len(predicted - expected)
        fn += len(expected - predicted)

    # max(..., 1) / max(..., 1e-9) turn the 0/0 cases into 0.0 instead of
    # raising ZeroDivisionError.
    prec = tp / max(tp + fp, 1)
    rec = tp / max(tp + fn, 1)
    f1 = 2 * prec * rec / max(prec + rec, 1e-9)
    return prec, rec, f1
|
|
|
|
def long_horizon_retention(
    model_predictions: list[torch.Tensor],
    truth: list[torch.Tensor],
    horizons: list[int] | tuple[int, ...] = (1, 5, 10, 20),
) -> dict[int, float]:
    """Top-1 next-event accuracy at each requested horizon.

    For horizon ``h`` the ``h``-th element (1-based) of every predicted
    sequence is compared with the ``h``-th element of the matching truth
    sequence. Sequence pairs where either side has fewer than ``h`` elements
    are left out of that horizon's denominator.

    EHRWorld reports 92.6% long-horizon retention (some definition of decay
    from horizon 1 to horizon N). This function returns the raw accuracy per
    horizon only; a retention ratio such as ``acc(h) / acc(1)`` has to be
    derived by the caller from the returned mapping.

    Args:
        model_predictions: One 1-D tensor of predicted event ids per episode.
        truth: One 1-D tensor of true event ids per episode, aligned with
            ``model_predictions``.
        horizons: 1-based step indices to score.

    Returns:
        Mapping ``horizon -> accuracy``. A horizon with no scorable pair maps
        to 0.0.

    Raises:
        ValueError: If any horizon is < 1. Such a value would otherwise turn
            into a negative index and silently score from the end of each
            sequence.
    """
    horizons = tuple(horizons)
    invalid = [h for h in horizons if h < 1]
    if invalid:
        raise ValueError(f"horizons must be >= 1, got {invalid}")

    accuracy: dict[int, float] = {}
    for h in horizons:
        step = h - 1
        n_correct = 0
        n_total = 0
        for pred_seq, true_seq in zip(model_predictions, truth):
            # Skip pairs that do not reach this horizon on either side.
            if true_seq.numel() < h or pred_seq.numel() < h:
                continue
            n_total += 1
            if pred_seq[step].item() == true_seq[step].item():
                n_correct += 1
        accuracy[h] = n_correct / max(n_total, 1)
    return accuracy
|
|
|
|
@torch.no_grad()
def evaluate_ehrworld_protocol(
    model,
    test_dataset,
    numeric_token_ids: list[int] | None = None,
    label_token_ids: list[int] | None = None,
) -> EHRWorldMetrics:
    """Run EHRWorld-protocol evaluation on a CWMDataset slice.

    For GEMEO-CWM:
    - Numeric tokens = LOS buckets (los_short/week/month/long), age buckets
    - Label tokens = CID prefixes + outcome tokens (death/discharge)

    Strategy:
    - For each sequence with at least 5 non-pad tokens, replace the final
      5 non-pad positions with ``model.cfg.mask_token`` and take the argmax
      prediction (mask token excluded) at each of those positions.
    - Numeric tokens: when the predicted or the true id is a numeric id and
      BOTH map to a known bucket midpoint, the pair of midpoints is scored
      with S@25 / SMAPE.
    - Label tokens: the predicted ids and the true ids inside the window are
      each filtered to label ids and compared as sets (micro P/R/F1).

    This is a SIMPLIFIED port — full EHRWorld eval requires their internal
    timestamp + observation schema. Use this as a within-DATASUS comparison
    metric, not as direct head-to-head.

    Args:
        model: Callable as ``model(tokens, t, cond)`` returning logits that
            are indexed as ``[batch, position, vocab]``; must expose
            ``parameters()`` and ``cfg.mask_token``.
        test_dataset: Either a ``torch.Tensor`` of token ids (one row per
            sequence; a zero condition vector is then used), or an object
            whose ``to(device)`` returns ``(tokens, condition)``. Optional
            attributes ``vocab`` and ``tok2id`` are used when present.
        numeric_token_ids: Ids treated as numeric. Defaults to vocab entries
            starting with ``los_`` or ``age_``.
        label_token_ids: Ids treated as labels. Defaults to vocab entries
            starting with ``cid_`` or ``EV_`` or containing ``outcome``.

    Returns:
        EHRWorldMetrics. ``s_at_25`` / ``smape`` are NaN when no numeric pair
        was scored; the label metrics are 0.0 when no label set was scored.
        ``n_episodes`` is the number of rows in the dataset, including rows
        skipped for being shorter than the window.
    """
    window_len = 5
    device = next(model.parameters()).device

    # A raw tensor also has ``.to`` but returns a single tensor, so it must
    # be told apart from a dataset object by type, not by hasattr().
    if isinstance(test_dataset, torch.Tensor):
        x_all = test_dataset.to(device)
        cond_all = torch.zeros(x_all.size(0), dtype=torch.long, device=device)
    else:
        x_all, cond_all = test_dataset.to(device)

    # Resolve the vocabulary once; an absent vocab degrades to "no token is
    # numeric / label" instead of raising inside the scoring loop.
    vocab = test_dataset.vocab if hasattr(test_dataset, 'vocab') else []

    if numeric_token_ids is None:
        numeric_token_ids = [
            idx for idx, tok in enumerate(vocab)
            if isinstance(tok, str) and tok.startswith(("los_", "age_"))
        ]
    if label_token_ids is None:
        label_token_ids = [
            idx for idx, tok in enumerate(vocab)
            if isinstance(tok, str)
            and (tok.startswith(("cid_", "EV_")) or "outcome" in tok)
        ]
    numeric_set = set(numeric_token_ids)
    label_set = set(label_token_ids)

    # Representative numeric value for each bucket token.
    bucket_values = {
        "los_short": 1.0, "los_week": 4.0, "los_month": 18.0, "los_long": 60.0,
        "age_0_1": 0.5, "age_1_2": 1.5, "age_2_5": 3.5, "age_5_12": 8.5,
        "age_12_18": 15.0, "age_18_30": 24.0, "age_30_50": 40.0,
        "age_50_70": 60.0, "age_70plus": 80.0,
    }

    def bucket_value(token_id):
        # None when the id is outside the vocab or is not a bucket token.
        if token_id < len(vocab):
            return bucket_values.get(vocab[token_id], None)
        return None

    numeric_preds, numeric_truths = [], []
    pred_label_sets, true_label_sets = [], []
    pad_id = test_dataset.tok2id.get("<PAD>", 0) if hasattr(test_dataset, 'tok2id') else 0

    # Loop invariants.
    mask_token = model.cfg.mask_token
    # Second positional model input; presumably a (near-zero) diffusion/noise
    # time — TODO confirm against the model definition.
    t_eval = torch.tensor([0.05], device=device)
    # NOTE(review): the model's train/eval mode is left as the caller set it.

    for i in range(x_all.size(0)):
        seq = x_all[i]

        # Assumes padding sits only at the end of the sequence, so the count
        # of non-pad tokens is also the index just past the last real token —
        # TODO confirm.
        n_valid = int((seq != pad_id).sum().item())
        if n_valid < window_len:
            continue
        window = slice(n_valid - window_len, n_valid)

        true_ids = [int(v) for v in seq[window].tolist()]

        masked_seq = seq.clone()
        masked_seq[window] = mask_token
        logits = model(masked_seq.unsqueeze(0), t_eval, cond_all[i:i + 1])
        # Never let the model "predict" the mask token itself.
        logits[:, :, mask_token] = -1e9
        pred_ids = [int(v) for v in logits[0].argmax(dim=-1)[window].tolist()]

        # Numeric scoring. NOTE(review): a pair is dropped unless BOTH sides
        # resolve to a bucket value, so a numeric truth predicted as a
        # non-numeric token is not penalised here.
        for pt, tt in zip(pred_ids, true_ids):
            if pt in numeric_set or tt in numeric_set:
                pv = bucket_value(pt)
                tv = bucket_value(tt)
                if pv is not None and tv is not None:
                    numeric_preds.append(pv)
                    numeric_truths.append(tv)

        # Label scoring: set comparison within the window.
        pred_lbl = {p for p in pred_ids if p in label_set}
        true_lbl = {t for t in true_ids if t in label_set}
        if pred_lbl or true_lbl:
            pred_label_sets.append(pred_lbl)
            true_label_sets.append(true_lbl)

    yp = torch.tensor(numeric_preds, dtype=torch.float)
    yt = torch.tensor(numeric_truths, dtype=torch.float)
    s25 = s_at_k(yp, yt, k=0.25) if yp.numel() else float("nan")
    sm = smape(yp, yt) if yp.numel() else float("nan")
    if pred_label_sets:
        prec, rec, f1 = label_f1(pred_label_sets, true_label_sets)
    else:
        prec, rec, f1 = 0.0, 0.0, 0.0

    return EHRWorldMetrics(
        s_at_25=s25, smape=sm, label_f1=f1,
        label_precision=prec, label_recall=rec,
        n_episodes=x_all.size(0), n_numeric_eval=len(numeric_preds),
        n_label_eval=len(pred_label_sets),
    )
|
|