# detection-heads / cuda_graph_trainer.py
# Repository snapshot 74e3c01 ("update repository"); author: phanerozoic.
"""Option 1: CUDA Graphs wrapper.
Captures the full forward + backward + optimizer step as a CUDA graph that can
be replayed on each iteration. Eliminates per-op Python dispatch overhead
(our likely bottleneck given that batch 8/16/32 all yield ~95 img/s).
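Constraint: CUDA Graphs require fixed input shapes. Our actual data has
variable per-image GT box counts. Workaround: targets are precomputed per
prediction location into fixed-size (B, N) buffers, where N is the total
number of locations across all pyramid levels, and a class index of -1 marks
a negative location. The positive-location mask is computed inside the
captured graph; only the buffer values change between iterations.
(`max_boxes`, e.g. 64 per image, is still accepted by the constructor but is
not used by the loss.)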
Includes a self-test against the mock backbone that benchmarks the graph-replay
training step against the eager-mode equivalent.
"""
import os
import sys
import time
import torch
import torch.nn.functional as F
# Put this file's directory first on sys.path so that the sibling modules
# imported by the self-test below resolve regardless of the working directory.
SCRIPT_DIR = os.path.split(os.path.abspath(__file__))[0]
sys.path.insert(0, SCRIPT_DIR)
class CudaGraphTrainStep:
"""Captures one full training step (forward + loss + backward + optimizer.step)
into a CUDA graph and replays it. Inputs and targets are written into pinned
device buffers; outputs are read back from device buffers.
Uses bf16 throughout (no GradScaler needed) because GradScaler.step() does
a Python-level .item() call which is forbidden during stream capture. bf16
has a wide enough exponent range (8-bit) to handle the value distributions
in our model without scaling.
Usage:
step = CudaGraphTrainStep(head, optimizer, batch_size=16)
step.warmup_and_capture()
for batch in loader:
step.set_inputs(spatial, tgt_cls, tgt_reg, tgt_ctr)
loss_value = step.run()
"""
def __init__(self, head, optimizer, batch_size, max_boxes, all_locs,
spatial_h=40, spatial_w=40, feat_dim=768, n_classes=80,
reg_weight=1.0, sheaf_weight=0.1):
"""all_locs: (N, 2) tensor of (cx, cy) for every prediction location across
every level, in the same flatten order produced by `cat([per-level], 1)`
in `_step_body`. Required for the in-graph GIoU box regression loss."""
self.head = head
self.optimizer = optimizer
self.B = batch_size
self.M = max_boxes
self.spatial_h = spatial_h
self.spatial_w = spatial_w
self.feat_dim = feat_dim
self.C = n_classes
self.N = all_locs.shape[0]
self.reg_weight = reg_weight
self.sheaf_weight = sheaf_weight
self.buf_spatial = torch.zeros(batch_size, feat_dim, spatial_h, spatial_w,
device="cuda", dtype=torch.bfloat16)
self.buf_tgt_cls = torch.full((batch_size, self.N), -1, device="cuda", dtype=torch.long)
self.buf_tgt_reg = torch.zeros(batch_size, self.N, 4, device="cuda", dtype=torch.float32)
self.buf_tgt_ctr = torch.zeros(batch_size, self.N, device="cuda", dtype=torch.float32)
self.buf_locs = all_locs.float().to("cuda")
self.buf_loss = torch.zeros(1, device="cuda", dtype=torch.float32)
self.graph = None
self.captured = False
def _step_body(self):
"""Forward under bf16 autocast; loss + backward + optimizer step in
fp32. GradScaler is incompatible with stream capture; autocast(bf16)
substitutes (bf16's fp32-equivalent exponent range needs no scaling)."""
with torch.autocast("cuda", dtype=torch.bfloat16):
cls_per, reg_per, ctr_per = self.head(self.buf_spatial)
flat_cls = torch.cat([c.permute(0, 2, 3, 1).reshape(self.B, -1, self.C) for c in cls_per], 1).float()
flat_reg = torch.cat([r.permute(0, 2, 3, 1).reshape(self.B, -1, 4) for r in reg_per], 1).float()
flat_ctr = torch.cat([c.permute(0, 2, 3, 1).reshape(self.B, -1) for c in ctr_per], 1).float()
pos = self.buf_tgt_cls >= 0
npos = pos.sum().clamp(min=1).float()
oh = torch.zeros_like(flat_cls)
cls_idx = self.buf_tgt_cls.clamp(min=0)
oh.scatter_(2, cls_idx.unsqueeze(-1), 1.0)
oh = oh * pos.unsqueeze(-1).float()
p = torch.sigmoid(flat_cls)
ce = F.binary_cross_entropy_with_logits(flat_cls, oh, reduction="none")
pt = p * oh + (1 - p) * (1 - oh)
at = 0.25 * oh + 0.75 * (1 - oh)
loss_cls = (at * (1 - pt) ** 2 * ce).sum() / npos
ctr_target = self.buf_tgt_ctr * pos.float()
loss_ctr = (F.binary_cross_entropy_with_logits(flat_ctr, ctr_target, reduction="none") * pos.float()).sum() / npos
# Element-wise GIoU on every location, masked by pos. Static-shape so it
# captures cleanly. Negatives have tgt_reg=0 and the `* pos.float()`
# mask zeros them out before the sum.
pl = self.buf_locs.unsqueeze(0).expand(self.B, -1, -1)
tgt_reg_f = self.buf_tgt_reg
pb_x1 = pl[..., 0] - flat_reg[..., 0]
pb_y1 = pl[..., 1] - flat_reg[..., 1]
pb_x2 = pl[..., 0] + flat_reg[..., 2]
pb_y2 = pl[..., 1] + flat_reg[..., 3]
tb_x1 = pl[..., 0] - tgt_reg_f[..., 0]
tb_y1 = pl[..., 1] - tgt_reg_f[..., 1]
tb_x2 = pl[..., 0] + tgt_reg_f[..., 2]
tb_y2 = pl[..., 1] + tgt_reg_f[..., 3]
inter_x1 = torch.maximum(pb_x1, tb_x1)
inter_y1 = torch.maximum(pb_y1, tb_y1)
inter_x2 = torch.minimum(pb_x2, tb_x2)
inter_y2 = torch.minimum(pb_y2, tb_y2)
inter = (inter_x2 - inter_x1).clamp(min=0) * (inter_y2 - inter_y1).clamp(min=0)
ap = (pb_x2 - pb_x1).clamp(min=0) * (pb_y2 - pb_y1).clamp(min=0)
at_a = (tb_x2 - tb_x1).clamp(min=0) * (tb_y2 - tb_y1).clamp(min=0)
union = ap + at_a - inter
iou = inter / union.clamp(min=1e-6)
enc_x1 = torch.minimum(pb_x1, tb_x1)
enc_y1 = torch.minimum(pb_y1, tb_y1)
enc_x2 = torch.maximum(pb_x2, tb_x2)
enc_y2 = torch.maximum(pb_y2, tb_y2)
enc = (enc_x2 - enc_x1).clamp(min=0) * (enc_y2 - enc_y1).clamp(min=0)
giou = iou - (enc - union) / enc.clamp(min=1e-6)
loss_reg = ((1 - giou) * pos.float()).sum() / npos
# Sheaf consistency: penalize spatial discontinuity in cls logits.
# For each pyramid level, adjacent patches should produce similar class
# predictions. L2 on logits, averaged across levels. Light weight so it
# regularizes without overwhelming focal loss.
loss_sheaf = torch.tensor(0.0, device=flat_cls.device)
for c in cls_per:
cf = c.float()
h_diff = (cf[:, :, :, :-1] - cf[:, :, :, 1:]).pow(2).mean()
v_diff = (cf[:, :, :-1, :] - cf[:, :, 1:, :]).pow(2).mean()
loss_sheaf = loss_sheaf + h_diff + v_diff
loss_sheaf = loss_sheaf * self.sheaf_weight
loss = loss_cls + self.reg_weight * loss_reg + loss_ctr + loss_sheaf
loss.backward()
self.optimizer.step()
self.optimizer.zero_grad(set_to_none=False)
self.buf_loss.copy_(loss.detach())
def warmup_and_capture(self):
"""Run a few warmup steps in eager mode (allocates buffers, primes
cudnn), then capture the graph."""
# Warmup
for _ in range(3):
self._step_body()
torch.cuda.synchronize()
# Capture
self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph):
self._step_body()
self.captured = True
def set_inputs(self, spatial, tgt_cls, tgt_reg, tgt_ctr):
"""Copy inputs into the persistent buffers."""
self.buf_spatial.copy_(spatial)
self.buf_tgt_cls.copy_(tgt_cls)
self.buf_tgt_reg.copy_(tgt_reg)
self.buf_tgt_ctr.copy_(tgt_ctr)
def run(self):
"""Async: replays the captured graph if available, else runs one eager
training step from the persistent buffers. No sync — call `last_loss()`
when the scalar is needed (sync point) for logging."""
if self.captured:
self.graph.replay()
else:
self._step_body()
def last_loss(self):
"""Read the most recent captured-graph loss from the GPU buffer.
Forces a stream sync — call only when you actually need the value."""
return self.buf_loss.item()
# ============================================================
# Self-test
# ============================================================
if __name__ == "__main__":
    # Self-test: capture a graph for a real head on mock data, then compare
    # graph replay against running the same step eagerly. The project-local
    # modules below are resolved via SCRIPT_DIR on sys.path.
    from mock_eupe_backbone import make_mock_features, make_mock_boxes
    from target_cache import (
        precompute_targets_for_image, make_locations,
        STRIDES, SIZE_RANGES, H,
    )
    from train_split_tower_5scale import SplitTowerHead

    if not torch.cuda.is_available():
        print("CUDA not available; skipping CUDA Graphs test.")
        sys.exit(0)

    print("CUDA Graphs self-test")
    print("=" * 60)

    B, M = 16, 64
    head = SplitTowerHead(hidden=192, n_std_layers=5, n_dw_layers=4, n_scales=4).cuda()
    # capturable=True makes AdamW's step() safe to record in a CUDA graph.
    optimizer = torch.optim.AdamW(head.parameters(), lr=5e-4, weight_decay=1e-4, capturable=True)

    feat_sizes = [(H * 2, H * 2), (H, H), (H // 2, H // 2),
                  (H // 4, H // 4), (H // 8, H // 8)]
    locs_per_level = make_locations(feat_sizes, STRIDES, torch.device("cuda"))
    n_total = sum(loc.shape[0] for loc in locs_per_level)
    print(f" total locations: {n_total}")

    boxes_list, labels_list = make_mock_boxes(B=B, n_boxes_per_image=8, device="cuda", seed=0)
    spatial = make_mock_features(B=B, device="cuda", seed=0).to(torch.bfloat16)

    # Levels are concatenated in order into one flat location index.
    # level_ranges holds (start, end, stride, lo, hi) per level, with (lo, hi)
    # taken from SIZE_RANGES.
    all_locs = torch.cat(locs_per_level, 0)
    level_ranges = []
    start = 0
    for i, loc in enumerate(locs_per_level):
        lo, hi = SIZE_RANGES[i]
        end = start + loc.shape[0]
        level_ranges.append((start, end, STRIDES[i], lo, hi))
        start = end

    # Dense per-location targets for the whole batch.
    tgt_cls_batch = torch.zeros(B, n_total, dtype=torch.long, device="cuda")
    tgt_reg_batch = torch.zeros(B, n_total, 4, device="cuda")
    tgt_ctr_batch = torch.zeros(B, n_total, device="cuda")
    for i in range(B):
        tcls, treg, tctr = precompute_targets_for_image(
            boxes_list[i], labels_list[i], all_locs, level_ranges, "cuda")
        tgt_cls_batch[i] = tcls
        tgt_reg_batch[i] = treg
        tgt_ctr_batch[i] = tctr

    step = CudaGraphTrainStep(head, optimizer, batch_size=B, max_boxes=M,
                              all_locs=all_locs)
    # Warmup performs real optimizer steps, so load the real batch first.
    step.set_inputs(spatial, tgt_cls_batch, tgt_reg_batch, tgt_ctr_batch)

    print("\n1. Capturing CUDA graph...")
    t0 = time.perf_counter()
    step.warmup_and_capture()
    capture_time = time.perf_counter() - t0
    print(f" capture time (3 warmup + 1 capture): {capture_time*1000:.0f} ms")

    print("\n2. Benchmarking graph replay vs eager-mode step...")
    n_iters = 50

    def eager_step():
        # Same step body as the graph (including the sheaf loss), dispatched
        # op-by-op. It zeroes grads with set_to_none=False. Never call
        # optimizer.zero_grad() with set_to_none=True here: the captured
        # graph reads/writes the existing .grad tensors, and freeing them
        # would leave the graph pointing at released memory.
        step._step_body()

    def bench(fn, iters):
        """Mean seconds per call of fn over iters calls, syncing only at the
        end so that eager and graph loops measure the same thing."""
        torch.cuda.synchronize()
        t_start = time.perf_counter()
        for _ in range(iters):
            fn()
        torch.cuda.synchronize()
        return (time.perf_counter() - t_start) / iters

    # Warmup eager
    for _ in range(3):
        eager_step()
    torch.cuda.synchronize()
    eager_time = bench(eager_step, n_iters)

    step.run()  # one untimed replay so first-launch overhead is excluded
    graph_time = bench(step.run, n_iters)

    print(f" eager step: {eager_time*1000:.2f} ms")
    print(f" graph replay: {graph_time*1000:.2f} ms")
    print(f" speedup: {eager_time / graph_time:.2f}x")
    if graph_time > eager_time:
        print(" WARNING: graph replay is slower; CUDA Graph not helping")
    else:
        print(f"\nGraph replay {eager_time / graph_time:.2f}x faster than eager mode.")