# detection-heads/scripts/_person_precision.py
# Last commit by phanerozoic: 74e3c01 ("update repository")
"""Push person classifier precision toward 99% while keeping recall high."""
import json, os, torch
import torch.nn.functional as F
from pycocotools.coco import COCO
# Required environment variables; a missing one raises KeyError.
COCO_ROOT = os.environ["ARENA_COCO_ROOT"]
VAL_CACHE = os.environ["ARENA_VAL_CACHE"]
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
# Evolved feature-selection results: a list of records with a "K" field and a
# "genome" list of feature (channel) indices.
with open(os.path.join(SCRIPT_DIR, "..", "circuit", "evolved_extreme.json"), encoding="utf-8") as fh:
    evolved = json.load(fh)
# Sorted, de-duplicated feature indices from the first record with K == 100
# (the genome may repeat indices, hence the set).
_k100_record = next((r for r in evolved if r["K"] == 100), None)
if _k100_record is None:
    raise ValueError("evolved_extreme.json has no record with K == 100")
dims_100 = sorted(set(_k100_record["genome"]))
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, so ARENA_VAL_CACHE must point at a trusted, locally produced cache.
val = torch.load(VAL_CACHE, map_location="cpu", weights_only=False)
coco = COCO(os.path.join(COCO_ROOT, "annotations", "instances_val2017.json"))
PERSON_CAT = 1  # COCO category id of "person"
def cofiber_decompose(f, n_scales):
    """Split ``f`` into ``n_scales`` tensors: per-level details, then a coarse residual.

    At every level the current map is 2x2 average-pooled. The difference between
    the map and the bilinear upsampling of that pooled map is recorded as the
    level's detail, and the pooled map is carried down to the next level. The
    last list entry is the remaining pooled map, so ``n_scales <= 1`` yields
    just ``[f]``.

    Args:
        f: 4-D (N, C, H, W) tensor, as built by the callers in this script.
        n_scales: number of tensors returned (always at least one).

    Returns:
        List of tensors ordered from finest detail to coarsest residual.
    """
    details = []
    current = f
    for _level in range(n_scales - 1):
        pooled = F.avg_pool2d(current, 2)
        upsampled = F.interpolate(pooled, size=current.shape[2:], mode="bilinear", align_corners=False)
        details.append(current - upsampled)
        current = pooled
    return details + [current]
def has_person(img_id):
    """Report whether image ``img_id`` has any non-crowd "person" annotation.

    Queries the module-level ``coco`` index for annotation ids in category
    ``PERSON_CAT`` with ``iscrowd=False``, so crowd-only person regions do not count.
    """
    person_ids = coco.getAnnIds(imgIds=img_id, catIds=[PERSON_CAT], iscrowd=False)
    return len(person_ids) != 0
# Pre-compute one max-pooled descriptor per image over every channel (768 for
# this cache) together with its "contains a non-crowd person" label.
print("Pre-computing image vectors...", flush=True)
all_vecs_768 = []
all_labels = []
for idx in range(len(val)):
    item = val[idx]
    spatial = item["spatial"].unsqueeze(0).float()
    cofibers = cofiber_decompose(spatial, 3)
    feats = []
    for cof in cofibers:
        _, C, _, _ = cof.shape
        # Flatten each level to (locations, C) and layer-normalise every
        # location's channel vector.
        feats.append(F.layer_norm(cof.permute(0, 2, 3, 1).reshape(-1, C), [C]))
    # Max-pool every channel over the locations of all three levels.
    # NOTE(review): earlier notes put all_f at (2100, 768) for this cache — confirm.
    all_f = torch.cat(feats)
    all_vecs_768.append(all_f.max(dim=0).values)
    all_labels.append(has_person(int(item["img_id"])))
    if (idx + 1) % 1000 == 0:
        print(f" {idx+1}/{len(val)}", flush=True)
all_vecs_768 = torch.stack(all_vecs_768)
all_labels = torch.tensor(all_labels, dtype=torch.bool)
# .item() so the counts print as plain integers rather than "tensor(N)".
print(f" {all_labels.sum().item()} person, {(~all_labels).sum().item()} non-person\n")
# Deterministic split by cache order: rows [0, 4000) fit the probe and every
# remaining row is held out (1000 rows if the cache holds all 5000 val2017 images).
train_f, test_f = all_vecs_768[:4000], all_vecs_768[4000:]
train_y, test_y = all_labels[:4000].float().unsqueeze(1), all_labels[4000:]
def solve_and_eval(features_train, y_train, features_test, labels_test, lam=0.1):
    """Fit a closed-form ridge-regression probe and score a held-out set.

    Solves ``(A^T A + lam * n * I) W = A^T y`` where ``A`` is ``features_train``
    with a column of ones appended for the bias. ``I`` spans all ``d + 1``
    coefficients, so the bias is regularised too.

    Args:
        features_train: (n, d) training features.
        y_train: (n, 1) or (n,) regression targets (0/1 labels in this script).
        features_test: (m, d) features to score.
        labels_test: unused; accepted so existing call sites keep working.
        lam: ridge strength, multiplied by ``n`` so it stays comparable across
            training-set sizes.

    Returns:
        ``(scores, W)``. ``scores`` holds the raw linear predictions, shape (m,).
        ``W`` holds the solved coefficients (last row is the bias), shaped like
        ``A^T @ y_train``.
    """
    n, d = features_train.shape
    ones = torch.ones(n, 1, dtype=features_train.dtype, device=features_train.device)
    design = torch.cat([features_train, ones], dim=1)
    eye = torch.eye(d + 1, dtype=design.dtype, device=design.device)
    W = torch.linalg.solve(design.T @ design + lam * n * eye, design.T @ y_train)
    # Flatten so both (d+1, 1) and (d+1,) solutions index the same way; unlike
    # squeeze(), this never collapses to a 0-d tensor when d == 1.
    coef = W.reshape(-1)
    # Scores are intentionally left unsquashed. The probe is a least-squares fit
    # to 0/1 targets, so its outputs already sit on that scale and 0.5 is the
    # natural cut. A sigmoid would map [0, 1] onto [0.5, 0.73], so 0.5 would mean
    # "raw > 0" and thresholds above ~0.73 would require raw outputs above 1.
    scores = features_test @ coef[:d] + coef[d]
    return scores, W
# ============================================================
# Shared threshold sweep: one confusion-matrix row per cut-off
# ============================================================
THRESHOLDS = (0.5, 0.6, 0.7, 0.8, 0.85, 0.9, 0.95, 0.97, 0.99)


def print_threshold_sweep(scores, labels, thresholds=THRESHOLDS):
    """Print TP/FP/FN/TN, precision, recall and F1 of ``scores > t`` for each ``t``.

    Args:
        scores: 1-D tensor of per-image scores.
        labels: bool tensor of ground-truth labels aligned with ``scores``.
        thresholds: cut-offs to evaluate, in printing order.

    Precision and recall are reported as 0 when their denominator is empty.
    Rows with precision >= 0.99 are flagged.
    """
    print(f"{'Thresh':>7s} {'TP':>5s} {'FP':>5s} {'FN':>5s} {'TN':>5s} {'Prec':>6s} {'Rec':>6s} {'F1':>6s}")
    for t in thresholds:
        pred = scores > t
        tp = (pred & labels).sum().item()
        fp = (pred & ~labels).sum().item()
        fn = (~pred & labels).sum().item()
        tn = (~pred & ~labels).sum().item()
        prec = tp / max(tp + fp, 1)
        rec = tp / max(tp + fn, 1)
        f1 = 2 * prec * rec / max(prec + rec, 1e-9)
        marker = " <-- 99%+ prec" if prec >= 0.99 else ""
        print(f" {t:5.2f} {tp:5d} {fp:5d} {fn:5d} {tn:5d} {prec:6.3f} {rec:6.3f} {f1:6.3f}{marker}")


# ============================================================
# Experiment 1: Threshold sweep on the de-duplicated K=100 evolved dims
# ============================================================
# Report the actual dim count rather than a hard-coded figure.
print(f"=== Exp 1: Threshold sweep ({len(dims_100)} evolved dims) ===")
scores_92, _ = solve_and_eval(train_f[:, dims_100], train_y, test_f[:, dims_100], test_y)
print_threshold_sweep(scores_92, test_y)

# ============================================================
# Experiment 2: Full 768 dims
# ============================================================
print("\n=== Exp 2: Full 768 dims ===")
scores_768, _ = solve_and_eval(train_f, train_y, test_f, test_y)
print_threshold_sweep(scores_768, test_y)

# ============================================================
# Experiment 3: Full 768 dims + mean pooling instead of max
# ============================================================
print("\n=== Exp 3: Full 768 dims, mean-pool ===")
# Same cofiber + per-location layer-norm features as the max-pooled vectors,
# averaged over locations instead.
all_vecs_mean = []
for idx in range(len(val)):
    spatial = val[idx]["spatial"].unsqueeze(0).float()
    feats = []
    for cof in cofiber_decompose(spatial, 3):
        _, C, _, _ = cof.shape
        feats.append(F.layer_norm(cof.permute(0, 2, 3, 1).reshape(-1, C), [C]))
    all_vecs_mean.append(torch.cat(feats).mean(dim=0))
all_vecs_mean = torch.stack(all_vecs_mean)
# Reuse the existing train/test boundary so every experiment shares one split.
n_train = train_f.shape[0]
train_fm = all_vecs_mean[:n_train]
test_fm = all_vecs_mean[n_train:]
scores_mean, _ = solve_and_eval(train_fm, train_y, test_fm, test_y)
print_threshold_sweep(scores_mean, test_y)

# ============================================================
# Experiment 4: 768 max + 768 mean concatenated (1536 dims)
# ============================================================
print("\n=== Exp 4: 768 max + 768 mean (1536 dims) ===")
all_vecs_cat = torch.cat([all_vecs_768, all_vecs_mean], dim=1)
train_fc = all_vecs_cat[:n_train]
test_fc = all_vecs_cat[n_train:]
scores_cat, _ = solve_and_eval(train_fc, train_y, test_fc, test_y)
print_threshold_sweep(scores_cat, test_y)