# Author: phanerozoic — "Restructure: one folder per head, shared losses/utils, registry runner" (commit ca63835, verified)
"""Auto-extracted from detection_arena.py."""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import List
from losses.fcos import fcos_loss, focal_loss, NUM_CLASSES
from losses.centernet import centernet_targets, centernet_loss
from utils.decode import make_locations, decode_fcos, decode_centernet, FPN_STRIDES
# Count of reserved prefix tokens. Unused within this chunk — presumably
# consumed by a sibling head or the runner in the original
# detection_arena.py; TODO confirm against the assembled file.
N_PREFIX = 5
class SparseQueries(nn.Module):
    """DETR-style head: 100 learned queries, 2 cross-attention layers. ~2.5M params.

    No FPN and no NMS — each query emits one (class, box) prediction; duplicates
    are discouraged implicitly by the (approximately) one-to-one matching loss.

    NOTE(review): this class references a module-level ``RESOLUTION`` constant
    (used to normalize/denormalize boxes) that is neither defined nor imported
    in this extracted chunk — confirm it is in scope in the assembled file.
    """
    name = "E_sparse_queries"       # registry key used by the runner
    needs_intermediates = False     # consumes only the final spatial feature map

    def __init__(self, n_queries=100, dim=256, n_layers=2):
        """Build query bank, input projection, decoder layers, and heads.

        Args:
            n_queries: number of learned object queries (upper bound on detections).
            dim: decoder embedding width.
            n_layers: number of cross-attention + FFN decoder layers.
        """
        super().__init__()
        self.queries = nn.Embedding(n_queries, dim)
        # Backbone channels (768) -> decoder width. 768 suggests a ViT-B
        # feature map, but that depends on the caller — not asserted here.
        self.input_proj = nn.Linear(768, dim)
        self.layers = nn.ModuleList()
        for _ in range(n_layers):
            self.layers.append(nn.ModuleDict({
                "cross_attn": nn.MultiheadAttention(dim, 8, batch_first=True),
                "norm1": nn.LayerNorm(dim),
                "ffn": nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)),
                "norm2": nn.LayerNorm(dim),
            }))
        self.cls_head = nn.Linear(dim, NUM_CLASSES)
        self.box_head = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, 4))
        # Focal-loss prior init: sigmoid(-log(99)) == 0.01 initial foreground prob.
        nn.init.constant_(self.cls_head.bias, -math.log(99))

    def forward(self, spatial, inter=None):
        """Run the query decoder over a [B, C, H, W] feature map.

        Returns:
            (cls_logits [B, Q, NUM_CLASSES], boxes [B, Q, 4] in [0, 1] via sigmoid).
        """
        B, C, H, W = spatial.shape
        # Flatten spatial grid to a token sequence: [B, H*W, C] -> project to dim.
        kv = self.input_proj(spatial.flatten(2).permute(0, 2, 1))
        q = self.queries.weight.unsqueeze(0).expand(B, -1, -1)
        for layer in self.layers:
            # Pre-norm-free residual blocks: attend, then FFN, each with post-norm.
            q2, _ = layer["cross_attn"](q, kv, kv)
            q = layer["norm1"](q + q2)
            q = layer["norm2"](q + layer["ffn"](q))
        return self.cls_head(q), torch.sigmoid(self.box_head(q))

    def loss(self, preds, locs, boxes_b, labels_b):
        """Greedy set-matching loss (focal cls + L1 box), averaged over the batch.

        Args:
            preds: (cls_logits, pred_boxes) as returned by ``forward``.
            locs: unused (kept for the shared head interface).
            boxes_b: per-image GT boxes in pixel units (divided by RESOLUTION here).
            labels_b: per-image GT class indices.
        """
        cls_logits, pred_boxes = preds
        B = cls_logits.shape[0]
        total = torch.tensor(0.0, device=cls_logits.device)
        for i in range(B):
            gt_boxes = boxes_b[i] / RESOLUTION
            gt_labels = labels_b[i]
            if gt_boxes.numel() == 0:
                # No GT: lightly push all class probabilities toward zero.
                total = total + torch.sigmoid(cls_logits[i]).sum() * 0.01
                continue
            # simplified Hungarian: use cost matrix of class prob + L1 box distance
            with torch.no_grad():
                prob = torch.sigmoid(cls_logits[i])                 # [Q, 80]
                cost_cls = -prob[:, gt_labels]                      # [Q, G]
                cost_box = torch.cdist(pred_boxes[i], gt_boxes, p=1)  # [Q, G]
                cost = cost_cls + 5 * cost_box
                # greedy assignment (not true Hungarian but fast)
                matched_q, matched_g = [], []
                c = cost.clone()
                for _ in range(min(cost.shape[0], cost.shape[1])):
                    # Split the flat argmin into (query, gt) coordinates.
                    qi, gi = divmod(c.argmin().item(), c.shape[1])
                    matched_q.append(qi)
                    matched_g.append(gi)
                    # Remove the matched row/column from further consideration.
                    c[qi, :] = float("inf")
                    c[:, gi] = float("inf")
            if matched_q:
                mq = torch.tensor(matched_q, device=cls_logits.device)
                mg = torch.tensor(matched_g, device=cls_logits.device)
                tgt = torch.zeros_like(cls_logits[i])
                tgt[mq, gt_labels[mg]] = 1.0
                loss_cls = focal_loss(cls_logits[i], tgt) / len(matched_q)
                loss_box = F.l1_loss(pred_boxes[i][mq], gt_boxes[mg]) * 5
                total = total + loss_cls + loss_box
            else:
                total = total + torch.sigmoid(cls_logits[i]).sum() * 0.01
        return total / B

    def decode(self, preds, locs=None, score_thresh=0.3, **kw):
        """Threshold per-query class scores into detection dicts.

        Returns a list (one per image) of {"boxes", "scores", "labels"};
        a query can yield several detections if multiple classes pass the
        threshold (no NMS is applied).
        """
        cls_logits, pred_boxes = preds
        B = cls_logits.shape[0]
        device = cls_logits.device
        results = []
        for i in range(B):
            scores = torch.sigmoid(cls_logits[i])
            boxes = pred_boxes[i] * RESOLUTION
            mask = scores > score_thresh
            if not mask.any():
                # Allocate empties on the model's device so downstream
                # concatenation does not hit a CPU/GPU mismatch.
                results.append({
                    "boxes": torch.zeros(0, 4, device=device),
                    "scores": torch.zeros(0, device=device),
                    "labels": torch.zeros(0, dtype=torch.long, device=device),
                })
                continue
            qi, ci = mask.nonzero(as_tuple=True)
            sc = scores[qi, ci]
            bx = boxes[qi]
            results.append({"boxes": bx, "scores": sc, "labels": ci})
        return results

    def get_locs(self, spatial):
        # This head has no anchor/location grid; the shared runner interface
        # still expects the method.
        return None