| """ |
| FCOS-Lite: Slim FCOS-style head on cofiber features. |
| |
| Separate cls and reg towers with standard 3x3 convolutions (full cross-channel mixing). |
| P3 stride-8 via transposed conv. Top-down lateral connections. |
| Cofiber decomposition replaces the heavy FPN. |
| |
| Target: match FCOS (41.0 mAP at 16.14M) at ≤4M params. |
| |
| Key differences from conv_deep: |
| - Standard Conv2d(256, 256, 3) instead of depthwise (256× more params per layer but full mixing) |
| - Separate cls and reg towers (FCOS-style) |
| - Fewer blocks (4 per tower instead of 20 shared) |
| """ |
|
|
import argparse
import json
import math
import os
import sys
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
|
|
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, SCRIPT_DIR)

CACHE_DIR = os.environ.get("ARENA_CACHE_DIR")
COCO_ROOT = os.environ.get("ARENA_COCO_ROOT")
VAL_CACHE = os.environ.get("ARENA_VAL_CACHE")
RESOLUTION = 640
NUM_CLASSES = 80
|
|
|
|
def cofiber_decompose(f, n_scales):
    """Pool/upsample residual pyramid: per-scale detail maps (cofibers) plus
    a coarse remainder; nested upsample-and-add reconstructs f exactly."""
    cofibers = []
    residual = f
    for _ in range(n_scales - 1):
        omega = F.avg_pool2d(residual, 2)
        sigma_omega = F.interpolate(omega, size=residual.shape[2:], mode="bilinear", align_corners=False)
        cofibers.append(residual - sigma_omega)
        residual = omega
    cofibers.append(residual)
    return cofibers
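

# A minimal check of the reconstruction property noted above (illustrative
# only; _check_cofiber_reconstruction is hypothetical and not called by the
# training flow): adding each cofiber to the upsampled coarser remainder
# recovers the input to float tolerance.
def _check_cofiber_reconstruction():
    x = torch.randn(1, 8, 32, 32)
    cofibers = cofiber_decompose(x, n_scales=3)
    recon = cofibers[-1]
    for c in reversed(cofibers[:-1]):
        recon = c + F.interpolate(recon, size=c.shape[2:], mode="bilinear", align_corners=False)
    assert torch.allclose(recon, x, atol=1e-5)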
|
|
|
|
class ConvGNBlock(nn.Module):
    """Standard 3x3 conv + GroupNorm + GELU. Full cross-channel mixing."""

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)
        self.norm = nn.GroupNorm(min(32, channels), channels)
        self.act = nn.GELU()

    def forward(self, x):
        return self.act(self.norm(self.conv(x)))
|
|
|
|
class DWResBlock(nn.Module):
    """Depthwise residual block: pointwise + GELU + depthwise 3x3 + GN + residual."""

    def __init__(self, channels):
        super().__init__()
        self.pw = nn.Conv2d(channels, channels, 1)
        self.act = nn.GELU()
        self.dw = nn.Conv2d(channels, channels, 3, padding=1, groups=channels)
        self.norm = nn.GroupNorm(min(32, channels), channels)

    def forward(self, x):
        return x + self.norm(self.dw(self.act(self.pw(x))))
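

# To make the module docstring's parameter trade-off concrete (illustrative
# only; _compare_block_params is hypothetical and unused elsewhere): a
# standard 3x3 conv has C*C*9 weights versus C*9 for a depthwise 3x3, so
# ConvGNBlock dominates the budget while DWResBlock adds depth cheaply.
def _compare_block_params(channels=224):
    std = sum(p.numel() for p in ConvGNBlock(channels).parameters())
    dw = sum(p.numel() for p in DWResBlock(channels).parameters())
    print(f"ConvGNBlock({channels}): {std:,}  DWResBlock({channels}): {dw:,}")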
|
|
|
|
def make_tower(hidden, n_std, n_dw):
    """Build a hybrid tower: standard 3x3 layers followed by depthwise residual blocks."""
    layers = []
    for _ in range(n_std):
        layers.append(ConvGNBlock(hidden))
    for _ in range(n_dw):
        layers.append(DWResBlock(hidden))
    return nn.Sequential(*layers)
|
|
|
|
class FCOSLiteHead(nn.Module):
    """Slim FCOS head on cofiber features with P3 + lateral + hybrid towers."""

    def __init__(self, feat_dim=768, hidden=256, n_std_layers=3, n_dw_layers=6, n_scales=3):
        super().__init__()
        self.n_scales = n_scales
        n_total = n_scales + 1  # cofiber scales plus the extra P3 level

        # Per-scale normalization (GroupNorm with one group, i.e. layer-norm-like)
        self.scale_norms = nn.ModuleList([nn.GroupNorm(1, feat_dim) for _ in range(n_scales)])

        # 1x1 stem projecting backbone features down to the head width
        self.stem = nn.Conv2d(feat_dim, hidden, 1)
        self.stem_act = nn.GELU()

        # Learned 2x upsample of the finest scale -> stride-8 P3 level
        self.p3_upsample = nn.ConvTranspose2d(hidden, hidden, 2, stride=2)
        self.p3_norm = nn.GroupNorm(min(32, hidden), hidden)

        # Top-down lateral connections (1x1 conv on the upsampled coarser map)
        self.lateral_convs = nn.ModuleList([nn.Conv2d(hidden, hidden, 1) for _ in range(n_scales - 1)])
        self.lateral_norms = nn.ModuleList([nn.GroupNorm(min(32, hidden), hidden) for _ in range(n_scales - 1)])

        # Separate classification and regression towers (FCOS-style)
        self.cls_tower = make_tower(hidden, n_std_layers, n_dw_layers)
        self.reg_tower = make_tower(hidden, n_std_layers, n_dw_layers)

        # Per-pixel heads: class logits, ltrb distances, centerness, plus one
        # learnable scale per pyramid level for the regression branch
        self.cls_pred = nn.Conv2d(hidden, NUM_CLASSES, 1)
        self.reg_pred = nn.Conv2d(hidden, 4, 1)
        self.ctr_pred = nn.Conv2d(hidden, 1, 1)
        self.scale_params = nn.Parameter(torch.ones(n_total))

        # Focal-loss prior: sigmoid(-log(99)) = 0.01 initial foreground probability
        nn.init.constant_(self.cls_pred.bias, -math.log(99))

    def forward(self, spatial):
        cofibers = cofiber_decompose(spatial, self.n_scales)
        cls_l, reg_l, ctr_l = [], [], []

        # Normalize each cofiber scale and project it through the shared stem
        scale_features = []
        for i, cof in enumerate(cofibers):
            x = self.stem_act(self.stem(self.scale_norms[i](cof)))
            scale_features.append(x)

        # Top-down pass, coarsest to finest: upsample the coarser map and add
        # it through a lateral 1x1 conv
        for i in range(len(scale_features) - 2, -1, -1):
            coarse_up = F.interpolate(scale_features[i + 1], size=scale_features[i].shape[2:],
                                      mode="bilinear", align_corners=False)
            scale_features[i] = self.lateral_norms[i](
                scale_features[i] + self.lateral_convs[i](coarse_up))

        # Stride-8 P3 level from the finest scale
        p3 = self.p3_norm(self.p3_upsample(scale_features[0]))
        all_features = [p3] + scale_features

        # Run both towers on every pyramid level
        for i, x in enumerate(all_features):
            cls_feat = self.cls_tower(x)
            reg_feat = self.reg_tower(x)

            cls = self.cls_pred(cls_feat)
            # Per-level scale, clamp for fp16 stability, exp keeps ltrb distances positive
            reg_raw = (self.reg_pred(reg_feat) * self.scale_params[i]).clamp(-10, 10)
            reg = reg_raw.exp()
            ctr = self.ctr_pred(reg_feat)

            cls_l.append(cls)
            reg_l.append(reg)
            ctr_l.append(ctr)
        return cls_l, reg_l, ctr_l
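

# A quick shape/budget check under this script's training configuration
# (illustrative only; _check_head_shapes is hypothetical and never called):
# with n_scales=4 and a 640px image cached at stride 16, the five levels
# should come out at strides 8, 16, 32, 64, 128.
def _check_head_shapes():
    head = FCOSLiteHead(feat_dim=768, hidden=32, n_scales=4)
    spatial = torch.randn(1, 768, 40, 40)  # 640 / 16 = 40
    cls_l, reg_l, ctr_l = head(spatial)
    assert [c.shape[-1] for c in cls_l] == [80, 40, 20, 10, 5]
    assert all(r.shape[1] == 4 for r in reg_l)
    assert all(c.shape[1] == 1 for c in ctr_l)
    print(f"{sum(p.numel() for p in head.parameters()):,} params")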
|
|
|
|
def make_locations(feature_sizes, strides, device):
    """Per-level (x, y) pixel-center coordinates of every feature-map cell."""
    locs = []
    for (h, w), s in zip(feature_sizes, strides):
        ys = (torch.arange(h, device=device, dtype=torch.float32) + 0.5) * s
        xs = (torch.arange(w, device=device, dtype=torch.float32) + 0.5) * s
        gy, gx = torch.meshgrid(ys, xs, indexing="ij")
        locs.append(torch.stack([gx.flatten(), gy.flatten()], -1))
    return locs
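# For example (worked by hand from the code above): a 2x2 map at stride 8
# gives centers (4, 4), (12, 4), (4, 12), (12, 12), flattened row-major so
# x varies fastest:
#     make_locations([(2, 2)], [8], torch.device("cpu"))[0]
#     # -> [[4., 4.], [12., 4.], [4., 12.], [12., 12.]]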
|
|
|
|
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--hidden", type=int, default=224)
    parser.add_argument("--std-layers", type=int, default=3)
    parser.add_argument("--dw-layers", type=int, default=6)
    parser.add_argument("--epochs", type=int, default=8)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--lr", type=float, default=5e-4)
    parser.add_argument("--resume", type=str, default=None)
    args = parser.parse_args()

    head = FCOSLiteHead(hidden=args.hidden, n_std_layers=args.std_layers, n_dw_layers=args.dw_layers, n_scales=4).cuda()
    start_step = 0
    if args.resume:
        ckpt = torch.load(args.resume, map_location="cuda", weights_only=False)
        head.load_state_dict(ckpt["head"])
        start_step = ckpt["step"]
        print(f"Resumed from step {start_step}")
    n_params = sum(p.numel() for p in head.parameters())
    print("=" * 60)
    print(f"FCOS-Lite: {args.hidden} hidden, {args.std_layers} std + {args.dw_layers} dw layers per tower")
    print(f"  {n_params:,} params")
    print("  Separate cls/reg towers, standard 3x3 convs, P3 + lateral")
    print("=" * 60, flush=True)
|
|
    # compute_loss lives next to this script (SCRIPT_DIR is on sys.path)
    from cache_and_train_fast import compute_loss
    manifest = json.load(open(os.path.join(CACHE_DIR, "manifest.json")))
    n_shards = manifest["n_shards"]
    n_images = manifest["n_images"]
    steps_per_epoch = n_images // args.batch_size
    total_steps = steps_per_epoch * args.epochs
    warmup = int(total_steps * 0.03)

    # Linear warmup over the first 3% of steps, then cosine decay to zero
    optimizer = torch.optim.AdamW(head.parameters(), lr=args.lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda s:
        s / max(warmup, 1) if s < warmup else
        0.5 * (1 + math.cos(math.pi * (s - warmup) / max(total_steps - warmup, 1))))
    scaler = GradScaler()
|
|
    H = RESOLUTION // 16  # feature size at stride 16 (40 for 640px input)
    strides = [8, 16, 32, 64, 128]
    locs = make_locations([(H*2, H*2), (H, H), (H//2, H//2), (H//4, H//4), (H//8, H//8)], strides, torch.device("cuda"))
    shard_paths = [os.path.join(CACHE_DIR, f"shard_{i:04d}.pt") for i in range(n_shards)]

    print(f"  {n_images} images, batch {args.batch_size}, {total_steps} steps, {args.epochs} epochs")
    print("  fp16 mixed precision")
    print("  Training...\n", flush=True)
|
|
    head.train()
    global_step = start_step
    if start_step > 0:
        # Fast-forward the LR schedule to the resumed step
        for _ in range(start_step):
            scheduler.step()
        print(f"  Scheduler advanced to step {start_step}", flush=True)
    t0 = time.time()
|
|
    for epoch in range(args.epochs):
        shard_order = torch.randperm(n_shards).tolist()
        epoch_t0 = time.time()

        for shard_idx in shard_order:
            if global_step >= total_steps: break
            shard = torch.load(shard_paths[shard_idx], map_location="cpu", weights_only=False)
            within = torch.randperm(len(shard)).tolist()

            for batch_start in range(0, len(shard), args.batch_size):
                if global_step >= total_steps: break
                batch_idx = within[batch_start:batch_start + args.batch_size]
                if len(batch_idx) < 2: continue

                spatial = torch.stack([shard[i]["spatial"] for i in batch_idx]).float().cuda()
                boxes = [shard[i]["boxes"].cuda() for i in batch_idx]
                labels = [shard[i]["labels"].cuda() for i in batch_idx]

                try:
                    with autocast():
                        cls_l, reg_l, ctr_l = head(spatial)
                        loss = compute_loss(cls_l, reg_l, ctr_l, locs, boxes, labels)

                    optimizer.zero_grad()
                    scaler.scale(loss).backward()
                    scaler.unscale_(optimizer)  # unscale so clipping sees true gradient norms
                    torch.nn.utils.clip_grad_norm_(head.parameters(), 5.0)
                    scaler.step(optimizer)
                    scaler.update()
                    scheduler.step()

                    global_step += 1

                    if global_step % 100 == 0:
                        lr = scheduler.get_last_lr()[0]
                        elapsed = time.time() - t0
                        print(f"  step {global_step}/{total_steps} (ep {epoch+1}) "
                              f"loss={loss.item():.4f} lr={lr:.2e} "
                              f"{(global_step - start_step)/elapsed:.1f} it/s", flush=True)

                    if global_step % 4000 == 0:
                        out_dir = os.path.join(SCRIPT_DIR, "heads", "cofiber_threshold", "split_tower_5scale")
                        os.makedirs(out_dir, exist_ok=True)
                        ckpt = os.path.join(out_dir, f"checkpoint_step{global_step}.pth")
                        torch.save({"head": head.state_dict(), "step": global_step}, ckpt)

                except RuntimeError as e:
                    # Skip OOM batches but keep the step and schedule counters in sync
                    if "out of memory" in str(e):
                        torch.cuda.empty_cache()
                        optimizer.zero_grad()
                        global_step += 1
                        scheduler.step()
                        continue
                    raise

            del shard

        print(f"  Epoch {epoch+1}/{args.epochs} complete ({time.time()-epoch_t0:.0f}s)\n", flush=True)
|
|
    out_dir = os.path.join(SCRIPT_DIR, "heads", "cofiber_threshold", "split_tower_5scale")
    os.makedirs(out_dir, exist_ok=True)
    out = os.path.join(out_dir, f"split_tower_5scale_{args.hidden}h_{args.std_layers}std_{args.dw_layers}dw_{args.epochs}ep.pth")
    torch.save({"head": head.state_dict(), "step": -1, "config": {
        "hidden": args.hidden, "std_layers": args.std_layers, "dw_layers": args.dw_layers
    }}, out)
    elapsed = time.time() - t0
    print(f"Saved: {out}")
    print(f"{n_params:,} params, {elapsed/60:.1f} minutes")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|