File size: 2,323 Bytes
31e2456
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""Linear probe + simple evaluators for frozen encoders.

AF AUROC on PTB-XL (lead II ECG, resampled 500->250 Hz), HR R^2, retrieval,
PTT regression (MLP).
"""
from __future__ import annotations

from pathlib import Path

import numpy as np
import torch
from sklearn.linear_model import LogisticRegression, Ridge
from sklearn.metrics import mean_absolute_error, r2_score, roc_auc_score
from sklearn.neural_network import MLPRegressor


@torch.no_grad()
def pooled_features(encoder: torch.nn.Module, x: torch.Tensor, device: torch.device,
                    batch_size: int = 64) -> np.ndarray:
    """Extract mean-pooled features from a frozen encoder, batched to bound memory.

    Parameters
    ----------
    encoder : module mapping a batch of inputs to token features of shape [B, N, d].
    x : full input tensor; processed in ``batch_size``-row chunks.
    device : device the encoder lives on; each chunk is moved there before the forward.
    batch_size : number of rows per forward pass.

    Returns
    -------
    np.ndarray of shape [len(x), d] — token-mean features for every row of ``x``.
    """
    if len(x) == 0:
        # np.concatenate raises on an empty list; return an empty feature matrix.
        return np.empty((0, 0), dtype=np.float32)
    # Remember the caller's mode: the original code left the encoder permanently
    # in eval mode, which silently disables dropout/batchnorm for later training.
    was_training = encoder.training
    encoder.eval()
    try:
        feats = [
            encoder(x[i : i + batch_size].to(device)).mean(dim=1).cpu().numpy()
            for i in range(0, len(x), batch_size)
        ]
    finally:
        encoder.train(was_training)
    return np.concatenate(feats, axis=0)


def linear_probe_auroc(
    train_X: np.ndarray, train_y: np.ndarray, test_X: np.ndarray, test_y: np.ndarray,
    max_iter: int = 2000, C: float = 1.0,
) -> float:
    """Fit a logistic-regression probe on frozen features; return test AUROC."""
    probe = LogisticRegression(max_iter=max_iter, C=C, solver="lbfgs").fit(train_X, train_y)
    positive_scores = probe.predict_proba(test_X)[:, 1]
    return float(roc_auc_score(test_y, positive_scores))


def linear_probe_r2(
    train_X: np.ndarray, train_y: np.ndarray, test_X: np.ndarray, test_y: np.ndarray
) -> float:
    """Fit a ridge-regression probe on frozen features; return test-split R^2."""
    probe = Ridge(alpha=1.0).fit(train_X, train_y)
    predictions = probe.predict(test_X)
    return float(r2_score(test_y, predictions))


def mlp_probe_mae(
    train_X: np.ndarray, train_y: np.ndarray, test_X: np.ndarray, test_y: np.ndarray,
    hidden: tuple[int, ...] = (128,), max_iter: int = 200,
) -> float:
    """Fit a small MLP regressor probe (fixed seed); return test-split MAE."""
    probe = MLPRegressor(hidden_layer_sizes=hidden, max_iter=max_iter, random_state=0)
    probe.fit(train_X, train_y)
    predictions = probe.predict(test_X)
    return float(mean_absolute_error(test_y, predictions))


def retrieval_recall(z_query: np.ndarray, z_gallery: np.ndarray, k_list=(1, 5, 10)) -> dict:
    """Recall@k for paired retrieval: query i's ground-truth match is gallery item i.

    Similarity is cosine (rows are L2-normalized with a small epsilon to avoid
    division by zero); returns {"R@k": recall} for each k in ``k_list``.
    """
    def _unit_rows(z: np.ndarray) -> np.ndarray:
        return z / (np.linalg.norm(z, axis=1, keepdims=True) + 1e-9)

    sim = _unit_rows(z_query) @ _unit_rows(z_gallery).T  # [Q, G]
    order = np.argsort(-sim, axis=1)                     # gallery ids, best first
    targets = np.arange(sim.shape[0])[:, None]           # gt index per query row
    return {
        f"R@{k}": float((order[:, :k] == targets).any(axis=1).mean())
        for k in k_list
    }