"""
Deep conv-stack detection head. Pure Conv2d — no reshaping overhead.
Pointwise (1x1) convs replace Linear layers.
Depthwise (3x3) convs provide spatial context.
Runs at full GPU throughput with fp16.
"""
import argparse
import json
import math
import os
import sys
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, SCRIPT_DIR)
CACHE_DIR = os.environ.get("ARENA_CACHE_DIR")
COCO_ROOT = os.environ.get("ARENA_COCO_ROOT")
VAL_CACHE = os.environ.get("ARENA_VAL_CACHE")
RESOLUTION = 640
NUM_CLASSES = 80
def cofiber_decompose(f, n_scales):
    """Laplacian-pyramid-style split of f into n_scales band-pass levels."""
    cofibers = []
    residual = f
    for _ in range(n_scales - 1):
        omega = F.avg_pool2d(residual, 2)
        sigma_omega = F.interpolate(omega, size=residual.shape[2:],
                                    mode="bilinear", align_corners=False)
        cofibers.append(residual - sigma_omega)  # detail lost by one 2x avg-pool
        residual = omega
    cofibers.append(residual)  # coarsest low-pass residual
    return cofibers
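# Sanity sketch (assuming a 40x40 stride-16 feature map, n_scales=3):
#   f = torch.randn(1, 768, 40, 40)
#   [c.shape[-1] for c in cofiber_decompose(f, 3)]  # -> [40, 20, 10]
# As in a Laplacian pyramid, f is recovered exactly by summing each level
# with the bilinearly upsampled coarser levels: f = c0 + up(c1 + up(c2)).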
class ConvBlock(nn.Module):
"""Pointwise conv + GELU + Depthwise spatial conv."""
def __init__(self, in_ch, out_ch):
super().__init__()
self.pw = nn.Conv2d(in_ch, out_ch, 1)
self.act = nn.GELU()
self.dw = nn.Conv2d(out_ch, out_ch, 3, padding=1, groups=out_ch)
        self.norm = nn.GroupNorm(1, out_ch)  # one group == LayerNorm over (C, H, W)
def forward(self, x):
x = self.act(self.pw(x))
x = self.norm(self.dw(x))
return x
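# Design note: ConvBlock is a depthwise-separable convolution with the
# pointwise mix applied first. At hidden=256 a block costs ~66k pointwise
# plus ~2.6k depthwise parameters, versus ~590k for a dense 3x3 conv, so
# stacking 10 blocks stays cheap.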
class DeepConvHead(nn.Module):
"""Pure conv-stack detection head on cofiber features."""
def __init__(self, feat_dim=768, hidden=256, n_blocks=10, n_scales=3,
with_p3=False, lateral=False):
super().__init__()
self.n_scales = n_scales
self.with_p3 = with_p3
self.lateral = lateral
n_total = n_scales + (1 if with_p3 else 0)
self.scale_norms = nn.ModuleList([nn.GroupNorm(1, feat_dim) for _ in range(n_scales)])
# Stem: project from feat_dim to hidden
self.stem = nn.Conv2d(feat_dim, hidden, 1)
self.stem_act = nn.GELU()
# Stride-8 upsample path (P3)
if with_p3:
self.p3_upsample = nn.ConvTranspose2d(hidden, hidden, 2, stride=2)
self.p3_norm = nn.GroupNorm(1, hidden)
# Deep conv stack with residual connections
self.blocks = nn.ModuleList()
for _ in range(n_blocks):
self.blocks.append(ConvBlock(hidden, hidden))
# Lateral top-down fusion
if lateral:
self.lateral_convs = nn.ModuleList()
self.lateral_norms = nn.ModuleList()
for _ in range(n_scales - 1):
self.lateral_convs.append(nn.Conv2d(hidden, hidden, 1))
self.lateral_norms.append(nn.GroupNorm(1, hidden))
# Output heads
self.cls_head = nn.Conv2d(hidden, NUM_CLASSES, 1)
self.reg_head = nn.Conv2d(hidden, 4, 1)
self.ctr_head = nn.Conv2d(hidden, 1, 1)
self.scale_params = nn.Parameter(torch.ones(n_total))
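        # Learnable per-level multiplier for the regression branch
        # (FCOS-style Scale); one entry per output level, including P3.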
def forward(self, spatial):
cofibers = cofiber_decompose(spatial, self.n_scales)
cls_l, reg_l, ctr_l = [], [], []
# Process stride-16 first (needed for P3 upsample)
scale_offset = 0
if self.with_p3:
cof16 = cofibers[0] # stride 16, 40x40
x16 = self.stem_act(self.stem(self.scale_norms[0](cof16)))
for block in self.blocks:
x16 = x16 + block(x16)
# Create stride-8 via transposed conv (80x80)
p3 = self.p3_norm(self.p3_upsample(x16))
for block in self.blocks:
p3 = p3 + block(p3)
cls = self.cls_head(p3)
reg_raw = (self.reg_head(p3) * self.scale_params[0]).clamp(-10, 10)
reg = reg_raw.exp()
ctr = self.ctr_head(p3)
            cls_l.append(cls)
            reg_l.append(reg)
            ctr_l.append(ctr)
scale_offset = 1
# Process each cofiber scale through the shared conv stack
scale_features = []
for i, cof in enumerate(cofibers):
x = self.scale_norms[i](cof)
x = self.stem_act(self.stem(x))
for block in self.blocks:
x = x + block(x)
scale_features.append(x)
# Top-down lateral fusion: coarse → fine
if self.lateral:
for i in range(len(scale_features) - 2, -1, -1):
coarse = scale_features[i + 1]
coarse_up = F.interpolate(coarse, size=scale_features[i].shape[2:],
mode="bilinear", align_corners=False)
scale_features[i] = self.lateral_norms[i](
scale_features[i] + self.lateral_convs[i](coarse_up))
# Predict from each scale
for i, x in enumerate(scale_features):
cls = self.cls_head(x)
reg_raw = (self.reg_head(x) * self.scale_params[i + scale_offset]).clamp(-10, 10)
reg = reg_raw.exp()
ctr = self.ctr_head(x)
            cls_l.append(cls)
            reg_l.append(reg)
            ctr_l.append(ctr)
return cls_l, reg_l, ctr_l
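# Output sketch (640px input, n_scales=3, --with-p3): four levels at strides
# [8, 16, 32, 64], so cls maps of shape (B, 80, 80, 80), (B, 80, 40, 40),
# (B, 80, 20, 20), (B, 80, 10, 10); reg (4 ch) and ctr (1 ch) match spatially.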
def make_locations(feature_sizes, strides, device):
locs = []
for (h, w), s in zip(feature_sizes, strides):
ys = (torch.arange(h, device=device, dtype=torch.float32) + 0.5) * s
xs = (torch.arange(w, device=device, dtype=torch.float32) + 0.5) * s
gy, gx = torch.meshgrid(ys, xs, indexing="ij")
locs.append(torch.stack([gx.flatten(), gy.flatten()], -1))
return locs
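# Locations are (x, y) pixel centers of each grid cell: a 40x40 grid at
# stride 16 starts at (8.0, 8.0) and steps by 16, matching the FCOS
# convention of predicting boxes from cell centers.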
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--hidden", type=int, default=256)
parser.add_argument("--blocks", type=int, default=10)
parser.add_argument("--epochs", type=int, default=8)
parser.add_argument("--batch-size", type=int, default=64)
parser.add_argument("--lr", type=float, default=1e-3)
parser.add_argument("--with-p3", action="store_true", help="Add stride-8 level via transposed conv")
parser.add_argument("--lateral", action="store_true", help="Top-down lateral connections between scales")
args = parser.parse_args()
head = DeepConvHead(hidden=args.hidden, n_blocks=args.blocks, with_p3=args.with_p3,
lateral=args.lateral).cuda()
n_params = sum(p.numel() for p in head.parameters())
print("=" * 60)
print(f"Deep Conv Head: {args.hidden} hidden, {args.blocks} blocks")
print(f" {n_params:,} params")
print("=" * 60, flush=True)
from cache_and_train_fast import compute_loss
    with open(os.path.join(CACHE_DIR, "manifest.json")) as fh:
        manifest = json.load(fh)
n_shards = manifest["n_shards"]
n_images = manifest["n_images"]
steps_per_epoch = n_images // args.batch_size
total_steps = steps_per_epoch * args.epochs
warmup = int(total_steps * 0.03)
optimizer = torch.optim.AdamW(head.parameters(), lr=args.lr, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda s:
s / max(warmup, 1) if s < warmup else
0.5 * (1 + math.cos(math.pi * (s - warmup) / max(total_steps - warmup, 1))))
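    # LR schedule: linear warmup over the first 3% of steps, then cosine
    # decay: lr(s) = base_lr * 0.5 * (1 + cos(pi * (s - warmup) / rest)).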
scaler = GradScaler()
H = RESOLUTION // 16
if args.with_p3:
strides = [8, 16, 32, 64]
locs = make_locations([(H*2,H*2),(H,H),(H//2,H//2),(H//4,H//4)], strides, torch.device("cuda"))
else:
strides = [16, 32, 64]
locs = make_locations([(H,H),(H//2,H//2),(H//4,H//4)], strides, torch.device("cuda"))
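    # Grid sizes at 640px: H = 40, giving 80/40/20/10 cells per side for
    # strides 8/16/32/64 (the stride-8 level exists only with --with-p3).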
shard_paths = [os.path.join(CACHE_DIR, f"shard_{i:04d}.pt") for i in range(n_shards)]
print(f" {n_images} images, batch {args.batch_size}, {total_steps} steps, {args.epochs} epochs")
print(f" fp16 mixed precision enabled")
print(f" Training...\n", flush=True)
head.train()
global_step = 0
t0 = time.time()
for epoch in range(args.epochs):
shard_order = torch.randperm(n_shards).tolist()
epoch_t0 = time.time()
for shard_idx in shard_order:
if global_step >= total_steps: break
shard = torch.load(shard_paths[shard_idx], map_location="cpu", weights_only=False)
within = torch.randperm(len(shard)).tolist()
for batch_start in range(0, len(shard), args.batch_size):
if global_step >= total_steps: break
batch_idx = within[batch_start:batch_start + args.batch_size]
if len(batch_idx) < 2: continue
spatial = torch.stack([shard[i]["spatial"] for i in batch_idx]).float().cuda()
boxes = [shard[i]["boxes"].cuda() for i in batch_idx]
labels = [shard[i]["labels"].cuda() for i in batch_idx]
try:
with autocast():
cls_l, reg_l, ctr_l = head(spatial)
loss = compute_loss(cls_l, reg_l, ctr_l, locs, boxes, labels)
optimizer.zero_grad()
scaler.scale(loss).backward()
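                    # Unscale gradients before clipping so the norm is
                    # measured on true fp32 values, not loss-scaled ones.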
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(head.parameters(), 5.0)
scaler.step(optimizer)
scaler.update()
scheduler.step()
global_step += 1
if global_step % 100 == 0:
lr = scheduler.get_last_lr()[0]
elapsed = time.time() - t0
print(f" step {global_step}/{total_steps} (ep {epoch+1}) "
f"loss={loss.item():.4f} lr={lr:.2e} "
f"{global_step/elapsed:.1f} it/s", flush=True)
if global_step % 4000 == 0:
ckpt = f"/home/zootest/checkpoint_convdeep_step{global_step}.pth"
torch.save({"head": head.state_dict(), "step": global_step}, ckpt)
except RuntimeError as e:
if "out of memory" in str(e):
torch.cuda.empty_cache()
optimizer.zero_grad()
global_step += 1
scheduler.step()
continue
raise
del shard
print(f" Epoch {epoch+1}/{args.epochs} complete ({time.time()-epoch_t0:.0f}s)\n", flush=True)
# Save
out_dir = os.path.join(SCRIPT_DIR, "heads", "cofiber_threshold", "conv_deep")
os.makedirs(out_dir, exist_ok=True)
out = os.path.join(out_dir, f"conv_deep_{args.hidden}h_{args.blocks}b_{args.epochs}ep.pth")
torch.save(head.state_dict(), out)
elapsed = time.time() - t0
print(f"Saved: {out}")
print(f"{n_params:,} params, {elapsed/60:.1f} minutes")
if __name__ == "__main__":
main()