File size: 13,416 Bytes
74e3c01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
"""Option 4: Precomputed FCOS target cache.

The FCOS target assignment for each image is deterministic given
(spatial_features, boxes, labels) and a fixed level layout (strides + sizes).
Our 5-scale layout is fixed, so we can precompute targets once per image and
cache them alongside the spatial features in each shard. Training then loads
targets directly instead of recomputing on every forward pass.

Specific to our backbone configuration: 640px input, 40x40 stride-16 spatial
output, 5 prediction levels at strides [8, 16, 32, 64, 128] with FCOS standard
size ranges. Any architecture change to scale count, strides, or size ranges
invalidates the cache.

Includes a thorough self-test: builds a synthetic shard via the mock backbone,
precomputes targets, runs the same data through the original
assign_targets_batched, and checks that the class targets match exactly and
that the regression/centerness targets agree within fp16 storage tolerance.
"""
import json
import os
import sys
import time

import torch

# Directory containing this file; put first on sys.path so that modules
# imported later by bare name (presumably siblings of this file) resolve
# regardless of the current working directory.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, SCRIPT_DIR)

# Fixed level layout for the 5-scale split-tower head this cache targets.
# STRIDES[i] is the pixel stride of prediction level i; SIZE_RANGES[i] is the
# (low, high) bound on the largest ltrb distance a location at level i may
# regress. The two lists are index-aligned and must stay the same length.
STRIDES = [8, 16, 32, 64, 128]
SIZE_RANGES = [(-1, 32), (32, 64), (64, 128), (128, 256), (256, float("inf"))]
# Input image side length in pixels.
RESOLUTION = 640
# Side length of the stride-16 feature grid; the per-level grids used in this
# file are derived from it as H*2, H, H//2, H//4 and H//8.
H = RESOLUTION // 16  # 40 — base patch grid


def make_locations(feature_sizes, strides, device):
    """Build the pixel-space centre coordinates of every cell on every level.

    feature_sizes: iterable of (h, w) grid sizes, one per level
    strides: iterable of pixel strides, paired with feature_sizes by position
    device: device on which the coordinate tensors are created
    Returns:
      list with one (h * w, 2) float32 tensor per level; rows are (x, y)
      centres in row-major order (x varies fastest).
    """
    def _centres(count, step):
        # Cell i covers [i * step, (i + 1) * step); its centre is (i + 0.5) * step.
        return (torch.arange(count, device=device, dtype=torch.float32) + 0.5) * step

    per_level = []
    for (rows, cols), step in zip(feature_sizes, strides):
        grid_y, grid_x = torch.meshgrid(
            _centres(rows, step), _centres(cols, step), indexing="ij")
        per_level.append(
            torch.stack((grid_x.flatten(), grid_y.flatten()), dim=-1))
    return per_level


def precompute_targets_for_image(boxes, labels, locs, level_ranges, device):
    """Assign FCOS targets to every location of a single image.

    Intended as the single-image counterpart of assign_targets_batched, so
    that targets can be stored per image.

    A location is positive for a box when it lies strictly inside the box,
    lies within 1.5 strides of the box centre, and its largest ltrb distance
    falls inside the level's size range. If several boxes qualify, the one
    with the smallest area is chosen.

    boxes: (M, 4) in (x1, y1, x2, y2)
    labels: (M,) int
    locs: concatenated (N_total, 2) of (cx, cy)
    level_ranges: list of (start, end, stride, size_lo, size_hi); start/end
      index into locs
    device: device on which the output tensors are allocated
    Returns:
      tgt_cls: (N_total,) class index or -1
      tgt_reg: (N_total, 4) ltrb distances (only valid where tgt_cls >= 0)
      tgt_ctr: (N_total,) centerness (only valid where tgt_cls >= 0)
    """
    num_locs = locs.shape[0]
    tgt_cls = torch.full((num_locs,), -1, dtype=torch.long, device=device)
    tgt_reg = torch.zeros(num_locs, 4, device=device)
    tgt_ctr = torch.zeros(num_locs, device=device)

    # No ground truth: every location stays a negative.
    if boxes.numel() == 0:
        return tgt_cls, tgt_reg, tgt_ctr

    # Per-box quantities do not depend on the level, so compute them once.
    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    box_area = (x2 - x1) * (y2 - y1)
    mid_x = (x1 + x2) / 2
    mid_y = (y1 + y2) / 2
    inf = float("inf")

    for start, stop, stride, size_lo, size_hi in level_ranges:
        # (n, 1) columns so they broadcast against the (M,) box vectors.
        px = locs[start:stop, 0][:, None]
        py = locs[start:stop, 1][:, None]

        # (n, M, 4) left/top/right/bottom distance from location to box edge.
        dists = torch.stack([px - x1, py - y1, x2 - px, y2 - py], dim=-1)
        inside_box = dists.min(dim=-1).values > 0

        radius = stride * 1.5
        near_centre = ((px >= mid_x - radius) & (px <= mid_x + radius) &
                       (py >= mid_y - radius) & (py <= mid_y + radius))

        longest = dists.max(dim=-1).values
        fits_level = (longest >= size_lo) & (longest <= size_hi)

        candidate = inside_box & near_centre & fits_level

        # Resolve ambiguity by smallest area; non-candidates get infinite cost
        # so a row whose minimum is still inf has no match at all.
        cost = box_area[None, :].expand_as(candidate).clone()
        cost[~candidate] = inf
        best_box = cost.argmin(dim=-1)
        assigned = cost.gather(1, best_box[:, None]).squeeze(1) < inf

        if not assigned.any():
            continue

        chosen = best_box[assigned]
        # tgt_*[start:stop] is a view, so the masked write lands in the output.
        tgt_cls[start:stop][assigned] = labels[chosen]

        row_idx = torch.arange(stop - start, device=device)[assigned]
        picked = dists[row_idx, chosen]
        tgt_reg[start:stop][assigned] = picked

        left, top, right, bottom = picked.unbind(-1)
        horizontal = torch.minimum(left, right) / torch.maximum(left, right).clamp(min=1e-6)
        vertical = torch.minimum(top, bottom) / torch.maximum(top, bottom).clamp(min=1e-6)
        tgt_ctr[start:stop][assigned] = torch.sqrt(horizontal * vertical)

    return tgt_cls, tgt_reg, tgt_ctr


def precompute_shard_targets(shard, device="cuda"):
    """Attach precomputed FCOS targets to every entry of a shard, in place.

    shard: iterable of dict entries; each must provide "boxes" and "labels"
      tensors
    device: device on which the assignment is computed
    Returns:
      the same shard object. Each entry gains three CPU tensors:
        tgt_cls: (N_total,) int16; -1 for negatives.
        tgt_reg: (N_total, 4) float16; only meaningful where tgt_cls >= 0.
        tgt_ctr: (N_total,) float16; only meaningful where tgt_cls >= 0.
    """
    # Square grids for the five levels, finest first.
    grid_sides = (H * 2, H, H // 2, H // 4, H // 8)
    feat_sizes = [(side, side) for side in grid_sides]
    locs_per_level = make_locations(feat_sizes, STRIDES, torch.device(device))
    all_locs = torch.cat(locs_per_level, 0)

    # Each level owns a contiguous [offset, end) slice of all_locs.
    level_ranges = []
    offset = 0
    for level_locs, stride, (size_lo, size_hi) in zip(locs_per_level, STRIDES, SIZE_RANGES):
        end = offset + level_locs.shape[0]
        level_ranges.append((offset, end, stride, size_lo, size_hi))
        offset = end

    for entry in shard:
        gt_boxes = entry["boxes"].to(device).float()
        gt_labels = entry["labels"].to(device).long()
        cls_t, reg_t, ctr_t = precompute_targets_for_image(
            gt_boxes, gt_labels, all_locs, level_ranges, device)
        # Compact storage: int16 instead of int64 for classes, fp16 for reg/ctr.
        entry["tgt_cls"] = cls_t.to(torch.int16).cpu()
        entry["tgt_reg"] = reg_t.to(torch.float16).cpu()
        entry["tgt_ctr"] = ctr_t.to(torch.float16).cpu()
    return shard


def precompute_loss_with_cache(cls_per, reg_per, ctr_per, batch_tgt_cls, batch_tgt_reg, batch_tgt_ctr,
                               num_classes=80):
    """Compute FCOS loss using PRECOMPUTED targets — replaces the assignment
    step with cache lookup. The classification, regression, and centerness
    losses themselves are unchanged from the in-line version.

    cls_per/reg_per/ctr_per: lists of per-level prediction tensors (B, C, H, W)
    batch_tgt_cls/reg/ctr: per-batch precomputed targets (B, N_total) and (B, N_total, 4)
    """
    import torch.nn.functional as F
    B = cls_per[0].shape[0]
    device = cls_per[0].device
    flat_cls = torch.cat([c.permute(0, 2, 3, 1).reshape(B, -1, num_classes) for c in cls_per], 1)
    flat_reg = torch.cat([r.permute(0, 2, 3, 1).reshape(B, -1, 4) for r in reg_per], 1)
    flat_ctr = torch.cat([c.permute(0, 2, 3, 1).reshape(B, -1) for c in ctr_per], 1)

    pos = batch_tgt_cls >= 0
    npos = max(pos.sum().item(), 1)
    oh = torch.zeros_like(flat_cls)
    pi = pos.nonzero(as_tuple=True)
    oh[pi[0], pi[1], batch_tgt_cls[pos].long()] = 1.0

    # Focal loss (matches existing implementation)
    p = torch.sigmoid(flat_cls)
    ce = F.binary_cross_entropy_with_logits(flat_cls, oh, reduction="none")
    pt = p * oh + (1 - p) * (1 - oh)
    at = 0.25 * oh + 0.75 * (1 - oh)
    loss_cls = (at * (1 - pt) ** 2 * ce).sum() / npos

    if pos.any():
        # Decode locations on the fly (still cheaper than full assignment)
        feat_sizes = [(H * 2, H * 2), (H, H), (H // 2, H // 2),
                      (H // 4, H // 4), (H // 8, H // 8)]
        all_locs = torch.cat(make_locations(feat_sizes, STRIDES, device), 0)
        pl = all_locs[None].expand(B, -1, -1)[pos]
        pp = flat_reg[pos]
        tp = batch_tgt_reg[pos].float()
        pb = torch.stack([pl[:, 0] - pp[:, 0], pl[:, 1] - pp[:, 1],
                          pl[:, 0] + pp[:, 2], pl[:, 1] + pp[:, 3]], -1)
        tb = torch.stack([pl[:, 0] - tp[:, 0], pl[:, 1] - tp[:, 1],
                          pl[:, 0] + tp[:, 2], pl[:, 1] + tp[:, 3]], -1)
        from torchvision.ops import generalized_box_iou
        giou = generalized_box_iou(pb, tb)
        loss_reg = (1 - giou.diagonal()).sum() / npos
        loss_ctr = F.binary_cross_entropy_with_logits(
            flat_ctr[pos], batch_tgt_ctr[pos].float(), reduction="sum") / npos
    else:
        loss_reg = torch.tensor(0.0, device=device)
        loss_ctr = torch.tensor(0.0, device=device)

    return loss_cls + loss_reg + loss_ctr


# ============================================================
# Self-test using the mock backbone
# ============================================================
if __name__ == "__main__":
    # Self-test and micro-benchmark for the precomputed-target path.
    # Needs the project modules mock_eupe_backbone and cache_and_train_fast to
    # be importable. Exits with status 1 if the cached targets disagree with
    # the in-line assignment.
    from mock_eupe_backbone import make_mock_features, make_mock_boxes

    device = "cuda" if torch.cuda.is_available() else "cpu"

    def _sync():
        """Wait for queued CUDA work so wall-clock timings are meaningful."""
        if device == "cuda":
            torch.cuda.synchronize()

    def _seconds_per_iter(fn, n_iters):
        """Average wall-clock seconds per call of fn over n_iters calls."""
        # perf_counter is monotonic and high-resolution; time.time() can jump
        # when the system clock is adjusted.
        start = time.perf_counter()
        for _ in range(n_iters):
            fn()
        _sync()
        return (time.perf_counter() - start) / n_iters

    print(f"Self-test on {device}")
    print("=" * 60)

    B = 4
    boxes_list, labels_list = make_mock_boxes(B=B, n_boxes_per_image=8, device=device, seed=0)

    # Build a synthetic shard
    print("\n1. Building synthetic shard via mock features + boxes...")
    shard = []
    for i in range(B):
        feats = make_mock_features(B=1, device=device, seed=i)[0].half()
        shard.append({
            "img_id": i,
            "spatial": feats,
            "boxes": boxes_list[i].cpu(),
            "labels": labels_list[i].cpu(),
            "scale": 1.0,
        })
    print(f"   shard with {len(shard)} entries")

    # Precompute targets for each image
    print("\n2. Precomputing targets for each image...")
    t0 = time.perf_counter()
    shard = precompute_shard_targets(shard, device=device)
    t_precompute = time.perf_counter() - t0
    print(f"   precompute time: {t_precompute*1000:.1f} ms ({t_precompute*1000/B:.1f} ms/image)")
    for i, e in enumerate(shard):
        n_pos = (e["tgt_cls"] >= 0).sum().item()
        print(f"   img {i}: tgt_cls shape {e['tgt_cls'].shape}, {n_pos} positives")

    # Verify equivalence with in-line assign_targets_batched
    print("\n3. Verifying equivalence with in-line assign_targets_batched...")
    from cache_and_train_fast import assign_targets_batched
    feat_sizes = [(H * 2, H * 2), (H, H), (H // 2, H // 2), (H // 4, H // 4), (H // 8, H // 8)]
    locs_per_level = make_locations(feat_sizes, STRIDES, torch.device(device))
    all_locs = torch.cat(locs_per_level, 0)
    n_per_level = [loc.shape[0] for loc in locs_per_level]
    level_ranges = []
    cumsum = 0
    # Per-location stride, passed to the batched assigner.
    strides_per_loc = torch.zeros(all_locs.shape[0], device=device)
    for i, n in enumerate(n_per_level):
        lo, hi = SIZE_RANGES[i]
        level_ranges.append((cumsum, cumsum + n, STRIDES[i], lo, hi))
        strides_per_loc[cumsum:cumsum + n] = STRIDES[i]
        cumsum += n

    # Pad the ragged per-image boxes to a dense (B, max_m, ...) batch with a
    # validity mask for the batched assigner.
    max_m = max(b.shape[0] for b in boxes_list)
    boxes_padded = torch.zeros(B, max_m, 4, device=device)
    labels_padded = torch.zeros(B, max_m, dtype=torch.long, device=device)
    valid_mask = torch.zeros(B, max_m, dtype=torch.bool, device=device)
    for i in range(B):
        m = boxes_list[i].shape[0]
        boxes_padded[i, :m] = boxes_list[i]
        labels_padded[i, :m] = labels_list[i]
        valid_mask[i, :m] = True

    inline_cls, inline_reg, inline_ctr = assign_targets_batched(
        all_locs, level_ranges, boxes_padded, labels_padded, valid_mask, strides_per_loc)

    cached_cls = torch.stack([e["tgt_cls"].to(device).long() for e in shard])
    cached_reg = torch.stack([e["tgt_reg"].to(device).float() for e in shard])
    cached_ctr = torch.stack([e["tgt_ctr"].to(device).float() for e in shard])

    # Classes must match exactly; reg/ctr are compared on positives only
    # because the cache stores them rounded to fp16.
    cls_match = torch.equal(cached_cls, inline_cls)
    inline_pos = inline_cls >= 0
    if inline_pos.any():
        reg_diff = (cached_reg - inline_reg)[inline_pos].abs().max().item()
        ctr_diff = (cached_ctr - inline_ctr)[inline_pos].abs().max().item()
    else:
        reg_diff = 0
        ctr_diff = 0
    print(f"   cls equal: {cls_match}")
    print(f"   reg max abs diff (positives only, fp16 precision): {reg_diff:.6f}")
    print(f"   ctr max abs diff (positives only, fp16 precision): {ctr_diff:.6f}")

    if not cls_match:
        n_diff = (cached_cls != inline_cls).sum().item()
        print(f"   WARNING: {n_diff} cls mismatches")
        sys.exit(1)
    if reg_diff > 0.5 or ctr_diff > 0.01:
        print("   WARNING: reg/ctr drift exceeds fp16 tolerance")
        sys.exit(1)

    print("\n4. Benchmarking loss computation: cached vs in-line...")
    # Build mock predictions
    cls_per = [torch.randn(B, 80, h, w, device=device) for (h, w) in feat_sizes]
    reg_per = [torch.rand(B, 4, h, w, device=device) * 30 for (h, w) in feat_sizes]
    ctr_per = [torch.randn(B, 1, h, w, device=device) for (h, w) in feat_sizes]

    from cache_and_train_fast import compute_loss

    def _run_inline():
        return compute_loss(cls_per, reg_per, ctr_per, locs_per_level, boxes_list, labels_list)

    def _run_cached():
        return precompute_loss_with_cache(cls_per, reg_per, ctr_per, cached_cls, cached_reg, cached_ctr)

    # Warmup
    for _ in range(3):
        _run_inline()
        _run_cached()
    _sync()

    N_ITERS = 100
    inline_time = _seconds_per_iter(_run_inline, N_ITERS)
    cached_time = _seconds_per_iter(_run_cached, N_ITERS)

    print(f"   in-line compute_loss:  {inline_time*1000:.2f} ms/iter")
    print(f"   cached compute_loss:   {cached_time*1000:.2f} ms/iter")
    print(f"   speedup:               {inline_time / cached_time:.2f}x")

    print("\nAll Option 4 tests passed.")