# CausalGrok/code/utils/metrics.py
# Uploaded by nileshsarkar-ai ("Upload code/utils", commit 9d2fc01, verified)
"""
utils.metrics — generic, reusable metric functions.
Shared by every experiment in experiments/. Anything specific to a
particular dataset or training recipe stays out of this module.
"""
from __future__ import annotations
import numpy as np
import torch
import torch.nn.functional as F
@torch.no_grad()
def accuracy(model, loader, device):
    """
    Top-1 classification accuracy of `model` over every batch in `loader`.

    `loader` yields `(imgs, labels)` pairs. Labels are flattened to one
    class index per sample, so both `(B,)` and `(B, 1)` label tensors work.
    The model is switched to eval mode and is NOT switched back; callers
    that are mid-training must call `model.train()` afterwards.

    Returns correct / total as a float, or NaN if the loader yields no
    samples.
    """
    model.eval()
    correct = total = 0
    for imgs, labels in loader:
        imgs = imgs.to(device)
        # reshape(-1) instead of squeeze(): squeeze() turns a batch of one
        # into a 0-dim tensor, after which size(0) raises IndexError.
        labels = labels.reshape(-1).long().to(device)
        preds = model(imgs).argmax(1)
        correct += (preds == labels).sum().item()
        total += labels.numel()
    if total == 0:
        # Nothing was evaluated; report "undefined" rather than dividing by 0.
        return float("nan")
    return correct / total
@torch.no_grad()
def weight_norm(model):
    """
    Global L2 norm of all parameters of `model`.

    Computed as the square root of the sum of the squared per-tensor
    2-norms, accumulated in Python floats. A model without parameters
    yields 0.0.
    """
    squared_total = 0
    for param in model.parameters():
        per_tensor_norm = param.data.norm(2).item()
        squared_total += per_tensor_norm ** 2
    return squared_total ** 0.5
@torch.no_grad()
def feature_rank(model, loader, device, n=200, hook_module_attr="avgpool"):
    """
    Effective rank of penultimate-layer features = exp(entropy of normalised
    singular values). Fan et al. 2024 — the most reliable progress measure
    for grokking. The rank collapses at the transition.

    `hook_module_attr` is the attribute name on the model whose forward output
    we treat as the penultimate representation. Defaults to ResNet's avgpool.

    Batches are drawn from `loader` until at least `n` samples have been
    seen; only the first `n` feature rows are used. Each captured output is
    flattened to `(batch, -1)` and moved to the CPU before the SVD.

    The model is switched to eval mode and is NOT switched back.

    Returns the effective rank as a float, or NaN when no features were
    captured (e.g. an empty loader) or the SVD fails. Errors raised by the
    model's forward pass propagate to the caller.
    """
    model.eval()
    feats = []
    target = getattr(model, hook_module_attr)
    hook = target.register_forward_hook(
        # reshape (not view) so non-contiguous module outputs flatten too.
        lambda module, inputs, output: feats.append(
            output.reshape(output.size(0), -1).cpu()
        )
    )
    try:
        count = 0
        for imgs, _ in loader:
            model(imgs.to(device))
            count += imgs.size(0)
            if count >= n:
                break
    finally:
        # Always detach the hook, even if a forward pass raised; a leaked
        # hook would keep appending to `feats` on every later forward.
        hook.remove()

    if not feats:
        return float("nan")

    features = torch.cat(feats)[:n]
    if features.dtype not in (torch.float32, torch.float64):
        # CPU SVD is not available for half precision; upcast first.
        features = features.float()
    try:
        # Only the singular values are needed, so skip computing U and V.
        s = torch.linalg.svdvals(features)
    except RuntimeError:
        # Best-effort metric: a non-converging SVD must not abort training.
        return float("nan")
    s = s / (s.sum() + 1e-10)
    return torch.exp(-(s * torch.log(s + 1e-10)).sum()).item()
def irm_penalty(model, envs, device):
    """
    IRMv1 penalty (Arjovsky et al. 2019). For each environment, the squared
    gradient of the loss w.r.t. a dummy scalar w=1 in the logits.

    LOW value = invariant predictor = causal features found.
    HIGH value = environment-specific (spurious) features still in use.

    `envs` is a list of dicts {"x": tensor, "y": tensor} already on `device`.
    The loss is cross-entropy between `model(env["x"]) * w` and `env["y"]`.

    The model is switched to eval mode and is NOT switched back. Gradient
    tracking is enabled locally, so the function also works when called
    from inside a `torch.no_grad()` block.

    Returns (mean, var) over environments. `var` is the unbiased sample
    variance, defined as 0.0 for a single environment. An empty `envs`
    returns (NaN, NaN).
    """
    model.eval()
    penalties = []
    # Evaluation code often runs under no_grad(); autograd.grad would then
    # fail because the loss has no graph. Force grad mode on for this loop.
    with torch.enable_grad():
        for env in envs:
            w = torch.tensor(1.0, requires_grad=True, device=device)
            logits = model(env["x"]) * w
            loss = F.cross_entropy(logits, env["y"])
            grad = torch.autograd.grad(loss, w, create_graph=False)[0]
            penalties.append(grad.item() ** 2)

    if not penalties:
        return float("nan"), float("nan")

    t = torch.tensor(penalties)
    mean = t.mean().item()
    # The unbiased variance of one sample is NaN (division by n - 1 == 0);
    # a single environment has no spread, so report 0.0 instead.
    var = t.var().item() if len(penalties) > 1 else 0.0
    return mean, var
@torch.no_grad()
def shortcut_ratio(model, loader, device):
    """
    Border-confidence / center-confidence proxy for artifact reliance.

    For every batch two views are built from the `(B, C, H, W)` images:
    the central crop (middle half of each spatial axis) upsampled back to
    `(H, W)`, and the original image with that central region zeroed out.
    Confidence is the maximum softmax probability of the model on a view.

    A border/center ratio > 1 means the model trusts the borders (where
    scanner markers, laterality letters, and other artefacts live) more
    than the center (where actual anatomy is). On CheXpert, replace with
    GradCAM pointed at known artifact locations vs. anatomical regions.

    The model is switched to eval mode and is NOT switched back.

    Returns (center_conf_mean, border_conf_mean) — the two per-sample means,
    not their ratio; the caller divides. Both are NaN if the loader yields
    no samples.
    """
    model.eval()
    center_sum = border_sum = 0.0
    n_samples = 0
    for imgs, _ in loader:
        imgs = imgs.to(device)
        B, _, H, W = imgs.shape
        hs, he = H // 4, 3 * H // 4
        ws, we = W // 4, 3 * W // 4
        center = F.interpolate(
            imgs[:, :, hs:he, ws:we],
            size=(H, W),
            mode="bilinear",
            align_corners=False,
        )
        border = imgs.clone()
        border[:, :, hs:he, ws:we] = 0.0
        # Accumulate per-sample sums (not per-batch means) so a short final
        # batch does not get the same weight as a full one.
        center_sum += F.softmax(model(center), 1).max(1).values.sum().item()
        border_sum += F.softmax(model(border), 1).max(1).values.sum().item()
        n_samples += B
    if n_samples == 0:
        return float("nan"), float("nan")
    return center_sum / n_samples, border_sum / n_samples