|
|
from __future__ import division |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
class Detect(nn.Module):
    """At test time, Detect is the final layer of SSD.

    It decodes the location predictions against the prior boxes, applies
    per-class non-maximum suppression to boxes whose confidence exceeds
    ``conf_thresh``, and keeps at most ``top_k`` detections per image across
    all classes.
    """

    def __init__(self, num_classes, bkg_label, top_k, conf_thresh, nms_thresh, variance=(0.1, 0.2)):
        """
        Args:
            num_classes: (int) Number of classes in the confidence predictions,
                including the background class.
            bkg_label: (int) Index of the background class; it is skipped
                when producing detections.
            top_k: (int) Maximum number of boxes kept per class by NMS, and
                maximum number of detections kept per image overall.
            conf_thresh: (float) A box is considered for a class only if its
                score for that class is strictly greater than this value.
            nms_thresh: (float) IoU threshold passed to ``nms``; must be > 0.
            variance: (sequence of two floats) Variances passed to ``decode``.
        Raises:
            ValueError: if ``nms_thresh`` is not positive.
        """
        super(Detect, self).__init__()
        if nms_thresh <= 0:
            raise ValueError('nms_threshold must be positive.')
        self.num_classes = num_classes
        self.background_label = bkg_label
        self.top_k = top_k
        self.nms_thresh = nms_thresh
        self.conf_thresh = conf_thresh
        self.variance = variance

    def forward(self, loc_data, conf_data, prior_data):
        """
        Args:
            loc_data: (tensor) Loc preds from loc layers,
                Shape: [batch, num_priors, 4]
            conf_data: (tensor) Conf preds from conf layers, any layout that
                views to [batch, num_priors, num_classes]
                (e.g. Shape: [batch*num_priors, num_classes])
            prior_data: (tensor) Prior boxes (center-offset form) from the
                priorbox layer, Shape: [num_priors, 4]
        Return:
            (tensor) Shape: [batch, num_classes, top_k, 5]. Each row is
            (score, x1, y1, x2, y2). Unused rows, the background class, and
            detections ranked beyond ``top_k`` within their image are zero.
        """
        num = loc_data.size(0)
        num_priors = prior_data.size(0)

        # NOTE(review): allocated with torch's default dtype/device, regardless
        # of loc_data's device -- confirm callers expect a CPU result.
        output = torch.zeros(num, self.num_classes, self.top_k, 5)
        # [batch, num_classes, num_priors] so each class row is contiguous in index.
        conf_preds = conf_data.view(num, num_priors, self.num_classes).transpose(2, 1)

        for i in range(num):
            decoded_boxes = decode(loc_data[i], prior_data, self.variance)
            conf_scores = conf_preds[i]

            for cl in range(self.num_classes):
                if cl == self.background_label:
                    continue
                c_mask = conf_scores[cl].gt(self.conf_thresh)
                scores = conf_scores[cl][c_mask]
                if scores.numel() == 0:
                    continue
                # Row-select the decoded boxes that passed the threshold.
                boxes = decoded_boxes[c_mask].view(-1, 4)

                ids, count = nms(boxes, scores, self.nms_thresh, self.top_k)
                kept = ids[:count]
                output[i, cl, :count] = torch.cat(
                    (scores[kept].unsqueeze(1), boxes[kept]), 1)

        # Keep only the top_k highest-scoring detections per image across all
        # classes. `output` is contiguous, so `flt` is a view and the in-place
        # masked_fill_ writes through to `output`. (Boolean-mask indexing
        # followed by fill_ would only modify a copy.)
        flt = output.view(num, -1, 5)
        _, order = flt[:, :, 0].sort(1, descending=True)
        _, rank = order.sort(1)
        overflow = (rank >= self.top_k).unsqueeze(-1).expand_as(flt)
        flt.masked_fill_(overflow, 0)
        return output
|
|
|
|
|
|
|
|
|
|
|
def decode(loc, priors, variances):
    """Undo the train-time offset encoding of box regression targets.

    Args:
        loc (tensor): location predictions from the loc layers,
            Shape: [num_priors, 4], laid out as (dcx, dcy, dw, dh).
        priors (tensor): prior boxes in center-size form (cx, cy, w, h),
            Shape: [num_priors, 4].
        variances: (list[float]) the two prior-box variances; the first
            scales the center offsets, the second the log-size offsets.

    Return:
        Decoded boxes in corner form (x1, y1, x2, y2), Shape: [num_priors, 4].
    """
    prior_centers = priors[:, :2]
    prior_sizes = priors[:, 2:]

    # Center offsets are relative to the prior size; sizes are log-encoded.
    centers = prior_centers + loc[:, :2] * variances[0] * prior_sizes
    sizes = prior_sizes * torch.exp(loc[:, 2:] * variances[1])

    # Convert (center, size) to (top-left, bottom-right).
    top_left = centers - sizes / 2
    bottom_right = top_left + sizes
    return torch.cat((top_left, bottom_right), 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def nms(boxes, scores, overlap=0.5, top_k=200):
    """Apply greedy non-maximum suppression to one class's candidate boxes.

    Boxes are visited in descending score order. Each visited box is kept and
    suppresses every remaining box whose IoU with it is greater than
    ``overlap``.

    Args:
        boxes: (tensor) Candidate boxes in corner form (x1, y1, x2, y2),
            Shape: [num_boxes, 4].
        scores: (tensor) Score of each box, Shape: [num_boxes].
        overlap: (float) IoU above which a lower-scoring box is suppressed.
            A box whose IoU equals ``overlap`` is kept.
        top_k: (int) Only the ``top_k`` highest-scoring boxes are considered.
            Should be positive: because of the ``idx[-top_k:]`` slice, 0
            considers every box.
    Return:
        A ``(keep, count)`` tuple. ``keep`` is a long tensor of length
        ``num_boxes`` whose first ``count`` entries are the indices of the
        kept boxes in descending score order; the remaining entries are zero.
    """
    keep = scores.new(scores.size(0)).zero_().long()
    if boxes.numel() == 0:
        # Same (keep, count) shape as the normal path so callers can unpack.
        return keep, 0

    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]
    area = torch.mul(x2 - x1, y2 - y1)

    # Ascending sort: the best remaining candidate is always idx[-1].
    _, idx = scores.sort(0)
    idx = idx[-top_k:]

    count = 0
    while idx.numel() > 0:
        i = idx[-1]
        keep[count] = i
        count += 1
        if idx.size(0) == 1:
            break
        idx = idx[:-1]

        # Intersection rectangle of box i with every remaining candidate.
        xx1 = torch.clamp(torch.index_select(x1, 0, idx), min=x1[i])
        yy1 = torch.clamp(torch.index_select(y1, 0, idx), min=y1[i])
        xx2 = torch.clamp(torch.index_select(x2, 0, idx), max=x2[i])
        yy2 = torch.clamp(torch.index_select(y2, 0, idx), max=y2[i])

        # Disjoint boxes give negative extents; clamp so their area is 0.
        w = torch.clamp(xx2 - xx1, min=0.0)
        h = torch.clamp(yy2 - yy1, min=0.0)
        inter = w * h

        rem_areas = torch.index_select(area, 0, idx)
        union = (rem_areas - inter) + area[i]
        iou = inter / union

        # Drop candidates that overlap the kept box too much.
        idx = idx[iou.le(overlap)]
    return keep, count
|
|
|