| """Option 4: Precomputed FCOS target cache. |
| |
The FCOS target assignment for each image is deterministic given its
(boxes, labels) and a fixed level layout (strides + size ranges); the spatial
features themselves do not enter into it.
| Our 5-scale layout is fixed, so we can precompute targets once per image and |
| cache them alongside the spatial features in each shard. Training then loads |
| targets directly instead of recomputing on every forward pass. |
| |
| Specific to our backbone configuration: 640px input, 40x40 stride-16 spatial |
| output, 5 prediction levels at strides [8, 16, 32, 64, 128] with FCOS standard |
| size ranges. Any architecture change to scale count, strides, or size ranges |
| invalidates the cache. |
| |
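Includes a thorough self-test: builds a synthetic shard via the mock backbone,
precomputes targets, runs the same data through the original
assign_targets_batched, and checks that the classification targets match
exactly and that the regression/centerness targets agree within fp16 storage
tolerance. It then benchmarks the cached loss against the in-line
compute_loss.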
| """ |
| import json |
| import os |
| import sys |
| import time |
|
|
| import torch |
|
|
# Put this file's directory first on sys.path so sibling modules
# (mock_eupe_backbone, cache_and_train_fast) import regardless of the cwd.
SCRIPT_DIR = os.path.split(os.path.abspath(__file__))[0]
sys.path.insert(0, SCRIPT_DIR)

# Fixed 5-level FCOS layout. Changing any value here invalidates cached targets.
STRIDES = [8, 16, 32, 64, 128]
SIZE_RANGES = [  # per level: (min, max) allowed value of the largest ltrb distance
    (-1, 32),
    (32, 64),
    (64, 128),
    (128, 256),
    (256, float("inf")),
]
RESOLUTION = 640  # input side length in pixels
H = RESOLUTION // STRIDES[1]  # side of the stride-16 feature map (40 at 640px)
|
|
|
|
| def make_locations(feature_sizes, strides, device): |
| locs = [] |
| for (h, w), s in zip(feature_sizes, strides): |
| ys = (torch.arange(h, device=device, dtype=torch.float32) + 0.5) * s |
| xs = (torch.arange(w, device=device, dtype=torch.float32) + 0.5) * s |
| gy, gx = torch.meshgrid(ys, xs, indexing="ij") |
| locs.append(torch.stack([gx.flatten(), gy.flatten()], -1)) |
| return locs |
|
|
|
|
def precompute_targets_for_image(boxes, labels, locs, level_ranges, device):
    """Compute FCOS targets for a single image.

    Intended to mirror ``assign_targets_batched`` for one image (B=1
    implicit) so per-image targets can be stored; this file's self-test
    compares the two. For each level, a location is admissible for a box
    when all three of the following hold:

    * it lies strictly inside the box (all ltrb distances > 0);
    * it is within ``1.5 * stride`` of the box centre on both axes;
    * the box's largest ltrb distance from it is inside the level's
      [size_lo, size_hi] range.

    Among admissible boxes, the one with the smallest area is assigned.

    Args:
        boxes: (M, 4) tensor of (x1, y1, x2, y2).
        labels: (M,) integer tensor of class indices.
        locs: (N_total, 2) concatenated (cx, cy) locations over all levels.
        level_ranges: list of (start, end, stride, size_lo, size_hi); rows
            ``locs[start:end]`` belong to that level.
        device: device for the output tensors.

    Returns:
        tgt_cls: (N_total,) long class index, or -1 for negatives.
        tgt_reg: (N_total, 4) ltrb distances; zero where tgt_cls < 0.
        tgt_ctr: (N_total,) centerness; zero where tgt_cls < 0.
    """
    num_locs = locs.shape[0]
    tgt_cls = torch.full((num_locs,), -1, dtype=torch.long, device=device)
    tgt_reg = torch.zeros(num_locs, 4, device=device)
    tgt_ctr = torch.zeros(num_locs, device=device)

    if boxes.numel() == 0:
        return tgt_cls, tgt_reg, tgt_ctr

    # Per-box quantities do not depend on the level: compute them once.
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    box_cx = (boxes[:, 0] + boxes[:, 2]) / 2
    box_cy = (boxes[:, 1] + boxes[:, 3]) / 2
    inf = float("inf")

    for lo, hi, stride, size_lo, size_hi in level_ranges:
        loc = locs[lo:hi]
        px = loc[:, None, 0]  # (n, 1): broadcasts against the M boxes
        py = loc[:, None, 1]
        ltrb = torch.stack([px - boxes[None, :, 0],
                            py - boxes[None, :, 1],
                            boxes[None, :, 2] - px,
                            boxes[None, :, 3] - py], dim=-1)  # (n, M, 4)
        in_box = ltrb.min(dim=-1).values > 0
        rad = stride * 1.5
        in_center = ((px >= box_cx - rad) & (px <= box_cx + rad) &
                     (py >= box_cy - rad) & (py <= box_cy + rad))
        max_d = ltrb.max(dim=-1).values
        in_level = (max_d >= size_lo) & (max_d <= size_hi)
        pos = in_box & in_center & in_level

        # Smallest-area admissible box wins; inf marks "not admissible".
        a = areas[None, :].expand_as(pos).clone()
        a[~pos] = inf
        matched = a.argmin(dim=-1)
        is_pos = a.gather(1, matched[:, None]).squeeze(1) < inf
        if not is_pos.any():
            continue

        loc_idx = is_pos.nonzero(as_tuple=True)[0]
        box_idx = matched[is_pos]
        ltrb_pos = ltrb[loc_idx, box_idx]
        lp, tp, rp, bp = ltrb_pos.unbind(-1)
        ctr = torch.sqrt(
            (torch.minimum(lp, rp) / torch.maximum(lp, rp).clamp(min=1e-6)) *
            (torch.minimum(tp, bp) / torch.maximum(tp, bp).clamp(min=1e-6)))

        # Slicing yields views, so these masked writes land in the full
        # tensors. Cast explicitly: masked assignment requires matching dtypes.
        tgt_cls[lo:hi][is_pos] = labels[box_idx].to(tgt_cls.dtype)
        tgt_reg[lo:hi][is_pos] = ltrb_pos.to(tgt_reg.dtype)
        tgt_ctr[lo:hi][is_pos] = ctr.to(tgt_ctr.dtype)

    return tgt_cls, tgt_reg, tgt_ctr
|
|
|
|
def precompute_shard_targets(shard, device="cuda"):
    """Attach precomputed FCOS targets to every entry of ``shard``.

    Targets are computed on ``device`` by ``precompute_targets_for_image``
    over the fixed 5-level layout. That layout is square feature maps of side
    2H, H, H/2, H/4 and H/8 at ``STRIDES``, with ``SIZE_RANGES``. The targets
    are then moved to the CPU in compact dtypes.

    Each entry must contain ``"boxes"`` (M, 4) as (x1, y1, x2, y2) and
    ``"labels"`` (M,). The shard is modified in place; each entry gains:

        tgt_cls: (N_total,) int16 -- class index, -1 for negatives. int16
            (rather than int8) leaves room for class indices above 127.
        tgt_reg: (N_total, 4) float16 -- ltrb distances; only meaningful
            where tgt_cls >= 0. float16 spacing is 0.5 px for values in
            [512, 1024), which bounds the storage error.
        tgt_ctr: (N_total,) float16 -- centerness; only meaningful where
            tgt_cls >= 0.

    Returns:
        The same ``shard`` list, for convenience.
    """
    sides = [H * 2, H, H // 2, H // 4, H // 8]
    feat_sizes = [(side, side) for side in sides]
    locs_per_level = make_locations(feat_sizes, STRIDES, torch.device(device))
    all_locs = torch.cat(locs_per_level, 0)

    # (start, end, stride, size_lo, size_hi): per-level row range into all_locs.
    level_ranges = []
    start = 0
    for level_locs, stride, (size_lo, size_hi) in zip(locs_per_level, STRIDES, SIZE_RANGES):
        end = start + level_locs.shape[0]
        level_ranges.append((start, end, stride, size_lo, size_hi))
        start = end

    for entry in shard:
        boxes = entry["boxes"].to(device).float()
        labels = entry["labels"].to(device).long()
        tcls, treg, tctr = precompute_targets_for_image(
            boxes, labels, all_locs, level_ranges, device)
        # Compact CPU storage (dtypes documented above).
        entry["tgt_cls"] = tcls.to(torch.int16).cpu()
        entry["tgt_reg"] = treg.to(torch.float16).cpu()
        entry["tgt_ctr"] = tctr.to(torch.float16).cpu()
    return shard
|
|
|
|
def _paired_giou(boxes1, boxes2):
    """Return the GIoU of row i of ``boxes1`` with row i of ``boxes2``.

    Both inputs are (P, 4) boxes in (x1, y1, x2, y2). Applies the same
    formula as ``torchvision.ops.generalized_box_iou``, but only to the P
    matched pairs instead of building the full P x P matrix. Time and memory
    are therefore O(P).
    """
    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
    lt = torch.max(boxes1[:, :2], boxes2[:, :2])
    rb = torch.min(boxes1[:, 2:], boxes2[:, 2:])
    wh = (rb - lt).clamp(min=0)
    inter = wh[:, 0] * wh[:, 1]
    union = area1 + area2 - inter
    iou = inter / union
    # Smallest enclosing box of each pair.
    lti = torch.min(boxes1[:, :2], boxes2[:, :2])
    rbi = torch.max(boxes1[:, 2:], boxes2[:, 2:])
    whi = (rbi - lti).clamp(min=0)
    areai = whi[:, 0] * whi[:, 1]
    return iou - (areai - union) / areai


def precompute_loss_with_cache(cls_per, reg_per, ctr_per, batch_tgt_cls, batch_tgt_reg, batch_tgt_ctr,
                               num_classes=80):
    """Compute the FCOS loss from PRECOMPUTED targets.

    Replaces the per-step target assignment with a cache lookup. The loss has
    three terms, each summed and divided by max(#positives, 1):

    * a sigmoid focal loss (alpha=0.25, gamma=2) on classification;
    * a GIoU loss on the ltrb regression of positives;
    * BCE on centerness for positives.

    GIoU is evaluated only for matched (prediction, target) pairs, in O(P).
    This yields the same values as the diagonal of the pairwise matrix.
    NOTE(review): if the in-line compute_loss still builds the pairwise
    matrix, a benchmark against it also measures that difference.

    Args:
        cls_per, reg_per, ctr_per: per-level predictions of shape
            (B, num_classes, h, w), (B, 4, h, w) and (B, 1, h, w). They must
            be in the same level order as the cached targets.
        batch_tgt_cls: (B, N_total) class targets, -1 for negatives (any
            integer dtype, e.g. the int16 cache format).
        batch_tgt_reg: (B, N_total, 4) ltrb targets (any float dtype).
        batch_tgt_ctr: (B, N_total) centerness targets (any float dtype).
        num_classes: number of classification channels per location.

    Returns:
        Scalar tensor ``loss_cls + loss_reg + loss_ctr``.
    """
    import torch.nn.functional as F

    B = cls_per[0].shape[0]
    device = cls_per[0].device
    flat_cls = torch.cat([c.permute(0, 2, 3, 1).reshape(B, -1, num_classes) for c in cls_per], 1)
    flat_reg = torch.cat([r.permute(0, 2, 3, 1).reshape(B, -1, 4) for r in reg_per], 1)
    flat_ctr = torch.cat([c.permute(0, 2, 3, 1).reshape(B, -1) for c in ctr_per], 1)

    pos = batch_tgt_cls >= 0
    npos = max(pos.sum().item(), 1)

    # One-hot classification targets; negatives (-1) stay all-zero.
    oh = torch.zeros_like(flat_cls)
    b_idx, loc_idx = pos.nonzero(as_tuple=True)
    oh[b_idx, loc_idx, batch_tgt_cls[pos].long()] = 1.0

    # Sigmoid focal loss.
    p = torch.sigmoid(flat_cls)
    ce = F.binary_cross_entropy_with_logits(flat_cls, oh, reduction="none")
    pt = p * oh + (1 - p) * (1 - oh)
    at = 0.25 * oh + 0.75 * (1 - oh)
    loss_cls = (at * (1 - pt) ** 2 * ce).sum() / npos

    if pos.any():
        # Derive grid sizes from the predictions so the locations follow the
        # same row-major, level-concatenated order as flat_reg.
        feat_sizes = [(c.shape[2], c.shape[3]) for c in cls_per]
        all_locs = torch.cat(make_locations(feat_sizes, STRIDES, device), 0)
        pl = all_locs[None].expand(B, -1, -1)[pos]
        pp = flat_reg[pos]
        tp = batch_tgt_reg[pos].float()
        pb = torch.stack([pl[:, 0] - pp[:, 0], pl[:, 1] - pp[:, 1],
                          pl[:, 0] + pp[:, 2], pl[:, 1] + pp[:, 3]], -1)
        tb = torch.stack([pl[:, 0] - tp[:, 0], pl[:, 1] - tp[:, 1],
                          pl[:, 0] + tp[:, 2], pl[:, 1] + tp[:, 3]], -1)
        loss_reg = (1 - _paired_giou(pb, tb)).sum() / npos
        loss_ctr = F.binary_cross_entropy_with_logits(
            flat_ctr[pos], batch_tgt_ctr[pos].float(), reduction="sum") / npos
    else:
        loss_reg = torch.tensor(0.0, device=device)
        loss_ctr = torch.tensor(0.0, device=device)

    return loss_cls + loss_reg + loss_ctr
|
|
|
|
| |
| |
| |
if __name__ == "__main__":
    from mock_eupe_backbone import make_mock_features, make_mock_boxes

    device = "cuda" if torch.cuda.is_available() else "cpu"

    def _sync():
        """Wait for queued CUDA work so wall-clock timings cover it (no-op on CPU)."""
        if device == "cuda":
            torch.cuda.synchronize()

    print(f"Self-test on {device}")
    print("=" * 60)

    B = 4
    boxes_list, labels_list = make_mock_boxes(B=B, n_boxes_per_image=8, device=device, seed=0)

    # --- 1. Synthetic shard built from mock features + boxes ---------------
    print("\n1. Building synthetic shard via mock features + boxes...")
    shard = []
    for i in range(B):
        feats = make_mock_features(B=1, device=device, seed=i)[0].half()
        shard.append({
            "img_id": i,
            "spatial": feats,
            "boxes": boxes_list[i].cpu(),
            "labels": labels_list[i].cpu(),
            "scale": 1.0,
        })
    print(f" shard with {len(shard)} entries")

    # --- 2. Precompute targets ---------------------------------------------
    print("\n2. Precomputing targets for each image...")
    t0 = time.perf_counter()
    # precompute_shard_targets finishes with .cpu() copies, which synchronize.
    shard = precompute_shard_targets(shard, device=device)
    t_precompute = time.perf_counter() - t0
    print(f" precompute time: {t_precompute*1000:.1f} ms ({t_precompute*1000/B:.1f} ms/image)")
    for i, e in enumerate(shard):
        n_pos = (e["tgt_cls"] >= 0).sum().item()
        print(f" img {i}: tgt_cls shape {e['tgt_cls'].shape}, {n_pos} positives")

    # --- 3. Equivalence against the in-line batched assigner ---------------
    print("\n3. Verifying equivalence with in-line assign_targets_batched...")
    from cache_and_train_fast import assign_targets_batched
    feat_sizes = [(H * 2, H * 2), (H, H), (H // 2, H // 2), (H // 4, H // 4), (H // 8, H // 8)]
    locs_per_level = make_locations(feat_sizes, STRIDES, torch.device(device))
    all_locs = torch.cat(locs_per_level, 0)
    level_ranges = []
    strides_per_loc = torch.zeros(all_locs.shape[0], device=device)
    start = 0
    for level_locs, stride, (size_lo, size_hi) in zip(locs_per_level, STRIDES, SIZE_RANGES):
        end = start + level_locs.shape[0]
        level_ranges.append((start, end, stride, size_lo, size_hi))
        strides_per_loc[start:end] = stride
        start = end

    # Pad per-image boxes/labels to a common length for the batched assigner.
    max_m = max(b.shape[0] for b in boxes_list)
    boxes_padded = torch.zeros(B, max_m, 4, device=device)
    labels_padded = torch.zeros(B, max_m, dtype=torch.long, device=device)
    valid_mask = torch.zeros(B, max_m, dtype=torch.bool, device=device)
    for i in range(B):
        m = boxes_list[i].shape[0]
        boxes_padded[i, :m] = boxes_list[i]
        labels_padded[i, :m] = labels_list[i]
        valid_mask[i, :m] = True

    inline_cls, inline_reg, inline_ctr = assign_targets_batched(
        all_locs, level_ranges, boxes_padded, labels_padded, valid_mask, strides_per_loc)

    cached_cls = torch.stack([e["tgt_cls"].to(device).long() for e in shard])
    cached_reg = torch.stack([e["tgt_reg"].to(device).float() for e in shard])
    cached_ctr = torch.stack([e["tgt_ctr"].to(device).float() for e in shard])

    cls_match = torch.equal(cached_cls, inline_cls)
    # Regression/centerness are only meaningful at positives; compare there.
    inline_pos = inline_cls >= 0
    if inline_pos.any():
        reg_diff = (cached_reg - inline_reg)[inline_pos].abs().max().item()
        ctr_diff = (cached_ctr - inline_ctr)[inline_pos].abs().max().item()
    else:
        reg_diff = ctr_diff = 0
    print(f" cls equal: {cls_match}")
    print(f" reg max abs diff (positives only, fp16 precision): {reg_diff:.6f}")
    print(f" ctr max abs diff (positives only, fp16 precision): {ctr_diff:.6f}")

    if not cls_match:
        n_diff = (cached_cls != inline_cls).sum().item()
        print(f" WARNING: {n_diff} cls mismatches")
        sys.exit(1)
    # Thresholds allow for float16 storage of the cached reg/ctr targets.
    if reg_diff > 0.5 or ctr_diff > 0.01:
        print(" WARNING: reg/ctr drift exceeds fp16 tolerance")
        sys.exit(1)

    # --- 4. Benchmark: cached vs in-line loss -------------------------------
    print("\n4. Benchmarking loss computation: cached vs in-line...")
    cls_per = [torch.randn(B, 80, h, w, device=device) for (h, w) in feat_sizes]
    reg_per = [torch.rand(B, 4, h, w, device=device) * 30 for (h, w) in feat_sizes]
    ctr_per = [torch.randn(B, 1, h, w, device=device) for (h, w) in feat_sizes]

    from cache_and_train_fast import compute_loss

    # Warm-up so one-time costs (kernel selection, allocator growth) are not timed.
    for _ in range(3):
        compute_loss(cls_per, reg_per, ctr_per, locs_per_level, boxes_list, labels_list)
        precompute_loss_with_cache(cls_per, reg_per, ctr_per, cached_cls, cached_reg, cached_ctr)
        _sync()

    N_ITERS = 100
    t0 = time.perf_counter()
    for _ in range(N_ITERS):
        compute_loss(cls_per, reg_per, ctr_per, locs_per_level, boxes_list, labels_list)
    _sync()
    inline_time = (time.perf_counter() - t0) / N_ITERS

    t0 = time.perf_counter()
    for _ in range(N_ITERS):
        precompute_loss_with_cache(cls_per, reg_per, ctr_per, cached_cls, cached_reg, cached_ctr)
    _sync()
    cached_time = (time.perf_counter() - t0) / N_ITERS

    print(f" in-line compute_loss: {inline_time*1000:.2f} ms/iter")
    print(f" cached compute_loss: {cached_time*1000:.2f} ms/iter")
    print(f" speedup: {inline_time / cached_time:.2f}x")

    print("\nAll Option 4 tests passed.")
|
|