"""Push person classifier precision toward 99% while keeping recall high.

For every image in a cached validation set, the spatial feature map is
cofiber-decomposed into ``N_SCALES`` levels, every location at every level is
layer-normalised over channels, and all locations are pooled (max and mean)
into one descriptor per image. A closed-form ridge regression onto binary
"has a non-crowd person annotation" labels (COCO val2017) is fitted on the
first ``N_TRAIN`` images, and a threshold sweep on the remaining images
reports TP/FP/FN/TN, precision, recall and F1 for four feature sets.

Environment variables (read in ``main``):
    ARENA_COCO_ROOT: COCO root containing annotations/instances_val2017.json.
    ARENA_VAL_CACHE: ``torch.save``-d indexable collection of dicts holding a
        "spatial" (C, H, W) tensor and an "img_id".
"""

import json
import os

import torch
import torch.nn.functional as F

SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
PERSON_CAT = 1  # COCO category id passed to getAnnIds for "person"
N_TRAIN = 4000  # first N_TRAIN cached images are train; the rest are test
N_SCALES = 3
THRESHOLDS = (0.5, 0.6, 0.7, 0.8, 0.85, 0.9, 0.95, 0.97, 0.99)


def cofiber_decompose(f, n_scales):
    """Split a feature map into ``n_scales`` band-pass "cofiber" levels.

    At each of the first ``n_scales - 1`` steps, the current residual is 2x
    average-pooled. That step's level is the residual minus the bilinear
    upsampling of the pooled map (the detail lost by pooling), and the
    pooled map becomes the next residual. The last level is the remaining
    coarsest map.

    Args:
        f: 4-D tensor (N, C, H, W).
        n_scales: number of levels to return; values <= 1 return ``[f]``.

    Returns:
        List of tensors, finest first. Level ``i`` has spatial size of about
        (H / 2**i, W / 2**i), because ``avg_pool2d`` floors odd sizes.
    """
    cofibers = []
    residual = f
    for _ in range(n_scales - 1):
        omega = F.avg_pool2d(residual, 2)
        sigma_omega = F.interpolate(
            omega, size=residual.shape[2:], mode="bilinear", align_corners=False
        )
        cofibers.append(residual - sigma_omega)
        residual = omega
    cofibers.append(residual)
    return cofibers


def pool_cofiber_features(spatial, n_scales=N_SCALES):
    """Return ``(max_pooled, mean_pooled)`` (C,) descriptors for one map.

    ``spatial`` is a (C, H, W) tensor. It is cofiber-decomposed, and each
    location of each level is layer-normalised over its C channels. All
    locations from all levels are then pooled together; for a 40x40 input
    with 3 scales that is 1600 + 400 + 100 = 2100 rows.
    """
    levels = cofiber_decompose(spatial.unsqueeze(0).float(), n_scales)
    rows = torch.cat(
        [
            F.layer_norm(
                level.permute(0, 2, 3, 1).reshape(-1, level.shape[1]),
                [level.shape[1]],
            )
            for level in levels
        ]
    )
    return rows.max(dim=0).values, rows.mean(dim=0)


def has_person(img_id, coco):
    """Return True if ``img_id`` has at least one person annotation.

    ``iscrowd=False`` makes pycocotools drop crowd annotations, so an image
    whose only person annotations are crowds counts as non-person.
    """
    return len(coco.getAnnIds(imgIds=img_id, catIds=[PERSON_CAT], iscrowd=False)) > 0


def solve_and_eval(features_train, y_train, features_test, labels_test, lam=0.1):
    """Fit ridge regression onto 0/1 targets in closed form and score test rows.

    Minimises ``||X w + b - y||^2 + lam * n * ||w||^2``. The intercept ``b``
    is deliberately left unpenalised, so shrinkage pulls predictions toward
    the training base rate rather than toward 0.

    Args:
        features_train: (n, d) training features.
        y_train: (n, 1) float targets in {0, 1}, same dtype as the features.
        features_test: (m, d) test features.
        labels_test: unused; kept so existing call sites stay valid.
        lam: ridge strength, scaled by ``n``.

    Returns:
        ``(scores, W)``. ``scores`` is the (m,) raw regression output on the
        0/1 target scale, so 0.5 is the natural decision midpoint. No sigmoid
        is applied: squashing outputs that already sit around [0, 1] would
        confine them to roughly [0.5, 0.73] and make high thresholds select
        almost nothing. ``W`` is the (d + 1, 1) solution, with the bias in
        the last row.
    """
    n, d = features_train.shape
    ones = torch.ones(n, 1, dtype=features_train.dtype)
    fa = torch.cat([features_train, ones], 1)
    penalty = torch.full((d + 1,), lam * n, dtype=fa.dtype)
    penalty[d] = 0.0  # do not shrink the intercept
    W = torch.linalg.solve(fa.T @ fa + torch.diag(penalty), fa.T @ y_train)
    # Index column 0 explicitly; squeeze() would turn W[:d] into a 0-d
    # tensor when d == 1 and break the matmul.
    scores = features_test @ W[:d, 0] + W[d, 0]
    return scores, W


def threshold_metrics(scores, labels, t):
    """Return ``(tp, fp, fn, tn, precision, recall, f1)`` for ``scores > t``.

    ``labels`` is a boolean tensor aligned with ``scores``. Precision and
    recall are 0.0 when their denominator is zero.
    """
    pred = scores > t
    tp = (pred & labels).sum().item()
    fp = (pred & ~labels).sum().item()
    fn = (~pred & labels).sum().item()
    tn = (~pred & ~labels).sum().item()
    prec = tp / max(tp + fp, 1)
    rec = tp / max(tp + fn, 1)
    f1 = 2 * prec * rec / max(prec + rec, 1e-9)
    return tp, fp, fn, tn, prec, rec, f1


def print_threshold_sweep(scores, labels, thresholds=THRESHOLDS):
    """Print one confusion/precision/recall/F1 row per threshold.

    Rows that reach at least 99% precision are marked.
    """
    print(f"{'Thresh':>7s} {'TP':>5s} {'FP':>5s} {'FN':>5s} {'TN':>5s} {'Prec':>6s} {'Rec':>6s} {'F1':>6s}")
    for t in thresholds:
        tp, fp, fn, tn, prec, rec, f1 = threshold_metrics(scores, labels, t)
        marker = " <-- 99%+ prec" if prec >= 0.99 else ""
        print(f" {t:5.2f} {tp:5d} {fp:5d} {fn:5d} {tn:5d} {prec:6.3f} {rec:6.3f} {f1:6.3f}{marker}")


def main():
    """Load data, pool features once, and print the four experiment sweeps."""
    # Imported here so the helpers above can be imported without pycocotools.
    from pycocotools.coco import COCO

    coco_root = os.environ["ARENA_COCO_ROOT"]
    val_cache = os.environ["ARENA_VAL_CACHE"]

    evolved_path = os.path.join(SCRIPT_DIR, "..", "circuit", "evolved_extreme.json")
    with open(evolved_path, encoding="utf-8") as f:
        evolved = json.load(f)
    genome = next((r["genome"] for r in evolved if r["K"] == 100), None)
    if genome is None:
        raise ValueError(f"no entry with K == 100 in {evolved_path}")
    # The genome may repeat dimensions; deduplicate and sort for indexing.
    dims_100 = sorted(set(genome))

    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only point ARENA_VAL_CACHE at a trusted file.
    val = torch.load(val_cache, map_location="cpu", weights_only=False)
    if len(val) <= N_TRAIN:
        raise ValueError(
            f"need more than {N_TRAIN} cached images for a test split, got {len(val)}"
        )
    coco = COCO(os.path.join(coco_root, "annotations", "instances_val2017.json"))

    # One pass computes both the max- and mean-pooled descriptors.
    print("Pre-computing image vectors...", flush=True)
    max_vecs, mean_vecs, labels = [], [], []
    for idx in range(len(val)):
        item = val[idx]
        v_max, v_mean = pool_cofiber_features(item["spatial"], N_SCALES)
        max_vecs.append(v_max)
        mean_vecs.append(v_mean)
        labels.append(has_person(int(item["img_id"]), coco))
        if (idx + 1) % 1000 == 0:
            print(f" {idx+1}/{len(val)}", flush=True)

    all_max = torch.stack(max_vecs)
    all_mean = torch.stack(mean_vecs)
    all_labels = torch.tensor(labels, dtype=torch.bool)
    n_pos = int(all_labels.sum())
    print(f" {n_pos} person, {len(labels) - n_pos} non-person\n")

    # Split: first N_TRAIN images train, the remainder test.
    train_y = all_labels[:N_TRAIN].float().unsqueeze(1)
    test_y = all_labels[N_TRAIN:]
    dim = all_max.shape[1]

    experiments = [
        # Exp 1: evolved subset of max-pooled dims.
        (f"=== Exp 1: Threshold sweep ({len(dims_100)} evolved dims) ===",
         all_max[:, dims_100]),
        # Exp 2: all max-pooled dims.
        (f"\n=== Exp 2: Full {dim} dims ===", all_max),
        # Exp 3: all mean-pooled dims.
        (f"\n=== Exp 3: Full {dim} dims, mean-pool ===", all_mean),
        # Exp 4: max and mean descriptors concatenated.
        (f"\n=== Exp 4: {dim} max + {dim} mean ({2 * dim} dims) ===",
         torch.cat([all_max, all_mean], dim=1)),
    ]
    for title, feats in experiments:
        print(title)
        scores, _ = solve_and_eval(feats[:N_TRAIN], train_y, feats[N_TRAIN:], test_y)
        print_threshold_sweep(scores, test_y)


if __name__ == "__main__":
    main()