| """Shared decode functions for FCOS-style and CenterNet-style heads.""" |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch import Tensor |
| from torchvision.ops import nms |
| from typing import Dict, List, Tuple |
|
|
# Not referenced by the functions visible in this file (decode_fcos reads the
# class count from the logits' shape); presumably the number of object
# categories the heads predict -- TODO confirm against the model config.
NUM_CLASSES = 80
# Not referenced by the functions visible in this file; presumably the
# per-level strides to pass to make_locations -- TODO confirm against callers.
FPN_STRIDES = [8, 16, 32, 64, 128]
# Upper bound used by decode_fcos when clamping both x and y box coordinates,
# i.e. decoded boxes are limited to the range [0, RESOLUTION] on each axis.
RESOLUTION = 640
|
|
|
|
def make_locations(feature_sizes, strides, device):
    """Build the grid of cell-centre coordinates for each feature level.

    Args:
        feature_sizes: Iterable of ``(height, width)`` pairs, one per level.
        strides: Iterable of per-level strides, paired with ``feature_sizes``
            by position.
        device: Device on which the coordinate tensors are created.

    Returns:
        A list with one float32 tensor of shape ``(height * width, 2)`` per
        level. Each row is ``(x, y)``; rows are in row-major order (x varies
        fastest), and every coordinate is ``(index + 0.5) * stride``.
    """

    def centers(count, stride):
        # Shift by half a cell so the coordinate is the cell centre, then
        # scale the cell index by the level's stride.
        return (torch.arange(count, device=device, dtype=torch.float32) + 0.5) * stride

    per_level = []
    for (rows, cols), stride in zip(feature_sizes, strides):
        grid_y, grid_x = torch.meshgrid(
            centers(rows, stride), centers(cols, stride), indexing="ij"
        )
        points = torch.stack((grid_x.flatten(), grid_y.flatten()), dim=-1)
        per_level.append(points)
    return per_level
|
|
|
|
def decode_fcos(
    cls_per: List[Tensor],
    reg_per: List[Tensor],
    ctr_per: List[Tensor],
    locs_per: List[Tensor],
    score_thresh: float = 0.3,
    nms_thresh: float = 0.5,
    max_det: int = 100,
    pre_nms_topk: int = 1000,
) -> List[Dict[str, Tensor]]:
    """Decode FCOS-style head outputs into per-image detections.

    For every image and pyramid level the class probabilities are multiplied
    by the centerness probability, candidates scoring above ``score_thresh``
    are converted to boxes, and the candidates pooled over all levels go
    through per-class NMS.

    Args:
        cls_per: Per-level class logits, indexed as (B, num_classes, H, W).
        reg_per: Per-level box regressions, indexed as (B, 4, H, W). The four
            channels are applied to each location as (left, top, right,
            bottom) distances, so they must already be expressed in the same
            units as ``locs_per``.
        ctr_per: Per-level centerness logits, indexed as (B, 1, H, W).
        locs_per: Per-level (H * W, 2) tensors of (x, y) locations, ordered
            like the row-major flattened feature map (the layout produced by
            ``make_locations``).
        score_thresh: A candidate is kept only if its combined score is
            strictly greater than this value.
        nms_thresh: IoU threshold passed to ``nms``.
        max_det: Maximum number of detections returned per image.
        pre_nms_topk: Maximum number of candidates kept per level before NMS.

    Returns:
        One dict per image with ``"boxes"`` (N, 4) as (x1, y1, x2, y2) clamped
        to ``[0, RESOLUTION]``, ``"scores"`` (N,) and ``"labels"`` (N,) int64
        class indices. An image without detections gets empty tensors that
        live on the inputs' device and use the same dtypes a non-empty result
        would have.
    """
    ref = cls_per[0]
    batch_size, num_classes = ref.shape[0], ref.shape[1]
    results = []
    for bi in range(batch_size):
        all_boxes, all_scores, all_labels = [], [], []
        # Remember the dtypes the non-empty path produces so that an image
        # with no detections still returns tensors of matching dtype.
        box_dtype = score_dtype = ref.dtype
        for cl, rg, ct, lo in zip(cls_per, reg_per, ctr_per, locs_per):
            # (C, H, W) -> (H * W, C): one row per location, row-major.
            cls_logits = cl[bi].permute(1, 2, 0).reshape(-1, num_classes)
            dists = rg[bi].permute(1, 2, 0).reshape(-1, 4)
            centerness = torch.sigmoid(ct[bi].permute(1, 2, 0).reshape(-1))
            scores = torch.sigmoid(cls_logits) * centerness[:, None]
            box_dtype, score_dtype = torch.result_type(lo, dists), scores.dtype

            loc_idx, cls_idx = (scores > score_thresh).nonzero(as_tuple=True)
            if loc_idx.numel() == 0:
                continue
            sc = scores[loc_idx, cls_idx]
            if sc.numel() > pre_nms_topk:
                # Cap the per-level candidate count before the NMS stage.
                top = sc.topk(pre_nms_topk)
                sc = top.values
                loc_idx, cls_idx = loc_idx[top.indices], cls_idx[top.indices]
            xy = lo[loc_idx]
            ltrb = dists[loc_idx]
            all_boxes.append(
                torch.stack(
                    [
                        xy[:, 0] - ltrb[:, 0],
                        xy[:, 1] - ltrb[:, 1],
                        xy[:, 0] + ltrb[:, 2],
                        xy[:, 1] + ltrb[:, 3],
                    ],
                    -1,
                )
            )
            all_scores.append(sc)
            all_labels.append(cls_idx)

        if not all_boxes:
            # Build the empty result from the inputs rather than with bare
            # torch.zeros(), which would always land on the CPU as float32.
            results.append({
                "boxes": ref.new_zeros((0, 4), dtype=box_dtype),
                "scores": ref.new_zeros((0,), dtype=score_dtype),
                "labels": ref.new_zeros((0,), dtype=torch.long),
            })
            continue

        # x and y share the same bounds, so all four columns clamp at once.
        bx = torch.cat(all_boxes).clamp(0, RESOLUTION)
        sc = torch.cat(all_scores)
        lb = torch.cat(all_labels)

        # Per-class NMS: boxes only suppress boxes with the same label.
        keep = []
        for c in lb.unique():
            class_idx = (lb == c).nonzero(as_tuple=True)[0]
            kept = nms(bx[class_idx], sc[class_idx], nms_thresh)
            keep.append(class_idx[kept])
        keep = torch.cat(keep)
        bx, sc, lb = bx[keep], sc[keep], lb[keep]

        if sc.numel() > max_det:
            top = sc.topk(max_det)
            bx, sc, lb = bx[top.indices], top.values, lb[top.indices]
        results.append({"boxes": bx, "scores": sc, "labels": lb})
    return results
|
|
|
|
def decode_centernet(
    pred_hm: Tensor,
    pred_off: Tensor,
    pred_sz: Tensor,
    stride: float,
    score_thresh: float = 0.3,
    max_det: int = 100,
) -> List[Dict[str, Tensor]]:
    """Decode CenterNet-style head outputs into per-image detections.

    A heatmap cell becomes a detection when it is the maximum of its 3x3
    neighbourhood within its class channel and its sigmoid score is strictly
    greater than ``score_thresh``.

    Args:
        pred_hm: Class heatmap logits, indexed as (B, num_classes, H, W).
        pred_off: Sub-cell centre offsets, indexed as (B, 2, H, W); channel 0
            is added to the x index and channel 1 to the y index.
        pred_sz: Box sizes, indexed as (B, 2, H, W); channel 0 is the width
            and channel 1 the height. Both are multiplied by ``stride``.
        stride: Factor that scales feature-map coordinates and sizes to the
            output coordinate system.
        score_thresh: Minimum heatmap score for a peak to be kept.
        max_det: Maximum number of detections returned per image.

    Returns:
        One dict per image with ``"boxes"`` (N, 4) as (x1, y1, x2, y2),
        ``"scores"`` (N,) and ``"labels"`` (N,) int64 class indices. An image
        without peaks gets empty tensors on the same device and with the same
        dtypes as a non-empty result.
    """
    batch_size = pred_hm.shape[0]
    results = []
    for bi in range(batch_size):
        hm = torch.sigmoid(pred_hm[bi])
        # A cell is a local peak iff it equals the max over its 3x3 window.
        hm_max = F.max_pool2d(hm.unsqueeze(0), 3, stride=1, padding=1)[0]
        peaks = (hm == hm_max) & (hm > score_thresh)

        # No special case for "no peaks": with empty index tensors every step
        # below yields correctly shaped empty tensors that inherit the
        # inputs' device and dtype (bare torch.zeros() would be CPU/float32).
        cls, ys, xs = peaks.nonzero(as_tuple=True)
        scores = hm[cls, ys, xs]

        off = pred_off[bi]
        sz = pred_sz[bi]
        cx = (xs.float() + off[0, ys, xs]) * stride
        cy = (ys.float() + off[1, ys, xs]) * stride
        half_w = sz[0, ys, xs] * stride / 2
        half_h = sz[1, ys, xs] * stride / 2
        boxes = torch.stack([cx - half_w, cy - half_h, cx + half_w, cy + half_h], -1)

        if scores.numel() > max_det:
            top = scores.topk(max_det)
            boxes, scores, cls = boxes[top.indices], top.values, cls[top.indices]
        results.append({"boxes": boxes, "scores": scores, "labels": cls})
    return results
|
|