| """ |
| Fractal cofiber decomposition — wavelet packet style. |
| |
| Instead of recursing only on the low-frequency residual (3 bands), |
| recurse on BOTH the cofiber and residual at each level. |
| |
| Depth 1: 2 bands (standard single split) |
| Depth 2: 4 bands |
| Depth 3: 8 bands |
| |
| Each band is 768 dims. Classification and regression are solved independently |
| per band, then results are merged. The solver picks which bands matter. |
| |
| Or: concatenate all bands and solve one large system. |
| """ |
|
|
| import json, os, sys, time |
| import torch, torch.nn.functional as F |
|
|
# Directory holding this script. It is placed first on sys.path
# (presumably so sibling modules can be imported — confirm against the repo).
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, SCRIPT_DIR)


# Dataset / cache locations are read from the environment; each one is None
# when the corresponding variable is unset.
COCO_ROOT = os.environ.get("ARENA_COCO_ROOT")
VAL_CACHE = os.environ.get("ARENA_VAL_CACHE")
CACHE_DIR = os.environ.get("ARENA_CACHE_DIR")
# Prefer the GPU, but fall back to CPU so the script still runs on hosts
# without CUDA instead of failing at the first `.to(DEVICE)`.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Input resolution in pixels; the evaluation grid side is RESOLUTION // 16.
RESOLUTION = 640
# Width of the one-hot classification target.
NUM_CLASSES = 80
|
|
|
|
def fractal_decompose(f, depth):
    """Split a feature map into 2**depth bands, wavelet-packet style.

    At every level the input is 2x average-pooled (low band) and the pooled
    map, bilinearly resized back to the input size, is subtracted from the
    input (high band / cofiber). Both halves are then split again.

    Args:
        f: 4-D tensor; dimensions 2 and 3 are treated as the spatial axes.
        depth: number of split levels. ``0`` returns ``[f]`` unchanged.

    Returns:
        A list of tensors. For each split, the bands derived from the
        high-frequency half come before those derived from the pooled half.
        Bands do not share a spatial size: every pooling step halves it.
    """
    if depth == 0:
        return [f]

    pooled = F.avg_pool2d(f, 2)
    restored = F.interpolate(pooled, size=f.shape[2:], mode="bilinear", align_corners=False)
    detail = f - restored

    bands = []
    # High-frequency branch first, then the pooled branch.
    for branch in (detail, pooled):
        bands.extend(fractal_decompose(branch, depth - 1))
    return bands
|
|
|
|
def standard_decompose(f, n_scales):
    """Build a Laplacian-pyramid-like list of bands from a feature map.

    Only the pooled (low-frequency) part is split again at each step.

    Args:
        f: 4-D tensor; dimensions 2 and 3 are treated as the spatial axes.
        n_scales: number of bands to return; ``n_scales - 1`` pooling steps
            are performed.

    Returns:
        A list of ``n_scales`` tensors: the detail (input minus resized
        pooled map) of each level from finest to coarsest, followed by the
        final pooled residual. Each level has half the spatial size of the
        previous one.
    """
    details = []
    level = f
    for _ in range(n_scales - 1):
        coarser = F.avg_pool2d(level, 2)
        restored = F.interpolate(
            coarser, size=level.shape[2:], mode="bilinear", align_corners=False
        )
        details.append(level - restored)
        level = coarser
    # The last, coarsest map is kept as-is as the final band.
    return details + [level]
|
|
|
|
def make_locations(sizes, strides):
    """Return the cell-centre coordinates of each feature grid.

    Args:
        sizes: iterable of ``(height, width)`` grid sizes.
        strides: iterable of strides, paired with ``sizes`` by position.

    Returns:
        One float32 tensor of shape ``(height * width, 2)`` per grid. Each
        row is ``(x, y)`` with ``x = (col + 0.5) * stride`` and
        ``y = (row + 0.5) * stride``; rows are ordered row-major (all
        columns of row 0 first).
    """

    def _grid_centres(height, width, stride):
        centre_y = (torch.arange(height, dtype=torch.float32) + 0.5) * stride
        centre_x = (torch.arange(width, dtype=torch.float32) + 0.5) * stride
        grid_y, grid_x = torch.meshgrid(centre_y, centre_x, indexing="ij")
        # x first, then y, to match the (x1, y1, x2, y2) box convention.
        return torch.stack([grid_x.flatten(), grid_y.flatten()], -1)

    return [_grid_centres(h, w, s) for (h, w), s in zip(sizes, strides)]
|
|
|
|
def assign_targets(loc, boxes, labels, stride, sr):
    """Match grid locations to ground-truth boxes and build training targets.

    A location is a candidate for a box when it lies strictly inside the box,
    lies within ``1.5 * stride`` of the box centre on both axes, and its
    largest distance to a box side falls within ``[sr[0], sr[1]]``. When
    several boxes qualify, the one with the smallest area is chosen.

    Args:
        loc: ``(n, 2)`` tensor of ``(x, y)`` locations.
        boxes: ``(m, 4)`` tensor of ``(x1, y1, x2, y2)`` boxes.
        labels: ``(m,)`` tensor of class indices, one per box.
        stride: stride of the grid, used for the centre-sampling radius.
        sr: pair ``(min, max)`` bounding the largest side distance.

    Returns:
        Tuple ``(ct, rt, ctrt)``:
        ``ct`` — ``(n,)`` long tensor, the matched label or ``-1`` if unmatched;
        ``rt`` — ``(n, 4)`` distances (left, top, right, bottom) to the matched
        box, zeros if unmatched;
        ``ctrt`` — ``(n,)`` centerness
        ``sqrt(min(l, r) / max(l, r) * min(t, b) / max(t, b))``, zeros if
        unmatched.
    """
    num_locs = loc.shape[0]
    cls_target = torch.full((num_locs,), -1, dtype=torch.long)
    box_target = torch.zeros(num_locs, 4)
    ctr_target = torch.zeros(num_locs)
    if boxes.numel() == 0:
        return cls_target, box_target, ctr_target

    # Broadcast every location (rows) against every box (columns).
    px = loc[:, None, 0]
    py = loc[:, None, 1]
    left = px - boxes[None, :, 0]
    top = py - boxes[None, :, 1]
    right = boxes[None, :, 2] - px
    bottom = boxes[None, :, 3] - py
    dists = torch.stack([left, top, right, bottom], -1)

    inside_box = dists.min(-1).values > 0

    radius = stride * 1.5
    centre_x = (boxes[:, 0] + boxes[:, 2]) / 2
    centre_y = (boxes[:, 1] + boxes[:, 3]) / 2
    near_x = (px >= centre_x - radius) & (px <= centre_x + radius)
    near_y = (py >= centre_y - radius) & (py <= centre_y + radius)

    longest = dists.max(-1).values
    in_range = (longest >= sr[0]) & (longest <= sr[1])

    candidate = inside_box & (near_x & near_y) & in_range

    # Non-candidates get infinite area so argmin never prefers them; a row
    # whose minimum is still infinite has no matching box at all.
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    cost = areas[None, :].expand_as(candidate).clone()
    cost[~candidate] = float("inf")
    best_box = cost.argmin(1)
    has_match = cost.gather(1, best_box[:, None]).squeeze(1) < float("inf")

    cls_target[has_match] = labels[best_box[has_match]]
    if has_match.any():
        rows = torch.arange(num_locs)[has_match]
        box_target[has_match] = dists[rows, best_box[has_match]]
        d_left, d_top, d_right, d_bottom = box_target[has_match].unbind(-1)
        horizontal = torch.minimum(d_left, d_right) / torch.maximum(d_left, d_right).clamp(min=1e-6)
        vertical = torch.minimum(d_top, d_bottom) / torch.maximum(d_top, d_bottom).clamp(min=1e-6)
        ctr_target[has_match] = torch.sqrt(horizontal * vertical)
    return cls_target, box_target, ctr_target
|
|
|
|
def _merge_bands(bands, target_size):
    """Resize every band to ``target_size`` and return their element-wise mean."""
    resized = []
    for band in bands:
        if band.shape[2:] != target_size:
            band = F.interpolate(band, size=target_size, mode="bilinear", align_corners=False)
        resized.append(band)
    return torch.stack(resized).mean(0)


def _normalized_features(spatial, decompose_fn, target_size):
    """Decompose ``spatial``, merge its bands, and layer-normalize each location."""
    merged = _merge_bands(decompose_fn(spatial), target_size)
    channels = merged.shape[1]
    # Channels-last, one row per location, in the same row-major order that
    # make_locations uses for its coordinates.
    rows = merged.permute(0, 2, 3, 1).reshape(-1, channels)
    return F.layer_norm(rows, [channels])


def eval_decomposition(val, coco_gt, cat_ids, decompose_fn, name, lam=0.1, n_train=10000):
    """Accumulate, solve, and eval a decomposition variant.

    All bands are resized to the stride-16 grid (``RESOLUTION // 16`` per
    side) and averaged into a single feature map, so every variant uses the
    same target assignment. The decomposition separates frequencies, not
    resolutions.

    Three ridge-regularized least-squares heads (class one-hot, log box
    distances, centerness) are fitted in closed form on the positive
    locations of up to ``n_train`` cached training items, then applied to
    ``val`` and scored with pycocotools.

    Args:
        val: indexable collection of items providing ``"spatial"``,
            ``"img_id"`` and ``"scale"``.
        coco_gt: pycocotools ``COCO`` ground-truth object.
        cat_ids: category ids; position ``i`` is the id reported for class
            index ``i``.
        decompose_fn: callable mapping a ``(1, C, H, W)`` tensor to a list of
            band tensors.
        name: label used in printed output.
        lam: ridge strength; the added diagonal is ``lam * n_pos``.
        n_train: maximum number of training items to consume.

    Returns:
        ``coco_eval.stats[0]`` (the primary COCO AP), or ``0.0`` when no
        positive training location was found or no detection passed the
        score threshold.
    """
    idx_to_cat = dict(enumerate(cat_ids))
    H = RESOLUTION // 16
    target_size = (H, H)
    stride = 16
    # A single level handles every box, so the size range is left open.
    sr = (-1, float("inf"))
    locs_flat = make_locations([target_size], [stride])

    # Context manager so the manifest handle is closed promptly.
    with open(os.path.join(CACHE_DIR, "manifest.json")) as manifest_file:
        manifest = json.load(manifest_file)

    # NOTE(review): feature width is hard-coded; items with a different
    # channel count will fail at the first accumulation.
    feat_dim = 768
    # Normal-equation accumulators; the extra row/column is the bias term.
    cls_XtX = torch.zeros(feat_dim + 1, feat_dim + 1, device=DEVICE)
    cls_XtY = torch.zeros(feat_dim + 1, NUM_CLASSES, device=DEVICE)
    reg_XtX = torch.zeros(feat_dim + 1, feat_dim + 1, device=DEVICE)
    reg_XtY = torch.zeros(feat_dim + 1, 4, device=DEVICE)
    ctr_XtX = torch.zeros(feat_dim + 1, feat_dim + 1, device=DEVICE)
    ctr_XtY = torch.zeros(feat_dim + 1, 1, device=DEVICE)
    n_pos = 0
    seen = 0

    t0 = time.time()
    for si in range(manifest["n_shards"]):
        if seen >= n_train:
            break
        # NOTE: weights_only=False unpickles arbitrary objects — only load
        # shard files from a trusted cache.
        shard = torch.load(os.path.join(CACHE_DIR, f"shard_{si:04d}.pt"),
                           map_location="cpu", weights_only=False)
        for item in shard:
            if seen >= n_train:
                break
            sp = item["spatial"].unsqueeze(0).float().to(DEVICE)
            boxes = item["boxes"]
            labels = item["labels"]

            feats = _normalized_features(sp, decompose_fn, target_size)
            ct, rt, ctrt = assign_targets(locs_flat[0], boxes, labels, stride, sr)
            pos_mask = ct >= 0
            if not pos_mask.any():
                # Items without positives still count toward n_train.
                seen += 1
                continue

            fp = feats[pos_mask]
            fa = torch.cat([fp, torch.ones(fp.shape[0], 1, device=DEVICE)], 1)

            yc = torch.zeros(fp.shape[0], NUM_CLASSES, device=DEVICE)
            yc[torch.arange(fp.shape[0], device=DEVICE), ct[pos_mask].to(DEVICE)] = 1.0
            cls_XtX += fa.T @ fa
            cls_XtY += fa.T @ yc

            # Box distances are regressed in log space, so all four must be > 0.
            ltrb = rt[pos_mask]
            valid = (ltrb > 0).all(1)
            if valid.any():
                fv = fa[valid]
                yt = torch.log(ltrb[valid]).to(DEVICE)
                reg_XtX += fv.T @ fv
                reg_XtY += fv.T @ yt

            ctr_XtX += fa.T @ fa
            ctr_XtY += fa.T @ ctrt[pos_mask].unsqueeze(1).to(DEVICE)
            n_pos += pos_mask.sum().item()
            seen += 1
        del shard

    if n_pos == 0:
        # Every system matrix is all zeros here (the ridge term is scaled by
        # n_pos), so the solves below would fail on a singular matrix.
        print(f"  {name}: no positive training samples")
        return 0.0

    I = torch.eye(feat_dim + 1, device=DEVICE)
    cls_W = torch.linalg.solve(cls_XtX + lam * I * n_pos, cls_XtY)
    reg_W = torch.linalg.solve(reg_XtX + lam * I * n_pos, reg_XtY)
    ctr_W = torch.linalg.solve(ctr_XtX + lam * I * n_pos, ctr_XtY)
    accum_time = time.time() - t0

    all_locs = locs_flat[0].to(DEVICE)
    all_results = []
    for idx in range(len(val)):
        sample = val[idx]
        spatial = sample["spatial"].unsqueeze(0).float().to(DEVICE)
        img_id = int(sample["img_id"])
        scale = sample["scale"]

        feats = _normalized_features(spatial, decompose_fn, target_size)
        cls_s = (feats @ cls_W[:feat_dim] + cls_W[feat_dim]).sigmoid()
        reg_s = (feats @ reg_W[:feat_dim] + reg_W[feat_dim]).exp()
        ctr_s = (feats @ ctr_W[:feat_dim] + ctr_W[feat_dim]).sigmoid().squeeze(1)

        scores = cls_s * ctr_s.unsqueeze(1)
        max_s, max_c = scores.max(1)
        topk = min(100, max_s.shape[0])
        top_s, top_i = max_s.topk(topk)
        tc = max_c[top_i]
        tr = reg_s[top_i]
        tl = all_locs[top_i]

        # Decode (l, t, r, b) distances around each location; dividing by
        # `scale` presumably maps back to original-image pixels — confirm
        # against the code that builds the val cache.
        x1 = (tl[:, 0] - tr[:, 0]) / scale
        y1 = (tl[:, 1] - tr[:, 1]) / scale
        x2 = (tl[:, 0] + tr[:, 2]) / scale
        y2 = (tl[:, 1] + tr[:, 3]) / scale
        w = (x2 - x1).clamp(min=0)
        h = (y2 - y1).clamp(min=0)

        # One host transfer per tensor instead of one `.item()` per value.
        xywh = torch.stack([x1, y1, w, h], 1).tolist()
        for s, cls_idx, bbox in zip(top_s.tolist(), tc.tolist(), xywh):
            if s < 0.01:
                continue
            all_results.append({"image_id": img_id, "category_id": idx_to_cat[cls_idx],
                                "bbox": bbox,
                                "score": s})

    from pycocotools.cocoeval import COCOeval
    if not all_results:
        print(f"  {name}: no detections")
        return 0.0
    coco_dt = coco_gt.loadRes(all_results)
    coco_eval = COCOeval(coco_gt, coco_dt, "bbox")
    # NOTE(review): this assumes `val` holds exactly the first len(val) image
    # ids in sorted order — verify against how the val cache is built.
    coco_eval.params.imgIds = sorted(coco_gt.getImgIds())[:len(val)]
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()
    mAP = coco_eval.stats[0]
    mAP50 = coco_eval.stats[1]
    mAP75 = coco_eval.stats[2]
    print(f"  {name}: mAP={mAP:.4f} mAP50={mAP50:.4f} mAP75={mAP75:.4f} "
          f"({accum_time:.0f}s accum, {n_pos} pos)")
    return mAP
|
|
|
|
def main():
    """Evaluate the standard and fractal decompositions and save their mAP.

    Runs ``eval_decomposition`` for the standard 3-band split and for fractal
    depths 2 and 3, prints a summary, and writes the results as JSON to
    ``analytical_variants/fractal_results.json`` next to this script.

    Raises:
        SystemExit: if any of ``ARENA_COCO_ROOT``, ``ARENA_VAL_CACHE`` or
            ``ARENA_CACHE_DIR`` is unset or empty.
    """
    from pycocotools.coco import COCO

    # Fail early with a readable message; otherwise an unset variable only
    # surfaces later as a TypeError from os.path.join / torch.load.
    required = {
        "ARENA_COCO_ROOT": COCO_ROOT,
        "ARENA_VAL_CACHE": VAL_CACHE,
        "ARENA_CACHE_DIR": CACHE_DIR,
    }
    missing = [var for var, value in required.items() if not value]
    if missing:
        raise SystemExit(f"Missing required environment variable(s): {', '.join(missing)}")

    print("=" * 60)
    print("Fractal vs Standard Cofiber Decomposition")
    print("=" * 60, flush=True)

    # NOTE: weights_only=False unpickles arbitrary objects — only load a
    # trusted cache file.
    val = torch.load(VAL_CACHE, map_location="cpu", weights_only=False)
    ann_file = os.path.join(COCO_ROOT, "annotations", "instances_val2017.json")
    coco_gt = COCO(ann_file)
    cat_ids = sorted(coco_gt.getCatIds())

    # (heading, result name, band count, decomposition callable)
    variants = [
        ("1. Standard 3-band cofiber:", "standard_3band", 3,
         lambda sp: standard_decompose(sp, 3)),
        ("2. Fractal depth 2 (4 bands):", "fractal_depth2", 4,
         lambda sp: fractal_decompose(sp, 2)),
        ("3. Fractal depth 3 (8 bands):", "fractal_depth3", 8,
         lambda sp: fractal_decompose(sp, 3)),
    ]

    results = []
    for heading, variant_name, n_bands, decompose_fn in variants:
        print(f"\n{heading}", flush=True)
        mAP = eval_decomposition(val, coco_gt, cat_ids, decompose_fn, variant_name)
        results.append({"name": variant_name, "mAP": mAP, "bands": n_bands})

    print(f"\n{'='*60}")
    print("Summary:")
    for r in results:
        print(f"  {r['name']:20s}: mAP={r['mAP']:.4f} ({r['bands']} bands)")

    out = os.path.join(SCRIPT_DIR, "analytical_variants", "fractal_results.json")
    # Create the output directory if needed so a long run is not lost to a
    # FileNotFoundError at the very end.
    os.makedirs(os.path.dirname(out), exist_ok=True)
    with open(out, "w") as f:
        json.dump(results, f, indent=2)
    print(f"Saved: {out}")


if __name__ == "__main__":
    main()
|
|