# detection-heads/utils/decode.py
# Committed by phanerozoic, commit ca63835 (verified):
# "Restructure: one folder per head, shared losses/utils, registry runner".
"""Shared decode functions for FCOS-style and CenterNet-style heads."""
import torch
import torch.nn.functional as F
from torch import Tensor
from torchvision.ops import nms
from typing import Dict, List, Tuple
# Default object-class count (80 matches COCO -- presumably the target
# dataset; confirm). The decoders in this file read the class count from
# tensor shapes rather than from this constant.
NUM_CLASSES = 80
# Feature-pyramid strides in pixels, one per level -- presumably what callers
# pass as `strides` to make_locations; confirm against the heads.
FPN_STRIDES = [8, 16, 32, 64, 128]
# Square input size in pixels; decode_fcos uses it as the bound when clamping
# box x and y coordinates to [0, RESOLUTION].
RESOLUTION = 640
def make_locations(feature_sizes, strides, device):
    """Build the pixel-space centre point of every cell for each feature level.

    Args:
        feature_sizes: iterable of ``(height, width)`` pairs, one per level.
        strides: pixel stride of each level, paired element-wise with
            ``feature_sizes`` (extra entries in the longer one are ignored).
        device: device on which the returned tensors are allocated.

    Returns:
        A list with one float32 tensor of shape ``(height * width, 2)`` per
        level. Row ``k`` holds ``(x, y) = ((col + 0.5) * stride,
        (row + 0.5) * stride)`` for cells visited in row-major order.
    """
    def _cell_centres(count, stride):
        # Centre of each of `count` cells along one axis, in pixels.
        return (torch.arange(count, device=device, dtype=torch.float32) + 0.5) * stride

    levels = []
    for (height, width), stride in zip(feature_sizes, strides):
        grid_y, grid_x = torch.meshgrid(
            _cell_centres(height, stride), _cell_centres(width, stride), indexing="ij"
        )
        levels.append(torch.stack((grid_x.reshape(-1), grid_y.reshape(-1)), dim=1))
    return levels
def decode_fcos(cls_per, reg_per, ctr_per, locs_per, score_thresh=0.3, nms_thresh=0.5, max_det=100,
                pre_nms_top_n=1000, image_size=RESOLUTION):
    """Decode multi-level FCOS-style head outputs into per-image detections.

    Args:
        cls_per: per-level classification logits, each ``(B, C, H, W)``
            (indexed per image and permuted ``(1, 2, 0)``); ``C`` is taken
            from ``cls_per[0].shape[1]``.
        reg_per: per-level box regressions, each ``(B, 4, H, W)``. The four
            channels are (left, top, right, bottom) distances from each
            location and are applied as-is (no exp / stride scaling here).
        ctr_per: per-level centerness logits, each ``(B, 1, H, W)``.
        locs_per: per-level ``(H * W, 2)`` (x, y) locations, e.g. from
            ``make_locations``.
        score_thresh: minimum ``sigmoid(cls) * sigmoid(ctr)`` for a candidate.
        nms_thresh: IoU threshold for per-class NMS.
        max_det: maximum detections returned per image.
        pre_nms_top_n: maximum candidates kept per level before NMS.
        image_size: clamp bound for box coordinates; an int clamps both axes
            to ``[0, image_size]``, a ``(width, height)`` pair clamps x and y
            separately.

    Returns:
        A list of ``B`` dicts with ``"boxes"`` ``(N, 4)`` xyxy, ``"scores"``
        ``(N,)`` sorted in descending order and ``"labels"`` ``(N,)`` long.
        Empty results are allocated on the same device as ``cls_per[0]`` so
        every image's output lives on one device.
    """
    device = cls_per[0].device
    num_classes = cls_per[0].shape[1]
    if isinstance(image_size, (tuple, list)):
        img_w, img_h = image_size
    else:
        img_w = img_h = image_size

    results = []
    for bi in range(cls_per[0].shape[0]):
        all_boxes, all_scores, all_labels = [], [], []
        for cl, rg, ct, lo in zip(cls_per, reg_per, ctr_per, locs_per):
            c = cl[bi].permute(1, 2, 0).reshape(-1, num_classes)
            r = rg[bi].permute(1, 2, 0).reshape(-1, 4)
            ctr = ct[bi].permute(1, 2, 0).reshape(-1)
            # Joint score (H*W, num_classes): class probability x centerness.
            s = torch.sigmoid(c) * torch.sigmoid(ctr)[:, None]
            ci, cj = (s > score_thresh).nonzero(as_tuple=True)
            if ci.numel() == 0:
                continue
            sc = s[ci, cj]
            # Bound the number of candidates per level so NMS cost stays capped.
            if sc.numel() > pre_nms_top_n:
                top = sc.topk(pre_nms_top_n)
                sc, ci, cj = top.values, ci[top.indices], cj[top.indices]
            xy = lo[ci]
            rr = r[ci]
            bx = torch.stack([xy[:, 0] - rr[:, 0], xy[:, 1] - rr[:, 1],
                              xy[:, 0] + rr[:, 2], xy[:, 1] + rr[:, 3]], -1)
            all_boxes.append(bx)
            all_scores.append(sc)
            all_labels.append(cj)

        if not all_boxes:
            # Allocate on the input device so a batch never mixes CPU/GPU outputs.
            results.append({"boxes": torch.zeros(0, 4, device=device),
                            "scores": torch.zeros(0, device=device),
                            "labels": torch.zeros(0, dtype=torch.long, device=device)})
            continue

        bx = torch.cat(all_boxes)
        sc = torch.cat(all_scores)
        lb = torch.cat(all_labels)
        bx[:, 0::2] = bx[:, 0::2].clamp(0, img_w)
        bx[:, 1::2] = bx[:, 1::2].clamp(0, img_h)

        # Per-class NMS: boxes of different classes never suppress each other.
        keep = []
        for c in lb.unique():
            idx = (lb == c).nonzero(as_tuple=True)[0]
            keep.append(idx[nms(bx[idx], sc[idx], nms_thresh)])
        keep = torch.cat(keep)
        bx, sc, lb = bx[keep], sc[keep], lb[keep]

        # Always return detections sorted by descending score, capped at max_det.
        top = sc.topk(min(max_det, sc.numel()))
        results.append({"boxes": bx[top.indices], "scores": top.values, "labels": lb[top.indices]})
    return results
def decode_centernet(pred_hm, pred_off, pred_sz, stride, score_thresh=0.3, max_det=100, peak_kernel=3):
    """Decode CenterNet-style heatmap outputs into per-image detections.

    A cell is a detection when its sigmoid score equals the maximum over a
    ``peak_kernel`` x ``peak_kernel`` window of the same class channel and
    exceeds ``score_thresh``.

    Args:
        pred_hm: heatmap logits ``(B, C, H, W)`` (indexed per image as
            ``(class, y, x)``).
        pred_off: sub-cell centre offsets; channel 0 is x and channel 1 is y,
            in feature-map cells. Presumably ``(B, 2, H, W)`` -- confirm.
        pred_sz: box sizes; channel 0 is width and channel 1 is height, in
            feature-map cells. Presumably ``(B, 2, H, W)`` -- confirm.
        stride: pixels per feature-map cell.
        score_thresh: minimum peak score to keep.
        max_det: maximum detections returned per image.
        peak_kernel: odd window size used for peak extraction (default 3).

    Returns:
        A list of ``B`` dicts with ``"boxes"`` ``(N, 4)`` xyxy in pixels,
        ``"scores"`` ``(N,)`` sorted in descending order and ``"labels"``
        ``(N,)`` long. Empty results are allocated on ``pred_hm``'s device.

    Raises:
        ValueError: if ``peak_kernel`` is not a positive odd integer.
    """
    if peak_kernel < 1 or peak_kernel % 2 == 0:
        raise ValueError(f"peak_kernel must be a positive odd integer, got {peak_kernel}")
    device = pred_hm.device
    results = []
    for bi in range(pred_hm.shape[0]):
        hm = torch.sigmoid(pred_hm[bi])
        # Peak test: a cell survives if nothing in its window is larger
        # (equal-valued neighbours on a plateau all survive).
        hm_max = F.max_pool2d(hm.unsqueeze(0), peak_kernel, stride=1, padding=peak_kernel // 2)[0]
        cls, ys, xs = ((hm == hm_max) & (hm > score_thresh)).nonzero(as_tuple=True)
        if cls.numel() == 0:
            # Allocate on the input device so a batch never mixes CPU/GPU outputs.
            results.append({"boxes": torch.zeros(0, 4, device=device),
                            "scores": torch.zeros(0, device=device),
                            "labels": torch.zeros(0, dtype=torch.long, device=device)})
            continue
        scores = hm[cls, ys, xs]
        # Keep the best max_det peaks, always ordered by descending score,
        # before building boxes so no work is spent on discarded peaks.
        top = scores.topk(min(max_det, scores.numel()))
        scores = top.values
        cls, ys, xs = cls[top.indices], ys[top.indices], xs[top.indices]

        off = pred_off[bi]
        sz = pred_sz[bi]
        cx = (xs.float() + off[0, ys, xs]) * stride
        cy = (ys.float() + off[1, ys, xs]) * stride
        w = sz[0, ys, xs] * stride
        h = sz[1, ys, xs] * stride
        boxes = torch.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], -1)
        results.append({"boxes": boxes, "scores": scores, "labels": cls})
    return results