""" Build the best analytical head from our findings and run full mAP eval. Classification: 768 raw LayerNorm'd features (69.6% accuracy) Regression: 768 raw + H^1 vertical + H^1 horizontal boundary features (68.7% quality) Centerness: 768 raw features Accumulate on training data, solve, save checkpoint, eval via eval_coco_map.py. """ import json import os import sys import time import torch import torch.nn.functional as F SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) sys.path.insert(0, SCRIPT_DIR) CACHE_DIR = os.environ.get("ARENA_CACHE_DIR", "feature_cache") COCO_ROOT = os.environ.get("ARENA_COCO_ROOT", "coco") VAL_CACHE = os.environ.get("ARENA_VAL_CACHE", "val_cache/val.pt") RESOLUTION = 640 NUM_CLASSES = 80 DEVICE = "cuda" def cofiber_decompose(f, n_scales): cofibers = [] residual = f for _ in range(n_scales - 1): omega = F.avg_pool2d(residual, 2) sigma_omega = F.interpolate(omega, size=residual.shape[2:], mode="bilinear", align_corners=False) cofibers.append(residual - sigma_omega) residual = omega cofibers.append(residual) return cofibers def make_locations(sizes, strides): locs = [] for (h, w), s in zip(sizes, strides): ys = (torch.arange(h, dtype=torch.float32) + 0.5) * s xs = (torch.arange(w, dtype=torch.float32) + 0.5) * s gy, gx = torch.meshgrid(ys, xs, indexing="ij") locs.append(torch.stack([gx.flatten(), gy.flatten()], -1)) return locs def assign_targets(loc, boxes, labels, stride, sr): n = loc.shape[0] ct = torch.full((n,), -1, dtype=torch.long) rt = torch.zeros(n, 4) ctrt = torch.zeros(n) if boxes.numel() == 0: return ct, rt, ctrt areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) l = loc[:, None, 0] - boxes[None, :, 0] t = loc[:, None, 1] - boxes[None, :, 1] r = boxes[None, :, 2] - loc[:, None, 0] b = boxes[None, :, 3] - loc[:, None, 1] ltrb = torch.stack([l, t, r, b], -1) in_box = ltrb.min(-1).values > 0 cx = (boxes[:, 0] + boxes[:, 2]) / 2 cy = (boxes[:, 1] + boxes[:, 3]) / 2 rad = stride * 1.5 in_center = ((loc[:, None, 0] >= cx - rad) & (loc[:, None, 0] <= cx + rad) & (loc[:, None, 1] >= cy - rad) & (loc[:, None, 1] <= cy + rad)) max_d = ltrb.max(-1).values in_level = (max_d >= sr[0]) & (max_d <= sr[1]) pos = in_box & in_center & in_level a = areas[None, :].expand_as(pos).clone() a[~pos] = float("inf") matched = a.argmin(1) is_pos = a.gather(1, matched[:, None]).squeeze(1) < float("inf") ct[is_pos] = labels[matched[is_pos]] if is_pos.any(): rt[is_pos] = ltrb[torch.arange(n)[is_pos], matched[is_pos]] lp, tp, rp, bp = rt[is_pos].unbind(-1) ctrt[is_pos] = torch.sqrt( (torch.minimum(lp, rp) / torch.maximum(lp, rp).clamp(min=1e-6)) * (torch.minimum(tp, bp) / torch.maximum(tp, bp).clamp(min=1e-6))) return ct, rt, ctrt def compute_h1(f, B, H, W, C): """Sheaf H^1 compact: vertical + horizontal boundary magnitudes.""" f_4d = f.reshape(B, H, W, C).permute(0, 3, 1, 2) d_up = f_4d - F.pad(f_4d[:, :, 1:, :], (0, 0, 0, 1)) d_down = f_4d - F.pad(f_4d[:, :, :-1, :], (0, 0, 1, 0)) d_left = f_4d - F.pad(f_4d[:, :, :, 1:], (0, 1, 0, 0)) d_right = f_4d - F.pad(f_4d[:, :, :, :-1], (1, 0, 0, 0)) v_bound = (d_up.abs() + d_down.abs()).permute(0, 2, 3, 1).reshape(-1, C) h_bound = (d_left.abs() + d_right.abs()).permute(0, 2, 3, 1).reshape(-1, C) return v_bound, h_bound def main(): print("=" * 60) print("Best Analytical Head: 768 cls + H^1 regression") print("=" * 60, flush=True) manifest = json.load(open(os.path.join(CACHE_DIR, "manifest.json"))) n_shards = manifest["n_shards"] strides = [16, 32, 64] H = RESOLUTION // 16 sizes = [(H, H), (H // 2, H // 2), (H // 4, H // 4)] sr = [(-1, 128), (128, 256), (256, float("inf"))] locs = make_locations(sizes, strides) feat_dim = 768 reg_dim = 768 * 3 # raw + h1v + h1h # Accumulators cls_XtX = torch.zeros(feat_dim + 1, feat_dim + 1, device=DEVICE) cls_XtY = torch.zeros(feat_dim + 1, NUM_CLASSES, device=DEVICE) reg_XtX = torch.zeros(reg_dim + 1, reg_dim + 1, device=DEVICE) reg_XtY = torch.zeros(reg_dim + 1, 4, device=DEVICE) ctr_XtX = torch.zeros(feat_dim + 1, feat_dim + 1, device=DEVICE) ctr_XtY = torch.zeros(feat_dim + 1, 1, device=DEVICE) n_pos = 0 n_images = 20000 seen = 0 t0 = time.time() for si in range(n_shards): if seen >= n_images: break shard = torch.load(os.path.join(CACHE_DIR, f"shard_{si:04d}.pt"), map_location="cpu", weights_only=False) for item in shard: if seen >= n_images: break sp = item["spatial"].unsqueeze(0).float() boxes = item["boxes"] labels = item["labels"] cofibers = cofiber_decompose(sp, 3) for sci, cof in enumerate(cofibers): B, C, Hc, Wc = cof.shape f = F.layer_norm(cof.permute(0, 2, 3, 1).reshape(-1, C), [C]) h1v, h1h = compute_h1(f, B, Hc, Wc, C) ct, rt, ctrt = assign_targets(locs[sci], boxes, labels, strides[sci], sr[sci]) pos_mask = ct >= 0 if not pos_mask.any(): continue # Classification: raw features only fp = f[pos_mask].to(DEVICE) fa = torch.cat([fp, torch.ones(fp.shape[0], 1, device=DEVICE)], 1) yc = torch.zeros(fp.shape[0], NUM_CLASSES, device=DEVICE) yc[torch.arange(fp.shape[0], device=DEVICE), ct[pos_mask].to(DEVICE)] = 1.0 cls_XtX += fa.T @ fa cls_XtY += fa.T @ yc # Regression: raw + H^1 f_reg = torch.cat([f[pos_mask], h1v[pos_mask], h1h[pos_mask]], 1).to(DEVICE) ltrb = rt[pos_mask] valid = (ltrb > 0).all(1) if valid.any(): fv = f_reg[valid] fva = torch.cat([fv, torch.ones(fv.shape[0], 1, device=DEVICE)], 1) yt = torch.log(ltrb[valid]).to(DEVICE) reg_XtX += fva.T @ fva reg_XtY += fva.T @ yt # Centerness: raw features ctr_XtX += fa.T @ fa ctr_XtY += fa.T @ ctrt[pos_mask].unsqueeze(1).to(DEVICE) n_pos += pos_mask.sum().item() seen += 1 del shard if (si + 1) % 5 == 0: print(f" shard {si+1}: {seen} imgs, {n_pos} pos, {time.time()-t0:.0f}s", flush=True) print(f"\nAccumulated {seen} images, {n_pos} positives", flush=True) # Solve lam = 0.1 I_cls = torch.eye(feat_dim + 1, device=DEVICE) I_reg = torch.eye(reg_dim + 1, device=DEVICE) I_ctr = torch.eye(feat_dim + 1, device=DEVICE) cls_W = torch.linalg.solve(cls_XtX + lam * I_cls * n_pos, cls_XtY) reg_W = torch.linalg.solve(reg_XtX + lam * I_reg * n_pos, reg_XtY) ctr_W = torch.linalg.solve(ctr_XtX + lam * I_ctr * n_pos, ctr_XtY) print(f"Solved. cls: {feat_dim}->80, reg: {reg_dim}->4, ctr: {feat_dim}->1", flush=True) # Save as state dict state = { "cls_weight": cls_W[:feat_dim].T.cpu(), "cls_bias": cls_W[feat_dim].cpu(), "reg_weight": reg_W[:reg_dim].T.cpu(), "reg_bias": reg_W[reg_dim].cpu(), "ctr_weight": ctr_W[:feat_dim].T.cpu(), "ctr_bias": ctr_W[feat_dim].cpu(), "scale_norms.0.weight": torch.ones(768), "scale_norms.0.bias": torch.zeros(768), "scale_norms.1.weight": torch.ones(768), "scale_norms.1.bias": torch.zeros(768), "scale_norms.2.weight": torch.ones(768), "scale_norms.2.bias": torch.zeros(768), "scale_params": torch.ones(3), "meta": {"cls_features": "768_layernorm", "reg_features": "768_layernorm_h1v_h1h", "ctr_features": "768_layernorm", "lambda": lam, "n_images": seen, "n_pos": n_pos}, } out_dir = os.path.join(SCRIPT_DIR, "heads", "cofiber_threshold", "analytical_h1") os.makedirs(out_dir, exist_ok=True) out_path = os.path.join(out_dir, "analytical_h1_best.pth") torch.save(state, out_path) n_params = sum(v.numel() for k, v in state.items() if isinstance(v, torch.Tensor)) elapsed = time.time() - t0 print(f"\nSaved: {out_path}") print(f"Total params: {n_params:,}") print(f"Construction time: {elapsed:.0f}s") print(f"\nClassification: 768 dims, {feat_dim * NUM_CLASSES + NUM_CLASSES:,} params") print(f"Regression: {reg_dim} dims (768+768+768), {reg_dim * 4 + 4:,} params") print(f"Centerness: 768 dims, {feat_dim + 1:,} params") if __name__ == "__main__": main()