| """Auto-extracted from detection_arena.py."""
|
|
|
| import math
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from torch import Tensor
|
| from typing import List
|
|
|
| from losses.fcos import fcos_loss, focal_loss, NUM_CLASSES
|
| from losses.centernet import centernet_targets, centernet_loss
|
| from utils.decode import make_locations, decode_fcos, decode_centernet, FPN_STRIDES
|
|
|
# NOTE(review): unused in this extracted chunk — presumably consumed by other
# parts of detection_arena.py; confirm before removing.
N_PREFIX = 5
|
|
|
class SparseQueries(nn.Module):
    """DETR-style head: 100 learned queries cross-attend to backbone features.

    ~2.5M params. No FPN and no NMS — each query emits one (class, box)
    prediction directly. Matching in `loss` is a greedy (not Hungarian)
    one-to-one assignment.
    """

    name = "E_sparse_queries"
    needs_intermediates = False

    def __init__(self, n_queries=100, dim=256, n_layers=2):
        """Build query embeddings, decoder layers, and prediction heads.

        Args:
            n_queries: number of learned object queries.
            dim: width of queries and projected features.
            n_layers: number of (cross-attention + FFN) decoder layers.
        """
        super().__init__()
        self.queries = nn.Embedding(n_queries, dim)
        # Backbone features arrive with 768 channels; project down to `dim`.
        self.input_proj = nn.Linear(768, dim)
        self.layers = nn.ModuleList()
        for _ in range(n_layers):
            self.layers.append(nn.ModuleDict({
                "cross_attn": nn.MultiheadAttention(dim, 8, batch_first=True),
                "norm1": nn.LayerNorm(dim),
                "ffn": nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)),
                "norm2": nn.LayerNorm(dim),
            }))
        self.cls_head = nn.Linear(dim, NUM_CLASSES)
        self.box_head = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, 4))
        # Focal-loss prior initialization: sigmoid(-log(99)) == 0.01, so every
        # class starts at ~1% probability and negatives don't swamp the loss.
        nn.init.constant_(self.cls_head.bias, -math.log(99))

    def forward(self, spatial, inter=None):
        """Decode the learned queries against the feature map.

        Args:
            spatial: (B, 768, H, W) backbone feature map.
            inter: unused; kept for interface parity with heads that consume
                intermediate activations (see `needs_intermediates`).

        Returns:
            Tuple of class logits (B, Q, NUM_CLASSES) and boxes (B, Q, 4)
            squashed into [0, 1] by a sigmoid.
        """
        B, C, H, W = spatial.shape
        # Flatten the spatial grid into a (B, H*W, dim) key/value sequence.
        kv = self.input_proj(spatial.flatten(2).permute(0, 2, 1))
        q = self.queries.weight.unsqueeze(0).expand(B, -1, -1)
        for layer in self.layers:
            q2, _ = layer["cross_attn"](q, kv, kv)
            q = layer["norm1"](q + q2)
            q = layer["norm2"](q + layer["ffn"](q))
        return self.cls_head(q), torch.sigmoid(self.box_head(q))

    def loss(self, preds, locs, boxes_b, labels_b):
        """Greedy-matched loss (focal cls + L1 box), averaged over the batch.

        Args:
            preds: (cls_logits, pred_boxes) as returned by `forward`.
            locs: unused; interface parity with dense heads.
            boxes_b: per-image GT boxes in pixels (normalized here).
            labels_b: per-image GT class indices.
        """
        cls_logits, pred_boxes = preds
        B = cls_logits.shape[0]
        total = torch.tensor(0.0, device=cls_logits.device)
        for i in range(B):
            # NOTE(review): RESOLUTION is neither defined nor imported in this
            # extracted file — presumably a module-level constant of
            # detection_arena.py; confirm, otherwise this raises NameError.
            gt_boxes = boxes_b[i] / RESOLUTION
            gt_labels = labels_b[i]
            if gt_boxes.numel() == 0:
                # No GT in this image: gently push all query scores to zero.
                total = total + torch.sigmoid(cls_logits[i]).sum() * 0.01
                continue

            # Matching cost (no gradient): class probability plus 5x L1 box
            # distance, mirroring the 5x box weight in the loss below.
            with torch.no_grad():
                prob = torch.sigmoid(cls_logits[i])
                cost_cls = -prob[:, gt_labels]
                cost_box = torch.cdist(pred_boxes[i], gt_boxes, p=1)
                cost = cost_cls + 5 * cost_box

            # Greedy one-to-one assignment: repeatedly take the globally
            # cheapest (query, gt) pair, then retire that row and column.
            matched_q, matched_g = [], []
            c = cost.clone()
            for _ in range(min(cost.shape[0], cost.shape[1])):
                idx = c.argmin()
                qi, gi = divmod(idx.item(), c.shape[1])
                matched_q.append(qi); matched_g.append(gi)
                c[qi, :] = float("inf"); c[:, gi] = float("inf")
            # matched_q is non-empty here (gt_boxes was non-empty), but keep
            # the guard so a zero-query configuration degrades gracefully.
            if matched_q:
                mq = torch.tensor(matched_q, device=cls_logits.device)
                mg = torch.tensor(matched_g, device=cls_logits.device)
                tgt = torch.zeros_like(cls_logits[i])
                tgt[mq, gt_labels[mg]] = 1.0
                loss_cls = focal_loss(cls_logits[i], tgt) / len(matched_q)
                loss_box = F.l1_loss(pred_boxes[i][mq], gt_boxes[mg]) * 5
                total = total + loss_cls + loss_box
            else:
                total = total + torch.sigmoid(cls_logits[i]).sum() * 0.01
        return total / B

    def decode(self, preds, locs=None, score_thresh=0.3, **kw):
        """Threshold query scores into per-image detection dicts.

        Returns a list (length B) of {"boxes", "scores", "labels"} dicts;
        a query can emit multiple detections if several classes clear the
        threshold (it shares one box across them).
        """
        cls_logits, pred_boxes = preds
        device = cls_logits.device
        results = []
        for i in range(cls_logits.shape[0]):
            scores = torch.sigmoid(cls_logits[i])
            # NOTE(review): RESOLUTION is not defined in this file — see `loss`.
            boxes = pred_boxes[i] * RESOLUTION
            mask = scores > score_thresh
            if not mask.any():
                # Fix: allocate the empty placeholders on the prediction device
                # so a mixed batch never returns CPU and GPU tensors together.
                results.append({
                    "boxes": torch.zeros(0, 4, device=device),
                    "scores": torch.zeros(0, device=device),
                    "labels": torch.zeros(0, dtype=torch.long, device=device),
                })
                continue
            qi, ci = mask.nonzero(as_tuple=True)
            results.append({"boxes": boxes[qi], "scores": scores[qi, ci], "labels": ci})
        return results

    def get_locs(self, spatial):
        """Query head is location-free: no anchor/location grid is needed."""
        return None
|
|
|