| """Option 1: CUDA Graphs wrapper. |
| |
| Captures the full forward + backward + optimizer step as a CUDA graph that can |
| be replayed on each iteration. Eliminates per-op Python dispatch overhead |
| (our likely bottleneck given that batch 8/16/32 all yield ~95 img/s). |
| |
| Constraint: CUDA Graphs require fixed input shapes. Our actual data has |
| variable per-image GT box counts. Workaround: pad boxes to a fixed maximum |
| (e.g., 64 per image) and use a validity mask. The mask is part of the captured |
| graph; only the values inside change between iterations. |
| |
| Includes a self-test against the mock backbone that benchmarks the graph-replay |
| training step against the eager-mode equivalent. |
| """ |
| import os |
| import sys |
| import time |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
| SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) |
| sys.path.insert(0, SCRIPT_DIR) |
|
|
|
|
| class CudaGraphTrainStep: |
| """Captures one full training step (forward + loss + backward + optimizer.step) |
| into a CUDA graph and replays it. Inputs and targets are written into pinned |
| device buffers; outputs are read back from device buffers. |
| |
| Uses bf16 throughout (no GradScaler needed) because GradScaler.step() does |
| a Python-level .item() call which is forbidden during stream capture. bf16 |
| has a wide enough exponent range (8-bit) to handle the value distributions |
| in our model without scaling. |
| |
| Usage: |
| step = CudaGraphTrainStep(head, optimizer, batch_size=16) |
| step.warmup_and_capture() |
| for batch in loader: |
| step.set_inputs(spatial, tgt_cls, tgt_reg, tgt_ctr) |
| loss_value = step.run() |
| """ |
| def __init__(self, head, optimizer, batch_size, max_boxes, all_locs, |
| spatial_h=40, spatial_w=40, feat_dim=768, n_classes=80, |
| reg_weight=1.0, sheaf_weight=0.1): |
| """all_locs: (N, 2) tensor of (cx, cy) for every prediction location across |
| every level, in the same flatten order produced by `cat([per-level], 1)` |
| in `_step_body`. Required for the in-graph GIoU box regression loss.""" |
| self.head = head |
| self.optimizer = optimizer |
| self.B = batch_size |
| self.M = max_boxes |
| self.spatial_h = spatial_h |
| self.spatial_w = spatial_w |
| self.feat_dim = feat_dim |
| self.C = n_classes |
| self.N = all_locs.shape[0] |
| self.reg_weight = reg_weight |
| self.sheaf_weight = sheaf_weight |
| self.buf_spatial = torch.zeros(batch_size, feat_dim, spatial_h, spatial_w, |
| device="cuda", dtype=torch.bfloat16) |
| self.buf_tgt_cls = torch.full((batch_size, self.N), -1, device="cuda", dtype=torch.long) |
| self.buf_tgt_reg = torch.zeros(batch_size, self.N, 4, device="cuda", dtype=torch.float32) |
| self.buf_tgt_ctr = torch.zeros(batch_size, self.N, device="cuda", dtype=torch.float32) |
| self.buf_locs = all_locs.float().to("cuda") |
| self.buf_loss = torch.zeros(1, device="cuda", dtype=torch.float32) |
| self.graph = None |
| self.captured = False |
|
|
| def _step_body(self): |
| """Forward under bf16 autocast; loss + backward + optimizer step in |
| fp32. GradScaler is incompatible with stream capture; autocast(bf16) |
| substitutes (bf16's fp32-equivalent exponent range needs no scaling).""" |
| with torch.autocast("cuda", dtype=torch.bfloat16): |
| cls_per, reg_per, ctr_per = self.head(self.buf_spatial) |
| flat_cls = torch.cat([c.permute(0, 2, 3, 1).reshape(self.B, -1, self.C) for c in cls_per], 1).float() |
| flat_reg = torch.cat([r.permute(0, 2, 3, 1).reshape(self.B, -1, 4) for r in reg_per], 1).float() |
| flat_ctr = torch.cat([c.permute(0, 2, 3, 1).reshape(self.B, -1) for c in ctr_per], 1).float() |
| pos = self.buf_tgt_cls >= 0 |
| npos = pos.sum().clamp(min=1).float() |
| oh = torch.zeros_like(flat_cls) |
| cls_idx = self.buf_tgt_cls.clamp(min=0) |
| oh.scatter_(2, cls_idx.unsqueeze(-1), 1.0) |
| oh = oh * pos.unsqueeze(-1).float() |
| p = torch.sigmoid(flat_cls) |
| ce = F.binary_cross_entropy_with_logits(flat_cls, oh, reduction="none") |
| pt = p * oh + (1 - p) * (1 - oh) |
| at = 0.25 * oh + 0.75 * (1 - oh) |
| loss_cls = (at * (1 - pt) ** 2 * ce).sum() / npos |
| ctr_target = self.buf_tgt_ctr * pos.float() |
| loss_ctr = (F.binary_cross_entropy_with_logits(flat_ctr, ctr_target, reduction="none") * pos.float()).sum() / npos |
| |
| |
| |
| pl = self.buf_locs.unsqueeze(0).expand(self.B, -1, -1) |
| tgt_reg_f = self.buf_tgt_reg |
| pb_x1 = pl[..., 0] - flat_reg[..., 0] |
| pb_y1 = pl[..., 1] - flat_reg[..., 1] |
| pb_x2 = pl[..., 0] + flat_reg[..., 2] |
| pb_y2 = pl[..., 1] + flat_reg[..., 3] |
| tb_x1 = pl[..., 0] - tgt_reg_f[..., 0] |
| tb_y1 = pl[..., 1] - tgt_reg_f[..., 1] |
| tb_x2 = pl[..., 0] + tgt_reg_f[..., 2] |
| tb_y2 = pl[..., 1] + tgt_reg_f[..., 3] |
| inter_x1 = torch.maximum(pb_x1, tb_x1) |
| inter_y1 = torch.maximum(pb_y1, tb_y1) |
| inter_x2 = torch.minimum(pb_x2, tb_x2) |
| inter_y2 = torch.minimum(pb_y2, tb_y2) |
| inter = (inter_x2 - inter_x1).clamp(min=0) * (inter_y2 - inter_y1).clamp(min=0) |
| ap = (pb_x2 - pb_x1).clamp(min=0) * (pb_y2 - pb_y1).clamp(min=0) |
| at_a = (tb_x2 - tb_x1).clamp(min=0) * (tb_y2 - tb_y1).clamp(min=0) |
| union = ap + at_a - inter |
| iou = inter / union.clamp(min=1e-6) |
| enc_x1 = torch.minimum(pb_x1, tb_x1) |
| enc_y1 = torch.minimum(pb_y1, tb_y1) |
| enc_x2 = torch.maximum(pb_x2, tb_x2) |
| enc_y2 = torch.maximum(pb_y2, tb_y2) |
| enc = (enc_x2 - enc_x1).clamp(min=0) * (enc_y2 - enc_y1).clamp(min=0) |
| giou = iou - (enc - union) / enc.clamp(min=1e-6) |
| loss_reg = ((1 - giou) * pos.float()).sum() / npos |
| |
| |
| |
| |
| loss_sheaf = torch.tensor(0.0, device=flat_cls.device) |
| for c in cls_per: |
| cf = c.float() |
| h_diff = (cf[:, :, :, :-1] - cf[:, :, :, 1:]).pow(2).mean() |
| v_diff = (cf[:, :, :-1, :] - cf[:, :, 1:, :]).pow(2).mean() |
| loss_sheaf = loss_sheaf + h_diff + v_diff |
| loss_sheaf = loss_sheaf * self.sheaf_weight |
| loss = loss_cls + self.reg_weight * loss_reg + loss_ctr + loss_sheaf |
| loss.backward() |
| self.optimizer.step() |
| self.optimizer.zero_grad(set_to_none=False) |
| self.buf_loss.copy_(loss.detach()) |
|
|
| def warmup_and_capture(self): |
| """Run a few warmup steps in eager mode (allocates buffers, primes |
| cudnn), then capture the graph.""" |
| |
| for _ in range(3): |
| self._step_body() |
| torch.cuda.synchronize() |
| |
| self.graph = torch.cuda.CUDAGraph() |
| with torch.cuda.graph(self.graph): |
| self._step_body() |
| self.captured = True |
|
|
| def set_inputs(self, spatial, tgt_cls, tgt_reg, tgt_ctr): |
| """Copy inputs into the persistent buffers.""" |
| self.buf_spatial.copy_(spatial) |
| self.buf_tgt_cls.copy_(tgt_cls) |
| self.buf_tgt_reg.copy_(tgt_reg) |
| self.buf_tgt_ctr.copy_(tgt_ctr) |
|
|
| def run(self): |
| """Async: replays the captured graph if available, else runs one eager |
| training step from the persistent buffers. No sync — call `last_loss()` |
| when the scalar is needed (sync point) for logging.""" |
| if self.captured: |
| self.graph.replay() |
| else: |
| self._step_body() |
|
|
| def last_loss(self): |
| """Read the most recent captured-graph loss from the GPU buffer. |
| Forces a stream sync — call only when you actually need the value.""" |
| return self.buf_loss.item() |
|
|
|
|
| |
| |
| |
if __name__ == "__main__":
    # Self-test: build a mock batch, capture the training step as a CUDA
    # graph, then time graph replay against the same step run eagerly.
    from mock_eupe_backbone import make_mock_features, make_mock_boxes
    from target_cache import (
        precompute_targets_for_image, make_locations,
        STRIDES, SIZE_RANGES, H,
    )
    from train_split_tower_5scale import SplitTowerHead

    if not torch.cuda.is_available():
        print("CUDA not available; skipping CUDA Graphs test.")
        sys.exit(0)

    print("CUDA Graphs self-test")
    print("=" * 60)

    B, M = 16, 64
    head = SplitTowerHead(hidden=192, n_std_layers=5, n_dw_layers=4, n_scales=4).cuda()
    # capturable=True so that optimizer.step() can be recorded in the graph.
    optimizer = torch.optim.AdamW(head.parameters(), lr=5e-4, weight_decay=1e-4, capturable=True)

    # NOTE(review): five feature-map sizes are listed while the head is built
    # with n_scales=4 — confirm this matches SplitTowerHead's level count.
    feat_sizes = [(H * 2, H * 2), (H, H), (H // 2, H // 2),
                  (H // 4, H // 4), (H // 8, H // 8)]
    locs_per_level = make_locations(feat_sizes, STRIDES, torch.device("cuda"))
    n_total = sum(loc.shape[0] for loc in locs_per_level)
    print(f" total locations: {n_total}")

    boxes_list, labels_list = make_mock_boxes(B=B, n_boxes_per_image=8, device="cuda", seed=0)
    spatial = make_mock_features(B=B, device="cuda", seed=0).to(torch.bfloat16)

    # Flat location table plus, per level, (start, end, stride, lo, hi) where
    # [start, end) is that level's slice of the flat table.
    all_locs = torch.cat(locs_per_level, 0)
    n_per_level = [loc.shape[0] for loc in locs_per_level]
    level_ranges = []
    offset = 0
    for i, n in enumerate(n_per_level):
        lo, hi = SIZE_RANGES[i]
        level_ranges.append((offset, offset + n, STRIDES[i], lo, hi))
        offset += n

    tgt_cls_batch = torch.zeros(B, n_total, dtype=torch.long, device="cuda")
    tgt_reg_batch = torch.zeros(B, n_total, 4, device="cuda")
    tgt_ctr_batch = torch.zeros(B, n_total, device="cuda")
    for i in range(B):
        tcls, treg, tctr = precompute_targets_for_image(
            boxes_list[i], labels_list[i], all_locs, level_ranges, "cuda")
        tgt_cls_batch[i] = tcls
        tgt_reg_batch[i] = treg
        tgt_ctr_batch[i] = tctr

    step = CudaGraphTrainStep(head, optimizer, batch_size=B, max_boxes=M,
                              all_locs=all_locs)
    step.set_inputs(spatial, tgt_cls_batch, tgt_reg_batch, tgt_ctr_batch)

    print("\n1. Capturing CUDA graph...")
    t0 = time.perf_counter()
    step.warmup_and_capture()
    capture_time = time.perf_counter() - t0
    print(f" capture time (3 warmup + 1 capture): {capture_time*1000:.0f} ms")

    print("\n2. Benchmarking graph replay vs eager-mode step...")

    def eager_step():
        """Eager baseline: the very same step body that was captured.

        Reusing `_step_body` keeps the comparison like-for-like (same losses,
        no per-step host sync) and, importantly, zeroes gradients with
        set_to_none=False. A plain `optimizer.zero_grad()` here would free
        the .grad tensors whose addresses the captured graph still uses.
        """
        step._step_body()

    def avg_step_seconds(fn, n_iters, n_warmup=3):
        """Average wall time of fn() over n_iters calls, GPU-synchronised."""
        for _ in range(n_warmup):
            fn()
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(n_iters):
            fn()
        torch.cuda.synchronize()
        return (time.perf_counter() - start) / n_iters

    n_iters = 50
    eager_time = avg_step_seconds(eager_step, n_iters)
    graph_time = avg_step_seconds(step.run, n_iters)

    print(f" eager step: {eager_time*1000:.2f} ms")
    print(f" graph replay: {graph_time*1000:.2f} ms")
    print(f" speedup: {eager_time / graph_time:.2f}x")

    if graph_time > eager_time:
        print(" WARNING: graph replay is slower; CUDA Graph not helping")
    else:
        print(f"\nGraph replay {eager_time / graph_time:.2f}x faster than eager mode.")
|
|