"""Push person classifier precision toward 99% while keeping recall high."""
import json, os, torch
import torch.nn.functional as F
from pycocotools.coco import COCO
# Dataset locations come from the environment; a missing variable raises KeyError.
COCO_ROOT = os.environ["ARENA_COCO_ROOT"]
VAL_CACHE = os.environ["ARENA_VAL_CACHE"]
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))

# JSON is read as UTF-8 explicitly so the result does not depend on the locale.
with open(
    os.path.join(SCRIPT_DIR, "..", "circuit", "evolved_extreme.json"),
    encoding="utf-8",
) as f:
    evolved = json.load(f)

# Feature indices come from the first record with K == 100; duplicate indices
# in its genome are dropped and the remainder sorted.
k100_records = [r for r in evolved if r["K"] == 100]
if not k100_records:
    raise IndexError("evolved_extreme.json contains no record with K == 100")
dims_100 = sorted(set(k100_records[0]["genome"]))

# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects;
# only point ARENA_VAL_CACHE at a cache file you trust.
val = torch.load(VAL_CACHE, map_location="cpu", weights_only=False)
coco = COCO(os.path.join(COCO_ROOT, "annotations", "instances_val2017.json"))

# Category id handed to getAnnIds when looking for people
# (presumably the COCO "person" class -- confirm against the annotation file).
PERSON_CAT = 1
def cofiber_decompose(f, n_scales):
    """Split a feature map into ``n_scales`` band-like components.

    At every level the current map is average-pooled by 2, the pooled map is
    bilinearly upsampled back to the current resolution, and the difference
    (current minus upsampled) is stored. The pooled map then becomes the
    current map for the next level. The final, coarsest map is appended
    unchanged as the last component.

    Args:
        f: tensor whose last two axes are spatial (it is passed to
            ``F.avg_pool2d`` and ``F.interpolate`` with a 2-D size).
        n_scales: number of components to return; ``n_scales - 1`` pooling
            steps are performed.

    Returns:
        List of tensors, finest resolution first, coarsest residual last.
    """
    bands = []
    current = f
    for _level in range(1, n_scales):
        coarse = F.avg_pool2d(current, 2)
        upsampled = F.interpolate(
            coarse,
            size=current.shape[2:],
            mode="bilinear",
            align_corners=False,
        )
        bands.append(current - upsampled)
        current = coarse
    bands.append(current)
    return bands
def has_person(img_id):
    """Return True when the image has at least one matching person annotation.

    Queries the module-level ``coco`` index for annotation ids of category
    ``PERSON_CAT`` on ``img_id`` with ``iscrowd=False`` and reports whether
    any id came back.
    """
    person_ann_ids = coco.getAnnIds(
        imgIds=img_id,
        catIds=[PERSON_CAT],
        iscrowd=False,
    )
    return len(person_ann_ids) > 0
# Build one max-pooled descriptor and one person / non-person label per cached
# image, using every channel (the original note calls this "full 768 dims").
print("Pre-computing image vectors...", flush=True)
all_vecs_768 = []
all_labels = []
n_images = len(val)
for idx in range(n_images):
    item = val[idx]
    spatial = item["spatial"].unsqueeze(0).float()

    # One row per spatial location, for each of the 3 scales, layer-normalized
    # over the channel axis.
    feats = []
    for cof in cofiber_decompose(spatial, 3):
        B, C, Hc, Wc = cof.shape
        per_location = cof.permute(0, 2, 3, 1).reshape(-1, C)
        feats.append(F.layer_norm(per_location, [C]))

    # Channel-wise maximum over every location of every scale
    # (original note: 2100 locations -> a (2100, 768) matrix).
    stacked = torch.cat(feats)
    all_vecs_768.append(torch.max(stacked, dim=0).values)
    all_labels.append(has_person(int(item["img_id"])))

    done = idx + 1
    if done % 1000 == 0:
        print(f" {done}/{n_images}", flush=True)

all_vecs_768 = torch.stack(all_vecs_768)
all_labels = torch.tensor(all_labels, dtype=torch.bool)
print(f" {all_labels.sum()} person, {(~all_labels).sum()} non-person\n")
# Hold-out split by cache order: rows [0, 4000) fit the solver and everything
# after that is the test set (1000 rows per the original note).
train_f, test_f = all_vecs_768[:4000], all_vecs_768[4000:]
# Training targets become a float column vector; test labels stay boolean so
# they can be combined with boolean predictions using & and ~.
train_y = all_labels[:4000].float().unsqueeze(1)
test_y = all_labels[4000:]
def solve_and_eval(features_train, y_train, features_test, labels_test, lam=0.1):
    """Fit a ridge-regularized linear model in closed form and score a test set.

    Solves ``(A^T A + lam * n * I) W = A^T y`` where ``A`` is
    ``features_train`` with a constant-one column appended, so the last row of
    ``W`` is the bias. The bias is penalized together with the weights, and
    the penalty is scaled by the number of training rows ``n``.

    Args:
        features_train: ``(n, d)`` floating-point training features.
        y_train: training targets of shape ``(n, 1)`` or ``(n,)``; cast to the
            feature dtype before solving.
        features_test: ``(m, d)`` test features.
        labels_test: unused; kept so existing call sites keep working.
        lam: ridge strength (multiplied by ``n`` inside the solve).

    Returns:
        ``(scores, W)`` where ``scores`` has shape ``(m,)`` and holds the
        sigmoid of the linear output, and ``W`` is the solved weight tensor
        exactly as returned by ``torch.linalg.solve``.
    """
    n, d = features_train.shape
    # Create the helper tensors on the features' dtype/device so float64 or
    # non-CPU inputs do not hit a dtype/device mismatch.
    ones = features_train.new_ones(n, 1)
    fa = torch.cat([features_train, ones], 1)
    eye = torch.eye(d + 1, dtype=fa.dtype, device=fa.device)
    W = torch.linalg.solve(fa.T @ fa + lam * eye * n, fa.T @ y_train.to(fa.dtype))
    # Flatten explicitly instead of squeeze(): with d == 1 squeeze() would
    # produce a 0-dim weight tensor and the matmul below would fail.
    w = W.reshape(d + 1)
    scores = (features_test @ w[:d] + w[d]).sigmoid()
    return scores, W
# ============================================================
# Shared helpers for the experiments below
# ============================================================
# Score thresholds swept by every experiment.
THRESHOLDS = [0.5, 0.6, 0.7, 0.8, 0.85, 0.9, 0.95, 0.97, 0.99]


def _print_threshold_sweep(scores, labels):
    """Print a confusion-matrix table with one row per entry of THRESHOLDS.

    Args:
        scores: 1-D tensor of per-image scores; an image is predicted
            positive when its score is strictly greater than the threshold.
        labels: boolean tensor of ground-truth labels, same length as scores.

    Rows whose precision is at least 0.99 get a trailing marker. Precision
    and recall use a denominator of at least 1, so a threshold with no
    predictions (or no positives) reports 0 instead of dividing by zero.
    """
    print(f"{'Thresh':>7s} {'TP':>5s} {'FP':>5s} {'FN':>5s} {'TN':>5s} {'Prec':>6s} {'Rec':>6s} {'F1':>6s}")
    for t in THRESHOLDS:
        pred = scores > t
        tp = (pred & labels).sum().item()
        fp = (pred & ~labels).sum().item()
        fn = (~pred & labels).sum().item()
        tn = (~pred & ~labels).sum().item()
        prec = tp / max(tp + fp, 1)
        rec = tp / max(tp + fn, 1)
        f1 = 2 * prec * rec / max(prec + rec, 1e-9)
        marker = " <-- 99%+ prec" if prec >= 0.99 else ""
        print(f" {t:5.2f} {tp:5d} {fp:5d} {fn:5d} {tn:5d} {prec:6.3f} {rec:6.3f} {f1:6.3f}{marker}")


def _mean_pooled_vector(item):
    """Return the channel-wise mean over all locations of all 3 scales.

    Uses the same decomposition and per-location layer normalization as the
    max-pooled descriptors computed earlier in the script; only the pooling
    operator differs.
    """
    spatial = item["spatial"].unsqueeze(0).float()
    feats = []
    for cof in cofiber_decompose(spatial, 3):
        B, C, Hc, Wc = cof.shape
        feats.append(F.layer_norm(cof.permute(0, 2, 3, 1).reshape(-1, C), [C]))
    return torch.cat(feats).mean(dim=0)


# ============================================================
# Experiment 1: Threshold sweep on 92 evolved dims
# ============================================================
print("=== Exp 1: Threshold sweep (92 evolved dims) ===")
scores_92, _ = solve_and_eval(train_f[:, dims_100], train_y, test_f[:, dims_100], test_y)
_print_threshold_sweep(scores_92, test_y)

# ============================================================
# Experiment 2: Full 768 dims
# ============================================================
print("\n=== Exp 2: Full 768 dims ===")
scores_768, _ = solve_and_eval(train_f, train_y, test_f, test_y)
_print_threshold_sweep(scores_768, test_y)

# ============================================================
# Experiment 3: Full 768 dims + mean pooling instead of max
# ============================================================
print("\n=== Exp 3: Full 768 dims, mean-pool ===")
all_vecs_mean = torch.stack([_mean_pooled_vector(val[idx]) for idx in range(len(val))])
train_fm = all_vecs_mean[:4000]
test_fm = all_vecs_mean[4000:]
scores_mean, _ = solve_and_eval(train_fm, train_y, test_fm, test_y)
_print_threshold_sweep(scores_mean, test_y)

# ============================================================
# Experiment 4: 768 max + 768 mean concatenated (1536 dims)
# ============================================================
print("\n=== Exp 4: 768 max + 768 mean (1536 dims) ===")
all_vecs_cat = torch.cat([all_vecs_768, all_vecs_mean], dim=1)
train_fc = all_vecs_cat[:4000]
test_fc = all_vecs_cat[4000:]
scores_cat, _ = solve_and_eval(train_fc, train_y, test_fc, test_y)
_print_threshold_sweep(scores_cat, test_y)
|