""" GPU-accelerated greedy forward construction of a minimal detection head. Batches all candidate evaluations into parallel matmuls on GPU. 50 greedy steps in seconds instead of minutes. """ import argparse import json import os import sys import time import torch import torch.nn.functional as F SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) sys.path.insert(0, SCRIPT_DIR) COCO_ROOT = os.environ.get("ARENA_COCO_ROOT", "coco") VAL_CACHE = os.environ.get("ARENA_VAL_CACHE", "val_cache/val.pt") NUM_CLASSES = 80 def cofiber_decompose(f, n_scales): cofibers = [] residual = f for _ in range(n_scales - 1): omega = F.avg_pool2d(residual, 2) sigma_omega = F.interpolate(omega, size=residual.shape[2:], mode="bilinear", align_corners=False) cofibers.append(residual - sigma_omega) residual = omega cofibers.append(residual) return cofibers def make_locations(sizes, strides): locs = [] for (h, w), s in zip(sizes, strides): ys = (torch.arange(h, dtype=torch.float32) + 0.5) * s xs = (torch.arange(w, dtype=torch.float32) + 0.5) * s gy, gx = torch.meshgrid(ys, xs, indexing="ij") locs.append(torch.stack([gx.flatten(), gy.flatten()], -1)) return locs def assign_targets(loc, boxes, labels, stride, sr): n = loc.shape[0] if boxes.numel() == 0: return torch.full((n,), -1, dtype=torch.long), torch.zeros(n, 4) areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) l = loc[:, None, 0] - boxes[None, :, 0] t = loc[:, None, 1] - boxes[None, :, 1] r = boxes[None, :, 2] - loc[:, None, 0] b = boxes[None, :, 3] - loc[:, None, 1] ltrb = torch.stack([l, t, r, b], -1) in_box = ltrb.min(-1).values > 0 cx = (boxes[:, 0] + boxes[:, 2]) / 2 cy = (boxes[:, 1] + boxes[:, 3]) / 2 rad = stride * 1.5 in_center = ((loc[:, None, 0] >= cx - rad) & (loc[:, None, 0] <= cx + rad) & (loc[:, None, 1] >= cy - rad) & (loc[:, None, 1] <= cy + rad)) max_d = ltrb.max(-1).values in_level = (max_d >= sr[0]) & (max_d <= sr[1]) pos = in_box & in_center & in_level a = areas[None, :].expand_as(pos).clone() a[~pos] = float("inf") matched = a.argmin(1) is_pos = a.gather(1, matched[:, None]).squeeze(1) < float("inf") ct = torch.full((n,), -1, dtype=torch.long) ct[is_pos] = labels[matched[is_pos]] rt = torch.zeros(n, 4) if is_pos.any(): rt[is_pos] = ltrb[torch.arange(n)[is_pos], matched[is_pos]] return ct, rt def build_val_data(val_path, n_images=500, device="cuda"): """Build feature matrix + targets on GPU.""" val = torch.load(val_path, map_location="cpu", weights_only=False) from pycocotools.coco import COCO ann_file = os.path.join(COCO_ROOT, "annotations", "instances_val2017.json") coco = COCO(ann_file) cat_ids = sorted(coco.getCatIds()) cat_to_idx = {c: i for i, c in enumerate(cat_ids)} strides = [16, 32, 64] H = 640 // 16 sizes = [(H, H), (H // 2, H // 2), (H // 4, H // 4)] sr = [(-1, 128), (128, 256), (256, float("inf"))] locs = make_locations(sizes, strides) all_f, all_cls = [], [] for idx in range(min(n_images, len(val))): item = val[idx] spatial = item["spatial"].unsqueeze(0).float() img_id = item["img_id"] scale = item["scale"] ann_ids = coco.getAnnIds(imgIds=int(img_id), iscrowd=False) anns = coco.loadAnns(ann_ids) boxes, labels = [], [] for ann in anns: x, y, w, h = ann["bbox"] if w < 1 or h < 1: continue boxes.append([x * scale, y * scale, (x + w) * scale, (y + h) * scale]) labels.append(cat_to_idx[ann["category_id"]]) boxes_t = torch.tensor(boxes, dtype=torch.float32) if boxes else torch.zeros(0, 4) labels_t = torch.tensor(labels, dtype=torch.long) if labels else torch.zeros(0, dtype=torch.long) cofibers = cofiber_decompose(spatial, 3) for sci, cof in enumerate(cofibers): B, C, Hc, Wc = cof.shape f = F.layer_norm(cof.permute(0, 2, 3, 1).reshape(-1, C), [C]) ct, _ = assign_targets(locs[sci], boxes_t, labels_t, strides[sci], sr[sci]) all_f.append(f) all_cls.append(ct) features = torch.cat(all_f).to(device) cls_targets = torch.cat(all_cls).to(device) return features, cls_targets def greedy_step_gpu(features, cls_targets, selected, remaining, lam=0.1): """Test all remaining candidates in parallel on GPU. Return best dim and accuracy.""" pos = cls_targets >= 0 n_pos = pos.sum().item() if n_pos == 0: return -1, 0.0 # Build one-hot targets f_pos = features[pos] y_cls = torch.zeros(n_pos, NUM_CLASSES, device=features.device) y_cls[torch.arange(n_pos, device=features.device), cls_targets[pos]] = 1.0 gt = cls_targets[pos] best_dim = -1 best_acc = -1.0 # For each candidate, solve and score # Batch in chunks to avoid OOM on very large candidate sets chunk_size = 64 for chunk_start in range(0, len(remaining), chunk_size): chunk = remaining[chunk_start:chunk_start + chunk_size] accs = [] for d in chunk: dims = selected + [d] fd = len(dims) fp = f_pos[:, dims] fa = torch.cat([fp, torch.ones(n_pos, 1, device=fp.device)], 1) I = torch.eye(fd + 1, device=fp.device) XtX = fa.T @ fa XtY = fa.T @ y_cls try: W = torch.linalg.solve(XtX + lam * I * n_pos, XtY) except Exception: accs.append(0.0) continue # Score on all positive locations scores = fp @ W[:fd] + W[fd] # (n_pos, 80) pred = scores.argmax(1) acc = (pred == gt).float().mean().item() accs.append(acc) for i, d in enumerate(chunk): if accs[i] > best_acc: best_acc = accs[i] best_dim = d return best_dim, best_acc def main(): parser = argparse.ArgumentParser() parser.add_argument("--max-dims", type=int, default=100) parser.add_argument("--n-eval", type=int, default=500) parser.add_argument("--lam", type=float, default=0.1) args = parser.parse_args() device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Device: {device}") print("=" * 60) print(f"GPU Greedy Forward Construction (max {args.max_dims} dims)") print("=" * 60, flush=True) print("Building val data...", flush=True) features, cls_targets = build_val_data(VAL_CACHE, args.n_eval, device) pos = cls_targets >= 0 print(f" {features.shape[0]} locations, {pos.sum().item()} positives, " f"{features.shape[1]} dims", flush=True) selected = [] remaining = list(range(768)) history = [] t0 = time.time() for step in range(args.max_dims): t_step = time.time() best_dim, best_acc = greedy_step_gpu(features, cls_targets, selected, remaining, args.lam) if best_dim < 0: break selected.append(best_dim) remaining.remove(best_dim) step_time = time.time() - t_step n_params = len(selected) * NUM_CLASSES + NUM_CLASSES # cls only for now entry = {"step": step + 1, "dim": best_dim, "cls_acc": round(best_acc, 4), "n_params": n_params, "step_ms": round(step_time * 1000)} history.append(entry) print(f" step {step+1:3d}: +dim{best_dim:3d} -> cls_acc={best_acc:.4f} " f"({len(selected)} dims, {n_params} params, {step_time*1000:.0f}ms)", flush=True) # Early stopping if len(history) >= 10: recent_gain = history[-1]["cls_acc"] - history[-10]["cls_acc"] if recent_gain < 0.005: print(f" Converged: <0.5% gain in 10 steps", flush=True) break elapsed = time.time() - t0 print(f"\n{'='*60}") print(f"Selected {len(selected)} dimensions in {elapsed:.1f}s") print(f"Final cls_acc: {history[-1]['cls_acc']:.4f}") print(f"Final params: {history[-1]['n_params']}") print(f"\nTop 20 dimensions (most to least important):") for h in history[:20]: print(f" step {h['step']:2d}: dim{h['dim']:3d} cumul_acc={h['cls_acc']:.4f} ({h['step_ms']}ms)") # Save result = {"selected_dims": selected, "history": history, "final_cls_acc": history[-1]["cls_acc"], "final_params": history[-1]["n_params"], "total_time_s": round(elapsed, 1)} out = os.path.join(SCRIPT_DIR, "analytical_variants", "greedy_forward_gpu.json") os.makedirs(os.path.dirname(out), exist_ok=True) with open(out, "w") as f: json.dump(result, f, indent=2) print(f"\nSaved: {out}") if __name__ == "__main__": main()