| """ |
| Build the best analytical head from our findings and run full mAP eval. |
| |
| Classification: 768 raw LayerNorm'd features (69.6% accuracy) |
| Regression: 768 raw + H^1 vertical + H^1 horizontal boundary features (68.7% quality) |
| Centerness: 768 raw features |
| |
| Accumulate on training data, solve, save checkpoint, eval via eval_coco_map.py. |
| """ |
|
|
| import json |
| import os |
| import sys |
| import time |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
# Absolute directory of this script. It is pushed to the front of sys.path
# (presumably so sibling modules next to the script resolve first -- confirm).
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, SCRIPT_DIR)

# Fixed run configuration.
RESOLUTION = 640
NUM_CLASSES = 80
DEVICE = "cuda"

# Data locations; each one can be overridden through its ARENA_* environment
# variable and otherwise falls back to a relative default path.
CACHE_DIR = os.getenv("ARENA_CACHE_DIR", "feature_cache")
COCO_ROOT = os.getenv("ARENA_COCO_ROOT", "coco")
VAL_CACHE = os.getenv("ARENA_VAL_CACHE", "val_cache/val.pt")
|
|
|
|
def cofiber_decompose(f, n_scales):
    """Split a feature map into a coarse-to-fine residual pyramid.

    At every step the current map is 2x average-pooled, the pooled map is
    bilinearly upsampled back to the current spatial size, and the difference
    (current minus upsampled) is stored. The pooled map then becomes the
    current map for the next step.

    Args:
        f: 4-D tensor; the last two dimensions are treated as spatial.
        n_scales: number of entries to return. ``n_scales - 1`` difference
            maps are produced, followed by the final pooled residual.

    Returns:
        List of tensors ordered from finest to coarsest. The last entry is
        the remaining low-resolution map itself (``f`` when ``n_scales`` is 1
        or less).
    """
    levels = []
    current = f
    steps_left = n_scales - 1
    while steps_left > 0:
        pooled = F.avg_pool2d(current, 2)
        upsampled = F.interpolate(
            pooled,
            size=current.shape[2:],
            mode="bilinear",
            align_corners=False,
        )
        # Detail lost by pooling at this resolution.
        levels.append(current - upsampled)
        current = pooled
        steps_left -= 1
    levels.append(current)
    return levels
|
|
|
|
def make_locations(sizes, strides):
    """Build the pixel-space centre coordinates of every cell on each level.

    Args:
        sizes: iterable of ``(height, width)`` grid sizes, one per level.
        strides: iterable of strides, paired with ``sizes`` by position.

    Returns:
        List with one float32 tensor of shape ``(height * width, 2)`` per
        level. Each row is ``(x, y)`` with ``x = (col + 0.5) * stride`` and
        ``y = (row + 0.5) * stride``; rows are ordered row-major.
    """

    def _level_centres(height, width, stride):
        # Cell centres along each axis, in input-pixel units.
        row_centres = (torch.arange(height, dtype=torch.float32) + 0.5) * stride
        col_centres = (torch.arange(width, dtype=torch.float32) + 0.5) * stride
        grid_y, grid_x = torch.meshgrid(row_centres, col_centres, indexing="ij")
        return torch.stack((grid_x.flatten(), grid_y.flatten()), dim=-1)

    return [
        _level_centres(height, width, stride)
        for (height, width), stride in zip(sizes, strides)
    ]
|
|
|
|
def assign_targets(loc, boxes, labels, stride, sr):
    """Match grid locations of one level to ground-truth boxes.

    A location is positive for a box when all three hold:
      * it lies strictly inside the box,
      * it lies within ``1.5 * stride`` of the box centre on both axes,
      * its largest distance to a box side falls inside ``sr`` (inclusive).
    When several boxes qualify, the one with the smallest area wins.

    Args:
        loc: ``(N, 2)`` tensor of ``(x, y)`` location coordinates.
        boxes: ``(M, 4)`` tensor of ``(x0, y0, x1, y1)`` boxes.
        labels: ``(M,)`` tensor of class indices, one per box.
        stride: stride of this level; sets the centre-sampling radius.
        sr: ``(low, high)`` bounds on the largest side distance.

    Returns:
        Tuple ``(ct, rt, ctrt)``:
          * ``ct``: ``(N,)`` long tensor, matched label or ``-1``.
          * ``rt``: ``(N, 4)`` left/top/right/bottom distances (zeros for
            unmatched locations).
          * ``ctrt``: ``(N,)`` centerness targets (zeros for unmatched).
    """
    num_locs = loc.shape[0]
    ct = torch.full((num_locs,), -1, dtype=torch.long)
    rt = torch.zeros(num_locs, 4)
    ctrt = torch.zeros(num_locs)
    if boxes.numel() == 0:
        return ct, rt, ctrt

    px = loc[:, None, 0]
    py = loc[:, None, 1]

    # Signed distances from every location to every box side: (N, M, 4).
    ltrb = torch.stack(
        [
            px - boxes[None, :, 0],
            py - boxes[None, :, 1],
            boxes[None, :, 2] - px,
            boxes[None, :, 3] - py,
        ],
        -1,
    )
    inside_box = ltrb.min(-1).values > 0

    # Centre sampling window around each box centre.
    centre_x = (boxes[:, 0] + boxes[:, 2]) / 2
    centre_y = (boxes[:, 1] + boxes[:, 3]) / 2
    radius = stride * 1.5
    near_centre = (
        (px >= centre_x - radius)
        & (px <= centre_x + radius)
        & (py >= centre_y - radius)
        & (py <= centre_y + radius)
    )

    # Scale range test on the largest side distance.
    largest = ltrb.max(-1).values
    lower, upper = sr[0], sr[1]
    right_level = (largest >= lower) & (largest <= upper)

    candidate = inside_box & near_centre & right_level

    # Non-candidates get infinite area so the argmin picks the smallest
    # qualifying box; a row that stays infinite has no match at all.
    box_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    cost = box_area[None, :].expand_as(candidate).masked_fill(~candidate, float("inf"))
    matched = torch.argmin(cost, dim=1)
    rows = torch.arange(num_locs)
    is_pos = cost[rows, matched] < float("inf")

    ct[is_pos] = labels[matched[is_pos]]
    if is_pos.any():
        rt[is_pos] = ltrb[rows[is_pos], matched[is_pos]]
        left, top, right, bottom = rt[is_pos].unbind(-1)
        horizontal = torch.minimum(left, right) / torch.maximum(left, right).clamp(min=1e-6)
        vertical = torch.minimum(top, bottom) / torch.maximum(top, bottom).clamp(min=1e-6)
        ctrt[is_pos] = torch.sqrt(horizontal * vertical)
    return ct, rt, ctrt
|
|
|
|
def compute_h1(f, B, H, W, C):
    """Sheaf H^1 compact: vertical + horizontal boundary magnitudes.

    ``f`` is reshaped to a ``(B, H, W, C)`` grid. For every cell the absolute
    differences to its two vertical neighbours are summed, and likewise for
    its two horizontal neighbours. Neighbours that fall outside the grid are
    zero-padded, so border cells are compared against zero.

    Args:
        f: tensor with ``B * H * W * C`` elements, row-major over (B, H, W).
        B, H, W, C: batch size, grid height, grid width, channel count.

    Returns:
        Tuple ``(v_bound, h_bound)``, each of shape ``(B * H * W, C)``.
    """
    grid = f.reshape(B, H, W, C).permute(0, 3, 1, 2)

    def _to_rows(channels_first):
        # Back to one row of C features per grid cell.
        return channels_first.permute(0, 2, 3, 1).reshape(-1, C)

    # F.pad spec for the last two dims is (left, right, top, bottom).
    next_row = F.pad(grid[:, :, 1:, :], (0, 0, 0, 1))
    prev_row = F.pad(grid[:, :, :-1, :], (0, 0, 1, 0))
    next_col = F.pad(grid[:, :, :, 1:], (0, 1, 0, 0))
    prev_col = F.pad(grid[:, :, :, :-1], (1, 0, 0, 0))

    vertical = (grid - next_row).abs() + (grid - prev_row).abs()
    horizontal = (grid - next_col).abs() + (grid - prev_col).abs()
    return _to_rows(vertical), _to_rows(horizontal)
|
|
|
|
def main():
    """Fit the analytical detection head in closed form and save it.

    Steps:
      1. Read ``manifest.json`` from ``CACHE_DIR`` to get the shard count.
      2. Stream ``shard_XXXX.pt`` files (up to 20000 images). For each image,
         decompose the cached spatial features into 3 scales, layer-normalise
         them, compute the H^1 boundary features and assign targets.
      3. Accumulate the normal-equation statistics (X^T X and X^T Y) over
         positive locations for three ridge regressions: classification
         (one-hot targets), box regression (log of ltrb distances, on
         features + both H^1 maps) and centerness.
      4. Solve each system with a ridge term ``lam * n_pos`` on the diagonal.
      5. Save the weights under ``heads/cofiber_threshold/analytical_h1``
         next to this script and print a summary.

    Raises:
        RuntimeError: if no positive location was accumulated, since the
            systems would then be singular.
    """
    print("=" * 60)
    print("Best Analytical Head: 768 cls + H^1 regression")
    print("=" * 60, flush=True)

    # Use a context manager so the manifest handle is closed deterministically.
    with open(os.path.join(CACHE_DIR, "manifest.json"), encoding="utf-8") as fh:
        manifest = json.load(fh)
    n_shards = manifest["n_shards"]
    strides = [16, 32, 64]
    H = RESOLUTION // 16
    sizes = [(H, H), (H // 2, H // 2), (H // 4, H // 4)]
    # Per-level bounds on the largest ltrb distance, passed to assign_targets.
    sr = [(-1, 128), (128, 256), (256, float("inf"))]
    locs = make_locations(sizes, strides)

    feat_dim = 768
    reg_dim = 768 * 3

    # Normal-equation accumulators; the extra row/column is the bias term.
    cls_XtX = torch.zeros(feat_dim + 1, feat_dim + 1, device=DEVICE)
    cls_XtY = torch.zeros(feat_dim + 1, NUM_CLASSES, device=DEVICE)
    reg_XtX = torch.zeros(reg_dim + 1, reg_dim + 1, device=DEVICE)
    reg_XtY = torch.zeros(reg_dim + 1, 4, device=DEVICE)
    ctr_XtX = torch.zeros(feat_dim + 1, feat_dim + 1, device=DEVICE)
    ctr_XtY = torch.zeros(feat_dim + 1, 1, device=DEVICE)
    n_pos = 0
    n_images = 20000
    seen = 0
    t0 = time.time()

    for si in range(n_shards):
        if seen >= n_images:
            break
        # NOTE(security): weights_only=False unpickles arbitrary objects;
        # only load shards from a trusted feature cache.
        shard = torch.load(os.path.join(CACHE_DIR, f"shard_{si:04d}.pt"),
                           map_location="cpu", weights_only=False)
        for item in shard:
            if seen >= n_images:
                break
            sp = item["spatial"].unsqueeze(0).float()
            boxes = item["boxes"]
            labels = item["labels"]
            cofibers = cofiber_decompose(sp, 3)
            for sci, cof in enumerate(cofibers):
                B, C, Hc, Wc = cof.shape
                f = F.layer_norm(cof.permute(0, 2, 3, 1).reshape(-1, C), [C])
                h1v, h1h = compute_h1(f, B, Hc, Wc, C)
                ct, rt, ctrt = assign_targets(locs[sci], boxes, labels, strides[sci], sr[sci])
                pos_mask = ct >= 0
                if not pos_mask.any():
                    continue

                # Classification: one-hot targets on the raw positive features.
                # The positive slice is taken once and reused for regression.
                f_pos = f[pos_mask]
                fp = f_pos.to(DEVICE)
                fa = torch.cat([fp, torch.ones(fp.shape[0], 1, device=DEVICE)], 1)
                yc = torch.zeros(fp.shape[0], NUM_CLASSES, device=DEVICE)
                yc[torch.arange(fp.shape[0], device=DEVICE), ct[pos_mask].to(DEVICE)] = 1.0
                # Classification and centerness share the same design matrix,
                # so its Gram matrix is computed once per scale.
                gram = fa.T @ fa
                cls_XtX += gram
                cls_XtY += fa.T @ yc

                # Regression: raw features plus both H^1 maps, log-ltrb targets.
                f_reg = torch.cat([f_pos, h1v[pos_mask], h1h[pos_mask]], 1).to(DEVICE)
                ltrb = rt[pos_mask]
                # Keep only rows with strictly positive distances so log() is finite.
                valid = (ltrb > 0).all(1)
                if valid.any():
                    fv = f_reg[valid]
                    fva = torch.cat([fv, torch.ones(fv.shape[0], 1, device=DEVICE)], 1)
                    yt = torch.log(ltrb[valid]).to(DEVICE)
                    reg_XtX += fva.T @ fva
                    reg_XtY += fva.T @ yt

                # Centerness: same inputs as classification, scalar target.
                ctr_XtX += gram
                ctr_XtY += fa.T @ ctrt[pos_mask].unsqueeze(1).to(DEVICE)

                n_pos += pos_mask.sum().item()
            seen += 1
        del shard
        if (si + 1) % 5 == 0:
            print(f"  shard {si+1}: {seen} imgs, {n_pos} pos, {time.time()-t0:.0f}s", flush=True)

    print(f"\nAccumulated {seen} images, {n_pos} positives", flush=True)

    # With zero positives both the data term and the ridge term (scaled by
    # n_pos) vanish, so the systems below would be singular.
    if n_pos == 0:
        raise RuntimeError("No positive samples accumulated; cannot solve for head weights")

    # Ridge solve; the regulariser is scaled by the number of positives.
    lam = 0.1
    I_cls = torch.eye(feat_dim + 1, device=DEVICE)
    I_reg = torch.eye(reg_dim + 1, device=DEVICE)
    I_ctr = torch.eye(feat_dim + 1, device=DEVICE)

    cls_W = torch.linalg.solve(cls_XtX + lam * I_cls * n_pos, cls_XtY)
    reg_W = torch.linalg.solve(reg_XtX + lam * I_reg * n_pos, reg_XtY)
    ctr_W = torch.linalg.solve(ctr_XtX + lam * I_ctr * n_pos, ctr_XtY)

    print(f"Solved. cls: {feat_dim}->80, reg: {reg_dim}->4, ctr: {feat_dim}->1", flush=True)

    # The last row of each solution is the bias; the rest is the weight matrix.
    state = {
        "cls_weight": cls_W[:feat_dim].T.cpu(),
        "cls_bias": cls_W[feat_dim].cpu(),
        "reg_weight": reg_W[:reg_dim].T.cpu(),
        "reg_bias": reg_W[reg_dim].cpu(),
        "ctr_weight": ctr_W[:feat_dim].T.cpu(),
        "ctr_bias": ctr_W[feat_dim].cpu(),
        "scale_norms.0.weight": torch.ones(768),
        "scale_norms.0.bias": torch.zeros(768),
        "scale_norms.1.weight": torch.ones(768),
        "scale_norms.1.bias": torch.zeros(768),
        "scale_norms.2.weight": torch.ones(768),
        "scale_norms.2.bias": torch.zeros(768),
        "scale_params": torch.ones(3),
        "meta": {"cls_features": "768_layernorm",
                 "reg_features": "768_layernorm_h1v_h1h",
                 "ctr_features": "768_layernorm",
                 "lambda": lam, "n_images": seen, "n_pos": n_pos},
    }

    out_dir = os.path.join(SCRIPT_DIR, "heads", "cofiber_threshold", "analytical_h1")
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, "analytical_h1_best.pth")
    torch.save(state, out_path)

    # Only tensor entries count as parameters; "meta" is skipped.
    n_params = sum(v.numel() for v in state.values() if isinstance(v, torch.Tensor))
    elapsed = time.time() - t0
    print(f"\nSaved: {out_path}")
    print(f"Total params: {n_params:,}")
    print(f"Construction time: {elapsed:.0f}s")
    print(f"\nClassification: 768 dims, {feat_dim * NUM_CLASSES + NUM_CLASSES:,} params")
    print(f"Regression: {reg_dim} dims (768+768+768), {reg_dim * 4 + 4:,} params")
    print(f"Centerness: 768 dims, {feat_dim + 1:,} params")


if __name__ == "__main__":
    main()
|
|