| """Split tower: FCOS-style detection head on cofiber features. |
| |
| Separate cls/reg towers (std 3x3 + depthwise residual). P3 stride-8 via |
| transposed conv, top-down lateral connections. Cofiber decomposition |
| replaces the FPN. 2D sinusoidal positional encoding concatenated to the |
| backbone patch features before the stem. |
| |
| Training: autocast bf16 forward, fp32 master params + fp32 moments, CUDA |
| graph capture of the forward+loss+backward+optimizer step, shard prefetch |
| on a dedicated copy stream, precomputed per-image FCOS targets stored in |
| the feature cache shards. |
| """ |
| import argparse |
| import json |
| import math |
| import os |
| import sys |
| import threading |
| import time |
| from queue import Queue |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
# Global PyTorch backend switches, applied once at import time.
# cudnn.benchmark lets cuDNN autotune convolution algorithms per input shape.
torch.backends.cudnn.benchmark = True
# Permit TF32 math in CUDA matmuls and in cuDNN convolutions.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
|
|
# Absolute directory of this script. It is pushed to the front of sys.path so
# the sibling modules imported below resolve regardless of the working directory.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, SCRIPT_DIR)
|
|
| from target_cache import ( |
| precompute_targets_for_image, make_locations, |
| STRIDES, SIZE_RANGES, |
| ) |
| from cuda_graph_trainer import CudaGraphTrainStep |
|
|
# Paths come from the environment; os.environ.get returns None when a variable
# is unset, so these may be None at import time.
CACHE_DIR = os.environ.get("ARENA_CACHE_DIR")    # directory holding manifest.json and shard_XXXX.pt
COCO_ROOT = os.environ.get("ARENA_COCO_ROOT")    # not referenced in the visible part of this file
VAL_CACHE = os.environ.get("ARENA_VAL_CACHE")    # not referenced in the visible part of this file
RESOLUTION = int(os.environ.get("ARENA_RESOLUTION", "640"))
# Base feature-grid side length; the level sizes used below are 2H, H, H/2, H/4, H/8.
# NOTE(review): presumably the backbone patch grid at stride 16 -- confirm.
H = RESOLUTION // 16
# Number of object classes; the text-embedding table must have this many rows.
# NOTE(review): 80 presumably corresponds to the COCO categories -- confirm.
NUM_CLASSES = 80
|
|
|
|
def cofiber_decompose(f, n_scales):
    """Split a feature map into ``n_scales`` multi-resolution components.

    For each of the first ``n_scales - 1`` steps the current map is 2x
    average-pooled, the pooled map is bilinearly upsampled back to the current
    spatial size, and the difference (the detail removed by pooling) is
    emitted. The last pooled map is emitted as the final, coarsest component.

    Args:
        f: input tensor whose spatial dimensions are ``f.shape[2:]``.
        n_scales: number of components to produce.

    Returns:
        A list of tensors ordered from finest to coarsest; each entry after the
        first has half the spatial size of the one before it. For
        ``n_scales <= 1`` the list contains only ``f`` itself.
    """
    bands = []
    current = f
    steps_left = n_scales - 1
    while steps_left > 0:
        pooled = F.avg_pool2d(current, 2)
        restored = F.interpolate(
            pooled, size=current.shape[2:], mode="bilinear", align_corners=False
        )
        # What pooling threw away at this resolution.
        bands.append(current - restored)
        current = pooled
        steps_left -= 1
    bands.append(current)
    return bands
|
|
|
|
class ConvGNBlock(nn.Module):
    """3x3 convolution followed by GroupNorm and GELU.

    Channel count and spatial size are preserved (padding=1). The norm uses
    ``min(32, channels)`` groups.
    """

    def __init__(self, channels):
        super().__init__()
        n_groups = min(32, channels)
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.norm = nn.GroupNorm(num_groups=n_groups, num_channels=channels)
        self.act = nn.GELU()

    def forward(self, x):
        """Apply conv -> norm -> activation to ``x``."""
        out = self.conv(x)
        out = self.norm(out)
        return self.act(out)
|
|
|
|
class DWResBlock(nn.Module):
    """Residual block: 1x1 conv -> GELU -> depthwise 3x3 conv -> GroupNorm.

    The branch output is added to the input, so shape is preserved. The
    depthwise conv uses one group per channel; the norm uses
    ``min(32, channels)`` groups.
    """

    def __init__(self, channels):
        super().__init__()
        self.pw = nn.Conv2d(channels, channels, kernel_size=1)
        self.act = nn.GELU()
        self.dw = nn.Conv2d(channels, channels, kernel_size=3, padding=1, groups=channels)
        self.norm = nn.GroupNorm(num_groups=min(32, channels), num_channels=channels)

    def forward(self, x):
        """Return ``x`` plus the normalised pointwise/depthwise branch."""
        branch = self.pw(x)
        branch = self.act(branch)
        branch = self.dw(branch)
        branch = self.norm(branch)
        return x + branch
|
|
|
|
def make_tower(hidden, n_std, n_dw):
    """Build a tower of ``n_std`` ConvGNBlocks followed by ``n_dw`` DWResBlocks.

    Args:
        hidden: channel width used by every block.
        n_std: number of standard 3x3 conv blocks (placed first).
        n_dw: number of depthwise residual blocks (placed after).

    Returns:
        An ``nn.Sequential`` containing the blocks in that order.
    """
    blocks = []
    for _ in range(n_std):
        blocks.append(ConvGNBlock(hidden))
    for _ in range(n_dw):
        blocks.append(DWResBlock(hidden))
    return nn.Sequential(*blocks)
|
|
|
|
def make_sin_pos_emb(H, W, dim, device):
    """Build a fixed 2D sinusoidal positional embedding.

    The first ``dim / 2`` channels encode the row index and the remaining
    ``dim / 2`` channels encode the column index. Within each half, channels
    alternate sin/cos over ``dim / 4`` geometrically spaced frequencies
    (base 10000).

    Args:
        H: grid height.
        W: grid width.
        dim: total number of channels; must be divisible by 4.
        device: device on which the tensors are created.

    Returns:
        Tensor of shape ``(1, dim, H, W)``.

    Raises:
        AssertionError: if ``dim`` is not divisible by 4.
    """
    assert dim % 4 == 0, "pos emb dim must be divisible by 4"
    quarter = dim // 4
    half = quarter * 2
    freq_index = torch.arange(quarter, device=device, dtype=torch.float32)
    omega = torch.exp(freq_index * -(math.log(10000.0) / quarter))

    def axis_table(length):
        # (length, half) table with sin in even columns and cos in odd columns.
        coords = torch.arange(length, device=device, dtype=torch.float32)
        angles = coords[:, None] * omega[None, :]
        table = torch.zeros(length, half, device=device)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        return table

    row_table = axis_table(H)
    col_table = axis_table(W)

    pos = torch.zeros(dim, H, W, device=device)
    # Row encoding is constant along the width axis ...
    pos[:half] = row_table.t().unsqueeze(2).expand(half, H, W)
    # ... and column encoding is constant along the height axis.
    pos[half:] = col_table.t().unsqueeze(1).expand(half, H, W)
    return pos.unsqueeze(0)
|
|
|
|
class SplitTowerHead(nn.Module):
    """Detection head with separate classification and regression towers.

    The input feature map is concatenated with a sinusoidal positional
    embedding, decomposed into ``n_scales`` cofiber components, projected by a
    shared 1x1 stem, fused top-down through lateral 1x1 convs, and extended
    with one extra finer level produced by a stride-2 transposed conv. Each of
    the resulting ``n_scales + 1`` levels is run through the shared cls/reg
    towers.

    Classification logits are obtained by projecting tower features into the
    space of a fixed text-embedding table, L2-normalising the projected
    features, and taking scaled dot products with the table rows plus a
    per-class bias.
    NOTE(review): this is a cosine similarity only if the stored text
    embeddings are themselves unit-norm -- confirm against the embedding file.

    Args:
        feat_dim: channel count of the incoming feature map.
        hidden: channel width of the head.
        n_std_layers: standard conv blocks per tower.
        n_dw_layers: depthwise residual blocks per tower.
        n_scales: number of cofiber components.
        pos_emb_dim: channels of the positional embedding (divisible by 4).
        text_embed_path: file loaded with ``torch.load`` that contains an
            ``"embeddings"`` entry with ``NUM_CLASSES`` rows. Falls back to the
            ``COCO_TEXT_EMBED_PATH`` environment variable when ``None``.
    """

    def __init__(self, feat_dim=768, hidden=192, n_std_layers=5, n_dw_layers=4, n_scales=4,
                 pos_emb_dim=64, text_embed_path=None):
        super().__init__()
        self.n_scales = n_scales
        self.pos_emb_dim = pos_emb_dim
        n_levels = n_scales + 1          # cofiber scales plus the upsampled level
        in_channels = feat_dim + pos_emb_dim
        hidden_groups = min(32, hidden)

        # Registration order below is kept stable: it determines state_dict
        # and parameter ordering.
        self.scale_norms = nn.ModuleList(
            [nn.GroupNorm(1, in_channels) for _ in range(n_scales)])
        self.stem = nn.Conv2d(in_channels, hidden, kernel_size=1)
        self.stem_act = nn.GELU()
        self.p3_upsample = nn.ConvTranspose2d(hidden, hidden, kernel_size=2, stride=2)
        self.p3_norm = nn.GroupNorm(hidden_groups, hidden)
        self.lateral_convs = nn.ModuleList(
            [nn.Conv2d(hidden, hidden, kernel_size=1) for _ in range(n_scales - 1)])
        self.lateral_norms = nn.ModuleList(
            [nn.GroupNorm(hidden_groups, hidden) for _ in range(n_scales - 1)])
        self.cls_tower = make_tower(hidden, n_std_layers, n_dw_layers)
        self.reg_tower = make_tower(hidden, n_std_layers, n_dw_layers)

        # Fixed text-embedding table used as the classifier weights.
        if text_embed_path is None:
            text_embed_path = os.environ.get("COCO_TEXT_EMBED_PATH")
        assert text_embed_path and os.path.isfile(text_embed_path), \
            f"text_embed_path missing: {text_embed_path}"
        # NOTE: weights_only=False unpickles arbitrary objects; only point this
        # at trusted files.
        blob = torch.load(text_embed_path, map_location="cpu", weights_only=False)
        text_embed = blob["embeddings"].float()
        assert text_embed.shape[0] == NUM_CLASSES
        self.text_embed_dim = text_embed.shape[1]
        self.register_buffer("text_embed", text_embed)
        self.cls_project = nn.Linear(hidden, self.text_embed_dim, bias=False)
        # Learnable log-temperature, initialised to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.tensor(math.log(1.0 / 0.07)))
        # Bias of -log(99) gives an initial sigmoid output of 0.01 at zero logit.
        self.cls_bias = nn.Parameter(torch.full((NUM_CLASSES,), -math.log(99)))

        self.reg_pred = nn.Conv2d(hidden, 4, kernel_size=1)
        self.ctr_pred = nn.Conv2d(hidden, 1, kernel_size=1)
        # One learnable multiplier per output level for the regression branch.
        self.scale_params = nn.Parameter(torch.ones(n_levels))

    def _text_aligned_logits(self, cls_feat):
        """Map tower features (B, C, h, w) to class logits (B, NUM_CLASSES, h, w)."""
        n_img, n_ch, h, w = cls_feat.shape
        tokens = cls_feat.permute(0, 2, 3, 1).reshape(-1, n_ch)
        unit = F.normalize(self.cls_project(tokens), p=2, dim=-1)
        sims = unit @ self.text_embed.t()
        scaled = sims * self.logit_scale.exp() + self.cls_bias
        return scaled.reshape(n_img, h, w, NUM_CLASSES).permute(0, 3, 1, 2)

    def forward(self, spatial):
        """Run the head on a 4-D feature map.

        Returns:
            ``(cls_l, reg_l, ctr_l)``: three lists with one tensor per level,
            finest level first. ``reg_l`` entries are strictly positive
            (exponential of a value clamped to [-10, 10]).
        """
        batch, _, grid_h, grid_w = spatial.shape
        pos = make_sin_pos_emb(grid_h, grid_w, self.pos_emb_dim, spatial.device)
        spatial = torch.cat([spatial, pos.expand(batch, -1, -1, -1)], dim=1)

        # Per-scale norm, then the shared stem.
        pyramid = [
            self.stem_act(self.stem(self.scale_norms[idx](band)))
            for idx, band in enumerate(cofiber_decompose(spatial, self.n_scales))
        ]

        # Top-down pathway: walk from the second-coarsest level to the finest.
        for idx in reversed(range(len(pyramid) - 1)):
            upsampled = F.interpolate(
                pyramid[idx + 1], size=pyramid[idx].shape[2:],
                mode="bilinear", align_corners=False)
            fused = pyramid[idx] + self.lateral_convs[idx](upsampled)
            pyramid[idx] = self.lateral_norms[idx](fused)

        # Extra level at twice the resolution of the finest cofiber level.
        p3 = self.p3_norm(self.p3_upsample(pyramid[0]))

        cls_l, reg_l, ctr_l = [], [], []
        for level, feat in enumerate([p3, *pyramid]):
            cls_feat = self.cls_tower(feat)
            reg_feat = self.reg_tower(feat)

            cls_l.append(self._text_aligned_logits(cls_feat))

            reg_logit = self.reg_pred(reg_feat) * self.scale_params[level]
            # Clamp before exp to keep the distances finite.
            reg_l.append(torch.exp(reg_logit.clamp(-10, 10)))
            ctr_l.append(self.ctr_pred(reg_feat))
        return cls_l, reg_l, ctr_l
|
|
|
|
class _NullCtx:
    """Do-nothing context manager, used where no stream context is wanted."""

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        # Returning None lets any exception raised in the body propagate.
        return None
|
|
|
|
def load_shard_to_gpu(shard_path, copy_stream=None):
    """Load a target-amended shard and move it to the GPU as stacked tensors.

    The shard is a sequence of per-image dicts with ``spatial``, ``tgt_cls``,
    ``tgt_reg`` and ``tgt_ctr`` entries. Each field is stacked along a new
    leading dimension into a pinned host tensor, then copied to the GPU with
    ``non_blocking=True`` and cast: ``spatial`` to bfloat16, ``tgt_cls`` to
    int64, ``tgt_reg`` / ``tgt_ctr`` to float32.

    Args:
        shard_path: path of the shard file (read with ``torch.load``).
        copy_stream: optional CUDA stream on which to issue the copies and
            casts; the current stream is used when ``None``.

    Returns:
        Dict with the four GPU tensors plus ``"n"``, the number of images.
        The copies may still be in flight when this returns; the caller is
        responsible for synchronising with the stream.
    """
    # NOTE: weights_only=False unpickles arbitrary objects; shards must be trusted.
    shard = torch.load(shard_path, map_location="cpu", weights_only=False)
    n_images = len(shard)

    def stack_pinned(key):
        return torch.stack([entry[key] for entry in shard]).pin_memory()

    # (field, dtype on the GPU) in the order the result dict is populated.
    fields = (
        ("spatial", torch.bfloat16),
        ("tgt_cls", torch.int64),
        ("tgt_reg", torch.float32),
        ("tgt_ctr", torch.float32),
    )
    # Stage every field in pinned memory before issuing any device copy.
    pinned = {key: stack_pinned(key) for key, _ in fields}

    stream_ctx = _NullCtx() if copy_stream is None else torch.cuda.stream(copy_stream)
    with stream_ctx:
        result = {
            key: pinned[key].cuda(non_blocking=True).to(dtype)
            for key, dtype in fields
        }
        result["n"] = n_images
    return result
|
|
|
|
def shard_prefetcher(shard_paths, shard_order, queue, copy_stream):
    """Load shards in ``shard_order`` and hand them to the consumer via ``queue``.

    Intended to run in a worker thread. Each shard is loaded with
    ``load_shard_to_gpu`` on ``copy_stream``; the stream is synchronised
    before the shard is enqueued, so the consumer only ever sees fully
    copied data.

    A ``None`` sentinel is always pushed last, including when loading a shard
    raises. Without that, a failed load would leave the consumer blocked on
    ``queue.get()`` forever. The exception itself still propagates out of this
    function so the thread's excepthook reports it.

    Args:
        shard_paths: indexable collection of shard file paths.
        shard_order: iterable of indices into ``shard_paths``.
        queue: queue-like object with a blocking ``put``.
        copy_stream: CUDA stream used for the host-to-device copies.
    """
    try:
        for idx in shard_order:
            shard_gpu = load_shard_to_gpu(shard_paths[idx], copy_stream=copy_stream)
            # Make sure the async copies have landed before publishing.
            copy_stream.synchronize()
            queue.put(shard_gpu)
    finally:
        # End-of-stream marker; must be sent on both success and failure.
        queue.put(None)
|
|
|
|
def amend_shard_with_targets(shard_path, device="cuda"):
    """Add precomputed FCOS targets to every entry of a shard file, in place.

    The call is idempotent: if the first entry already carries ``tgt_cls`` the
    shard is returned untouched. An empty shard is also returned untouched.

    For each entry, ``boxes`` and ``labels`` are passed to
    ``precompute_targets_for_image`` together with the concatenated locations
    of all five levels; the results are stored as ``tgt_cls`` (int16),
    ``tgt_reg`` (float16) and ``tgt_ctr`` (float16) CPU tensors.

    The updated shard is written to a temporary sibling file and then moved
    over the original with ``os.replace``, so an interrupted run cannot leave
    a truncated shard behind.

    Args:
        shard_path: path of the shard file to amend.
        device: device on which targets are computed.

    Returns:
        The (possibly amended) shard as loaded in memory.
    """
    # NOTE: weights_only=False unpickles arbitrary objects; shards must be trusted.
    shard = torch.load(shard_path, map_location="cpu", weights_only=False)
    if not shard or "tgt_cls" in shard[0]:
        return shard

    # Feature-map sizes of the five output levels, finest first.
    feat_sizes = [(H * 2, H * 2), (H, H), (H // 2, H // 2),
                  (H // 4, H // 4), (H // 8, H // 8)]
    locs_per_level = make_locations(feat_sizes, STRIDES, torch.device(device))
    all_locs = torch.cat(locs_per_level, 0)

    # (start, end, stride, size_lo, size_hi) per level, indexing into all_locs.
    level_ranges = []
    offset = 0
    for i, locs in enumerate(locs_per_level):
        count = locs.shape[0]
        lo, hi = SIZE_RANGES[i]
        level_ranges.append((offset, offset + count, STRIDES[i], lo, hi))
        offset += count

    for entry in shard:
        boxes = entry["boxes"].to(device).float()
        labels = entry["labels"].to(device).long()
        tcls, treg, tctr = precompute_targets_for_image(
            boxes, labels, all_locs, level_ranges, device)
        # Stored in compact dtypes to keep the shard files small.
        entry["tgt_cls"] = tcls.to(torch.int16).cpu()
        entry["tgt_reg"] = treg.to(torch.float16).cpu()
        entry["tgt_ctr"] = tctr.to(torch.float16).cpu()

    # Write-then-rename so the original stays intact until the new file is complete.
    tmp_path = f"{shard_path}.tmp"
    try:
        torch.save(shard, tmp_path)
        os.replace(tmp_path, shard_path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    return shard
|
|
|
|
def main():
    """Train a SplitTowerHead on cached features and save checkpoints.

    Steps:
      1. Parse CLI arguments and build the head and an AdamW optimizer whose
         learning rate is a CUDA tensor that is overwritten before each step.
      2. Read ``manifest.json`` from ``CACHE_DIR``; unless disabled or already
         done, amend every shard with precomputed targets and record that in
         the manifest.
      3. Optionally resume model/optimizer state from a checkpoint.
      4. Per epoch, stream shards through a prefetch thread, fill the static
         buffers of ``CudaGraphTrainStep`` with each batch, optionally apply
         horizontal flip and feature mixup, and run one training step.
      5. Save a checkpoint every 4000 steps and a final weights file at the end.

    Incomplete trailing batches of a shard are skipped.
    """
    # Local import: only the queue-drain loop below needs it.
    from queue import Empty

    parser = argparse.ArgumentParser()
    parser.add_argument("--hidden", type=int, default=192)
    parser.add_argument("--std-layers", type=int, default=5)
    parser.add_argument("--dw-layers", type=int, default=4)
    parser.add_argument("--epochs", type=int, default=8)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--lr", type=float, default=5e-4)
    parser.add_argument("--no-graph", action="store_true",
                        help="Disable CUDA Graphs (eager bf16 fallback)")
    parser.add_argument("--no-precompute", action="store_true",
                        help="Skip shard target-amend pass")
    parser.add_argument("--resume", type=str, default=None,
                        help="Resume from a checkpoint path")
    parser.add_argument("--no-aug", action="store_true",
                        help="Disable feature-space augmentations")
    parser.add_argument("--mixup-alpha", type=float, default=0.2,
                        help="Beta distribution alpha for mixup (0 = disabled)")
    args = parser.parse_args()

    # Fail early with a readable message instead of a TypeError from os.path.join.
    if not CACHE_DIR:
        raise SystemExit("ARENA_CACHE_DIR environment variable must be set")

    head = SplitTowerHead(hidden=args.hidden, n_std_layers=args.std_layers,
                          n_dw_layers=args.dw_layers, n_scales=4).cuda()
    n_params = sum(p.numel() for p in head.parameters())
    print("=" * 60)
    print(f"Split tower 5-scale: {args.hidden} hidden, "
          f"{args.std_layers} std + {args.dw_layers} dw layers per tower")
    print(f" {n_params:,} params")
    print(f" CUDA Graphs: {'enabled' if not args.no_graph else 'disabled'}")
    print("=" * 60, flush=True)

    manifest_path = os.path.join(CACHE_DIR, "manifest.json")
    with open(manifest_path) as fh:
        manifest = json.load(fh)
    n_shards = manifest["n_shards"]
    n_images = manifest["n_images"]
    steps_per_epoch = n_images // args.batch_size
    total_steps = steps_per_epoch * args.epochs
    warmup = int(total_steps * 0.03)

    optimizer = torch.optim.AdamW(head.parameters(), lr=args.lr, weight_decay=1e-4,
                                  capturable=not args.no_graph)

    # The learning rate lives in a device tensor shared by all param groups so
    # it can be updated in place (lr_tensor.fill_) without touching the groups.
    lr_tensor = torch.tensor(args.lr, device="cuda", dtype=torch.float32)

    def bind_lr_tensor():
        for group in optimizer.param_groups:
            group["lr"] = lr_tensor

    bind_lr_tensor()

    def lr_at(step):
        # Linear warmup followed by cosine decay to zero.
        if step < warmup:
            return args.lr * step / max(warmup, 1)
        progress = (step - warmup) / max(total_steps - warmup, 1)
        return args.lr * 0.5 * (1 + math.cos(math.pi * progress))

    feat_sizes = [(H * 2, H * 2), (H, H), (H // 2, H // 2),
                  (H // 4, H // 4), (H // 8, H // 8)]
    locs_per_level = make_locations(feat_sizes, STRIDES, torch.device("cuda"))
    all_locs = torch.cat(locs_per_level, 0)
    shard_paths = [os.path.join(CACHE_DIR, f"shard_{i:04d}.pt") for i in range(n_shards)]

    print(f" {n_images} images, batch {args.batch_size}, {total_steps} steps, "
          f"{args.epochs} epochs", flush=True)

    # One-time pass that stores targets inside the shards; guarded by a manifest flag.
    if not args.no_precompute and not manifest.get("targets_amended"):
        print("\nAmending shards with precomputed FCOS targets (one-time cost)...", flush=True)
        t0 = time.time()
        for i, sp in enumerate(shard_paths):
            amend_shard_with_targets(sp)
            if (i + 1) % 10 == 0:
                print(f" shard {i+1}/{n_shards} ({(time.time()-t0)/(i+1):.1f}s/shard)", flush=True)
        manifest["targets_amended"] = True
        with open(manifest_path, "w") as f:
            json.dump(manifest, f, indent=2)
        print(f" done in {(time.time()-t0)/60:.1f} min", flush=True)

    start_step = 0
    if args.resume:
        # NOTE: weights_only=False unpickles arbitrary objects; checkpoints must be trusted.
        ckpt = torch.load(args.resume, map_location="cuda", weights_only=False)
        head.load_state_dict(ckpt["head"])
        if "optimizer" in ckpt:
            optimizer.load_state_dict(ckpt["optimizer"])
            # load_state_dict replaces each group's hyperparameters with the
            # saved ones, so group["lr"] would no longer be lr_tensor and the
            # schedule below would silently stop taking effect. Re-attach it.
            bind_lr_tensor()
        # The final weights file stores step=-1; treat that as a fresh start.
        start_step = max(ckpt.get("step", 0), 0)
        print(f" Resumed from step {start_step}", flush=True)

    graph_step = CudaGraphTrainStep(head, optimizer, batch_size=args.batch_size,
                                    max_boxes=64, all_locs=all_locs)

    # Index permutation over the concatenated per-level locations that mirrors
    # every level left-to-right; used to flip the flattened targets.
    flip_perm_parts = []
    for h, w in feat_sizes:
        grid = torch.arange(h * w, device="cuda").reshape(h, w)
        flip_perm_parts.append(grid.flip(-1).flatten())
    flip_perm = torch.cat(flip_perm_parts)

    out_dir = os.path.join(SCRIPT_DIR, "heads", "cofiber_threshold",
                           "split_tower_5scale_textaligned")

    print(f"\n Training...\n", flush=True)
    head.train()
    global_step = start_step
    t0 = time.time()

    copy_stream = torch.cuda.Stream()

    for epoch in range(args.epochs):
        if global_step >= total_steps:
            break
        shard_order = torch.randperm(n_shards).tolist()
        epoch_t0 = time.time()

        prefetch_q = Queue(maxsize=1)
        prefetch_thread = threading.Thread(
            target=shard_prefetcher,
            args=(shard_paths, shard_order, prefetch_q, copy_stream),
            daemon=True,
        )
        prefetch_thread.start()

        for _ in range(len(shard_order)):
            if global_step >= total_steps:
                break
            shard_gpu = prefetch_q.get()
            if shard_gpu is None:
                break

            # The shard tensors were allocated on copy_stream but are read on
            # the current stream; record that so the caching allocator does
            # not recycle their memory while reads are still queued.
            compute_stream = torch.cuda.current_stream()
            for value in shard_gpu.values():
                if torch.is_tensor(value):
                    value.record_stream(compute_stream)

            n = shard_gpu["n"]
            within = torch.randperm(n, device="cuda")
            for batch_start in range(0, n, args.batch_size):
                if global_step >= total_steps:
                    break
                batch_idx = within[batch_start:batch_start + args.batch_size]
                if batch_idx.shape[0] < args.batch_size:
                    # Static buffers need a full batch; drop the remainder.
                    continue

                lr_tensor.fill_(lr_at(global_step))

                # Gather the batch directly into the trainer's static buffers.
                torch.index_select(shard_gpu["spatial"], 0, batch_idx, out=graph_step.buf_spatial)
                torch.index_select(shard_gpu["tgt_cls"], 0, batch_idx, out=graph_step.buf_tgt_cls)
                torch.index_select(shard_gpu["tgt_reg"], 0, batch_idx, out=graph_step.buf_tgt_reg)
                torch.index_select(shard_gpu["tgt_ctr"], 0, batch_idx, out=graph_step.buf_tgt_ctr)

                if not args.no_aug:
                    # Horizontal flip of the whole batch with probability 0.5.
                    if torch.rand(1).item() < 0.5:
                        graph_step.buf_spatial.copy_(graph_step.buf_spatial.flip(-1))
                        graph_step.buf_tgt_cls.copy_(graph_step.buf_tgt_cls[:, flip_perm])
                        graph_step.buf_tgt_ctr.copy_(graph_step.buf_tgt_ctr[:, flip_perm])
                        flipped_reg = graph_step.buf_tgt_reg[:, flip_perm].clone()
                        # Swap regression components 0 and 2 to match the mirrored geometry.
                        flipped_reg[..., [0, 2]] = flipped_reg[..., [2, 0]]
                        graph_step.buf_tgt_reg.copy_(flipped_reg)

                    # Feature mixup: blend in features of random images from the
                    # same shard. Targets are left untouched; lam is folded to
                    # >= 0.5 so the original sample dominates.
                    if args.mixup_alpha > 0 and n >= 2 * args.batch_size:
                        lam = torch.distributions.Beta(args.mixup_alpha, args.mixup_alpha).sample().item()
                        if lam < 0.5:
                            lam = 1.0 - lam
                        mix_idx = within[torch.randperm(n, device="cuda")[:args.batch_size]]
                        mix_sp = shard_gpu["spatial"][mix_idx]
                        graph_step.buf_spatial.mul_(lam).add_(mix_sp, alpha=1.0 - lam)

                # Capture on the very first step, once the buffers hold real data.
                if global_step == start_step and not args.no_graph:
                    print("Capturing CUDA graph (first batch is slow)...", flush=True)
                    try:
                        graph_step.warmup_and_capture()
                        print(" graph captured", flush=True)
                    except Exception as e:
                        # Deliberate best-effort: training continues in eager mode.
                        print(f" graph capture failed ({type(e).__name__}); falling back to eager: {e}", flush=True)
                        graph_step.captured = False

                graph_step.run()
                global_step += 1

                if global_step % 100 == 0:
                    loss_val = graph_step.last_loss()
                    elapsed = time.time() - t0
                    print(f" step {global_step}/{total_steps} (ep {epoch+1}) "
                          f"loss={loss_val:.4f} lr={lr_tensor.item():.2e} "
                          f"{global_step/elapsed:.1f} it/s", flush=True)

                if global_step % 4000 == 0:
                    os.makedirs(out_dir, exist_ok=True)
                    ckpt_path = os.path.join(out_dir, f"checkpoint_step{global_step}.pth")
                    sd = head.state_dict()
                    torch.save({"head": {k: v.float() for k, v in sd.items()},
                                "optimizer": optimizer.state_dict(),
                                "step": global_step}, ckpt_path)
            del shard_gpu
            torch.cuda.empty_cache()

        # Discard anything still queued (shards or the end sentinel).
        while True:
            try:
                prefetch_q.get_nowait()
            except Empty:
                break
        prefetch_thread.join(timeout=30)
        torch.cuda.empty_cache()
        print(f" Epoch {epoch+1}/{args.epochs} complete ({time.time()-epoch_t0:.0f}s)\n", flush=True)

    os.makedirs(out_dir, exist_ok=True)
    sd = head.state_dict()
    torch.save({"head": {k: v.float() for k, v in sd.items()}, "step": -1},
               os.path.join(out_dir, f"split_tower_5scale_{args.hidden}h_{args.std_layers}std_{args.dw_layers}dw_{args.epochs}ep.pth"))
    print(f"Done. Total time: {(time.time()-t0)/60:.1f} min")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|