| """ |
| GPU-accelerated greedy forward construction of a minimal detection head. |
| |
| Batches all candidate evaluations into parallel matmuls on GPU. |
| 50 greedy steps in seconds instead of minutes. |
| """ |
|
|
| import argparse |
| import json |
| import os |
| import sys |
| import time |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
# Directory that contains this script; it is prepended to sys.path below and
# also used by main() as the base directory for the JSON output.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, SCRIPT_DIR)


# Dataset locations; each can be overridden through its environment variable.
# COCO_ROOT is the directory whose "annotations" sub-folder is read by
# build_val_data(); VAL_CACHE is the file passed to torch.load() there.
COCO_ROOT = os.environ.get("ARENA_COCO_ROOT", "coco")
VAL_CACHE = os.environ.get("ARENA_VAL_CACHE", "val_cache/val.pt")
# Number of class columns in the one-hot targets (greedy_step_gpu) and the
# multiplier used for the parameter count reported by main().
NUM_CLASSES = 80
|
|
|
|
def cofiber_decompose(f, n_scales):
    """Split a feature map into a fine-to-coarse list of residual bands.

    At each level the current map is 2x2 average-pooled, the pooled map is
    bilinearly resized back to the current resolution, and the difference
    between the current map and that resized copy is stored.  The pooled map
    then becomes the input of the next level.

    Args:
        f: Input tensor.  It is fed to ``F.avg_pool2d`` and its spatial size
            is read from ``shape[2:]``, so an ``(N, C, H, W)`` layout is
            presumably expected -- TODO confirm against callers.
        n_scales: Number of tensors to return.

    Returns:
        A list holding ``n_scales - 1`` difference tensors (finest first)
        followed by the last pooled map.  When ``n_scales <= 1`` the list
        contains only ``f``.
    """
    bands = []
    current = f
    for _ in range(n_scales - 1):
        coarse = F.avg_pool2d(current, kernel_size=2)
        restored = F.interpolate(
            coarse,
            size=current.shape[2:],
            mode="bilinear",
            align_corners=False,
        )
        # What the coarser level cannot represent at this resolution.
        bands.append(current - restored)
        current = coarse
    return bands + [current]
|
|
|
|
def make_locations(sizes, strides):
    """Build the cell-centre coordinates for every pyramid level.

    Args:
        sizes: Iterable of ``(height, width)`` grid sizes, one per level.
        strides: Iterable of strides, paired with ``sizes`` by position.

    Returns:
        A list with one float32 tensor of shape ``(height * width, 2)`` per
        level.  Rows are ``(x, y)`` centres, ``(index + 0.5) * stride``,
        enumerated row by row (x varies fastest).
    """

    def _centres(count, stride):
        # Centre of cell i along one axis.
        return (torch.arange(count, dtype=torch.float32) + 0.5) * stride

    grids = []
    for (rows, cols), stride in zip(sizes, strides):
        row_c, col_c = torch.meshgrid(
            _centres(rows, stride), _centres(cols, stride), indexing="ij"
        )
        # Stored as (x, y), i.e. column coordinate first.
        grids.append(torch.stack((col_c.reshape(-1), row_c.reshape(-1)), dim=-1))
    return grids
|
|
|
|
def assign_targets(loc, boxes, labels, stride, sr):
    """Match every location to at most one box and emit its targets.

    A location is a candidate for a box when all of the following hold:
      * it lies strictly inside the box,
      * it is within ``1.5 * stride`` of the box centre along both axes,
      * its largest distance to a box side lies in ``[sr[0], sr[1]]``.
    When several boxes qualify, the one with the smallest area is chosen.

    Args:
        loc: ``(n, 2)`` tensor of ``(x, y)`` points.
        boxes: ``(m, 4)`` tensor of ``(x1, y1, x2, y2)`` boxes; may be empty.
        labels: Length-``m`` tensor of class indices, aligned with ``boxes``.
        stride: Stride of the level; sets the centre-sampling radius.
        sr: ``(low, high)`` inclusive bounds on the largest side distance.

    Returns:
        ``(ct, rt)``: ``ct`` is a long tensor of length ``n`` holding the
        matched label or ``-1``; ``rt`` is an ``(n, 4)`` tensor holding the
        ``(left, top, right, bottom)`` distances to the matched box, and zeros
        for unmatched locations.  Both are created on the default device.
    """
    n = loc.shape[0]
    cls_t = torch.full((n,), -1, dtype=torch.long)
    reg_t = torch.zeros(n, 4)
    if boxes.numel() == 0:
        return cls_t, reg_t

    # Column views shaped (n, 1) and (1, m) so every op broadcasts to (n, m).
    px = loc[:, None, 0]
    py = loc[:, None, 1]
    x1, y1, x2, y2 = (boxes[None, :, k] for k in range(4))
    dists = torch.stack([px - x1, py - y1, x2 - px, y2 - py], -1)

    inside = dists.min(-1).values > 0

    mid_x = (boxes[:, 0] + boxes[:, 2]) / 2
    mid_y = (boxes[:, 1] + boxes[:, 3]) / 2
    radius = stride * 1.5
    near_x = (px >= mid_x - radius) & (px <= mid_x + radius)
    near_y = (py >= mid_y - radius) & (py <= mid_y + radius)

    largest = dists.max(-1).values
    in_range = (largest >= sr[0]) & (largest <= sr[1])

    valid = inside & near_x & near_y & in_range

    # Invalid pairs get infinite cost so argmin picks the smallest valid box.
    box_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    cost = box_area[None, :].expand_as(valid).masked_fill(~valid, float("inf"))
    winner = cost.argmin(1)
    hit = cost.gather(1, winner[:, None]).squeeze(1) < float("inf")

    cls_t[hit] = labels[winner[hit]]
    if hit.any():
        rows = torch.arange(n)[hit]
        reg_t[hit] = dists[rows, winner[hit]]
    return cls_t, reg_t
|
|
|
|
def build_val_data(val_path, n_images=500, device="cuda"):
    """Assemble per-location features and class targets from a cached val set.

    For each of the first ``n_images`` cached items the spatial feature map
    is decomposed into three scales with ``cofiber_decompose``; every scale
    is flattened to one row per location, layer-normalised over channels, and
    paired with class targets from ``assign_targets`` computed against the
    COCO ``instances_val2017.json`` boxes (rescaled by the item's ``scale``).

    Args:
        val_path: File loaded with ``torch.load``.  The loaded object must
            support ``len()`` and integer indexing, and each item must provide
            the keys ``"spatial"``, ``"img_id"`` and ``"scale"``.
        n_images: Upper bound on the number of items used.
        device: Device the concatenated outputs are moved to.

    Returns:
        ``(features, cls_targets)``: the row-wise concatenation of all
        per-location feature vectors, and the matching class targets
        (``-1`` for locations without a matched box).
    """
    # NOTE(review): weights_only=False unpickles arbitrary objects; only point
    # val_path at cache files from a trusted source.
    val = torch.load(val_path, map_location="cpu", weights_only=False)
    from pycocotools.coco import COCO

    coco = COCO(os.path.join(COCO_ROOT, "annotations", "instances_val2017.json"))
    # Map sorted COCO category ids onto contiguous class indices.
    cat_to_idx = {cat_id: i for i, cat_id in enumerate(sorted(coco.getCatIds()))}

    # Three levels at strides 16/32/64 over a fixed 40x40 base grid.
    # NOTE(review): assumes every cached "spatial" map is 40x40 so that the
    # location count matches the feature rows per level -- TODO confirm.
    strides = [16, 32, 64]
    base = 640 // 16
    sizes = [(base // div, base // div) for div in (1, 2, 4)]
    sr = [(-1, 128), (128, 256), (256, float("inf"))]
    locs = make_locations(sizes, strides)

    feature_chunks, target_chunks = [], []
    n_used = min(n_images, len(val))
    for idx in range(n_used):
        item = val[idx]
        spatial = item["spatial"].unsqueeze(0).float()
        img_id = item["img_id"]
        scale = item["scale"]

        anns = coco.loadAnns(coco.getAnnIds(imgIds=int(img_id), iscrowd=False))
        boxes, labels = [], []
        for ann in anns:
            x, y, w, h = ann["bbox"]
            # Keep only boxes at least one pixel wide and tall.
            if not (w < 1 or h < 1):
                boxes.append([x * scale, y * scale, (x + w) * scale, (y + h) * scale])
                labels.append(cat_to_idx[ann["category_id"]])
        boxes_t = torch.tensor(boxes, dtype=torch.float32) if boxes else torch.zeros(0, 4)
        labels_t = torch.tensor(labels, dtype=torch.long) if labels else torch.zeros(0, dtype=torch.long)

        cofibers = cofiber_decompose(spatial, 3)
        for cof, level_loc, level_stride, level_range in zip(cofibers, locs, strides, sr):
            _, channels, _, _ = cof.shape
            # Channels-last, one row per spatial location, then normalise.
            rows = cof.permute(0, 2, 3, 1).reshape(-1, channels)
            feature_chunks.append(F.layer_norm(rows, [channels]))
            cls_t, _ = assign_targets(level_loc, boxes_t, labels_t, level_stride, level_range)
            target_chunks.append(cls_t)

    features = torch.cat(feature_chunks).to(device)
    cls_targets = torch.cat(target_chunks).to(device)
    return features, cls_targets
|
|
|
|
def greedy_step_gpu(features, cls_targets, selected, remaining, lam=0.1):
    """Score every remaining dimension and return the best one to add.

    For each candidate ``d`` a ridge-regularised least-squares classifier is
    fitted in closed form on the columns ``selected + [d]`` (plus a bias
    column) against one-hot class targets, using only the positive locations
    (``cls_targets >= 0``).  The candidate's score is the training accuracy
    of the arg-max prediction on those same locations.  Candidates are
    evaluated one after another on ``features.device``.

    Args:
        features: ``(num_locations, num_dims)`` feature matrix.
        cls_targets: Length-``num_locations`` long tensor of class indices,
            ``-1`` marking locations to ignore.
        selected: List of already chosen column indices.
        remaining: Iterable of candidate column indices, tried in order.
        lam: Ridge strength; the penalty added to the normal equations is
            ``lam * n_pos`` on the diagonal (bias included).

    Returns:
        ``(best_dim, best_acc)``.  ``(-1, 0.0)`` when there are no positive
        locations and ``(-1, -1.0)`` when ``remaining`` is empty.  Ties keep
        the earliest candidate.  A candidate whose linear solve fails is
        scored ``0.0``.
    """
    pos = cls_targets >= 0
    n_pos = pos.sum().item()
    if n_pos == 0:
        return -1, 0.0

    device = features.device
    f_pos = features[pos]
    gt = cls_targets[pos]
    y_cls = torch.zeros(n_pos, NUM_CLASSES, device=device)
    y_cls[torch.arange(n_pos, device=device), gt] = 1.0

    # These depend only on len(selected), not on the candidate, so build
    # them once instead of once per candidate.
    fd = len(selected) + 1
    bias_col = torch.ones(n_pos, 1, device=device)
    ridge = lam * torch.eye(fd + 1, device=device) * n_pos

    best_dim = -1
    best_acc = -1.0
    for d in remaining:
        fp = f_pos[:, selected + [d]]
        fa = torch.cat([fp, bias_col], 1)
        XtX = fa.T @ fa
        XtY = fa.T @ y_cls
        try:
            W = torch.linalg.solve(XtX + ridge, XtY)
        except RuntimeError:
            # torch.linalg.solve raises RuntimeError (LinAlgError subclasses
            # it) for a singular system; treat it as a zero-accuracy
            # candidate.  Other exception types now propagate instead of
            # being hidden.
            acc = 0.0
        else:
            scores = fp @ W[:fd] + W[fd]
            acc = (scores.argmax(1) == gt).float().mean().item()

        # Strict comparison keeps the first candidate on ties.
        if acc > best_acc:
            best_acc = acc
            best_dim = d

    return best_dim, best_acc
|
|
|
|
def main():
    """Run greedy forward selection of feature dimensions and save the result.

    Command-line options:
        --max-dims: maximum number of dimensions to select (default 100).
        --n-eval:   number of cached validation images to use (default 500).
        --lam:      ridge strength forwarded to ``greedy_step_gpu``.

    The selection history is printed as it progresses and written to
    ``analytical_variants/greedy_forward_gpu.json`` next to this script.

    Raises:
        SystemExit: if no dimension could be selected (no positive locations
            or ``--max-dims`` below 1); nothing is written in that case.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--max-dims", type=int, default=100)
    parser.add_argument("--n-eval", type=int, default=500)
    parser.add_argument("--lam", type=float, default=0.1)
    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")
    print("=" * 60)
    print(f"GPU Greedy Forward Construction (max {args.max_dims} dims)")
    print("=" * 60, flush=True)

    print("Building val data...", flush=True)
    features, cls_targets = build_val_data(VAL_CACHE, args.n_eval, device)
    pos = cls_targets >= 0
    print(f" {features.shape[0]} locations, {pos.sum().item()} positives, "
          f"{features.shape[1]} dims", flush=True)

    selected = []
    # Candidate pool is every column actually present in the feature matrix
    # (previously hard-coded to 768, which breaks for any other width).
    remaining = list(range(features.shape[1]))
    history = []
    t0 = time.time()

    for step in range(args.max_dims):
        t_step = time.time()
        best_dim, best_acc = greedy_step_gpu(features, cls_targets, selected, remaining, args.lam)

        # A negative index means nothing could be scored (no positives or no
        # candidates left).
        if best_dim < 0:
            break

        selected.append(best_dim)
        remaining.remove(best_dim)
        step_time = time.time() - t_step

        # One weight per selected dim per class, plus one bias per class.
        n_params = len(selected) * NUM_CLASSES + NUM_CLASSES
        entry = {"step": step + 1, "dim": best_dim, "cls_acc": round(best_acc, 4),
                 "n_params": n_params, "step_ms": round(step_time * 1000)}
        history.append(entry)

        print(f" step {step+1:3d}: +dim{best_dim:3d} -> cls_acc={best_acc:.4f} "
              f"({len(selected)} dims, {n_params} params, {step_time*1000:.0f}ms)", flush=True)

        # Early stop: compares the latest (rounded) accuracy with the entry
        # nine steps earlier, i.e. a window of ten history entries.
        if len(history) >= 10:
            recent_gain = history[-1]["cls_acc"] - history[-10]["cls_acc"]
            if recent_gain < 0.005:
                print(" Converged: <0.5% gain in 10 steps", flush=True)
                break

    elapsed = time.time() - t0

    # Without this guard the summary below would fail with IndexError on
    # history[-1]; exit with a clear message and a non-zero status instead.
    if not history:
        raise SystemExit(
            "No dimensions were selected (no positive locations, or "
            "--max-dims < 1); nothing to save."
        )

    final = history[-1]
    print(f"\n{'='*60}")
    print(f"Selected {len(selected)} dimensions in {elapsed:.1f}s")
    print(f"Final cls_acc: {final['cls_acc']:.4f}")
    print(f"Final params: {final['n_params']}")
    print("\nTop 20 dimensions (most to least important):")
    for h in history[:20]:
        print(f" step {h['step']:2d}: dim{h['dim']:3d} cumul_acc={h['cls_acc']:.4f} ({h['step_ms']}ms)")

    result = {"selected_dims": selected, "history": history,
              "final_cls_acc": final["cls_acc"], "final_params": final["n_params"],
              "total_time_s": round(elapsed, 1)}
    out = os.path.join(SCRIPT_DIR, "analytical_variants", "greedy_forward_gpu.json")
    os.makedirs(os.path.dirname(out), exist_ok=True)
    with open(out, "w", encoding="utf-8") as f:
        json.dump(result, f, indent=2)
    print(f"\nSaved: {out}")
|
|
|
# Run the greedy search only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == "__main__":
    main()
|
|