"""Evaluate a trained checkpoint on PTB-XL AF + downstream probes. Loads the model from `--ckpt`, fetches PTB-XL via HF, extracts pooled latents from the ECG encoder, runs a logistic-regression linear probe, and writes results JSON. Used at epoch 25 (K-gate eval) and epoch 100 (final eval). """ from __future__ import annotations import argparse import json import os from pathlib import Path import numpy as np import torch from dotenv import load_dotenv load_dotenv() os.environ.setdefault("HF_TOKEN", os.environ.get("HUGGINGFACE_API_KEY", "")) import sys sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src")) from physiojepa.models import MODEL_REGISTRY, ModelConfig from physiojepa.probe import linear_probe_auroc, pooled_features def get_ecg_encoder(model_letter: str, model: torch.nn.Module) -> torch.nn.Module: if model_letter == "A": return model.ecg if model_letter == "C": return model.ecg return model.bb.ecg def main() -> None: ap = argparse.ArgumentParser() ap.add_argument("--ckpt", required=True) ap.add_argument("--model", required=True, choices=["A", "B", "C", "F"]) ap.add_argument("--ptbxl_npz", default="/workspace/cache/ptbxl_af.npz") ap.add_argument("--out", required=True) args = ap.parse_args() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") sd = torch.load(args.ckpt, map_location=device, weights_only=False) saved_cfg = sd.get("cfg", {}) # Respect ablation knobs saved in the TrainConfig cfg = ModelConfig( pred_depth=saved_cfg.get("pred_depth", 4), query_mode=saved_cfg.get("query_mode", "learned"), mask_ratio=saved_cfg.get("mask_ratio", 0.50), ) print(f"[eval] model cfg: pred_depth={cfg.pred_depth} query_mode={cfg.query_mode} mask_ratio={cfg.mask_ratio}") model = MODEL_REGISTRY[args.model](cfg) model.load_state_dict(sd["model"]) model.to(device) model.train(False) enc = get_ecg_encoder(args.model, model) print(f"[eval] loading PTB-XL cache from {args.ptbxl_npz}") arr = np.load(args.ptbxl_npz) X, y = arr["X"], arr["y"] print(f"[eval] X={X.shape} y_pos={int(y.sum())} y_neg={int((1 - y).sum())}") X_t = torch.from_numpy(X) feats = pooled_features(enc, X_t, device=device, batch_size=64) rng = np.random.default_rng(0) idx = rng.permutation(len(y)) cut = int(len(idx) * 0.8) train_idx, test_idx = idx[:cut], idx[cut:] auroc = linear_probe_auroc(feats[train_idx], y[train_idx], feats[test_idx], y[test_idx]) print(f"[eval] AF AUROC = {auroc:.4f}") Path(args.out).parent.mkdir(parents=True, exist_ok=True) Path(args.out).write_text(json.dumps({ "ckpt": args.ckpt, "model": args.model, "auroc": auroc, "n_train": int(cut), "n_test": int(len(idx) - cut), "n_pos": int(y.sum()), "n_neg": int((1 - y).sum()), }, indent=2)) if __name__ == "__main__": main()