"""Shared decode functions for FCOS-style and CenterNet-style heads.""" import torch import torch.nn.functional as F from torch import Tensor from torchvision.ops import nms from typing import Dict, List, Tuple NUM_CLASSES = 80 FPN_STRIDES = [8, 16, 32, 64, 128] RESOLUTION = 640 def make_locations(feature_sizes, strides, device): locs = [] for (h, w), s in zip(feature_sizes, strides): ys = (torch.arange(h, device=device, dtype=torch.float32) + 0.5) * s xs = (torch.arange(w, device=device, dtype=torch.float32) + 0.5) * s gy, gx = torch.meshgrid(ys, xs, indexing="ij") locs.append(torch.stack([gx.flatten(), gy.flatten()], dim=-1)) return locs def decode_fcos(cls_per, reg_per, ctr_per, locs_per, score_thresh=0.3, nms_thresh=0.5, max_det=100): B = cls_per[0].shape[0] num_classes = cls_per[0].shape[1] results = [] for bi in range(B): all_boxes, all_scores, all_labels = [], [], [] for cl, rg, ct, lo in zip(cls_per, reg_per, ctr_per, locs_per): c = cl[bi].permute(1, 2, 0).reshape(-1, num_classes) r = rg[bi].permute(1, 2, 0).reshape(-1, 4) s = torch.sigmoid(c) * torch.sigmoid(ct[bi].permute(1, 2, 0).reshape(-1))[:, None] mask = s > score_thresh if not mask.any(): continue ci, cj = mask.nonzero(as_tuple=True) sc = s[ci, cj] if sc.numel() > 1000: top = sc.topk(1000) sc, ci, cj = top.values, ci[top.indices], cj[top.indices] xy = lo[ci] rr = r[ci] bx = torch.stack([xy[:, 0] - rr[:, 0], xy[:, 1] - rr[:, 1], xy[:, 0] + rr[:, 2], xy[:, 1] + rr[:, 3]], -1) all_boxes.append(bx) all_scores.append(sc) all_labels.append(cj) if all_boxes: bx = torch.cat(all_boxes) sc = torch.cat(all_scores) lb = torch.cat(all_labels) bx[:, 0::2] = bx[:, 0::2].clamp(0, RESOLUTION) bx[:, 1::2] = bx[:, 1::2].clamp(0, RESOLUTION) keep = [] for c in lb.unique(): m = lb == c k = nms(bx[m], sc[m], nms_thresh) keep.append(m.nonzero(as_tuple=True)[0][k]) keep = torch.cat(keep) bx, sc, lb = bx[keep], sc[keep], lb[keep] if sc.numel() > max_det: top = sc.topk(max_det) bx, sc, lb = bx[top.indices], top.values, lb[top.indices] results.append({"boxes": bx, "scores": sc, "labels": lb}) else: results.append({"boxes": torch.zeros(0, 4), "scores": torch.zeros(0), "labels": torch.zeros(0, dtype=torch.long)}) return results def decode_centernet(pred_hm, pred_off, pred_sz, stride, score_thresh=0.3, max_det=100): B = pred_hm.shape[0] results = [] for bi in range(B): hm = torch.sigmoid(pred_hm[bi]) hm_max = F.max_pool2d(hm.unsqueeze(0), 3, stride=1, padding=1)[0] keep = (hm == hm_max) & (hm > score_thresh) positions = keep.nonzero() if positions.numel() == 0: results.append({"boxes": torch.zeros(0, 4), "scores": torch.zeros(0), "labels": torch.zeros(0, dtype=torch.long)}) continue cls, ys, xs = positions[:, 0], positions[:, 1], positions[:, 2] scores = hm[cls, ys, xs] off = pred_off[bi] sz = pred_sz[bi] cx = (xs.float() + off[0, ys, xs]) * stride cy = (ys.float() + off[1, ys, xs]) * stride w = sz[0, ys, xs] * stride h = sz[1, ys, xs] * stride boxes = torch.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], -1) if scores.numel() > max_det: top = scores.topk(max_det) boxes, scores, cls = boxes[top.indices], top.values, cls[top.indices] results.append({"boxes": boxes, "scores": scores, "labels": cls}) return results