| """Push person classifier precision toward 99% while keeping recall high.""" |
| import json, os, torch |
| import torch.nn.functional as F |
| from pycocotools.coco import COCO |
|
|
COCO_ROOT = os.environ["ARENA_COCO_ROOT"]
VAL_CACHE = os.environ["ARENA_VAL_CACHE"]
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))


# Evolved feature-dimension subsets. The genome of the first entry with
# K == 100, deduplicated and sorted, selects the feature columns used in Exp 1.
with open(os.path.join(SCRIPT_DIR, "..", "circuit", "evolved_extreme.json"), encoding="utf-8") as f:
    evolved = json.load(f)
_k100_entry = next((r for r in evolved if r["K"] == 100), None)
if _k100_entry is None:
    raise ValueError("evolved_extreme.json has no entry with K == 100")
dims_100 = sorted(set(_k100_entry["genome"]))


# NOTE(review): weights_only=False unpickles arbitrary Python objects; only
# point ARENA_VAL_CACHE at a cache file you produced yourself.
val = torch.load(VAL_CACHE, map_location="cpu", weights_only=False)
coco = COCO(os.path.join(COCO_ROOT, "annotations", "instances_val2017.json"))
PERSON_CAT = 1  # COCO category id for "person"
|
|
def cofiber_decompose(f, n_scales):
    """Split a feature map into per-scale detail bands plus a coarse remainder.

    At each of the first ``n_scales - 1`` levels, the current map is 2x
    average-pooled; the stored band is the current map minus the pooled map
    bilinearly upsampled back to the current size. The pooled map then
    becomes the current map. The final pooled map is appended last.

    Args:
        f: 4-D tensor in (N, C, H, W) layout (pooling and bilinear
            interpolation act on the last two dims).
        n_scales: total number of tensors returned; values below 2 yield
            ``[f]`` unchanged.

    Returns:
        List of tensors, finest detail band first, coarsest remainder last.
    """
    bands = []
    current = f
    for _level in range(1, n_scales):
        coarser = F.avg_pool2d(current, 2)
        upsampled = F.interpolate(
            coarser, size=current.shape[2:], mode="bilinear", align_corners=False
        )
        bands.append(current - upsampled)
        current = coarser
    bands.append(current)
    return bands
|
|
def has_person(img_id, *, include_crowd=False):
    """Return True if image ``img_id`` has at least one person annotation in ``coco``.

    By default only non-crowd annotations are counted (``iscrowd=False``
    filters on each annotation's ``iscrowd`` flag). An image whose only
    person annotations are crowd regions is therefore labeled negative. In
    a precision study that turns real person images into "false positives".

    Args:
        img_id: COCO image id.
        include_crowd: if True, crowd annotations count as well
            (``iscrowd=None`` disables the filter in ``getAnnIds``).
    """
    iscrowd = None if include_crowd else False
    return len(coco.getAnnIds(imgIds=img_id, catIds=[PERSON_CAT], iscrowd=iscrowd)) > 0
|
|
| |
# Build one vector per cached image: 3-scale cofiber decomposition of
# item["spatial"], LayerNorm of every spatial token over its channels, then an
# elementwise max over all tokens of all scales. Labels come from has_person.
print("Pre-computing image vectors...", flush=True)
all_vecs_768 = []  # NOTE(review): name assumes C == 768 channels — confirm against the cache
all_labels = []
n_images = len(val)
for idx in range(n_images):
    item = val[idx]
    # unsqueeze(0) adds the batch dim, so item["spatial"] is (C, H, W).
    scales = cofiber_decompose(item["spatial"].unsqueeze(0).float(), 3)
    normed_tokens = []
    for band in scales:
        channels = band.shape[1]
        tokens = band.permute(0, 2, 3, 1).reshape(-1, channels)  # (H*W, C)
        normed_tokens.append(F.layer_norm(tokens, [channels]))
    all_vecs_768.append(torch.cat(normed_tokens).max(dim=0).values)
    all_labels.append(has_person(int(item["img_id"])))
    done = idx + 1
    if done % 1000 == 0:
        print(f" {done}/{n_images}", flush=True)

all_vecs_768 = torch.stack(all_vecs_768)
all_labels = torch.tensor(all_labels, dtype=torch.bool)
print(f" {all_labels.sum()} person, {(~all_labels).sum()} non-person\n")
|
|
| |
# Index-based split: the first N_TRAIN cached images fit the probe, the rest
# are held out for evaluation.
N_TRAIN = 4000
train_f, test_f = all_vecs_768[:N_TRAIN], all_vecs_768[N_TRAIN:]
train_y = all_labels[:N_TRAIN].float().unsqueeze(1)  # (N_TRAIN, 1) float targets for the solver
test_y = all_labels[N_TRAIN:]  # bool labels used for confusion-matrix counts
|
|
def solve_and_eval(features_train, y_train, features_test, labels_test, lam=0.1):
    """Fit a closed-form ridge regressor and score a test set.

    Solves ``(X^T X + lam * n * R) W = X^T y``. Here ``X`` is
    ``features_train`` with a column of ones appended. ``R`` is the identity
    except that its intercept entry is zero, so only the feature weights
    are shrunk.

    Args:
        features_train: (n, d) training features.
        y_train: (n, 1) or (n,) regression targets (0/1 labels in this script).
        features_test: (m, d) test features.
        labels_test: unused; kept so existing call sites keep working.
        lam: ridge strength, scaled by the number of training rows ``n``.

    Returns:
        ``(scores, W)``. ``scores`` is the (m,) raw linear prediction per test
        row, on the same scale as the 0/1 targets, so thresholds in [0, 1]
        are meaningful. ``W`` is the solution with the intercept in its
        last row.
    """
    n, d = features_train.shape
    ones_train = torch.ones(n, 1, dtype=features_train.dtype, device=features_train.device)
    fa = torch.cat([features_train, ones_train], 1)
    # Penalize feature weights only; shrinking the intercept drags every
    # score toward 0 regardless of the features.
    reg = torch.eye(d + 1, dtype=fa.dtype, device=fa.device)
    reg[d, d] = 0.0
    W = torch.linalg.solve(fa.T @ fa + lam * n * reg, fa.T @ y_train)
    ones_test = torch.ones(
        features_test.shape[0], 1, dtype=features_test.dtype, device=features_test.device
    )
    fa_test = torch.cat([features_test, ones_test], 1)
    # Least-squares outputs already live on the 0/1 target scale. Applying a
    # sigmoid here mapped [0, 1] to about [0.5, 0.73], so thresholds >= 0.75
    # almost never fired. Augmenting the test matrix also avoids
    # W[:d].squeeze() collapsing to a scalar when d == 1.
    scores = (fa_test @ W).reshape(-1)
    return scores, W
|
|
| |
| |
| |
SWEEP_THRESHOLDS = (0.5, 0.6, 0.7, 0.8, 0.85, 0.9, 0.95, 0.97, 0.99)


def _print_threshold_sweep(scores, labels, thresholds=SWEEP_THRESHOLDS):
    """Print TP/FP/FN/TN, precision, recall and F1 for each threshold.

    Args:
        scores: (m,) classifier scores; a row is predicted positive when its
            score is strictly greater than the threshold.
        labels: (m,) bool ground-truth labels.
        thresholds: thresholds to sweep, printed in the given order.
    """
    print(f"{'Thresh':>7s} {'TP':>5s} {'FP':>5s} {'FN':>5s} {'TN':>5s} {'Prec':>6s} {'Rec':>6s} {'F1':>6s}")
    for t in thresholds:
        pred = scores > t
        tp = (pred & labels).sum().item()
        fp = (pred & ~labels).sum().item()
        fn = (~pred & labels).sum().item()
        tn = (~pred & ~labels).sum().item()
        # max(..., 1) reports 0 instead of dividing by zero when nothing is predicted/present.
        prec = tp / max(tp + fp, 1)
        rec = tp / max(tp + fn, 1)
        f1 = 2 * prec * rec / max(prec + rec, 1e-9)
        marker = " <-- 99%+ prec" if prec >= 0.99 else ""
        print(f" {t:5.2f} {tp:5d} {fp:5d} {fn:5d} {tn:5d} {prec:6.3f} {rec:6.3f} {f1:6.3f}{marker}")


def _mean_pooled_vector(item):
    """Mean of the LayerNormed tokens over all 3 cofiber scales of ``item["spatial"]``."""
    spatial = item["spatial"].unsqueeze(0).float()
    normed = []
    for cof in cofiber_decompose(spatial, 3):
        channels = cof.shape[1]
        normed.append(F.layer_norm(cof.permute(0, 2, 3, 1).reshape(-1, channels), [channels]))
    return torch.cat(normed).mean(dim=0)


# Exp 1: ridge probe restricted to the evolved dimension subset.
print(f"=== Exp 1: Threshold sweep ({len(dims_100)} evolved dims) ===")
scores_92, _ = solve_and_eval(train_f[:, dims_100], train_y, test_f[:, dims_100], test_y)
_print_threshold_sweep(scores_92, test_y)

# Exp 2: all max-pooled dimensions.
print(f"\n=== Exp 2: Full {train_f.shape[1]} dims ===")
scores_768, _ = solve_and_eval(train_f, train_y, test_f, test_y)
_print_threshold_sweep(scores_768, test_y)

# Exp 3: same features but mean-pooled over tokens instead of max-pooled.
print(f"\n=== Exp 3: Full {train_f.shape[1]} dims, mean-pool ===")
n_train = train_f.shape[0]  # reuse the main split so every experiment sees the same images
all_vecs_mean = torch.stack([_mean_pooled_vector(val[idx]) for idx in range(len(val))])
train_fm = all_vecs_mean[:n_train]
test_fm = all_vecs_mean[n_train:]
scores_mean, _ = solve_and_eval(train_fm, train_y, test_fm, test_y)
_print_threshold_sweep(scores_mean, test_y)

# Exp 4: max-pooled and mean-pooled vectors concatenated.
all_vecs_cat = torch.cat([all_vecs_768, all_vecs_mean], dim=1)
print(
    f"\n=== Exp 4: {all_vecs_768.shape[1]} max + {all_vecs_mean.shape[1]} mean "
    f"({all_vecs_cat.shape[1]} dims) ==="
)
train_fc = all_vecs_cat[:n_train]
test_fc = all_vecs_cat[n_train:]
scores_cat, _ = solve_and_eval(train_fc, train_y, test_fc, test_y)
_print_threshold_sweep(scores_cat, test_y)
|
|