"""Auto-extracted from detection_arena.py."""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import List
from losses.fcos import fcos_loss, focal_loss, NUM_CLASSES
from losses.centernet import centernet_targets, centernet_loss
from utils.decode import make_locations, decode_fcos, decode_centernet, FPN_STRIDES

N_PREFIX = 5


class PatchAssembly(nn.Module):
    """Patches vote on object membership. Boxes emerge from vote clustering. 5.65M params.

    A small post-norm transformer runs over the flattened spatial feature
    map; every token ("patch") predicts class logits, a bounded (dx, dy)
    vote toward its object's center, and an objectness score.  Detections
    are decoded as fixed-size squares around the voted centers.
    """

    name = "J_patch_assembly"
    needs_intermediates = False

    def __init__(self, dim=384, n_layers=3):
        """Build the token projection, attention blocks, and prediction heads.

        Args:
            dim: transformer width.
            n_layers: number of self-attention + FFN blocks.
        """
        super().__init__()
        # Input channel count is fixed at 768 -- assumes the backbone's
        # spatial features are 768-dim; TODO confirm against caller.
        self.proj = nn.Linear(768, dim)
        self.layers = nn.ModuleList()
        for _ in range(n_layers):
            self.layers.append(nn.ModuleDict({
                "attn": nn.MultiheadAttention(dim, 8, batch_first=True),
                "norm1": nn.LayerNorm(dim),
                "ffn": nn.Sequential(
                    nn.Linear(dim, dim * 4),
                    nn.GELU(),
                    nn.Linear(dim * 4, dim),
                ),
                "norm2": nn.LayerNorm(dim),
            }))
        self.cls_head = nn.Linear(dim, NUM_CLASSES)
        self.offset_head = nn.Linear(dim, 2)  # (dx, dy) center vote, in stride units
        self.obj_head = nn.Linear(dim, 1)     # objectness logit
        # Focal-loss prior initialization: sigmoid(-log(99)) == 0.01, so the
        # classifier starts out predicting "background" almost everywhere.
        nn.init.constant_(self.cls_head.bias, -math.log(99))

    @staticmethod
    def _pixel_grid(H, W, device, stride=16):
        """Return an (H*W, 2) tensor of (x, y) pixel-center coordinates.

        Shared by loss() and decode() so both use the identical grid layout
        (row-major over the H x W feature map, cell centers at +0.5*stride).
        """
        ys = (torch.arange(H, device=device, dtype=torch.float32) + 0.5) * stride
        xs = (torch.arange(W, device=device, dtype=torch.float32) + 0.5) * stride
        gy, gx = torch.meshgrid(ys, xs, indexing="ij")
        return torch.stack([gx.flatten(), gy.flatten()], -1)

    def forward(self, spatial, inter=None):
        """Run the transformer over patch tokens.

        Args:
            spatial: (B, C, H, W) backbone feature map; C must match proj (768).
            inter: unused; kept for a uniform model interface.

        Returns:
            Tuple of (cls_logits (B, HW, NUM_CLASSES),
                      offsets (B, HW, 2) center votes bounded to [-20, 20],
                      objectness (B, HW, 1), H, W).
        """
        B, C, H, W = spatial.shape
        tokens = spatial.flatten(2).permute(0, 2, 1)  # (B, HW, C)
        x = self.proj(tokens)
        for layer in self.layers:
            x2, _ = layer["attn"](x, x, x)
            x = layer["norm1"](x + x2)              # post-norm residual attention
            x = layer["norm2"](x + layer["ffn"](x))  # post-norm residual FFN
        cls = self.cls_head(x)
        offset = self.offset_head(x).tanh() * 20  # bounded vote, in stride units
        obj = self.obj_head(x)
        return cls, offset, obj, H, W

    def loss(self, preds, locs, boxes_b, labels_b):
        """Objectness BCE + center-offset L1 + focal classification loss.

        Args:
            preds: output tuple of forward().
            locs: unused; kept for a uniform trainer interface.
            boxes_b: per-image (N, 4) GT boxes in xyxy pixel coords
                     -- presumably; verify against caller.
            labels_b: per-image (N,) GT class indices.

        Returns:
            Scalar loss tensor averaged over the batch.
        """
        cls_logits, offsets, objectness, H, W = preds
        B = cls_logits.shape[0]
        device = cls_logits.device
        stride = 16
        grid = self._pixel_grid(H, W, device, stride)
        total = torch.tensor(0.0, device=device)
        for i in range(B):
            gt_boxes = boxes_b[i]
            gt_labels = labels_b[i]
            if gt_boxes.numel() == 0:
                # No objects: lightly push every objectness score toward zero.
                total = total + torch.sigmoid(objectness[i]).sum() * 0.01
                continue
            gt_cx = (gt_boxes[:, 0] + gt_boxes[:, 2]) / 2
            gt_cy = (gt_boxes[:, 1] + gt_boxes[:, 3]) / 2
            gt_centers = torch.stack([gt_cx, gt_cy], -1)
            # Distance from every grid cell to every GT center; each cell is
            # assigned to its nearest object (one pass gives both min and argmin).
            dist = torch.cdist(grid.unsqueeze(0), gt_centers.unsqueeze(0))[0]
            min_dist, nearest = dist.min(dim=1)
            # A cell is positive when it lies within half the larger side of
            # its assigned box.
            max_dist = torch.stack(
                [gt_boxes[:, 2] - gt_boxes[:, 0], gt_boxes[:, 3] - gt_boxes[:, 1]], -1
            ).max(-1).values
            in_obj = min_dist < max_dist[nearest] * 0.5
            loss_obj = F.binary_cross_entropy_with_logits(
                objectness[i, :, 0], in_obj.float()
            )
            n_pos = int(in_obj.sum())  # hoisted: reused three times below
            if n_pos > 0:
                # Regress the (dx, dy) vote toward the assigned center,
                # expressed in stride units to match the head's scale.
                target_offset = (gt_centers[nearest[in_obj]] - grid[in_obj]) / stride
                loss_off = F.l1_loss(offsets[i, in_obj], target_offset)
                cls_tgt = torch.zeros(n_pos, NUM_CLASSES, device=device)
                cls_tgt[torch.arange(n_pos), gt_labels[nearest[in_obj]]] = 1.0
                # Normalize focal loss by the positive count.
                loss_cls = focal_loss(cls_logits[i, in_obj], cls_tgt) / max(n_pos, 1)
            else:
                loss_off = loss_cls = torch.tensor(0.0, device=device)
            total = total + loss_obj + loss_off + loss_cls
        return total / B

    def decode(self, preds, locs=None, score_thresh=0.3, **kw):
        """Decode per-patch votes into per-image detections.

        Score is sigmoid(cls) * sigmoid(objectness), thresholded at
        score_thresh; boxes are fixed 4*stride squares around the voted
        centers, capped at the top-100 scores per image.

        Returns:
            List of dicts with "boxes" (K, 4), "scores" (K,), "labels" (K,).
        """
        cls_logits, offsets, objectness, H, W = preds
        B = cls_logits.shape[0]
        device = cls_logits.device
        stride = 16
        grid = self._pixel_grid(H, W, device, stride)
        results = []
        for i in range(B):
            obj_score = torch.sigmoid(objectness[i, :, 0])
            cls_score = torch.sigmoid(cls_logits[i])
            combined = cls_score * obj_score[:, None]
            mask = combined > score_thresh
            if not mask.any():
                results.append({
                    "boxes": torch.zeros(0, 4, device=device),
                    "scores": torch.zeros(0, device=device),
                    "labels": torch.zeros(0, dtype=torch.long, device=device),
                })
                continue
            pi, ci = mask.nonzero(as_tuple=True)  # (position, class) survivors
            scores = combined[pi, ci]
            # Voted center in pixels; offsets are predicted in stride units.
            centers = grid[pi] + offsets[i, pi] * stride
            half = stride * 2  # fixed half-extent: boxes are 4*stride squares
            boxes = torch.stack([
                centers[:, 0] - half, centers[:, 1] - half,
                centers[:, 0] + half, centers[:, 1] + half,
            ], -1)
            if scores.numel() > 100:
                top = scores.topk(100)
                scores, keep = top.values, top.indices
                boxes, ci = boxes[keep], ci[keep]
            results.append({"boxes": boxes, "scores": scores, "labels": ci})
        return results

    def get_locs(self, spatial):
        """Locations are built internally per forward pass; nothing to precompute."""
        return None