File size: 6,818 Bytes
74e3c01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""Push person classifier precision toward 99% while keeping recall high."""
import json, os, torch
import torch.nn.functional as F
from pycocotools.coco import COCO

COCO_ROOT = os.environ["ARENA_COCO_ROOT"]
VAL_CACHE = os.environ["ARENA_VAL_CACHE"]
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))

# Channel subset produced by the evolutionary circuit search: use the genome of
# the entry whose "K" is 100.
EVOLVED_PATH = os.path.join(SCRIPT_DIR, "..", "circuit", "evolved_extreme.json")
with open(EVOLVED_PATH, encoding="utf-8") as fh:
    evolved = json.load(fh)
run_k100 = next((r for r in evolved if r["K"] == 100), None)
if run_k100 is None:
    raise ValueError(f"no entry with K == 100 in {EVOLVED_PATH}")
# The genome can repeat channel indices, so de-duplicate before using it to
# index feature columns.
dims_100 = sorted(set(run_k100["genome"]))

# NOTE(review): weights_only=False unpickles arbitrary objects; only point
# ARENA_VAL_CACHE at caches you produced yourself.
val = torch.load(VAL_CACHE, map_location="cpu", weights_only=False)
coco = COCO(os.path.join(COCO_ROOT, "annotations", "instances_val2017.json"))
PERSON_CAT = 1  # COCO category id of "person"

def cofiber_decompose(f, n_scales):
    """Split a feature map into band-pass layers plus a low-pass remainder.

    For each of the first ``n_scales - 1`` levels:
    - the current map is 2x2 average-pooled;
    - the pooled map is bilinearly upsampled back to the current size;
    - the difference (the detail removed by pooling) is emitted.
    The final entry is the coarsest pooled map itself. The bilinear upsample
    needs a 4-D (N, C, H, W) input.

    Returns:
        List of ``max(n_scales, 1)`` tensors, finest level first.
    """
    layers = []
    current = f
    levels_left = n_scales - 1
    while levels_left > 0:
        pooled = F.avg_pool2d(current, 2)
        restored = F.interpolate(
            pooled, size=current.shape[2:], mode="bilinear", align_corners=False
        )
        layers.append(current - restored)
        current = pooled
        levels_left -= 1
    layers.append(current)
    return layers

def has_person(img_id):
    """Return True if COCO image ``img_id`` has at least one non-crowd person annotation.

    ``iscrowd=False`` makes ``getAnnIds`` keep only annotations whose
    ``iscrowd`` flag is 0. An image whose only person annotations are crowd
    regions is therefore labelled negative.
    NOTE(review): such images count as false positives when the classifier
    fires on them; confirm that is intended for the precision target.
    """
    person_ann_ids = coco.getAnnIds(imgIds=img_id, catIds=[PERSON_CAT], iscrowd=False)
    return bool(person_ann_ids)

# Pre-compute one max-pooled vector per cached image, using every channel.
print("Pre-computing image vectors...", flush=True)
all_vecs_768 = []
all_labels = []
for idx in range(len(val)):
    item = val[idx]
    # Add a batch axis. The permute below implies item["spatial"] is
    # (C, H, W) — presumably channels-first; confirm against the cache writer.
    spatial = item["spatial"].unsqueeze(0).float()
    feats = []
    for cof in cofiber_decompose(spatial, 3):
        C = cof.shape[1]
        # One row per spatial location of this scale, layer-normalised
        # across its C channels.
        feats.append(F.layer_norm(cof.permute(0, 2, 3, 1).reshape(-1, C), [C]))
    # Max-pool each channel over every location of every scale.
    all_f = torch.cat(feats)  # (sum of H*W over the 3 scales, C)
    all_vecs_768.append(all_f.max(dim=0).values)
    all_labels.append(has_person(int(item["img_id"])))
    if (idx + 1) % 1000 == 0:
        print(f"  {idx+1}/{len(val)}", flush=True)

all_vecs_768 = torch.stack(all_vecs_768)
all_labels = torch.tensor(all_labels, dtype=torch.bool)
print(f"  {all_labels.sum()} person, {(~all_labels).sum()} non-person\n")

# Split by cache order: the first N_TRAIN images fit the probe and all
# remaining images are held out for the threshold sweeps.
N_TRAIN = 4000
if len(all_labels) <= N_TRAIN:
    raise ValueError(
        f"need more than {N_TRAIN} cached images for a train/test split, got {len(all_labels)}"
    )
train_f = all_vecs_768[:N_TRAIN]
# Ridge targets: an (N_TRAIN, 1) float column of 0/1 labels.
train_y = all_labels[:N_TRAIN].float().unsqueeze(1)
test_f = all_vecs_768[N_TRAIN:]
test_y = all_labels[N_TRAIN:]

def solve_and_eval(features_train, y_train, features_test, labels_test, lam=0.1):
    """Fit closed-form ridge regression on the training set and score the test set.

    Solves ``(Xa^T Xa + lam * n * P) W = Xa^T y``, where:
    - ``Xa`` is the training features with a column of ones appended (the bias);
    - ``P`` is the identity with the bias entry zeroed, so the intercept is
      not shrunk.

    The scores are the raw linear predictions, which live on the 0/1 target
    scale, so thresholds in [0, 1] are meaningful. The previous version
    passed them through ``sigmoid``. That squashed the usual 0..1 outputs
    into roughly [0.5, 0.73] and made high thresholds almost unreachable.

    Args:
        features_train: (n, d) float tensor of training features.
        y_train: (n,) or (n, 1) float tensor of 0/1 targets.
        features_test: (m, d) float tensor of test features.
        labels_test: Unused; kept so existing call sites keep working.
        lam: Ridge strength, multiplied by ``n`` so it does not depend on the
            training-set size.

    Returns:
        ``(scores, W)``: ``scores`` is an (m,) tensor of linear predictions.
        ``W`` is the (d + 1, ...) solution (trailing dim matching ``y_train``),
        whose last row is the bias.
    """
    n, d = features_train.shape
    design = torch.cat(
        [features_train, torch.ones(n, 1, dtype=features_train.dtype)], dim=1
    )
    # Penalise the weights but not the intercept: shrinking the bias would
    # drag every score toward 0 and skew the threshold sweep.
    penalty = torch.eye(d + 1, dtype=design.dtype)
    penalty[d, d] = 0.0
    W = torch.linalg.solve(design.T @ design + lam * n * penalty, design.T @ y_train)
    # Flatten explicitly rather than squeeze(), which would also drop the
    # feature axis when d == 1.
    w = W.reshape(-1)
    scores = features_test @ w[:d] + w[d]
    return scores, W

# ============================================================
# Shared threshold sweep used by every experiment below
# ============================================================
SWEEP_THRESHOLDS = (0.5, 0.6, 0.7, 0.8, 0.85, 0.9, 0.95, 0.97, 0.99)


def print_threshold_sweep(scores, labels, thresholds=SWEEP_THRESHOLDS):
    """Print confusion counts, precision, recall and F1 for each threshold.

    Args:
        scores: 1-D tensor of classifier scores; ``score > t`` is a positive
            prediction at threshold ``t``.
        labels: 1-D bool tensor of ground truth, aligned with ``scores``.
        thresholds: Thresholds to evaluate, printed in the given order.
    """
    print(f"{'Thresh':>7s} {'TP':>5s} {'FP':>5s} {'FN':>5s} {'TN':>5s} {'Prec':>6s} {'Rec':>6s} {'F1':>6s}")
    for t in thresholds:
        pred = scores > t
        tp = (pred & labels).sum().item()
        fp = (pred & ~labels).sum().item()
        fn = (~pred & labels).sum().item()
        tn = (~pred & ~labels).sum().item()
        # An empty denominator (no predictions / no positives) reports 0.
        prec = tp / max(tp + fp, 1)
        rec = tp / max(tp + fn, 1)
        f1 = 2 * prec * rec / max(prec + rec, 1e-9)
        marker = " <-- 99%+ prec" if prec >= 0.99 else ""
        # 7-wide threshold column so every row lines up under the header.
        print(f"{t:7.2f} {tp:5d} {fp:5d} {fn:5d} {tn:5d} {prec:6.3f} {rec:6.3f} {f1:6.3f}{marker}")


n_dims = all_vecs_768.shape[1]
n_train = len(train_f)  # reuse the same train/test boundary as train_f / test_f

# ============================================================
# Experiment 1: Threshold sweep on the evolved dims
# ============================================================
print(f"=== Exp 1: Threshold sweep ({len(dims_100)} evolved dims) ===")
scores_92, _ = solve_and_eval(train_f[:, dims_100], train_y, test_f[:, dims_100], test_y)
print_threshold_sweep(scores_92, test_y)

# ============================================================
# Experiment 2: All dims, max-pooled
# ============================================================
print(f"\n=== Exp 2: Full {n_dims} dims ===")
scores_768, _ = solve_and_eval(train_f, train_y, test_f, test_y)
print_threshold_sweep(scores_768, test_y)

# ============================================================
# Experiment 3: All dims, mean pooling instead of max
# ============================================================
print(f"\n=== Exp 3: Full {n_dims} dims, mean-pool ===")
# NOTE(review): this repeats the cofiber decomposition and per-location layer
# norm already done by the max-pool precompute loop; producing both poolings
# in that single pass would halve feature-extraction time.
all_vecs_mean = []
for idx in range(len(val)):
    spatial = val[idx]["spatial"].unsqueeze(0).float()
    per_scale = []
    for cof in cofiber_decompose(spatial, 3):
        C = cof.shape[1]
        per_scale.append(F.layer_norm(cof.permute(0, 2, 3, 1).reshape(-1, C), [C]))
    # Mean over every location of every scale.
    all_vecs_mean.append(torch.cat(per_scale).mean(dim=0))
all_vecs_mean = torch.stack(all_vecs_mean)
train_fm = all_vecs_mean[:n_train]
test_fm = all_vecs_mean[n_train:]
scores_mean, _ = solve_and_eval(train_fm, train_y, test_fm, test_y)
print_threshold_sweep(scores_mean, test_y)

# ============================================================
# Experiment 4: max-pooled and mean-pooled vectors concatenated
# ============================================================
print(f"\n=== Exp 4: {n_dims} max + {n_dims} mean ({2 * n_dims} dims) ===")
all_vecs_cat = torch.cat([all_vecs_768, all_vecs_mean], dim=1)
train_fc = all_vecs_cat[:n_train]
test_fc = all_vecs_cat[n_train:]
scores_cat, _ = solve_and_eval(train_fc, train_y, test_fc, test_y)
print_threshold_sweep(scores_cat, test_y)