# detection-heads / scripts / _person_mlp.py
# Author: phanerozoic
# Commit: 74e3c01 ("update repository")
"""Train a tiny MLP on 92 evolved dims for image-level person classification."""
import json, os, torch, torch.nn as nn
import torch.nn.functional as F
from pycocotools.coco import COCO
# Dataset locations are taken from the environment (KeyError if unset).
COCO_ROOT = os.environ["ARENA_COCO_ROOT"]
VAL_CACHE = os.environ["ARENA_VAL_CACHE"]
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))

# Load the evolved genomes and keep the one from the K == 100 run, reduced to
# a sorted list of unique feature indices (used later to select columns).
with open(os.path.join(SCRIPT_DIR, "..", "circuit", "evolved_extreme.json"),
          encoding="utf-8") as f:
    evolved = json.load(f)
genome = next((r["genome"] for r in evolved if r["K"] == 100), None)
if genome is None:
    raise ValueError("evolved_extreme.json contains no record with K == 100")
dims = sorted(set(genome))
N = len(dims)

# NOTE(security): weights_only=False lets torch.load unpickle arbitrary Python
# objects -- only point ARENA_VAL_CACHE at a trusted file.
val = torch.load(VAL_CACHE, map_location="cpu", weights_only=False)
coco = COCO(os.path.join(COCO_ROOT, "annotations", "instances_val2017.json"))
PERSON_CAT = 1  # COCO category id for "person"
def cofiber_decompose(f, n_scales):
    """Split a feature map into band-pass "cofibers" plus a coarse residual.

    For each of the first ``n_scales - 1`` levels, the current map is
    average-pooled with a 2x2 kernel (stride 2). The pooled map is bilinearly
    upsampled back to the current spatial size, and the difference
    ``current - upsampled`` is emitted. The last pooled map is appended as-is.

    Args:
        f: 4-D tensor (N, C, H, W), as required by bilinear F.interpolate.
        n_scales: number of tensors to return; values <= 1 return ``[f]``.

    Returns:
        list of tensors ordered finest to coarsest; each band keeps the spatial
        size of the level it was taken at.
    """
    bands = []
    current = f
    for _level in range(1, n_scales):
        coarser = F.avg_pool2d(current, 2)
        lifted = F.interpolate(
            coarser,
            size=current.shape[2:],
            mode="bilinear",
            align_corners=False,
        )
        bands.append(current - lifted)
        current = coarser
    bands.append(current)
    return bands
# Build one descriptor per image. At each of 3 cofiber scales, every spatial
# position is layer-normalised over its C channels. Only the `dims` columns
# are kept, and they are max-pooled over all positions of all scales. The
# label is 1.0 when the image has at least one person annotation returned by
# getAnnIds(iscrowd=False), else 0.0.
print("Pre-computing image vectors...", flush=True)
all_vecs = []
all_labels = []
total = len(val)
for idx in range(total):
    item = val[idx]
    fmap = item["spatial"].unsqueeze(0).float()

    per_scale = []
    for band in cofiber_decompose(fmap, 3):
        _, channels, _, _ = band.shape
        # (1, C, H, W) -> (H*W, C): one row per spatial position.
        tokens = band.permute(0, 2, 3, 1).reshape(-1, channels)
        per_scale.append(F.layer_norm(tokens, [channels]))
    stacked = torch.cat(per_scale)
    all_vecs.append(stacked[:, dims].max(dim=0).values)

    person_ids = coco.getAnnIds(imgIds=int(item["img_id"]), catIds=[PERSON_CAT], iscrowd=False)
    all_labels.append(1.0 if len(person_ids) > 0 else 0.0)

    done = idx + 1
    if done % 1000 == 0:
        print(f" {done}/{total}", flush=True)

X = torch.stack(all_vecs).cuda()
Y = torch.tensor(all_labels).cuda()
# ---------------------------------------------------------------------------
# K-fold cross-validation of small GELU MLP heads on the evolved dims.
#
# Each architecture is trained once per fold. Its held-out-fold predictions
# are scored at two operating points:
#   1. the threshold in 0.50..0.99 that maximises recall subject to
#      precision >= TARGET_PRECISION. This threshold is chosen on a
#      calibration slice of the *training* fold; tuning it on the test fold
#      would make the reported precision/recall optimistically biased.
#   2. a fixed threshold of 0.5.
# Both operating points use the very same trained model per fold.
# ---------------------------------------------------------------------------
N_FOLDS = 5
EPOCHS = 200
BATCH_SIZE = 256
LR = 1e-3
CALIB_EVERY = 5            # every 5th training item is held out for threshold picking
TARGET_PRECISION = 0.99
THRESHOLDS = [t_int / 100.0 for t_int in range(50, 100)]
HIDDEN_CONFIGS = [(32,), (64,), (64, 64), (128, 64)]


def _build_mlp(n_in, hidden, device):
    """Return nn.Sequential(Linear, GELU, ..., Linear(->1)) with the given hidden widths."""
    layers = []
    width = n_in
    for h in hidden:
        layers += [nn.Linear(width, h), nn.GELU()]
        width = h
    layers.append(nn.Linear(width, 1))
    return nn.Sequential(*layers).to(device)


def _train(model, x, y):
    """Fit `model` in place with Adam + BCE-with-logits over shuffled mini-batches."""
    opt = torch.optim.Adam(model.parameters(), lr=LR)
    model.train()
    for _ in range(EPOCHS):
        order = torch.randperm(x.shape[0], device=x.device)
        for start in range(0, len(order), BATCH_SIZE):
            batch = order[start:start + BATCH_SIZE]
            # squeeze(-1), not squeeze(): a trailing batch of size 1 must stay
            # shape (1,) to match its target, or BCE raises a size mismatch.
            logits = model(x[batch]).squeeze(-1)
            loss = F.binary_cross_entropy_with_logits(logits, y[batch])
            opt.zero_grad()
            loss.backward()
            opt.step()
    model.eval()
    return model


@torch.no_grad()
def _scores(model, x):
    """Return sigmoid probabilities, shape (len(x),), without tracking gradients."""
    return model(x).squeeze(-1).sigmoid()


def _confusion(pred, truth):
    """Return (tp, fp, fn, tn) as ints for boolean `pred` vs boolean `truth`."""
    return (
        (pred & truth).sum().item(),
        (pred & ~truth).sum().item(),
        (~pred & truth).sum().item(),
        (~pred & ~truth).sum().item(),
    )


def _pick_threshold(scores, truth):
    """Return the THRESHOLDS entry with max recall s.t. precision >= TARGET_PRECISION.

    Falls back to 0.5 when no threshold reaches the precision target.
    """
    best_t, best_rec = 0.5, 0.0
    for t in THRESHOLDS:
        tp, fp, fn, _ = _confusion(scores > t, truth)
        prec = tp / max(tp + fp, 1)
        rec = tp / max(tp + fn, 1)
        if prec >= TARGET_PRECISION and rec > best_rec:
            best_t, best_rec = t, rec
    return best_t


def _summary(counts):
    """Return (precision, recall, F1, accuracy) from summed [tp, fp, fn, tn]."""
    tp, fp, fn, tn = counts
    prec = tp / max(tp + fp, 1)
    rec = tp / max(tp + fn, 1)
    f1 = 2 * prec * rec / max(prec + rec, 1e-9)
    acc = (tp + tn) / max(tp + fp + fn + tn, 1)  # evaluated count, not a hard-coded 5000
    return prec, rec, f1, acc


torch.manual_seed(0)  # reproducible weight init and batch order

device = X.device
n_items = X.shape[0]
truth = Y.bool()
# Contiguous, near-equal folds: item i goes to fold i * N_FOLDS // n_items.
# For 5000 items this is exactly the [0, 1000), [1000, 2000), ... split.
fold_of = torch.arange(n_items, device=device) * N_FOLDS // max(n_items, 1)

print(f"\n{N_FOLDS}-fold CV with MLPs on {N} evolved dims "
      f"(threshold for P>={TARGET_PRECISION} picked on a calibration split)\n", flush=True)
half_results = []  # (label, n_params, [tp, fp, fn, tn]) at threshold 0.5
for hidden in HIDDEN_CONFIGS:
    label = "->".join(str(width) for width in (N, *hidden, 1))
    hp_counts = [0, 0, 0, 0]    # summed over folds at the calibrated threshold
    half_counts = [0, 0, 0, 0]  # summed over folds at threshold 0.5
    n_params = 0
    for fold in range(N_FOLDS):
        test_mask = fold_of == fold
        train_idx = torch.nonzero(~test_mask, as_tuple=True)[0]
        # Carve a calibration slice out of the training fold for threshold picking.
        is_calib = torch.arange(len(train_idx), device=device) % CALIB_EVERY == 0
        fit_idx, calib_idx = train_idx[~is_calib], train_idx[is_calib]

        model = _build_mlp(N, hidden, device)
        n_params = sum(p.numel() for p in model.parameters())
        _train(model, X[fit_idx], Y[fit_idx])

        threshold = _pick_threshold(_scores(model, X[calib_idx]), truth[calib_idx])
        test_scores = _scores(model, X[test_mask])
        test_truth = truth[test_mask]
        for totals, thr in ((hp_counts, threshold), (half_counts, 0.5)):
            for i, count in enumerate(_confusion(test_scores > thr, test_truth)):
                totals[i] += count

    prec, rec, f1, acc = _summary(hp_counts)
    tp, fp, fn, tn = hp_counts
    print(f" {label:20s} ({n_params:5d} params): P={prec:.3f} R={rec:.3f} F1={f1:.3f} acc={acc:.3f} "
          f"(TP={tp} FP={fp} FN={fn} TN={tn})")
    half_results.append((label, n_params, half_counts))

print("\nSame models at threshold=0.5:\n")
for label, n_params, counts in half_results:
    prec, rec, f1, _ = _summary(counts)
    print(f" {label:20s} ({n_params:5d} params): P={prec:.3f} R={rec:.3f} F1={f1:.3f}")