""" Fractal cofiber decomposition — wavelet packet style. Instead of recursing only on the low-frequency residual (3 bands), recurse on BOTH the cofiber and residual at each level. Depth 1: 2 bands (standard single split) Depth 2: 4 bands Depth 3: 8 bands Each band is 768 dims. Classification and regression are solved independently per band, then results are merged. The solver picks which bands matter. Or: concatenate all bands and solve one large system. """ import json, os, sys, time import torch, torch.nn.functional as F SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) sys.path.insert(0, SCRIPT_DIR) COCO_ROOT = os.environ.get("ARENA_COCO_ROOT") VAL_CACHE = os.environ.get("ARENA_VAL_CACHE") CACHE_DIR = os.environ.get("ARENA_CACHE_DIR") DEVICE = "cuda" RESOLUTION = 640 NUM_CLASSES = 80 def fractal_decompose(f, depth): """Fractal cofiber decomposition. Returns list of 2^depth feature maps.""" if depth == 0: return [f] omega = F.avg_pool2d(f, 2) sigma_omega = F.interpolate(omega, size=f.shape[2:], mode="bilinear", align_corners=False) cofiber = f - sigma_omega # high frequency at this scale # Recurse on BOTH branches high_bands = fractal_decompose(cofiber, depth - 1) low_bands = fractal_decompose(omega, depth - 1) return high_bands + low_bands def standard_decompose(f, n_scales): """Standard cofiber: recurse only on residual.""" cofibers = [] residual = f for _ in range(n_scales - 1): omega = F.avg_pool2d(residual, 2) sigma_omega = F.interpolate(omega, size=residual.shape[2:], mode="bilinear", align_corners=False) cofibers.append(residual - sigma_omega) residual = omega cofibers.append(residual) return cofibers def make_locations(sizes, strides): locs = [] for (h, w), s in zip(sizes, strides): ys = (torch.arange(h, dtype=torch.float32) + 0.5) * s xs = (torch.arange(w, dtype=torch.float32) + 0.5) * s gy, gx = torch.meshgrid(ys, xs, indexing="ij") locs.append(torch.stack([gx.flatten(), gy.flatten()], -1)) return locs def assign_targets(loc, boxes, labels, stride, sr): n = loc.shape[0] ct = torch.full((n,), -1, dtype=torch.long) rt = torch.zeros(n, 4) ctrt = torch.zeros(n) if boxes.numel() == 0: return ct, rt, ctrt areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) l = loc[:, None, 0] - boxes[None, :, 0] t = loc[:, None, 1] - boxes[None, :, 1] r = boxes[None, :, 2] - loc[:, None, 0] b = boxes[None, :, 3] - loc[:, None, 1] ltrb = torch.stack([l, t, r, b], -1) in_box = ltrb.min(-1).values > 0 cx = (boxes[:, 0] + boxes[:, 2]) / 2 cy = (boxes[:, 1] + boxes[:, 3]) / 2 rad = stride * 1.5 in_center = ((loc[:, None, 0] >= cx - rad) & (loc[:, None, 0] <= cx + rad) & (loc[:, None, 1] >= cy - rad) & (loc[:, None, 1] <= cy + rad)) max_d = ltrb.max(-1).values in_level = (max_d >= sr[0]) & (max_d <= sr[1]) pos = in_box & in_center & in_level a = areas[None, :].expand_as(pos).clone() a[~pos] = float("inf") matched = a.argmin(1) is_pos = a.gather(1, matched[:, None]).squeeze(1) < float("inf") ct[is_pos] = labels[matched[is_pos]] if is_pos.any(): rt[is_pos] = ltrb[torch.arange(n)[is_pos], matched[is_pos]] lp, tp, rp, bp = rt[is_pos].unbind(-1) ctrt[is_pos] = torch.sqrt( (torch.minimum(lp, rp) / torch.maximum(lp, rp).clamp(min=1e-6)) * (torch.minimum(tp, bp) / torch.maximum(tp, bp).clamp(min=1e-6))) return ct, rt, ctrt def eval_decomposition(val, coco_gt, cat_ids, decompose_fn, name, lam=0.1, n_train=10000): """Accumulate, solve, and eval a decomposition variant. All bands are upsampled to stride-16 resolution (40x40) and use the same target assignment. The decomposition separates frequencies, not resolutions. """ idx_to_cat = {i: c for i, c in enumerate(cat_ids)} H = RESOLUTION // 16 target_size = (H, H) stride = 16 sr = (-1, float("inf")) # single scale, all object sizes locs_flat = make_locations([target_size], [stride]) n_locs = H * H manifest = json.load(open(os.path.join(CACHE_DIR, "manifest.json"))) feat_dim = 768 cls_XtX = torch.zeros(feat_dim + 1, feat_dim + 1, device=DEVICE) cls_XtY = torch.zeros(feat_dim + 1, NUM_CLASSES, device=DEVICE) reg_XtX = torch.zeros(feat_dim + 1, feat_dim + 1, device=DEVICE) reg_XtY = torch.zeros(feat_dim + 1, 4, device=DEVICE) ctr_XtX = torch.zeros(feat_dim + 1, feat_dim + 1, device=DEVICE) ctr_XtY = torch.zeros(feat_dim + 1, 1, device=DEVICE) n_pos = 0; seen = 0 t0 = time.time() for si in range(manifest["n_shards"]): if seen >= n_train: break shard = torch.load(os.path.join(CACHE_DIR, f"shard_{si:04d}.pt"), map_location="cpu", weights_only=False) for item in shard: if seen >= n_train: break sp = item["spatial"].unsqueeze(0).float().to(DEVICE) boxes = item["boxes"]; labels = item["labels"] bands = decompose_fn(sp) # Upsample all bands to 40x40, average them upsampled = [] for band in bands: if band.shape[2:] != target_size: band = F.interpolate(band, size=target_size, mode="bilinear", align_corners=False) upsampled.append(band) # Average across all bands — the solver sees the mean multi-frequency representation merged = torch.stack(upsampled).mean(0) # (1, 768, 40, 40) B, C, Hc, Wc = merged.shape f = F.layer_norm(merged.permute(0, 2, 3, 1).reshape(-1, C), [C]) ct, rt, ctrt = assign_targets(locs_flat[0], boxes, labels, stride, sr) pos_mask = ct >= 0 if not pos_mask.any(): seen += 1; continue fp = f[pos_mask] fa = torch.cat([fp, torch.ones(fp.shape[0], 1, device=DEVICE)], 1) yc = torch.zeros(fp.shape[0], NUM_CLASSES, device=DEVICE) yc[torch.arange(fp.shape[0], device=DEVICE), ct[pos_mask].to(DEVICE)] = 1.0 cls_XtX += fa.T @ fa; cls_XtY += fa.T @ yc ltrb = rt[pos_mask]; valid = (ltrb > 0).all(1) if valid.any(): fv = fa[valid]; yt = torch.log(ltrb[valid]).to(DEVICE) reg_XtX += fv.T @ fv; reg_XtY += fv.T @ yt ctr_XtX += fa.T @ fa ctr_XtY += fa.T @ ctrt[pos_mask].unsqueeze(1).to(DEVICE) n_pos += pos_mask.sum().item() seen += 1 del shard I = torch.eye(feat_dim + 1, device=DEVICE) cls_W = torch.linalg.solve(cls_XtX + lam * I * n_pos, cls_XtY) reg_W = torch.linalg.solve(reg_XtX + lam * I * n_pos, reg_XtY) ctr_W = torch.linalg.solve(ctr_XtX + lam * I * n_pos, ctr_XtY) accum_time = time.time() - t0 all_locs = locs_flat[0].to(DEVICE) all_results = [] for idx in range(len(val)): spatial = val[idx]["spatial"].unsqueeze(0).float().to(DEVICE) img_id = int(val[idx]["img_id"]); scale = val[idx]["scale"] bands = decompose_fn(spatial) upsampled = [] for band in bands: if band.shape[2:] != target_size: band = F.interpolate(band, size=target_size, mode="bilinear", align_corners=False) upsampled.append(band) merged = torch.stack(upsampled).mean(0) B, C, Hc, Wc = merged.shape f = F.layer_norm(merged.permute(0, 2, 3, 1).reshape(-1, C), [C]) cls_s = (f @ cls_W[:feat_dim] + cls_W[feat_dim]).sigmoid() reg_s = (f @ reg_W[:feat_dim] + reg_W[feat_dim]).exp() ctr_s = (f @ ctr_W[:feat_dim] + ctr_W[feat_dim]).sigmoid().squeeze(1) scores = cls_s * ctr_s.unsqueeze(1) max_s, max_c = scores.max(1) topk = min(100, max_s.shape[0]) top_s, top_i = max_s.topk(topk) tc = max_c[top_i]; tr = reg_s[top_i]; tl = all_locs[top_i] x1 = (tl[:,0]-tr[:,0])/scale; y1 = (tl[:,1]-tr[:,1])/scale x2 = (tl[:,0]+tr[:,2])/scale; y2 = (tl[:,1]+tr[:,3])/scale w = (x2-x1).clamp(min=0); h = (y2-y1).clamp(min=0) for i in range(topk): s = top_s[i].item() if s < 0.01: continue all_results.append({"image_id": img_id, "category_id": idx_to_cat[tc[i].item()], "bbox": [x1[i].item(), y1[i].item(), w[i].item(), h[i].item()], "score": s}) # pycocotools eval from pycocotools.cocoeval import COCOeval if not all_results: print(f" {name}: no detections"); return 0.0 coco_dt = coco_gt.loadRes(all_results) coco_eval = COCOeval(coco_gt, coco_dt, "bbox") coco_eval.params.imgIds = sorted(coco_gt.getImgIds())[:len(val)] coco_eval.evaluate(); coco_eval.accumulate(); coco_eval.summarize() mAP = coco_eval.stats[0] mAP50 = coco_eval.stats[1] mAP75 = coco_eval.stats[2] print(f" {name}: mAP={mAP:.4f} mAP50={mAP50:.4f} mAP75={mAP75:.4f} " f"({accum_time:.0f}s accum, {n_pos} pos)") return mAP def main(): from pycocotools.coco import COCO print("=" * 60) print("Fractal vs Standard Cofiber Decomposition") print("=" * 60, flush=True) val = torch.load(VAL_CACHE, map_location="cpu", weights_only=False) ann_file = os.path.join(COCO_ROOT, "annotations", "instances_val2017.json") coco_gt = COCO(ann_file) cat_ids = sorted(coco_gt.getCatIds()) results = [] # Standard 3-band cofiber (baseline) print("\n1. Standard 3-band cofiber:", flush=True) mAP = eval_decomposition(val, coco_gt, cat_ids, lambda sp: standard_decompose(sp, 3), "standard_3band") results.append({"name": "standard_3band", "mAP": mAP, "bands": 3}) # Fractal depth 2 (4 bands) print("\n2. Fractal depth 2 (4 bands):", flush=True) mAP = eval_decomposition(val, coco_gt, cat_ids, lambda sp: fractal_decompose(sp, 2), "fractal_depth2") results.append({"name": "fractal_depth2", "mAP": mAP, "bands": 4}) # Fractal depth 3 (8 bands) print("\n3. Fractal depth 3 (8 bands):", flush=True) mAP = eval_decomposition(val, coco_gt, cat_ids, lambda sp: fractal_decompose(sp, 3), "fractal_depth3") results.append({"name": "fractal_depth3", "mAP": mAP, "bands": 8}) print(f"\n{'='*60}") print("Summary:") for r in results: print(f" {r['name']:20s}: mAP={r['mAP']:.4f} ({r['bands']} bands)") out = os.path.join(SCRIPT_DIR, "analytical_variants", "fractal_results.json") with open(out, "w") as f: json.dump(results, f, indent=2) print(f"Saved: {out}") if __name__ == "__main__": main()