""" Deep narrow head on evolved feature dims. Analytical evolution selected the 100 most informative dims. Now train a deep nonlinear MLP on those 100 dims with spatial context. """ import argparse import json import math import os import sys import time import torch import torch.nn as nn import torch.nn.functional as F SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) sys.path.insert(0, SCRIPT_DIR) CACHE_DIR = os.environ.get("ARENA_CACHE_DIR") COCO_ROOT = os.environ.get("ARENA_COCO_ROOT") VAL_CACHE = os.environ.get("ARENA_VAL_CACHE") RESOLUTION = 640 NUM_CLASSES = 80 def cofiber_decompose(f, n_scales): cofibers = []; residual = f for _ in range(n_scales - 1): omega = F.avg_pool2d(residual, 2) sigma_omega = F.interpolate(omega, size=residual.shape[2:], mode="bilinear", align_corners=False) cofibers.append(residual - sigma_omega); residual = omega cofibers.append(residual); return cofibers class EvolvedDeepHead(nn.Module): """Deep MLP on evolved feature dims with spatial depthwise convolutions.""" def __init__(self, evolved_dims, hidden=128, n_layers=10, n_scales=3): super().__init__() self.evolved_dims = evolved_dims self.n_scales = n_scales K = len(evolved_dims) self.dim_idx = nn.Parameter(torch.tensor(evolved_dims, dtype=torch.long), requires_grad=False) self.scale_norms = nn.ModuleList([nn.LayerNorm(768) for _ in range(n_scales)]) # Deep MLP with interleaved spatial convolutions layers = [] in_dim = K for i in range(n_layers): layers.append(nn.Linear(in_dim, hidden)) layers.append(nn.GELU()) if i % 2 == 1: # spatial conv every other layer layers.append(SpatialDWConv(hidden)) in_dim = hidden self.backbone = nn.Sequential(*layers) # Separate output heads self.cls_head = nn.Linear(hidden, NUM_CLASSES) self.reg_head = nn.Linear(hidden, 4) self.ctr_head = nn.Linear(hidden, 1) self.scale_params = nn.Parameter(torch.ones(n_scales)) def forward(self, spatial): cofibers = cofiber_decompose(spatial, self.n_scales) cls_l, reg_l, ctr_l = [], [], [] for i, cof in enumerate(cofibers): B, C, H, W = cof.shape f = self.scale_norms[i](cof.permute(0, 2, 3, 1).reshape(-1, C)) # Select evolved dims f_sel = f[:, self.dim_idx] # Deep MLP with spatial context # Need to reshape for spatial convs f_sel = f_sel.reshape(B, H, W, -1) h = self._forward_with_spatial(f_sel, B, H, W) # Output heads cls = self.cls_head(h.reshape(-1, h.shape[-1])).reshape(B, H, W, -1).permute(0, 3, 1, 2) reg_raw = (self.reg_head(h.reshape(-1, h.shape[-1])) * self.scale_params[i]).clamp(-10, 10) reg = reg_raw.exp().reshape(B, H, W, 4).permute(0, 3, 1, 2) ctr = self.ctr_head(h.reshape(-1, h.shape[-1])).reshape(B, H, W, 1).permute(0, 3, 1, 2) cls_l.append(cls); reg_l.append(reg); ctr_l.append(ctr) return cls_l, reg_l, ctr_l def _forward_with_spatial(self, x, B, H, W): """Run the backbone layers, reshaping for spatial convs.""" # x: (B, H, W, K) for layer in self.backbone: if isinstance(layer, SpatialDWConv): x = layer(x, B, H, W) elif isinstance(layer, nn.Linear): x = layer(x) elif isinstance(layer, nn.GELU): x = layer(x) return x class SpatialDWConv(nn.Module): """Depthwise 3x3 conv that operates on (B, H, W, C) tensors.""" def __init__(self, channels): super().__init__() self.conv = nn.Conv2d(channels, channels, 3, padding=1, groups=channels) def forward(self, x, B, H, W): # x: (B, H, W, C) or (B*H*W, C) if x.dim() == 4: c = x.shape[-1] x = x.permute(0, 3, 1, 2) # (B, C, H, W) x = self.conv(x) x = x.permute(0, 2, 3, 1) # (B, H, W, C) return x def make_locations(feature_sizes, strides, device): locs = [] for (h, w), s in 


def make_locations(feature_sizes, strides, device):
    locs = []
    for (h, w), s in zip(feature_sizes, strides):
        ys = (torch.arange(h, device=device, dtype=torch.float32) + 0.5) * s
        xs = (torch.arange(w, device=device, dtype=torch.float32) + 0.5) * s
        gy, gx = torch.meshgrid(ys, xs, indexing="ij")
        locs.append(torch.stack([gx.flatten(), gy.flatten()], -1))
    return locs
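

# Shape note (for reference; values follow from RESOLUTION and the strides used in
# main): with RESOLUTION=640 the three levels are 40x40, 20x20 and 10x10 at strides
# 16/32/64, so make_locations returns (x, y) pixel-center tensors of shape
# (1600, 2), (400, 2) and (100, 2).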


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--hidden", type=int, default=128)
    parser.add_argument("--layers", type=int, default=10)
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--batch-size", type=int, default=128)
    parser.add_argument("--lr", type=float, default=1e-3)
    args = parser.parse_args()

    # Load evolved dims
    evolved_path = os.path.join(SCRIPT_DIR, "circuit", "evolved_extreme.json")
    with open(evolved_path) as f:
        evolved = json.load(f)
    dims = None
    for r in evolved:
        if r["K"] == 100:
            dims = sorted(list(set(r["genome"])))
            break
    if dims is None:
        print("No K=100 genome found")
        return

    print("=" * 60)
    print(f"Deep Evolved Head: {len(dims)} dims, {args.hidden} hidden, {args.layers} layers")
    print("=" * 60, flush=True)

    head = EvolvedDeepHead(dims, hidden=args.hidden, n_layers=args.layers).cuda()
    n_params = sum(p.numel() for p in head.parameters() if p.requires_grad)
    print(f" {n_params:,} trainable params", flush=True)

    # Training setup
    from cache_and_train_fast import compute_loss

    with open(os.path.join(CACHE_DIR, "manifest.json")) as f:
        manifest = json.load(f)
    n_shards = manifest["n_shards"]
    n_images = manifest["n_images"]
    steps_per_epoch = n_images // args.batch_size
    total_steps = steps_per_epoch * args.epochs
    warmup = int(total_steps * 0.03)

    optimizer = torch.optim.AdamW(head.parameters(), lr=args.lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda s: s / max(warmup, 1) if s < warmup
        else 0.5 * (1 + math.cos(math.pi * (s - warmup) / max(total_steps - warmup, 1))),
    )

    strides = [16, 32, 64]
    H = RESOLUTION // 16
    locs = make_locations([(H, H), (H // 2, H // 2), (H // 4, H // 4)], strides, torch.device("cuda"))

    shard_paths = [os.path.join(CACHE_DIR, f"shard_{i:04d}.pt") for i in range(n_shards)]
    print(f" {n_images} images, batch {args.batch_size}, {total_steps} steps, {args.epochs} epochs")
    print(" Training...", flush=True)

    head.train()
    global_step = 0
    t0 = time.time()
    for epoch in range(args.epochs):
        shard_order = torch.randperm(n_shards).tolist()
        epoch_t0 = time.time()
        for shard_idx in shard_order:
            if global_step >= total_steps:
                break
            shard = torch.load(shard_paths[shard_idx], map_location="cpu", weights_only=False)
            within = torch.randperm(len(shard)).tolist()
            for batch_start in range(0, len(shard), args.batch_size):
                if global_step >= total_steps:
                    break
                batch_idx = within[batch_start:batch_start + args.batch_size]
                if len(batch_idx) < 2:
                    continue
                spatial = torch.stack([shard[i]["spatial"] for i in batch_idx]).float().cuda()
                boxes = [shard[i]["boxes"].cuda() for i in batch_idx]
                labels = [shard[i]["labels"].cuda() for i in batch_idx]
                try:
                    cls_l, reg_l, ctr_l = head(spatial)
                    loss = compute_loss(cls_l, reg_l, ctr_l, locs, boxes, labels)
                    optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(head.parameters(), 5.0)
                    optimizer.step()
                    scheduler.step()
                    if global_step % 200 == 0:
                        torch.cuda.synchronize()
                    global_step += 1
                    if global_step % 100 == 0:
                        lr = scheduler.get_last_lr()[0]
                        elapsed = time.time() - t0
                        print(f" step {global_step}/{total_steps} (ep {epoch+1}) "
                              f"loss={loss.item():.4f} lr={lr:.2e} "
                              f"{global_step/elapsed:.1f} it/s", flush=True)
                except RuntimeError as e:
                    if "out of memory" in str(e):
                        # Skip this batch but keep the step/schedule counters consistent
                        torch.cuda.empty_cache()
                        optimizer.zero_grad()
                        global_step += 1
                        scheduler.step()
                        continue
                    raise
            del shard
        print(f" Epoch {epoch+1}/{args.epochs} complete ({time.time()-epoch_t0:.0f}s)", flush=True)

    # Save
    out_dir = os.path.join(SCRIPT_DIR, "heads", "cofiber_threshold", "evolved_deep")
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, f"evolved_deep_{args.hidden}h_{args.layers}l_{args.epochs}ep.pth")
    torch.save(head.state_dict(), out_path)
    elapsed = time.time() - t0
    print(f"\nSaved: {out_path}")
    print(f"{n_params:,} params, {elapsed/60:.1f} minutes")


if __name__ == "__main__":
    main()
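

# Usage sketch (assumed workflow; evaluation code is not part of this file):
# reload the saved head with the same K=100 genome and architecture flags used
# at training time, e.g.
#
#   head = EvolvedDeepHead(dims, hidden=128, n_layers=10)
#   head.load_state_dict(torch.load(out_path, map_location="cpu"))
#   head.eval()
#   cls_l, reg_l, ctr_l = head(spatial)  # spatial: (B, 768, 40, 40) cached backbone features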