| """ |
| utils.metrics — generic, reusable metric functions. |
| |
| Shared by every experiment in experiments/. Anything specific to a |
| particular dataset or training recipe stays out of this module. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
|
|
|
|
@torch.no_grad()
def accuracy(model, loader, device):
    """
    Top-1 classification accuracy of `model` over every batch in `loader`.

    Args:
        model: classifier; the prediction is the argmax over dim 1 of its
            output (presumably (batch, num_classes) logits).
        loader: iterable of `(imgs, labels)` batches. Labels may be (B,) or
            carry trailing singleton dims such as (B, 1); they are flattened
            to (B,) and cast to int64 before comparison.
        device: device the images and labels are moved to.

    Returns:
        Fraction of correctly classified samples in [0, 1], or NaN if the
        loader yields no samples.

    The model runs in eval mode; every submodule's previous train/eval
    flag is restored afterwards, so calling this mid-training is safe.
    """
    saved_modes = [(mod, mod.training) for mod in model.modules()]
    model.eval()
    correct = total = 0
    try:
        for imgs, labels in loader:
            imgs = imgs.to(device)
            # reshape(-1) rather than squeeze(): squeeze() collapses a batch
            # of one into a 0-d tensor, on which .size(0) raises.
            labels = labels.reshape(-1).long().to(device)
            preds = model(imgs).argmax(1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    finally:
        for mod, mode in saved_modes:
            mod.training = mode
    return correct / total if total else float("nan")
|
|
|
|
@torch.no_grad()
def weight_norm(model):
    """
    Global L2 norm of all of `model`'s parameters.

    Equivalent to the L2 norm of every parameter flattened and concatenated
    into a single vector: sqrt(sum over parameters of ||p||_2 ** 2).
    Returns 0.0 for a model with no parameters.
    """
    squared_total = 0
    for param in model.parameters():
        per_param = param.data.norm(2).item()
        squared_total += per_param ** 2
    return squared_total ** 0.5
|
|
|
|
@torch.no_grad()
def feature_rank(model, loader, device, n=200, hook_module_attr="avgpool"):
    """
    Effective rank of penultimate-layer features: exp(Shannon entropy of the
    singular values normalised to sum to 1). Per Fan et al. 2024 this is the
    most reliable progress measure for grokking; the rank collapses at the
    transition.

    Args:
        model: network exposing a submodule named `hook_module_attr`.
        loader: iterable of `(imgs, _)` batches; only the images are used.
        device: device the images are moved to before the forward pass.
        n: maximum number of samples (rows) in the feature matrix. Batches
            are consumed until at least `n` samples have been seen.
        hook_module_attr: attribute name on the model whose forward output is
            treated as the penultimate representation (flattened to
            (batch, -1)). Defaults to ResNet's `avgpool`.

    Returns:
        The effective rank as a float, roughly in [1, min(rows, feature_dim)]
        (up to the 1e-10 smoothing). Returns NaN if no features were captured
        or the SVD fails.

    The hook is always removed, and every submodule's train/eval flag is
    restored, even if a forward pass raises.
    """
    # Resolve the target first so a bad attribute name fails before any
    # side effects on the model.
    target = getattr(model, hook_module_attr)
    saved_modes = [(mod, mod.training) for mod in model.modules()]
    model.eval()
    feats = []
    # NOTE(review): assumes `target` runs once per forward pass; every call
    # appends one batch of features.
    hook = target.register_forward_hook(
        lambda m, i, o: feats.append(o.reshape(o.size(0), -1).cpu())
    )
    count = 0
    try:
        for imgs, _ in loader:
            model(imgs.to(device))
            count += imgs.size(0)
            if count >= n:
                break
    finally:
        # A hook left behind after a failed forward would keep appending to
        # this call's `feats` on every later forward of the model.
        hook.remove()
        for mod, mode in saved_modes:
            mod.training = mode
    if not feats:
        return float("nan")
    F_mat = torch.cat(feats)[:n]
    if F_mat.numel() == 0:
        return float("nan")
    if F_mat.dtype in (torch.float16, torch.bfloat16):
        # Features live on the CPU, whose SVD does not support half precision.
        F_mat = F_mat.float()
    try:
        # Only singular values are needed; svdvals skips computing U and V.
        s = torch.linalg.svdvals(F_mat)
    except RuntimeError:
        # e.g. non-convergence; torch.linalg.LinAlgError subclasses RuntimeError.
        return float("nan")
    s = s / (s.sum() + 1e-10)
    return torch.exp(-(s * torch.log(s + 1e-10)).sum()).item()
|
|
|
|
def irm_penalty(model, envs, device):
    """
    IRMv1 penalty (Arjovsky et al. 2019). For each environment, the squared
    gradient of the cross-entropy loss w.r.t. a dummy scalar w = 1 that
    multiplies the logits.

    LOW value = invariant predictor = causal features found.
    HIGH value = environment-specific (spurious) features still in use.

    Args:
        model: classifier producing logits from `env["x"]`.
        envs: list of dicts {"x": tensor, "y": tensor} already on `device`;
            "y" must be a valid `F.cross_entropy` target for those logits.
        device: device on which the dummy scale `w` is created.

    Returns:
        (mean, var) over environments. `var` is the unbiased sample
        variance, so it is NaN with a single environment; both values are
        NaN when `envs` is empty.

    Works even when called inside `torch.no_grad()`. The model runs in eval
    mode and every submodule's train/eval flag is restored afterwards.
    """
    saved_modes = [(mod, mod.training) for mod in model.modules()]
    model.eval()
    penalties = []
    try:
        for env in envs:
            # d(loss)/dw depends only on the logits' values, not on their
            # graph, so the forward pass skips autograd bookkeeping entirely.
            with torch.no_grad():
                logits = model(env["x"])
            # enable_grad keeps this working when the caller is itself inside
            # a torch.no_grad() block.
            with torch.enable_grad():
                w = torch.tensor(1.0, requires_grad=True, device=device)
                loss = F.cross_entropy(logits * w, env["y"])
                grad = torch.autograd.grad(loss, w)[0]
            penalties.append(grad.item() ** 2)
    finally:
        for mod, mode in saved_modes:
            mod.training = mode
    t = torch.tensor(penalties)
    return t.mean().item(), t.var().item()
|
|
|
|
@torch.no_grad()
def shortcut_ratio(model, loader, device):
    """
    Center-vs-border confidence probe for artifact reliance.

    Each image batch is evaluated twice:
      * center view: the central crop (rows H//4:3H//4, cols W//4:3W//4),
        bilinearly upsampled back to (H, W);
      * border view: a copy with that central region set to 0.
    A view's confidence is the model's max softmax probability.

    The ratio border_conf / center_conf is a proxy for shortcut use: > 1
    means the model trusts the borders (where scanner markers, laterality
    letters, and other artefacts live) more than the center (where actual
    anatomy is). On CheXpert, replace with GradCAM pointed at known artifact
    locations vs. anatomical regions.

    Args:
        model: classifier whose output is softmaxed over dim 1.
        loader: iterable of `(imgs, _)` batches of 4-D image tensors
            (batch, channels, H, W); labels are ignored.
        device: device the images are moved to.

    Returns:
        (center_conf_mean, border_conf_mean), each averaged per sample over
        the whole loader; (nan, nan) if the loader yields no samples.

    NOTE(review): 0 is a "blank" fill only for unnormalised inputs; after
    mean/std normalisation 0 is the dataset mean — confirm that is intended.
    """
    saved_modes = [(mod, mod.training) for mod in model.modules()]
    model.eval()
    center_sum = border_sum = 0.0
    n_seen = 0
    try:
        for imgs, _ in loader:
            imgs = imgs.to(device)
            _, _, H, W = imgs.shape
            hs, he = H // 4, 3 * H // 4
            ws, we = W // 4, 3 * W // 4
            center = F.interpolate(imgs[:, :, hs:he, ws:we], size=(H, W),
                                   mode='bilinear', align_corners=False)
            border = imgs.clone()
            border[:, :, hs:he, ws:we] = 0.
            # Sum per-sample confidences (not per-batch means) so a smaller
            # final batch is not over-weighted in the overall average.
            center_sum += F.softmax(model(center), 1).max(1).values.sum().item()
            border_sum += F.softmax(model(border), 1).max(1).values.sum().item()
            n_seen += imgs.size(0)
    finally:
        for mod, mode in saved_modes:
            mod.training = mode
    if n_seen == 0:
        return float("nan"), float("nan")
    return center_sum / n_seen, border_sum / n_seen
|
|