"""Linear probe + simple evaluators for frozen encoders. AF AUROC on PTB-XL (lead II ECG, resampled 500->250 Hz), HR R^2, retrieval, PTT regression (MLP). """ from __future__ import annotations from pathlib import Path import numpy as np import torch from sklearn.linear_model import LogisticRegression, Ridge from sklearn.metrics import mean_absolute_error, r2_score, roc_auc_score from sklearn.neural_network import MLPRegressor @torch.no_grad() def pooled_features(encoder: torch.nn.Module, x: torch.Tensor, device: torch.device, batch_size: int = 64) -> np.ndarray: encoder.train(False) feats = [] for i in range(0, len(x), batch_size): chunk = x[i : i + batch_size].to(device) z = encoder(chunk) # [B, N, d] feats.append(z.mean(dim=1).cpu().numpy()) return np.concatenate(feats, axis=0) def linear_probe_auroc( train_X: np.ndarray, train_y: np.ndarray, test_X: np.ndarray, test_y: np.ndarray, max_iter: int = 2000, C: float = 1.0, ) -> float: clf = LogisticRegression(max_iter=max_iter, C=C, solver="lbfgs") clf.fit(train_X, train_y) return float(roc_auc_score(test_y, clf.predict_proba(test_X)[:, 1])) def linear_probe_r2( train_X: np.ndarray, train_y: np.ndarray, test_X: np.ndarray, test_y: np.ndarray ) -> float: reg = Ridge(alpha=1.0) reg.fit(train_X, train_y) return float(r2_score(test_y, reg.predict(test_X))) def mlp_probe_mae( train_X: np.ndarray, train_y: np.ndarray, test_X: np.ndarray, test_y: np.ndarray, hidden: tuple[int, ...] = (128,), max_iter: int = 200, ) -> float: m = MLPRegressor(hidden_layer_sizes=hidden, max_iter=max_iter, random_state=0) m.fit(train_X, train_y) return float(mean_absolute_error(test_y, m.predict(test_X))) def retrieval_recall(z_query: np.ndarray, z_gallery: np.ndarray, k_list=(1, 5, 10)) -> dict: # normalize qn = z_query / (np.linalg.norm(z_query, axis=1, keepdims=True) + 1e-9) gn = z_gallery / (np.linalg.norm(z_gallery, axis=1, keepdims=True) + 1e-9) sim = qn @ gn.T # [Q, G] n = sim.shape[0] ranks = (-sim).argsort(axis=1) gt = np.arange(n) out = {} for k in k_list: top = ranks[:, :k] hits = (top == gt[:, None]).any(axis=1).mean() out[f"R@{k}"] = float(hits) return out