"""Option 1: CUDA Graphs wrapper. Captures the full forward + backward + optimizer step as a CUDA graph that can be replayed on each iteration. Eliminates per-op Python dispatch overhead (our likely bottleneck given that batch 8/16/32 all yield ~95 img/s). Constraint: CUDA Graphs require fixed input shapes. Our actual data has variable per-image GT box counts. Workaround: pad boxes to a fixed maximum (e.g., 64 per image) and use a validity mask. The mask is part of the captured graph; only the values inside change between iterations. Includes a self-test against the mock backbone that benchmarks the graph-replay training step against the eager-mode equivalent. """ import os import sys import time import torch import torch.nn.functional as F SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) sys.path.insert(0, SCRIPT_DIR) class CudaGraphTrainStep: """Captures one full training step (forward + loss + backward + optimizer.step) into a CUDA graph and replays it. Inputs and targets are written into pinned device buffers; outputs are read back from device buffers. Uses bf16 throughout (no GradScaler needed) because GradScaler.step() does a Python-level .item() call which is forbidden during stream capture. bf16 has a wide enough exponent range (8-bit) to handle the value distributions in our model without scaling. Usage: step = CudaGraphTrainStep(head, optimizer, batch_size=16) step.warmup_and_capture() for batch in loader: step.set_inputs(spatial, tgt_cls, tgt_reg, tgt_ctr) loss_value = step.run() """ def __init__(self, head, optimizer, batch_size, max_boxes, all_locs, spatial_h=40, spatial_w=40, feat_dim=768, n_classes=80, reg_weight=1.0, sheaf_weight=0.1): """all_locs: (N, 2) tensor of (cx, cy) for every prediction location across every level, in the same flatten order produced by `cat([per-level], 1)` in `_step_body`. Required for the in-graph GIoU box regression loss.""" self.head = head self.optimizer = optimizer self.B = batch_size self.M = max_boxes self.spatial_h = spatial_h self.spatial_w = spatial_w self.feat_dim = feat_dim self.C = n_classes self.N = all_locs.shape[0] self.reg_weight = reg_weight self.sheaf_weight = sheaf_weight self.buf_spatial = torch.zeros(batch_size, feat_dim, spatial_h, spatial_w, device="cuda", dtype=torch.bfloat16) self.buf_tgt_cls = torch.full((batch_size, self.N), -1, device="cuda", dtype=torch.long) self.buf_tgt_reg = torch.zeros(batch_size, self.N, 4, device="cuda", dtype=torch.float32) self.buf_tgt_ctr = torch.zeros(batch_size, self.N, device="cuda", dtype=torch.float32) self.buf_locs = all_locs.float().to("cuda") self.buf_loss = torch.zeros(1, device="cuda", dtype=torch.float32) self.graph = None self.captured = False def _step_body(self): """Forward under bf16 autocast; loss + backward + optimizer step in fp32. GradScaler is incompatible with stream capture; autocast(bf16) substitutes (bf16's fp32-equivalent exponent range needs no scaling).""" with torch.autocast("cuda", dtype=torch.bfloat16): cls_per, reg_per, ctr_per = self.head(self.buf_spatial) flat_cls = torch.cat([c.permute(0, 2, 3, 1).reshape(self.B, -1, self.C) for c in cls_per], 1).float() flat_reg = torch.cat([r.permute(0, 2, 3, 1).reshape(self.B, -1, 4) for r in reg_per], 1).float() flat_ctr = torch.cat([c.permute(0, 2, 3, 1).reshape(self.B, -1) for c in ctr_per], 1).float() pos = self.buf_tgt_cls >= 0 npos = pos.sum().clamp(min=1).float() oh = torch.zeros_like(flat_cls) cls_idx = self.buf_tgt_cls.clamp(min=0) oh.scatter_(2, cls_idx.unsqueeze(-1), 1.0) oh = oh * pos.unsqueeze(-1).float() p = torch.sigmoid(flat_cls) ce = F.binary_cross_entropy_with_logits(flat_cls, oh, reduction="none") pt = p * oh + (1 - p) * (1 - oh) at = 0.25 * oh + 0.75 * (1 - oh) loss_cls = (at * (1 - pt) ** 2 * ce).sum() / npos ctr_target = self.buf_tgt_ctr * pos.float() loss_ctr = (F.binary_cross_entropy_with_logits(flat_ctr, ctr_target, reduction="none") * pos.float()).sum() / npos # Element-wise GIoU on every location, masked by pos. Static-shape so it # captures cleanly. Negatives have tgt_reg=0 and the `* pos.float()` # mask zeros them out before the sum. pl = self.buf_locs.unsqueeze(0).expand(self.B, -1, -1) tgt_reg_f = self.buf_tgt_reg pb_x1 = pl[..., 0] - flat_reg[..., 0] pb_y1 = pl[..., 1] - flat_reg[..., 1] pb_x2 = pl[..., 0] + flat_reg[..., 2] pb_y2 = pl[..., 1] + flat_reg[..., 3] tb_x1 = pl[..., 0] - tgt_reg_f[..., 0] tb_y1 = pl[..., 1] - tgt_reg_f[..., 1] tb_x2 = pl[..., 0] + tgt_reg_f[..., 2] tb_y2 = pl[..., 1] + tgt_reg_f[..., 3] inter_x1 = torch.maximum(pb_x1, tb_x1) inter_y1 = torch.maximum(pb_y1, tb_y1) inter_x2 = torch.minimum(pb_x2, tb_x2) inter_y2 = torch.minimum(pb_y2, tb_y2) inter = (inter_x2 - inter_x1).clamp(min=0) * (inter_y2 - inter_y1).clamp(min=0) ap = (pb_x2 - pb_x1).clamp(min=0) * (pb_y2 - pb_y1).clamp(min=0) at_a = (tb_x2 - tb_x1).clamp(min=0) * (tb_y2 - tb_y1).clamp(min=0) union = ap + at_a - inter iou = inter / union.clamp(min=1e-6) enc_x1 = torch.minimum(pb_x1, tb_x1) enc_y1 = torch.minimum(pb_y1, tb_y1) enc_x2 = torch.maximum(pb_x2, tb_x2) enc_y2 = torch.maximum(pb_y2, tb_y2) enc = (enc_x2 - enc_x1).clamp(min=0) * (enc_y2 - enc_y1).clamp(min=0) giou = iou - (enc - union) / enc.clamp(min=1e-6) loss_reg = ((1 - giou) * pos.float()).sum() / npos # Sheaf consistency: penalize spatial discontinuity in cls logits. # For each pyramid level, adjacent patches should produce similar class # predictions. L2 on logits, averaged across levels. Light weight so it # regularizes without overwhelming focal loss. loss_sheaf = torch.tensor(0.0, device=flat_cls.device) for c in cls_per: cf = c.float() h_diff = (cf[:, :, :, :-1] - cf[:, :, :, 1:]).pow(2).mean() v_diff = (cf[:, :, :-1, :] - cf[:, :, 1:, :]).pow(2).mean() loss_sheaf = loss_sheaf + h_diff + v_diff loss_sheaf = loss_sheaf * self.sheaf_weight loss = loss_cls + self.reg_weight * loss_reg + loss_ctr + loss_sheaf loss.backward() self.optimizer.step() self.optimizer.zero_grad(set_to_none=False) self.buf_loss.copy_(loss.detach()) def warmup_and_capture(self): """Run a few warmup steps in eager mode (allocates buffers, primes cudnn), then capture the graph.""" # Warmup for _ in range(3): self._step_body() torch.cuda.synchronize() # Capture self.graph = torch.cuda.CUDAGraph() with torch.cuda.graph(self.graph): self._step_body() self.captured = True def set_inputs(self, spatial, tgt_cls, tgt_reg, tgt_ctr): """Copy inputs into the persistent buffers.""" self.buf_spatial.copy_(spatial) self.buf_tgt_cls.copy_(tgt_cls) self.buf_tgt_reg.copy_(tgt_reg) self.buf_tgt_ctr.copy_(tgt_ctr) def run(self): """Async: replays the captured graph if available, else runs one eager training step from the persistent buffers. No sync — call `last_loss()` when the scalar is needed (sync point) for logging.""" if self.captured: self.graph.replay() else: self._step_body() def last_loss(self): """Read the most recent captured-graph loss from the GPU buffer. Forces a stream sync — call only when you actually need the value.""" return self.buf_loss.item() # ============================================================ # Self-test # ============================================================ if __name__ == "__main__": from mock_eupe_backbone import make_mock_features, make_mock_boxes from target_cache import ( precompute_targets_for_image, make_locations, STRIDES, SIZE_RANGES, H, ) from train_split_tower_5scale import SplitTowerHead if not torch.cuda.is_available(): print("CUDA not available; skipping CUDA Graphs test.") sys.exit(0) print("CUDA Graphs self-test") print("=" * 60) B, M = 16, 64 head = SplitTowerHead(hidden=192, n_std_layers=5, n_dw_layers=4, n_scales=4).cuda() optimizer = torch.optim.AdamW(head.parameters(), lr=5e-4, weight_decay=1e-4, capturable=True) feat_sizes = [(H * 2, H * 2), (H, H), (H // 2, H // 2), (H // 4, H // 4), (H // 8, H // 8)] locs_per_level = make_locations(feat_sizes, STRIDES, torch.device("cuda")) n_total = sum(loc.shape[0] for loc in locs_per_level) print(f" total locations: {n_total}") boxes_list, labels_list = make_mock_boxes(B=B, n_boxes_per_image=8, device="cuda", seed=0) spatial = make_mock_features(B=B, device="cuda", seed=0).to(torch.bfloat16) all_locs = torch.cat(locs_per_level, 0) n_per_level = [loc.shape[0] for loc in locs_per_level] level_ranges = [] cumsum = 0 for i, n in enumerate(n_per_level): lo, hi = SIZE_RANGES[i] level_ranges.append((cumsum, cumsum + n, STRIDES[i], lo, hi)) cumsum += n tgt_cls_batch = torch.zeros(B, n_total, dtype=torch.long, device="cuda") tgt_reg_batch = torch.zeros(B, n_total, 4, device="cuda") tgt_ctr_batch = torch.zeros(B, n_total, device="cuda") for i in range(B): tcls, treg, tctr = precompute_targets_for_image( boxes_list[i], labels_list[i], all_locs, level_ranges, "cuda") tgt_cls_batch[i] = tcls tgt_reg_batch[i] = treg tgt_ctr_batch[i] = tctr # Initialize graph-step step = CudaGraphTrainStep(head, optimizer, batch_size=B, max_boxes=M, all_locs=all_locs) step.set_inputs(spatial, tgt_cls_batch, tgt_reg_batch, tgt_ctr_batch) print("\n1. Capturing CUDA graph...") t0 = time.time() step.warmup_and_capture() capture_time = time.time() - t0 print(f" capture time (3 warmup + 1 capture): {capture_time*1000:.0f} ms") print("\n2. Benchmarking graph replay vs eager-mode step...") # Eager-mode body that replicates _step_body without graph capture def eager_step(): with torch.autocast("cuda", dtype=torch.bfloat16): cls_per, reg_per, ctr_per = head(step.buf_spatial) flat_cls = torch.cat([c.permute(0, 2, 3, 1).reshape(B, -1, 80) for c in cls_per], 1).float() flat_reg = torch.cat([r.permute(0, 2, 3, 1).reshape(B, -1, 4) for r in reg_per], 1).float() flat_ctr = torch.cat([c.permute(0, 2, 3, 1).reshape(B, -1) for c in ctr_per], 1).float() pos = step.buf_tgt_cls >= 0 npos = pos.sum().clamp(min=1).float() oh = torch.zeros_like(flat_cls) cls_idx = step.buf_tgt_cls.clamp(min=0) oh.scatter_(2, cls_idx.unsqueeze(-1), 1.0) oh = oh * pos.unsqueeze(-1).float() p = torch.sigmoid(flat_cls) ce = F.binary_cross_entropy_with_logits(flat_cls, oh, reduction="none") pt = p * oh + (1 - p) * (1 - oh) at = 0.25 * oh + 0.75 * (1 - oh) loss_cls = (at * (1 - pt) ** 2 * ce).sum() / npos ctr_target = step.buf_tgt_ctr * pos.float() loss_ctr = (F.binary_cross_entropy_with_logits(flat_ctr, ctr_target, reduction="none") * pos.float()).sum() / npos pl = step.buf_locs.unsqueeze(0).expand(B, -1, -1) tgt_reg_f = step.buf_tgt_reg pb_x1 = pl[..., 0] - flat_reg[..., 0]; pb_y1 = pl[..., 1] - flat_reg[..., 1] pb_x2 = pl[..., 0] + flat_reg[..., 2]; pb_y2 = pl[..., 1] + flat_reg[..., 3] tb_x1 = pl[..., 0] - tgt_reg_f[..., 0]; tb_y1 = pl[..., 1] - tgt_reg_f[..., 1] tb_x2 = pl[..., 0] + tgt_reg_f[..., 2]; tb_y2 = pl[..., 1] + tgt_reg_f[..., 3] inter_x1 = torch.maximum(pb_x1, tb_x1); inter_y1 = torch.maximum(pb_y1, tb_y1) inter_x2 = torch.minimum(pb_x2, tb_x2); inter_y2 = torch.minimum(pb_y2, tb_y2) inter = (inter_x2 - inter_x1).clamp(min=0) * (inter_y2 - inter_y1).clamp(min=0) ap = (pb_x2 - pb_x1).clamp(min=0) * (pb_y2 - pb_y1).clamp(min=0) at_a = (tb_x2 - tb_x1).clamp(min=0) * (tb_y2 - tb_y1).clamp(min=0) union = ap + at_a - inter iou = inter / union.clamp(min=1e-6) enc_x1 = torch.minimum(pb_x1, tb_x1); enc_y1 = torch.minimum(pb_y1, tb_y1) enc_x2 = torch.maximum(pb_x2, tb_x2); enc_y2 = torch.maximum(pb_y2, tb_y2) enc = (enc_x2 - enc_x1).clamp(min=0) * (enc_y2 - enc_y1).clamp(min=0) giou = iou - (enc - union) / enc.clamp(min=1e-6) loss_reg = ((1 - giou) * pos.float()).sum() / npos loss = loss_cls + step.reg_weight * loss_reg + loss_ctr loss.backward() optimizer.step() optimizer.zero_grad() return loss.item() # Warmup eager for _ in range(3): eager_step() torch.cuda.synchronize() N = 50 t0 = time.time() for _ in range(N): eager_step() torch.cuda.synchronize() eager_time = (time.time() - t0) / N t0 = time.time() for _ in range(N): step.run() torch.cuda.synchronize() # one sync at the end, not per-step graph_time = (time.time() - t0) / N print(f" eager step: {eager_time*1000:.2f} ms") print(f" graph replay: {graph_time*1000:.2f} ms") print(f" speedup: {eager_time / graph_time:.2f}x") if graph_time > eager_time: print(" WARNING: graph replay is slower; CUDA Graph not helping") else: print(f"\nGraph replay {eager_time / graph_time:.2f}x faster than eager mode.")