# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..box_utils import match, log_sum_exp
from .focal_loss import FocalLoss


class MultiBoxLoss(nn.Module):
    """SSD weighted loss function.

    Compute targets:
        1) Produce confidence target indices by matching ground-truth boxes
           with the default 'prior boxes' whose jaccard overlap exceeds
           ``overlap_thresh`` (the SSD paper uses 0.5).
        2) Produce localization targets by 'encoding' the offsets between the
           ground-truth boxes and their matched prior boxes using the
           configured variances.
        3) Optionally (``args.neg_mining``) apply hard negative mining to
           filter the excessive number of negative examples that comes with
           using a large number of default boxes, keeping at most
           ``neg_pos`` negatives per positive (the SSD paper uses 3:1).
           When mining is disabled, every background prior is used.

    Objective loss:
        L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
        where Lloc is the Smooth-L1 loss over positive priors and Lconf is
        either cross-entropy (``args.loss_fun == 'ce'``) or the project's
        FocalLoss (any other value).
        Symbols:
            c: class confidences,
            l: predicted boxes,
            g: ground truth boxes,
            N: number of matched (positive) default boxes.
        ``forward`` returns the two normalized terms separately; combining
        them (i.e. choosing α) is left to the caller.
        See: https://arxiv.org/pdf/1512.02325.pdf for more details.

    Args:
        args: options object; ``args.neg_mining`` (truthy enables hard
            negative mining) and ``args.loss_fun`` (``'ce'`` selects
            cross-entropy, anything else selects FocalLoss) are read.
        cfg: mapping providing ``'num_classes'`` and ``'variance'``.
        overlap_thresh: IoU threshold forwarded to ``match``.
        bkg_label: background label; stored on the module but not used here.
        neg_pos: maximum negative:positive ratio used by hard negative mining.
    """

    def __init__(self, args, cfg, overlap_thresh, bkg_label, neg_pos):
        super(MultiBoxLoss, self).__init__()
        self.args = args
        self.num_classes = cfg['num_classes']
        self.threshold = overlap_thresh
        self.background_label = bkg_label
        self.negpos_ratio = neg_pos
        self.variance = cfg['variance']
        self.focal_loss = FocalLoss()

    def forward(self, predictions, targets):
        """Compute the normalized localization and confidence losses.

        Args:
            predictions (tuple): ``(loc_data, conf_data, priors)`` from the SSD
                net.
                conf shape: (batch_size, num_priors, num_classes)
                loc shape: (batch_size, num_priors, 4)
                priors shape: (num_priors, 4)
            targets: per-image ground truth, indexed as ``targets[idx]``; each
                entry is a 2-D tensor whose last column is the class label and
                whose remaining columns are the box coordinates (presumably
                ``[num_objs, 5]``; may be a list since ``num_objs`` can vary
                per image — confirm against the data loader).

        Returns:
            tuple: ``(loss_l, loss_c)``, the summed Smooth-L1 localization loss
            and the summed confidence loss, each divided by the number of
            positive priors in the batch (at least 1).
        """
        loc_data, conf_data, priors = predictions
        device = loc_data.device
        num = loc_data.size(0)
        # Keep only as many priors as there are per-image predictions.
        priors = priors[:loc_data.size(1), :]
        num_priors = priors.size(0)

        # Match priors (default boxes) with ground-truth boxes; `match` fills
        # row `idx` of loc_t / conf_t in place.
        loc_t = torch.Tensor(num, num_priors, 4)
        conf_t = torch.LongTensor(num, num_priors)
        defaults = priors.detach()
        for idx in range(num):
            truths = targets[idx][:, :-1].detach()
            labels = targets[idx][:, -1].detach()
            match(self.threshold, truths, defaults, self.variance, labels,
                  loc_t, conf_t, idx)
        # Targets are constants (no gradient); place them next to the
        # predictions rather than on the "current" CUDA device.
        loc_t = loc_t.to(device)
        conf_t = conf_t.to(device)

        # Label 0 is background; anything above it is a positive prior.
        pos = conf_t > 0
        num_pos = pos.long().sum(1, keepdim=True)

        # Localization loss (Smooth L1) over positive priors only.
        pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
        loc_p = loc_data[pos_idx].view(-1, 4)
        loc_t = loc_t[pos_idx].view(-1, 4)
        loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum')

        # Choose which negatives contribute to the confidence loss.
        if self.args.neg_mining:
            neg = self._hard_negative_mask(conf_data, conf_t, pos, num_pos)
        else:
            neg = conf_t == 0

        # Confidence loss over positive and selected negative priors.
        pos_idx = pos.unsqueeze(2).expand_as(conf_data)
        neg_idx = neg.unsqueeze(2).expand_as(conf_data)
        conf_p = conf_data[pos_idx | neg_idx].view(-1, self.num_classes)
        targets_weighted = conf_t[pos | neg]
        if self.args.loss_fun == 'ce':
            loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum')
        else:
            loss_c = self.focal_loss.compute(conf_p, targets_weighted)

        # L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N.  Clamp N so a batch
        # without any positive prior yields 0 instead of NaN/inf.
        N = num_pos.sum().clamp(min=1).float()
        loss_l = loss_l / N
        loss_c = loss_c / N
        return loss_l, loss_c

    def _hard_negative_mask(self, conf_data, conf_t, pos, num_pos):
        """Return a boolean mask of the hardest negatives per image.

        Ranks non-positive priors by their confidence loss and keeps the top
        ``negpos_ratio * num_pos`` of them (capped at ``num_priors - 1``).
        Only used for selection, so it runs without autograd tracking.
        """
        num = conf_t.size(0)
        with torch.no_grad():
            batch_conf = conf_data.view(-1, self.num_classes)
            loss_c = (log_sum_exp(batch_conf)
                      - batch_conf.gather(1, conf_t.view(-1, 1)))
            loss_c = loss_c.view(num, -1)
            loss_c[pos] = 0  # exclude positives from the ranking
            _, loss_idx = loss_c.sort(1, descending=True)
            _, idx_rank = loss_idx.sort(1)
            num_neg = torch.clamp(self.negpos_ratio * num_pos,
                                  max=pos.size(1) - 1)
            neg = idx_rank < num_neg.expand_as(idx_rank)
        return neg