| """Auto-extracted from detection_arena.py.""" | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch import Tensor | |
| from typing import List | |
| from losses.fcos import fcos_loss, focal_loss, NUM_CLASSES | |
| from losses.centernet import centernet_targets, centernet_loss | |
| from utils.decode import make_locations, decode_fcos, decode_centernet, FPN_STRIDES | |
N_PREFIX = 5  # NOTE(review): unused in this extracted chunk — presumably the count of non-spatial prefix tokens (CLS/register) from the backbone; confirm against detection_arena.py
class PatchAssembly(nn.Module):
    """Patches vote on object membership. Boxes emerge from vote clustering. 5.65M params.

    A small transformer encoder runs over the flattened backbone feature
    map; each token then predicts per-class logits, a 2-D offset vote
    toward its object's centre, and an objectness logit.  Training
    assigns every token to its nearest ground-truth centre (positives
    are tokens closer than half the box's longer side); decoding turns
    each confident vote into a fixed-size box around the voted centre.
    """
    name = "J_patch_assembly"
    needs_intermediates = False

    # Pixels per feature-map cell. Single source of truth for `loss` and
    # `decode`; must match the backbone that produces `spatial`
    # (assumed 16 here — TODO confirm against the backbone config).
    STRIDE = 16

    def __init__(self, dim=384, n_layers=3):
        """dim: transformer width; n_layers: encoder depth."""
        super().__init__()
        # Backbone channel width (768) -> transformer width.
        self.proj = nn.Linear(768, dim)
        self.layers = nn.ModuleList()
        for _ in range(n_layers):
            self.layers.append(nn.ModuleDict({
                "attn": nn.MultiheadAttention(dim, 8, batch_first=True),
                "norm1": nn.LayerNorm(dim),
                "ffn": nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)),
                "norm2": nn.LayerNorm(dim),
            }))
        self.cls_head = nn.Linear(dim, NUM_CLASSES)
        self.offset_head = nn.Linear(dim, 2)
        self.obj_head = nn.Linear(dim, 1)
        # Focal-loss prior initialisation: sigmoid(-log(99)) == 0.01, so
        # every class starts near p=0.01 and easy negatives do not swamp
        # the loss early in training.
        nn.init.constant_(self.cls_head.bias, -math.log(99))

    def _make_grid(self, H, W, device):
        """Return an (H*W, 2) float tensor of (x, y) pixel coordinates of
        feature-cell centres, in row-major (y-then-x) flattening order.

        Shared by `loss` and `decode` so their location grids can never
        drift apart.
        """
        stride = self.STRIDE
        ys = (torch.arange(H, device=device, dtype=torch.float32) + 0.5) * stride
        xs = (torch.arange(W, device=device, dtype=torch.float32) + 0.5) * stride
        gy, gx = torch.meshgrid(ys, xs, indexing="ij")
        return torch.stack([gx.flatten(), gy.flatten()], -1)

    def forward(self, spatial, inter=None):
        """Run the encoder over spatial tokens.

        spatial: (B, C, H, W) backbone feature map (C expected to be 768).
        inter:   unused; accepted for interface compatibility.
        Returns (cls_logits (B,HW,NUM_CLASSES), offsets (B,HW,2),
                 objectness (B,HW,1), H, W).
        """
        B, C, H, W = spatial.shape
        # (B, C, H, W) -> (B, H*W, C) token sequence.
        tokens = spatial.flatten(2).permute(0, 2, 1)
        x = self.proj(tokens)
        for layer in self.layers:
            # Post-norm transformer block: attention then FFN, each with
            # a residual connection.
            x2, _ = layer["attn"](x, x, x)
            x = layer["norm1"](x + x2)
            x = layer["norm2"](x + layer["ffn"](x))
        cls = self.cls_head(x)
        # Bounded centre vote: at most 20 cells of displacement in
        # feature-map units (scaled to pixels by STRIDE at use sites).
        offset = self.offset_head(x).tanh() * 20
        obj = self.obj_head(x)
        return cls, offset, obj, H, W

    def loss(self, preds, locs, boxes_b, labels_b):
        """Per-image mean of objectness BCE + offset L1 + classification
        focal loss.

        preds:    output tuple of `forward`.
        locs:     unused (kept for the shared loss interface).
        boxes_b:  per-image (N_i, 4) GT boxes as (x1, y1, x2, y2) pixels.
        labels_b: per-image (N_i,) GT class indices.
        """
        cls_logits, offsets, objectness, H, W = preds
        B = cls_logits.shape[0]
        device = cls_logits.device
        stride = self.STRIDE
        grid = self._make_grid(H, W, device)
        total = torch.tensor(0.0, device=device)
        for i in range(B):
            gt_boxes = boxes_b[i]
            gt_labels = labels_b[i]
            if gt_boxes.numel() == 0:
                # No objects: lightly push all objectness scores toward 0.
                total = total + torch.sigmoid(objectness[i]).sum() * 0.01
                continue
            gt_cx = (gt_boxes[:, 0] + gt_boxes[:, 2]) / 2
            gt_cy = (gt_boxes[:, 1] + gt_boxes[:, 3]) / 2
            gt_centers = torch.stack([gt_cx, gt_cy], -1)
            # (HW, N) distances from each cell centre to each GT centre.
            dist = torch.cdist(grid.unsqueeze(0), gt_centers.unsqueeze(0))[0]
            nearest = dist.argmin(dim=1)
            # Positive iff the cell lies within half the longer side of
            # its nearest box.
            max_dist = torch.stack([gt_boxes[:, 2] - gt_boxes[:, 0], gt_boxes[:, 3] - gt_boxes[:, 1]], -1).max(-1).values
            in_obj = dist[torch.arange(len(grid)), nearest] < max_dist[nearest] * 0.5
            obj_tgt = in_obj.float()
            loss_obj = F.binary_cross_entropy_with_logits(objectness[i, :, 0], obj_tgt)
            if in_obj.any():
                # Offset targets are in feature-map units, matching the
                # tanh-bounded head output.
                target_offset = (gt_centers[nearest[in_obj]] - grid[in_obj]) / stride
                loss_off = F.l1_loss(offsets[i, in_obj], target_offset)
                cls_tgt = torch.zeros(in_obj.sum(), NUM_CLASSES, device=device)
                cls_tgt[torch.arange(in_obj.sum()), gt_labels[nearest[in_obj]]] = 1.0
                # Normalise the focal loss by the positive count.
                loss_cls = focal_loss(cls_logits[i, in_obj], cls_tgt) / max(in_obj.sum().item(), 1)
            else:
                loss_off = loss_cls = torch.tensor(0.0, device=device)
            total = total + loss_obj + loss_off + loss_cls
        return total / B

    def decode(self, preds, locs=None, score_thresh=0.3, max_detections=100, **kw):
        """Turn raw predictions into per-image detections.

        Each (cell, class) pair whose joint score (class prob x
        objectness prob) exceeds `score_thresh` yields a fixed-size
        (4*STRIDE square) box centred on the cell's voted centre;
        the top `max_detections` by score are kept per image.
        Returns a list of {"boxes", "scores", "labels"} dicts.
        """
        cls_logits, offsets, objectness, H, W = preds
        B = cls_logits.shape[0]
        device = cls_logits.device
        stride = self.STRIDE
        grid = self._make_grid(H, W, device)
        results = []
        for i in range(B):
            obj_score = torch.sigmoid(objectness[i, :, 0])
            cls_score = torch.sigmoid(cls_logits[i])
            combined = cls_score * obj_score[:, None]
            mask = combined > score_thresh
            if not mask.any():
                results.append({"boxes": torch.zeros(0, 4, device=device), "scores": torch.zeros(0, device=device),
                                "labels": torch.zeros(0, dtype=torch.long, device=device)})
                continue
            pi, ci = mask.nonzero(as_tuple=True)
            scores = combined[pi, ci]
            # Voted centre in pixels; boxes are fixed-size squares around it.
            centers = grid[pi] + offsets[i, pi] * stride
            half = stride * 2
            boxes = torch.stack([centers[:, 0]-half, centers[:, 1]-half,
                                 centers[:, 0]+half, centers[:, 1]+half], -1)
            if scores.numel() > max_detections:
                top = scores.topk(max_detections)
                scores, keep = top.values, top.indices
                boxes, ci = boxes[keep], ci[keep]
            results.append({"boxes": boxes, "scores": scores, "labels": ci})
        return results

    def get_locs(self, spatial):
        """No precomputed locations needed; the grid is built on the fly."""
        return None