# Source provenance (extraction residue from model-hub page, kept as comment):
# phanerozoic — "Restructure: one folder per head, shared losses/utils, registry runner"
# commit ca63835 (verified)
"""Auto-extracted from detection_arena.py."""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import List
from losses.fcos import fcos_loss, focal_loss, NUM_CLASSES
from losses.centernet import centernet_targets, centernet_loss
from utils.decode import make_locations, decode_fcos, decode_centernet, FPN_STRIDES
N_PREFIX = 5
class PatchAssembly(nn.Module):
    """Patches vote on object membership. Boxes emerge from vote clustering. 5.65M params.

    Each backbone patch token is refined by a small self-attention stack and
    emits three predictions per patch: per-class logits, a 2-D center offset
    (in stride units, bounded to +/-20 by tanh), and an objectness logit.
    Training marks a patch positive when its grid center lies within half of
    the nearest ground-truth box's longer side; decoding turns confident
    patches into fixed-size boxes around the voted centers (no NMS here).
    """
    name = "J_patch_assembly"
    needs_intermediates = False
    # Backbone stride assumed by target assignment and decoding.  Shared via
    # _make_grid() so loss() and decode() cannot drift apart.
    STRIDE = 16

    def __init__(self, dim=384, n_layers=3):
        """Build the projection, attention stack, and the three heads.

        Args:
            dim: working token width after projecting the 768-dim backbone tokens.
            n_layers: number of self-attention + FFN refinement layers.
        """
        super().__init__()
        # Backbone tokens are 768-dim (see forward: spatial has C=768 channels).
        self.proj = nn.Linear(768, dim)
        self.layers = nn.ModuleList()
        for _ in range(n_layers):
            self.layers.append(nn.ModuleDict({
                "attn": nn.MultiheadAttention(dim, 8, batch_first=True),
                "norm1": nn.LayerNorm(dim),
                "ffn": nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)),
                "norm2": nn.LayerNorm(dim),
            }))
        self.cls_head = nn.Linear(dim, NUM_CLASSES)
        self.offset_head = nn.Linear(dim, 2)
        self.obj_head = nn.Linear(dim, 1)
        # Focal-loss prior: bias class logits so sigmoid starts near 1% positive.
        nn.init.constant_(self.cls_head.bias, -math.log(99))

    def _make_grid(self, H, W, device):
        """Return an (H*W, 2) tensor of patch-center (x, y) pixel coordinates.

        Shared by loss() and decode() so assignment and decoding use the exact
        same geometry (previously each rebuilt this grid independently).
        """
        stride = self.STRIDE
        ys = (torch.arange(H, device=device, dtype=torch.float32) + 0.5) * stride
        xs = (torch.arange(W, device=device, dtype=torch.float32) + 0.5) * stride
        gy, gx = torch.meshgrid(ys, xs, indexing="ij")
        # Row-major flatten: patch order matches spatial.flatten(2) in forward().
        return torch.stack([gx.flatten(), gy.flatten()], -1)

    def forward(self, spatial, inter=None):
        """Refine patch tokens and predict per-patch class/offset/objectness.

        Args:
            spatial: (B, 768, H, W) backbone feature map.
            inter: unused; kept so all heads share a uniform interface.

        Returns:
            Tuple (cls_logits, offsets, objectness, H, W) with shapes
            (B, H*W, NUM_CLASSES), (B, H*W, 2), (B, H*W, 1) plus the grid dims.
        """
        B, C, H, W = spatial.shape
        tokens = spatial.flatten(2).permute(0, 2, 1)  # (B, H*W, C)
        x = self.proj(tokens)
        for layer in self.layers:
            # Post-norm transformer block: attention then FFN, each residual.
            x2, _ = layer["attn"](x, x, x)
            x = layer["norm1"](x + x2)
            x = layer["norm2"](x + layer["ffn"](x))
        cls = self.cls_head(x)
        # tanh bounds the center vote to +/-20 stride units around the patch.
        offset = self.offset_head(x).tanh() * 20
        obj = self.obj_head(x)
        return cls, offset, obj, H, W

    def loss(self, preds, locs, boxes_b, labels_b):
        """Per-image assignment loss, averaged over the batch.

        A patch is positive when its center is closer than half the nearest
        GT box's longer side.  Positives regress the center offset (L1, in
        stride units) and the class (focal, normalized by positive count);
        every patch gets a BCE objectness target.  Images without boxes pay
        only a small penalty on predicted objectness.

        Args:
            preds: tuple from forward().
            locs: unused (uniform head interface).
            boxes_b: per-image (N, 4) xyxy GT boxes.
            labels_b: per-image (N,) GT class indices.
        """
        cls_logits, offsets, objectness, H, W = preds
        B = cls_logits.shape[0]
        device = cls_logits.device
        stride = self.STRIDE
        grid = self._make_grid(H, W, device)
        total = torch.tensor(0.0, device=device)
        for i in range(B):
            gt_boxes = boxes_b[i]
            gt_labels = labels_b[i]
            if gt_boxes.numel() == 0:
                # No objects: gently push all objectness toward zero.
                total = total + torch.sigmoid(objectness[i]).sum() * 0.01
                continue
            gt_cx = (gt_boxes[:, 0] + gt_boxes[:, 2]) / 2
            gt_cy = (gt_boxes[:, 1] + gt_boxes[:, 3]) / 2
            gt_centers = torch.stack([gt_cx, gt_cy], -1)
            dist = torch.cdist(grid.unsqueeze(0), gt_centers.unsqueeze(0))[0]
            # min() yields both the nearest-GT index and its distance in one pass
            # (the original re-gathered the minima after an argmin).
            min_dist, nearest = dist.min(dim=1)
            # Longer side of each GT box; membership radius is half of it.
            max_dist = torch.stack([gt_boxes[:, 2] - gt_boxes[:, 0],
                                    gt_boxes[:, 3] - gt_boxes[:, 1]], -1).max(-1).values
            in_obj = min_dist < max_dist[nearest] * 0.5
            obj_tgt = in_obj.float()
            loss_obj = F.binary_cross_entropy_with_logits(objectness[i, :, 0], obj_tgt)
            if in_obj.any():
                n_pos = int(in_obj.sum())
                # Offset target in stride units, matching forward()'s scale.
                target_offset = (gt_centers[nearest[in_obj]] - grid[in_obj]) / stride
                loss_off = F.l1_loss(offsets[i, in_obj], target_offset)
                cls_tgt = torch.zeros(n_pos, NUM_CLASSES, device=device)
                cls_tgt[torch.arange(n_pos), gt_labels[nearest[in_obj]]] = 1.0
                loss_cls = focal_loss(cls_logits[i, in_obj], cls_tgt) / max(n_pos, 1)
            else:
                loss_off = loss_cls = torch.tensor(0.0, device=device)
            total = total + loss_obj + loss_off + loss_cls
        return total / B

    def decode(self, preds, locs=None, score_thresh=0.3, **kw):
        """Turn patch votes into per-image detections.

        Score = sigmoid(class logit) * sigmoid(objectness).  Every
        (patch, class) pair above ``score_thresh`` becomes a fixed-size
        square box (side 4 * stride) centered on the voted location; at most
        the top-100 scores per image are kept.  No NMS is applied here.

        Returns:
            List of dicts with "boxes" (M, 4), "scores" (M,), "labels" (M,).
        """
        cls_logits, offsets, objectness, H, W = preds
        B = cls_logits.shape[0]
        device = cls_logits.device
        stride = self.STRIDE
        grid = self._make_grid(H, W, device)
        results = []
        for i in range(B):
            obj_score = torch.sigmoid(objectness[i, :, 0])
            cls_score = torch.sigmoid(cls_logits[i])
            combined = cls_score * obj_score[:, None]
            mask = combined > score_thresh
            if not mask.any():
                results.append({"boxes": torch.zeros(0, 4, device=device),
                                "scores": torch.zeros(0, device=device),
                                "labels": torch.zeros(0, dtype=torch.long, device=device)})
                continue
            pi, ci = mask.nonzero(as_tuple=True)
            scores = combined[pi, ci]
            # Apply the voted offset (predicted in stride units) to the grid center.
            centers = grid[pi] + offsets[i, pi] * stride
            half = stride * 2
            boxes = torch.stack([centers[:, 0] - half, centers[:, 1] - half,
                                 centers[:, 0] + half, centers[:, 1] + half], -1)
            if scores.numel() > 100:
                top = scores.topk(100)
                scores, keep = top.values, top.indices
                boxes, ci = boxes[keep], ci[keep]
            results.append({"boxes": boxes, "scores": scores, "labels": ci})
        return results

    def get_locs(self, spatial):
        """This head builds its own grid internally; no external locations needed."""
        return None