"""EHRWorld-protocol evaluation metrics, ported for any EHR foundation model. EHRWorld (arXiv 2602.03569, SJTU Jan 2026) reports three headline metrics on their 579-episode MIMIC-IV test set: - S@25 (success at 25%): for numerical values (lab results, vitals), fraction predicted within +-25% relative error of ground truth. - SMAPE: symmetric mean absolute percentage error, SMAPE = (1/N) * sum( |y_hat - y| / ((|y| + |y_hat|)/2) ) bounded in [0,2], lower is better. - Label-F1: precision/recall F1 over discrete diagnostic labels (set-based, multi-label). We re-implement them here so any GEMEO-CWM checkpoint can be scored on its own DATASUS test slice with EHRWorld-comparable numbers. We cannot run on their actual test set without MIMIC-IV credentialing, but the methodological alignment is legitimate for a "rare-disease + PT-BR" framing. Reference: Section 4.2 + Appendix C of EHRWorld paper. """ from __future__ import annotations import math import logging from dataclasses import dataclass import torch log = logging.getLogger("gemeo.cwm.eval_protocol") @dataclass class EHRWorldMetrics: s_at_25: float smape: float label_f1: float label_precision: float label_recall: float n_episodes: int n_numeric_eval: int n_label_eval: int def s_at_k(y_pred: torch.Tensor, y_true: torch.Tensor, k: float = 0.25) -> float: """Fraction of numerical predictions within +-k relative error of truth. Both tensors should be 1-D floats (same length). Predictions where y_true == 0 are skipped (cannot define relative error). """ if y_pred.numel() == 0: return float("nan") mask = y_true.abs() > 1e-9 if mask.sum() == 0: return float("nan") yt = y_true[mask].float() yp = y_pred[mask].float() rel_err = (yp - yt).abs() / yt.abs() return (rel_err <= k).float().mean().item() def smape(y_pred: torch.Tensor, y_true: torch.Tensor) -> float: """Symmetric MAPE: mean of |yp - yt| / ((|yt| + |yp|)/2), in [0, 2].""" if y_pred.numel() == 0: return float("nan") num = (y_pred - y_true).abs() den = (y_pred.abs() + y_true.abs()) / 2.0 den = den.clamp(min=1e-9) return (num / den).mean().item() def label_f1(pred_sets: list[set], true_sets: list[set]) -> tuple[float, float, float]: """Multi-label macro-F1 over discrete labels. Each entry of pred_sets / true_sets is a set of label strings (or ids). Returns (precision, recall, f1) micro-averaged across all episodes. """ assert len(pred_sets) == len(true_sets) tp = fp = fn = 0 for p, t in zip(pred_sets, true_sets): tp += len(p & t) fp += len(p - t) fn += len(t - p) prec = tp / max(tp + fp, 1) rec = tp / max(tp + fn, 1) f1 = 2 * prec * rec / max(prec + rec, 1e-9) return prec, rec, f1 def long_horizon_retention( model_predictions: list[torch.Tensor], truth: list[torch.Tensor], horizons: list[int] = (1, 5, 10, 20), ) -> dict[int, float]: """Fraction of correct predictions at each horizon (next-event accuracy). EHRWorld reports 92.6% long-horizon retention (some definition of decay from horizon 1 to horizon N). We use Top-1 next-event accuracy at each horizon and compute the retention ratio acc(h) / acc(1). """ out = {} for h in horizons: n_correct = 0 n_total = 0 for pred_seq, true_seq in zip(model_predictions, truth): if true_seq.numel() < h: continue t = true_seq[h - 1] p = pred_seq[h - 1] if pred_seq.numel() >= h else None if p is None: continue n_total += 1 if p.item() == t.item(): n_correct += 1 out[h] = n_correct / max(n_total, 1) return out @torch.no_grad() def evaluate_ehrworld_protocol( model, test_dataset, numeric_token_ids: list[int] = None, label_token_ids: list[int] = None, ) -> EHRWorldMetrics: """Run EHRWorld-protocol evaluation on a CWMDataset slice. For GEMEO-CWM: - Numeric tokens = LOS buckets (los_short/week/month/long), age buckets - Label tokens = CID prefixes + outcome tokens (death/discharge) Strategy: - Hide last token of each sequence, predict it - For numeric tokens: convert bucket -> midpoint value for S@25 / SMAPE - For label tokens: predict top-K, compare set vs truth This is a SIMPLIFIED port — full EHRWorld eval requires their internal timestamp + observation schema. Use this as a within-DATASUS comparison metric, not as direct head-to-head. """ device = next(model.parameters()).device if hasattr(test_dataset, 'to'): x_all, cond_all = test_dataset.to(device) else: x_all = test_dataset.to(device) cond_all = torch.zeros(x_all.size(0), dtype=torch.long, device=device) if numeric_token_ids is None: # Default: try to identify LOS / age bucket tokens by string prefix vocab = test_dataset.vocab if hasattr(test_dataset, 'vocab') else [] numeric_token_ids = [ i for i, t in enumerate(vocab) if isinstance(t, str) and (t.startswith("los_") or t.startswith("age_")) ] if label_token_ids is None: vocab = test_dataset.vocab if hasattr(test_dataset, 'vocab') else [] label_token_ids = [ i for i, t in enumerate(vocab) if isinstance(t, str) and (t.startswith("cid_") or "outcome" in t or t.startswith("EV_")) ] numeric_set = set(numeric_token_ids) label_set = set(label_token_ids) # Bucket value map (midpoints) bucket_values = { "los_short": 1.0, "los_week": 4.0, "los_month": 18.0, "los_long": 60.0, "age_0_1": 0.5, "age_1_2": 1.5, "age_2_5": 3.5, "age_5_12": 8.5, "age_12_18": 15.0, "age_18_30": 24.0, "age_30_50": 40.0, "age_50_70": 60.0, "age_70plus": 80.0, } numeric_preds, numeric_truths = [], [] pred_label_sets, true_label_sets = [], [] pad_id = test_dataset.tok2id.get("", 0) if hasattr(test_dataset, 'tok2id') else 0 for i in range(x_all.size(0)): seq = x_all[i] # find last non-pad position valid = (seq != pad_id).sum().item() if valid < 5: continue # Mask the last 5 positions, ask the model to fill truth_window = seq[max(0, valid - 5):valid].clone() masked_seq = seq.clone() mask_token = model.cfg.mask_token masked_seq[max(0, valid - 5):valid] = mask_token t_zero = torch.tensor([0.05], device=device) cond = cond_all[i:i+1] logits = model(masked_seq.unsqueeze(0), t_zero, cond) logits[:, :, mask_token] = -1e9 preds = logits[0].argmax(dim=-1) pred_window = preds[max(0, valid - 5):valid] # Numeric: convert tokens to bucket values for j in range(pred_window.size(0)): pt, tt = pred_window[j].item(), truth_window[j].item() if pt in numeric_set or tt in numeric_set: vocab = test_dataset.vocab pv = bucket_values.get(vocab[pt], None) if pt < len(vocab) else None tv = bucket_values.get(vocab[tt], None) if tt < len(vocab) else None if pv is not None and tv is not None: numeric_preds.append(pv) numeric_truths.append(tv) # Labels: collect sets pred_lbl = set(int(x.item()) for x in pred_window if int(x.item()) in label_set) true_lbl = set(int(x.item()) for x in truth_window if int(x.item()) in label_set) if pred_lbl or true_lbl: pred_label_sets.append(pred_lbl) true_label_sets.append(true_lbl) yp = torch.tensor(numeric_preds, dtype=torch.float) yt = torch.tensor(numeric_truths, dtype=torch.float) s25 = s_at_k(yp, yt, k=0.25) if yp.numel() else float("nan") sm = smape(yp, yt) if yp.numel() else float("nan") prec, rec, f1 = (label_f1(pred_label_sets, true_label_sets) if pred_label_sets else (0.0, 0.0, 0.0)) return EHRWorldMetrics( s_at_25=s25, smape=sm, label_f1=f1, label_precision=prec, label_recall=rec, n_episodes=x_all.size(0), n_numeric_eval=len(numeric_preds), n_label_eval=len(pred_label_sets), )