| """Linear probe + simple evaluators for frozen encoders. |
| |
| AF AUROC on PTB-XL (lead II ECG, resampled 500->250 Hz), HR R^2, retrieval, |
| PTT regression (MLP). |
| """ |
| from __future__ import annotations |
|
|
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| from sklearn.linear_model import LogisticRegression, Ridge |
| from sklearn.metrics import mean_absolute_error, r2_score, roc_auc_score |
| from sklearn.neural_network import MLPRegressor |
|
|
|
|
@torch.no_grad()
def pooled_features(encoder: torch.nn.Module, x: torch.Tensor, device: torch.device,
                    batch_size: int = 64) -> np.ndarray:
    """Encode ``x`` in mini-batches without autograd and mean-pool over dim 1.

    The encoder is switched to eval mode for the duration of the call and its
    previous train/eval mode is restored afterwards, so this is safe to call
    in the middle of a training loop.

    Args:
        encoder: Module applied to each mini-batch. Its output is averaged over
            dim 1 (assumes output is (batch, seq, feat) — TODO confirm).
        x: Inputs, split into mini-batches along dim 0.
        device: Device each mini-batch is moved to before encoding.
        batch_size: Number of samples per forward pass; must be >= 1.

    Returns:
        The pooled per-batch outputs, moved to CPU and concatenated along axis 0.

    Raises:
        ValueError: If ``batch_size`` < 1 or ``x`` is empty.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if len(x) == 0:
        raise ValueError("x must contain at least one sample")

    # Remember the caller's mode so we don't leave a training model in eval mode.
    was_training = encoder.training
    encoder.train(False)
    try:
        feats = []
        for start in range(0, len(x), batch_size):
            chunk = x[start : start + batch_size].to(device)
            z = encoder(chunk)
            feats.append(z.mean(dim=1).cpu().numpy())
    finally:
        encoder.train(was_training)
    return np.concatenate(feats, axis=0)
|
|
|
|
def linear_probe_auroc(
    train_X: np.ndarray, train_y: np.ndarray, test_X: np.ndarray, test_y: np.ndarray,
    max_iter: int = 2000, C: float = 1.0,
) -> float:
    """Fit a logistic-regression probe on train features and return test AUROC.

    Args:
        train_X, train_y: Features and labels used to fit the probe.
        test_X, test_y: Held-out features and labels used for scoring.
        max_iter: Iteration cap for the lbfgs solver.
        C: Inverse regularisation strength passed to ``LogisticRegression``.

    Returns:
        ROC AUC of the probe's scores on the test set, as a Python float.
    """
    probe = LogisticRegression(max_iter=max_iter, C=C, solver="lbfgs")
    probe.fit(train_X, train_y)
    # Column 1 is the probability of probe.classes_[1]; presumably the
    # positive class for binary 0/1 labels — confirm with callers.
    scores = probe.predict_proba(test_X)[:, 1]
    auroc = roc_auc_score(test_y, scores)
    return float(auroc)
|
|
|
|
def linear_probe_r2(
    train_X: np.ndarray, train_y: np.ndarray, test_X: np.ndarray, test_y: np.ndarray,
    alpha: float = 1.0,
) -> float:
    """Fit a ridge-regression probe on train features and return test R^2.

    Args:
        train_X, train_y: Features and targets used to fit the probe.
        test_X, test_y: Held-out features and targets used for scoring.
        alpha: Ridge regularisation strength. The default of 1.0 matches the
            value that was previously hard-coded, so existing callers see no
            change.

    Returns:
        Coefficient of determination on the test set, as a Python float.
    """
    reg = Ridge(alpha=alpha)
    reg.fit(train_X, train_y)
    return float(r2_score(test_y, reg.predict(test_X)))
|
|
|
|
def mlp_probe_mae(
    train_X: np.ndarray, train_y: np.ndarray, test_X: np.ndarray, test_y: np.ndarray,
    hidden: tuple[int, ...] = (128,), max_iter: int = 200,
) -> float:
    """Fit a small MLP regressor on train features and return test MAE.

    Args:
        train_X, train_y: Features and targets used to fit the regressor.
        test_X, test_y: Held-out features and targets used for scoring.
        hidden: Hidden layer sizes for ``MLPRegressor``.
        max_iter: Maximum training iterations for ``MLPRegressor``.

    Returns:
        Mean absolute error of the test predictions, as a Python float.
    """
    # Fixed seed so repeated evaluations of the same features agree.
    config = {"hidden_layer_sizes": hidden, "max_iter": max_iter, "random_state": 0}
    regressor = MLPRegressor(**config)
    predictions = regressor.fit(train_X, train_y).predict(test_X)
    mae = mean_absolute_error(test_y, predictions)
    return float(mae)
|
|
|
|
def retrieval_recall(z_query: np.ndarray, z_gallery: np.ndarray, k_list=(1, 5, 10)) -> dict:
    """Cosine-similarity Recall@k where query ``i``'s true match is gallery row ``i``.

    Gallery rows beyond ``len(z_query)`` act as extra distractors.

    Args:
        z_query: (n_query, dim) query embeddings.
        z_gallery: (n_gallery, dim) gallery embeddings, n_gallery >= n_query.
        k_list: Cut-offs k to report.

    Returns:
        Dict mapping ``"R@{k}"`` to the fraction of queries whose true match
        ranks within the top k.

    Raises:
        ValueError: If inputs are not 2-D, ``z_query`` is empty, or there are
            more queries than gallery rows (their matches could not exist).
    """
    z_query = np.asarray(z_query)
    z_gallery = np.asarray(z_gallery)
    if z_query.ndim != 2 or z_gallery.ndim != 2:
        raise ValueError("z_query and z_gallery must be 2-D (n, dim) arrays")
    n_query, n_gallery = z_query.shape[0], z_gallery.shape[0]
    if n_query == 0:
        raise ValueError("z_query must contain at least one embedding")
    if n_query > n_gallery:
        raise ValueError(
            f"z_query has {n_query} rows but z_gallery only {n_gallery}; "
            "query i must have its match at gallery row i"
        )

    qn = z_query / (np.linalg.norm(z_query, axis=1, keepdims=True) + 1e-9)
    gn = z_gallery / (np.linalg.norm(z_gallery, axis=1, keepdims=True) + 1e-9)
    sim = qn @ gn.T  # (n_query, n_gallery)

    idx = np.arange(n_query)
    gt_sim = sim[idx, idx]
    # 0-based rank of the true match = number of OTHER gallery items scoring at
    # least as high. Ties count against the query, so a collapsed encoder
    # (all similarities equal) cannot report inflated recall. This is O(n^2)
    # and avoids fully sorting every row.
    rank = np.count_nonzero(sim >= gt_sim[:, None], axis=1) - 1
    # NaN comparisons are False, which would otherwise look like rank -1 (a hit).
    rank[np.isnan(gt_sim)] = n_gallery

    return {f"R@{k}": float(np.mean(rank < k)) for k in k_list}
|
|