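"""Shared decode functions for FCOS-style and CenterNet-style heads.

Both decoders return one dict per image with "boxes" (N, 4; x1, y1, x2, y2 in
pixels), "scores" (N,) and "labels" (N,), the output format used by
torchvision detection models.
"""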

import torch
import torch.nn.functional as F
from torch import Tensor
from torchvision.ops import nms
from typing import Dict, List, Tuple

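NUM_CLASSES = 80                   # number of object categories
FPN_STRIDES = [8, 16, 32, 64, 128]  # stride of each FPN level, finest to coarsest
RESOLUTION = 640                   # input side length in pixels; decode_fcos clamps boxes to [0, RESOLUTION]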


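def make_locations(feature_sizes: List[Tuple[int, int]], strides: List[int],
                   device: torch.device) -> List[Tensor]:
    """Return the pixel coordinates of every feature-map cell center, per level.

    Each entry is an (H*W, 2) float tensor of (x, y) points, where cell (i, j)
    at stride s maps to ((j + 0.5) * s, (i + 0.5) * s), in row-major order to
    match a flattened (H, W) feature map.
    """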
    locs = []
    for (h, w), s in zip(feature_sizes, strides):
        ys = (torch.arange(h, device=device, dtype=torch.float32) + 0.5) * s
        xs = (torch.arange(w, device=device, dtype=torch.float32) + 0.5) * s
        gy, gx = torch.meshgrid(ys, xs, indexing="ij")
        locs.append(torch.stack([gx.flatten(), gy.flatten()], dim=-1))
    return locs
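

# Sketch (hypothetical helper, not part of the original module): feature-map
# sizes to pass to make_locations. It ASSUMES each level is
# ceil(resolution / stride) cells on a side, which is unambiguous when the
# resolution is divisible by every stride (e.g. 640 with strides 8..128).
# Backbones may round differently for other sizes, so reading H and W from the
# head outputs' shapes is the safer source in practice.
import math
from typing import List, Tuple


def feature_sizes_for(resolution: int, strides: List[int]) -> List[Tuple[int, int]]:
    """Square (H, W) per level under the ceil(resolution / stride) assumption."""
    return [(math.ceil(resolution / s),) * 2 for s in strides]

# Usage: make_locations(feature_sizes_for(RESOLUTION, FPN_STRIDES), FPN_STRIDES, "cpu")
# yields one (H*W, 2) tensor per level; at 640 px the grids are 80x80, 40x40,
# 20x20, 10x10 and 5x5, i.e. 8525 points per image.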


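def decode_fcos(cls_per: List[Tensor], reg_per: List[Tensor], ctr_per: List[Tensor],
                locs_per: List[Tensor], score_thresh: float = 0.3, nms_thresh: float = 0.5,
                max_det: int = 100) -> List[Dict[str, Tensor]]:
    """Decode per-level FCOS outputs into per-image detections.

    Args:
        cls_per: per-level class logits, each (B, C, H, W).
        reg_per: per-level (l, t, r, b) distances in pixels, each (B, 4, H, W).
        ctr_per: per-level centerness logits, each (B, 1, H, W).
        locs_per: per-level (H*W, 2) point coordinates from make_locations.

    Returns:
        One dict per image with "boxes" (N, 4) as (x1, y1, x2, y2) pixels,
        "scores" (N,) and "labels" (N,), after per-class NMS.
    """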
    B = cls_per[0].shape[0]
    num_classes = cls_per[0].shape[1]
    results = []
    for bi in range(B):
        all_boxes, all_scores, all_labels = [], [], []
        for cl, rg, ct, lo in zip(cls_per, reg_per, ctr_per, locs_per):
            c = cl[bi].permute(1, 2, 0).reshape(-1, num_classes)
            r = rg[bi].permute(1, 2, 0).reshape(-1, 4)
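            # FCOS score: class probability times centerness, which down-weights
            # low-quality boxes predicted far from an object's center.
            ctr = torch.sigmoid(ct[bi].permute(1, 2, 0).reshape(-1))
            s = torch.sigmoid(c) * ctr[:, None]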
            mask = s > score_thresh
            if not mask.any():
                continue
            ci, cj = mask.nonzero(as_tuple=True)
            sc = s[ci, cj]
            # Keep at most 1000 candidates per level before NMS.
            if sc.numel() > 1000:
                top = sc.topk(1000)
                sc, ci, cj = top.values, ci[top.indices], cj[top.indices]
            xy = lo[ci]
            rr = r[ci]
            # (l, t, r, b) are distances from the point to the four box sides.
            bx = torch.stack([xy[:, 0] - rr[:, 0], xy[:, 1] - rr[:, 1],
                              xy[:, 0] + rr[:, 2], xy[:, 1] + rr[:, 3]], -1)
            all_boxes.append(bx)
            all_scores.append(sc)
            all_labels.append(cj)
        if all_boxes:
            bx = torch.cat(all_boxes)
            sc = torch.cat(all_scores)
            lb = torch.cat(all_labels)
            # Clip to the image, assumed square with side RESOLUTION.
            bx[:, 0::2] = bx[:, 0::2].clamp(0, RESOLUTION)
            bx[:, 1::2] = bx[:, 1::2].clamp(0, RESOLUTION)
            # Class-aware NMS: a box can only suppress boxes with the same label.
            keep = []
            for c in lb.unique():
                m = lb == c
                k = nms(bx[m], sc[m], nms_thresh)
                keep.append(m.nonzero(as_tuple=True)[0][k])
            keep = torch.cat(keep)
            bx, sc, lb = bx[keep], sc[keep], lb[keep]
            # Sort by score and keep the top max_det detections.
            order = sc.argsort(descending=True)[:max_det]
            bx, sc, lb = bx[order], sc[order], lb[order]
            results.append({"boxes": bx, "scores": sc, "labels": lb})
        else:
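            # Empty result on the same device as the inputs.
            dev = cls_per[0].device
            results.append({"boxes": torch.zeros(0, 4, device=dev),
                            "scores": torch.zeros(0, device=dev),
                            "labels": torch.zeros(0, dtype=torch.long, device=dev)})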
    return results
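

# Worked example (a standalone sketch mirroring the box decoding inside
# decode_fcos; the helper name ltrb_to_xyxy is hypothetical): a point (x, y)
# with predicted distances (l, t, r, b) to the four box sides becomes the box
# (x - l, y - t, x + r, y + b), so width = l + r and height = t + b.
import torch
from torch import Tensor


def ltrb_to_xyxy(points: Tensor, ltrb: Tensor) -> Tensor:
    """points: (N, 2) as (x, y); ltrb: (N, 4) distances -> (N, 4) xyxy boxes."""
    x, y = points[:, 0], points[:, 1]
    return torch.stack([x - ltrb[:, 0], y - ltrb[:, 1],
                        x + ltrb[:, 2], y + ltrb[:, 3]], dim=-1)

# For the point (100, 60) with (l, t, r, b) = (10, 20, 30, 40), the box is
# (90, 40, 130, 100): 40 px wide and 60 px tall.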

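
# Sketch, assuming torchvision's batched_nms: the per-class loop in decode_fcos
# (one nms call per label) has the same per-class behaviour as a single
# torchvision.ops.batched_nms(boxes, scores, idxs, iou_threshold) call, which
# only lets boxes sharing an idxs value suppress each other and returns kept
# indices sorted by decreasing score. The wrapper name per_class_nms is
# hypothetical.
import torch
from torch import Tensor
from torchvision.ops import batched_nms


def per_class_nms(boxes: Tensor, scores: Tensor, labels: Tensor, iou_thresh: float) -> Tensor:
    """Indices of boxes kept by class-aware NMS, highest score first."""
    return batched_nms(boxes, scores, labels, iou_thresh)

# Two heavily overlapping boxes with the same label collapse to one, while an
# identical box with a different label survives:
#   boxes  = [[0, 0, 10, 10], [1, 1, 10, 10], [0, 0, 10, 10]]
#   scores = [0.9, 0.8, 0.7]; labels = [0, 0, 1]
# At iou_thresh=0.5 indices 0 and 2 are kept (IoU of the first two boxes is
# 81 / 100 = 0.81).
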

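def decode_centernet(pred_hm: Tensor, pred_off: Tensor, pred_sz: Tensor, stride: int,
                     score_thresh: float = 0.3, max_det: int = 100) -> List[Dict[str, Tensor]]:
    """Decode CenterNet outputs into per-image detections.

    Args:
        pred_hm: class heatmap logits, (B, C, H, W).
        pred_off: sub-cell center offsets (dx, dy), (B, 2, H, W).
        pred_sz: box sizes (w, h) in feature-map cells, (B, 2, H, W).
        stride: output stride of the heatmap relative to the input image.

    Returns:
        One dict per image with "boxes" (N, 4) as (x1, y1, x2, y2) pixels,
        "scores" (N,) and "labels" (N,).
    """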
    B = pred_hm.shape[0]
    results = []
    for bi in range(B):
        hm = torch.sigmoid(pred_hm[bi])
        # Peak extraction: a 3x3 max-pool keeps only local maxima, replacing box NMS.
        hm_max = F.max_pool2d(hm.unsqueeze(0), 3, stride=1, padding=1)[0]
        keep = (hm == hm_max) & (hm > score_thresh)
        positions = keep.nonzero()
        if positions.numel() == 0:
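            # Empty result on the same device as the inputs.
            dev = pred_hm.device
            results.append({"boxes": torch.zeros(0, 4, device=dev),
                            "scores": torch.zeros(0, device=dev),
                            "labels": torch.zeros(0, dtype=torch.long, device=dev)})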
            continue
        cls, ys, xs = positions[:, 0], positions[:, 1], positions[:, 2]
        scores = hm[cls, ys, xs]
        off = pred_off[bi]
        sz = pred_sz[bi]
        # Center = (cell index + sub-cell offset) * stride; sizes are in cells.
        cx = (xs.float() + off[0, ys, xs]) * stride
        cy = (ys.float() + off[1, ys, xs]) * stride
        w = sz[0, ys, xs] * stride
        h = sz[1, ys, xs] * stride
        boxes = torch.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], -1)
        # Sort by score and keep the top max_det detections.
        order = scores.argsort(descending=True)[:max_det]
        boxes, scores, cls = boxes[order], scores[order], cls[order]
        results.append({"boxes": boxes, "scores": scores, "labels": cls})
    return results
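

# Sketch of the two CenterNet decoding steps in decode_centernet, on toy inputs
# (helper names heatmap_peaks and cell_to_box are hypothetical). A cell counts
# as a peak when it equals the maximum of its 3x3 neighbourhood and exceeds the
# threshold. Sizes are ASSUMED to be in feature-map cells, matching the
# "* stride" scaling in decode_centernet.
from typing import Tuple

import torch
import torch.nn.functional as F
from torch import Tensor


def heatmap_peaks(hm: Tensor, thresh: float) -> Tensor:
    """hm: (C, H, W) probabilities -> (N, 3) rows of (class, y, x) local maxima."""
    hm_max = F.max_pool2d(hm.unsqueeze(0), 3, stride=1, padding=1)[0]
    return ((hm == hm_max) & (hm > thresh)).nonzero()


def cell_to_box(x: int, y: int, dx: float, dy: float, w: float, h: float,
                stride: int) -> Tuple[float, float, float, float]:
    """(x1, y1, x2, y2) in pixels from a peak cell, its offset and its size."""
    cx, cy = (x + dx) * stride, (y + dy) * stride
    bw, bh = w * stride, h * stride
    return (cx - bw / 2, cy - bh / 2, cx + bw / 2, cy + bh / 2)

# Example: at stride 4, the cell (x=10, y=5) with offset (0.25, 0.5) and size
# (4, 2) cells has center (41, 22) and size 16 x 8 px, giving the box
# (33, 18, 49, 26). In a heatmap with values 0.9 at (2, 2) and 0.6 at (2, 3),
# only (2, 2) is a peak, because its neighbour is larger than 0.6.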