# PhysioJEPA / src / physiojepa / probe.py
# guychuk's picture
# Upload folder using huggingface_hub
# 31e2456 verified
"""Linear probe + simple evaluators for frozen encoders.
AF AUROC on PTB-XL (lead II ECG, resampled 500->250 Hz), HR R^2, retrieval,
PTT regression (MLP).
"""
from __future__ import annotations
from pathlib import Path
import numpy as np
import torch
from sklearn.linear_model import LogisticRegression, Ridge
from sklearn.metrics import mean_absolute_error, r2_score, roc_auc_score
from sklearn.neural_network import MLPRegressor
@torch.no_grad()
def pooled_features(
    encoder: torch.nn.Module,
    x: torch.Tensor,
    device: torch.device,
    batch_size: int = 64,
) -> np.ndarray:
    """Encode ``x`` in mini-batches and return mean-pooled features.

    The encoder is taken out of training mode with ``train(False)`` and is
    left in that mode when the function returns. Gradient tracking is
    disabled for the whole call by the ``torch.no_grad()`` decorator.

    Args:
        encoder: Module called on each mini-batch. Its output is averaged
            over dim 1 (the original author notes the shape as [B, N, d]).
        x: Inputs, sliced along the first axis.
        device: Device each mini-batch is moved to before encoding.
        batch_size: Maximum number of items per encoder call.

    Returns:
        The pooled outputs of all mini-batches, concatenated along axis 0.

    Raises:
        ValueError: From ``np.concatenate`` when ``x`` is empty, because no
            mini-batch is produced.
    """
    encoder.train(False)
    pooled_batches = []
    for start in range(0, len(x), batch_size):
        stop = start + batch_size
        batch = x[start:stop].to(device)
        tokens = encoder(batch)  # [B, N, d] per the original annotation
        batch_mean = tokens.mean(dim=1)
        # Move to host memory before converting to NumPy.
        pooled_batches.append(batch_mean.cpu().numpy())
    return np.concatenate(pooled_batches, axis=0)
def linear_probe_auroc(
    train_X: np.ndarray,
    train_y: np.ndarray,
    test_X: np.ndarray,
    test_y: np.ndarray,
    max_iter: int = 2000,
    C: float = 1.0,
) -> float:
    """Fit a logistic-regression probe and score it by AUROC on the test split.

    Args:
        train_X: Features used to fit the probe.
        train_y: Labels used to fit the probe.
        test_X: Features the fitted probe is scored on.
        test_y: Labels the test scores are compared against.
        max_iter: Iteration cap passed to ``LogisticRegression``.
        C: Inverse regularization strength passed to ``LogisticRegression``.

    Returns:
        ``roc_auc_score`` of ``test_y`` against the probability column at
        index 1 of ``predict_proba`` (presumably the positive class of a
        binary task -- confirm against callers).
    """
    probe = LogisticRegression(C=C, max_iter=max_iter, solver="lbfgs")
    probe.fit(train_X, train_y)
    test_probabilities = probe.predict_proba(test_X)
    positive_scores = test_probabilities[:, 1]
    auroc = roc_auc_score(test_y, positive_scores)
    return float(auroc)
def linear_probe_r2(
    train_X: np.ndarray, train_y: np.ndarray, test_X: np.ndarray, test_y: np.ndarray,
    alpha: float = 1.0,
) -> float:
    """Fit a ridge-regression probe and return its R^2 on the test split.

    Args:
        train_X: Features used to fit the probe.
        train_y: Regression targets used to fit the probe.
        test_X: Features the fitted probe predicts on.
        test_y: Targets the predictions are compared against.
        alpha: L2 regularization strength passed to ``Ridge``. The default
            of 1.0 matches the value that used to be hard-coded, so existing
            callers are unaffected.

    Returns:
        ``r2_score`` of ``test_y`` against the probe's predictions.
    """
    reg = Ridge(alpha=alpha)
    reg.fit(train_X, train_y)
    return float(r2_score(test_y, reg.predict(test_X)))
def mlp_probe_mae(
    train_X: np.ndarray,
    train_y: np.ndarray,
    test_X: np.ndarray,
    test_y: np.ndarray,
    hidden: tuple[int, ...] = (128,),
    max_iter: int = 200,
) -> float:
    """Fit an MLP regression probe and return its test mean absolute error.

    Args:
        train_X: Features used to fit the probe.
        train_y: Regression targets used to fit the probe.
        test_X: Features the fitted probe predicts on.
        test_y: Targets the predictions are compared against.
        hidden: Hidden layer sizes passed to ``MLPRegressor``.
        max_iter: Iteration cap passed to ``MLPRegressor``.

    Returns:
        ``mean_absolute_error`` of ``test_y`` against the probe's predictions.
    """
    # random_state is pinned to 0 so repeated runs build the same probe.
    regressor = MLPRegressor(
        random_state=0,
        max_iter=max_iter,
        hidden_layer_sizes=hidden,
    )
    regressor.fit(train_X, train_y)
    predictions = regressor.predict(test_X)
    return float(mean_absolute_error(test_y, predictions))
def retrieval_recall(z_query: np.ndarray, z_gallery: np.ndarray, k_list=(1, 5, 10)) -> dict:
    """Compute Recall@k for paired retrieval under cosine similarity.

    Row ``i`` of ``z_query`` is scored as a hit when gallery row ``i`` is
    among its ``k`` most similar gallery rows, so the two arrays must be
    index-aligned. Query rows with no gallery row at the same index always
    count as misses.

    Args:
        z_query: Query embeddings, one row per query.
        z_gallery: Gallery embeddings, one row per candidate.
        k_list: Cut-offs to evaluate.

    Returns:
        Mapping from ``"R@{k}"`` to the fraction of queries whose paired
        gallery row is ranked within the top ``k``.
    """
    # L2-normalize rows so the dot product is cosine similarity; the epsilon
    # keeps all-zero rows from dividing by zero.
    qn = z_query / (np.linalg.norm(z_query, axis=1, keepdims=True) + 1e-9)
    gn = z_gallery / (np.linalg.norm(z_gallery, axis=1, keepdims=True) + 1e-9)
    sim = qn @ gn.T  # [Q, G]

    # Descending similarity. A stable sort breaks ties by gallery index, so
    # the result is deterministic; with the default unstable sort, tied
    # scores (e.g. collapsed embeddings) rank in an implementation-defined
    # order.
    ranks = np.argsort(-sim, axis=1, kind="stable")
    gt = np.arange(sim.shape[0])[:, None]
    return {
        f"R@{k}": float((ranks[:, :k] == gt).any(axis=1).mean())
        for k in k_list
    }