|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
from torch.autograd import Variable
|
|
|
|
|
|
|
|
|
class FocalLoss(nn.Module):
    r"""Multi-class focal loss, as proposed in "Focal Loss for Dense Object
    Detection".

    For a row of logits ``x`` and its ground-truth class ``c``::

        Loss(x, c) = -\alpha[c] * (1 - softmax(x)[c])^\gamma * \log(softmax(x)[c])

    The per-sample losses are averaged over the minibatch by default.

    Args:
        class_num (int): number of classes ``C``.
        alpha (Tensor or sequence, optional): per-class weighting factors with
            ``C`` entries, shaped ``(C,)`` or ``(C, 1)``. Defaults to all ones.
        gamma (float): focusing parameter, ``gamma > 0``; reduces the relative
            loss for well-classified examples (p > .5), putting more focus on
            hard, misclassified examples.
        size_average (bool): if True (default) the losses are averaged over
            the minibatch; if False they are summed instead.
        device (str or torch.device): stored as ``self.device`` and kept for
            backward compatibility only; ``alpha`` is moved to whatever device
            ``inputs`` lives on when ``forward`` is called.
    """

    def __init__(self, class_num, alpha=None, gamma=2, size_average=True, device='cuda:0'):
        super(FocalLoss, self).__init__()
        if alpha is None:
            alpha = torch.ones(class_num, 1)
        elif not isinstance(alpha, torch.Tensor):
            # Accept plain sequences / arrays as well as tensors.
            alpha = torch.as_tensor(alpha, dtype=torch.float32)
        self.alpha = alpha
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average
        self.device = device

    def forward(self, inputs, targets):
        """Compute the focal loss.

        Args:
            inputs (Tensor): raw, unnormalised logits of shape ``(N, C)``.
            targets (Tensor): integer (int64) class indices with ``N`` elements.

        Returns:
            Tensor: scalar loss, the mean over the batch when
            ``size_average`` is True, otherwise the sum.
        """
        ids = targets.reshape(-1, 1)

        # log_softmax + gather is numerically stable: log(softmax(x)) yields
        # -inf (and NaN gradients) once the target probability underflows to 0.
        log_p = F.log_softmax(inputs, dim=1).gather(1, ids)
        probs = log_p.exp()

        # Follow the device of the inputs instead of a hard-coded device, so
        # CPU, single-GPU and multi-GPU inputs all work.
        if self.alpha.device != inputs.device:
            self.alpha = self.alpha.to(inputs.device)

        # Force a column vector: a 1-D alpha would otherwise broadcast against
        # the (N, 1) probabilities into an (N, N) matrix and skew the result.
        alpha = self.alpha[ids.view(-1)].reshape(-1, 1)

        batch_loss = -alpha * torch.pow(1 - probs, self.gamma) * log_p

        if self.size_average:
            return batch_loss.mean()
        return batch_loss.sum()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|