""" Deep conv-stack detection head. Pure Conv2d — no reshaping overhead. Pointwise (1x1) convs replace Linear layers. Depthwise (3x3) convs provide spatial context. Runs at full GPU throughput with fp16. """ import argparse import json import math import os import sys import time import torch import torch.nn as nn import torch.nn.functional as F from torch.cuda.amp import autocast, GradScaler SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) sys.path.insert(0, SCRIPT_DIR) CACHE_DIR = os.environ.get("ARENA_CACHE_DIR") COCO_ROOT = os.environ.get("ARENA_COCO_ROOT") VAL_CACHE = os.environ.get("ARENA_VAL_CACHE") RESOLUTION = 640 NUM_CLASSES = 80 def cofiber_decompose(f, n_scales): cofibers = []; residual = f for _ in range(n_scales - 1): omega = F.avg_pool2d(residual, 2) sigma_omega = F.interpolate(omega, size=residual.shape[2:], mode="bilinear", align_corners=False) cofibers.append(residual - sigma_omega); residual = omega cofibers.append(residual); return cofibers class ConvBlock(nn.Module): """Pointwise conv + GELU + Depthwise spatial conv.""" def __init__(self, in_ch, out_ch): super().__init__() self.pw = nn.Conv2d(in_ch, out_ch, 1) self.act = nn.GELU() self.dw = nn.Conv2d(out_ch, out_ch, 3, padding=1, groups=out_ch) self.norm = nn.GroupNorm(1, out_ch) # instance norm per channel def forward(self, x): x = self.act(self.pw(x)) x = self.norm(self.dw(x)) return x class DeepConvHead(nn.Module): """Pure conv-stack detection head on cofiber features.""" def __init__(self, feat_dim=768, hidden=256, n_blocks=10, n_scales=3, with_p3=False, lateral=False): super().__init__() self.n_scales = n_scales self.with_p3 = with_p3 self.lateral = lateral n_total = n_scales + (1 if with_p3 else 0) self.scale_norms = nn.ModuleList([nn.GroupNorm(1, feat_dim) for _ in range(n_scales)]) # Stem: project from feat_dim to hidden self.stem = nn.Conv2d(feat_dim, hidden, 1) self.stem_act = nn.GELU() # Stride-8 upsample path (P3) if with_p3: self.p3_upsample = nn.ConvTranspose2d(hidden, hidden, 2, stride=2) self.p3_norm = nn.GroupNorm(1, hidden) # Deep conv stack with residual connections self.blocks = nn.ModuleList() for _ in range(n_blocks): self.blocks.append(ConvBlock(hidden, hidden)) # Lateral top-down fusion if lateral: self.lateral_convs = nn.ModuleList() self.lateral_norms = nn.ModuleList() for _ in range(n_scales - 1): self.lateral_convs.append(nn.Conv2d(hidden, hidden, 1)) self.lateral_norms.append(nn.GroupNorm(1, hidden)) # Output heads self.cls_head = nn.Conv2d(hidden, NUM_CLASSES, 1) self.reg_head = nn.Conv2d(hidden, 4, 1) self.ctr_head = nn.Conv2d(hidden, 1, 1) self.scale_params = nn.Parameter(torch.ones(n_total)) def forward(self, spatial): cofibers = cofiber_decompose(spatial, self.n_scales) cls_l, reg_l, ctr_l = [], [], [] # Process stride-16 first (needed for P3 upsample) scale_offset = 0 if self.with_p3: cof16 = cofibers[0] # stride 16, 40x40 x16 = self.stem_act(self.stem(self.scale_norms[0](cof16))) for block in self.blocks: x16 = x16 + block(x16) # Create stride-8 via transposed conv (80x80) p3 = self.p3_norm(self.p3_upsample(x16)) for block in self.blocks: p3 = p3 + block(p3) cls = self.cls_head(p3) reg_raw = (self.reg_head(p3) * self.scale_params[0]).clamp(-10, 10) reg = reg_raw.exp() ctr = self.ctr_head(p3) cls_l.append(cls); reg_l.append(reg); ctr_l.append(ctr) scale_offset = 1 # Process each cofiber scale through the shared conv stack scale_features = [] for i, cof in enumerate(cofibers): x = self.scale_norms[i](cof) x = self.stem_act(self.stem(x)) for block in self.blocks: x = x + block(x) scale_features.append(x) # Top-down lateral fusion: coarse → fine if self.lateral: for i in range(len(scale_features) - 2, -1, -1): coarse = scale_features[i + 1] coarse_up = F.interpolate(coarse, size=scale_features[i].shape[2:], mode="bilinear", align_corners=False) scale_features[i] = self.lateral_norms[i]( scale_features[i] + self.lateral_convs[i](coarse_up)) # Predict from each scale for i, x in enumerate(scale_features): cls = self.cls_head(x) reg_raw = (self.reg_head(x) * self.scale_params[i + scale_offset]).clamp(-10, 10) reg = reg_raw.exp() ctr = self.ctr_head(x) cls_l.append(cls); reg_l.append(reg); ctr_l.append(ctr) return cls_l, reg_l, ctr_l def make_locations(feature_sizes, strides, device): locs = [] for (h, w), s in zip(feature_sizes, strides): ys = (torch.arange(h, device=device, dtype=torch.float32) + 0.5) * s xs = (torch.arange(w, device=device, dtype=torch.float32) + 0.5) * s gy, gx = torch.meshgrid(ys, xs, indexing="ij") locs.append(torch.stack([gx.flatten(), gy.flatten()], -1)) return locs def main(): parser = argparse.ArgumentParser() parser.add_argument("--hidden", type=int, default=256) parser.add_argument("--blocks", type=int, default=10) parser.add_argument("--epochs", type=int, default=8) parser.add_argument("--batch-size", type=int, default=64) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--with-p3", action="store_true", help="Add stride-8 level via transposed conv") parser.add_argument("--lateral", action="store_true", help="Top-down lateral connections between scales") args = parser.parse_args() head = DeepConvHead(hidden=args.hidden, n_blocks=args.blocks, with_p3=args.with_p3, lateral=args.lateral).cuda() n_params = sum(p.numel() for p in head.parameters()) print("=" * 60) print(f"Deep Conv Head: {args.hidden} hidden, {args.blocks} blocks") print(f" {n_params:,} params") print("=" * 60, flush=True) from cache_and_train_fast import compute_loss manifest = json.load(open(os.path.join(CACHE_DIR, "manifest.json"))) n_shards = manifest["n_shards"] n_images = manifest["n_images"] steps_per_epoch = n_images // args.batch_size total_steps = steps_per_epoch * args.epochs warmup = int(total_steps * 0.03) optimizer = torch.optim.AdamW(head.parameters(), lr=args.lr, weight_decay=1e-4) scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda s: s / max(warmup, 1) if s < warmup else 0.5 * (1 + math.cos(math.pi * (s - warmup) / max(total_steps - warmup, 1)))) scaler = GradScaler() H = RESOLUTION // 16 if args.with_p3: strides = [8, 16, 32, 64] locs = make_locations([(H*2,H*2),(H,H),(H//2,H//2),(H//4,H//4)], strides, torch.device("cuda")) else: strides = [16, 32, 64] locs = make_locations([(H,H),(H//2,H//2),(H//4,H//4)], strides, torch.device("cuda")) shard_paths = [os.path.join(CACHE_DIR, f"shard_{i:04d}.pt") for i in range(n_shards)] print(f" {n_images} images, batch {args.batch_size}, {total_steps} steps, {args.epochs} epochs") print(f" fp16 mixed precision enabled") print(f" Training...\n", flush=True) head.train() global_step = 0 t0 = time.time() for epoch in range(args.epochs): shard_order = torch.randperm(n_shards).tolist() epoch_t0 = time.time() for shard_idx in shard_order: if global_step >= total_steps: break shard = torch.load(shard_paths[shard_idx], map_location="cpu", weights_only=False) within = torch.randperm(len(shard)).tolist() for batch_start in range(0, len(shard), args.batch_size): if global_step >= total_steps: break batch_idx = within[batch_start:batch_start + args.batch_size] if len(batch_idx) < 2: continue spatial = torch.stack([shard[i]["spatial"] for i in batch_idx]).float().cuda() boxes = [shard[i]["boxes"].cuda() for i in batch_idx] labels = [shard[i]["labels"].cuda() for i in batch_idx] try: with autocast(): cls_l, reg_l, ctr_l = head(spatial) loss = compute_loss(cls_l, reg_l, ctr_l, locs, boxes, labels) optimizer.zero_grad() scaler.scale(loss).backward() scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(head.parameters(), 5.0) scaler.step(optimizer) scaler.update() scheduler.step() global_step += 1 if global_step % 100 == 0: lr = scheduler.get_last_lr()[0] elapsed = time.time() - t0 print(f" step {global_step}/{total_steps} (ep {epoch+1}) " f"loss={loss.item():.4f} lr={lr:.2e} " f"{global_step/elapsed:.1f} it/s", flush=True) if global_step % 4000 == 0: ckpt = f"/home/zootest/checkpoint_convdeep_step{global_step}.pth" torch.save({"head": head.state_dict(), "step": global_step}, ckpt) except RuntimeError as e: if "out of memory" in str(e): torch.cuda.empty_cache() optimizer.zero_grad() global_step += 1 scheduler.step() continue raise del shard print(f" Epoch {epoch+1}/{args.epochs} complete ({time.time()-epoch_t0:.0f}s)\n", flush=True) # Save out_dir = os.path.join(SCRIPT_DIR, "heads", "cofiber_threshold", "conv_deep") os.makedirs(out_dir, exist_ok=True) out = os.path.join(out_dir, f"conv_deep_{args.hidden}h_{args.blocks}b_{args.epochs}ep.pth") torch.save(head.state_dict(), out) elapsed = time.time() - t0 print(f"Saved: {out}") print(f"{n_params:,} params, {elapsed/60:.1f} minutes") if __name__ == "__main__": main()