import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    r"""
    This criterion is an implementation of Focal Loss, proposed in
    "Focal Loss for Dense Object Detection" (Lin et al., 2017):

        Loss(x, class) = -alpha[class] * (1 - softmax(x)[class])^gamma * log(softmax(x)[class])

    The losses are averaged across observations for each minibatch.

    Args:
        class_num (int): number of classes C.
        alpha (Tensor, optional): per-class weighting factor of shape (C, 1).
            Defaults to uniform weights (all ones).
        gamma (float): gamma >= 0; reduces the relative loss for well-classified
            examples (p > .5), putting more focus on hard, misclassified
            examples. Defaults to 2.
        size_average (bool): if True (default), the losses are averaged over
            observations for each minibatch; if False, they are instead summed
            for each minibatch.
        device (str): device to move ``alpha`` to when the inputs live on GPU.
    """

    def __init__(self, class_num, alpha=None, gamma=2, size_average=True, device='cuda:0'):
        super().__init__()
        if alpha is None:
            self.alpha = torch.ones(class_num, 1)
        else:
            # Accept a Tensor or any sequence; normalize to shape (C, 1).
            self.alpha = torch.as_tensor(alpha, dtype=torch.float).view(-1, 1)
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average
        self.device = device

    def forward(self, inputs, targets):
        # inputs: raw logits of shape (N, C); targets: class indices of shape (N,).
        N = inputs.size(0)
        C = inputs.size(1)
        P = F.softmax(inputs, dim=1)

        # One-hot mask selecting the probability of the target class per sample.
        class_mask = inputs.new_zeros(N, C)
        ids = targets.view(-1, 1)
        class_mask.scatter_(1, ids, 1.)

        # Keep the per-class weights on the same device as the inputs.
        if inputs.is_cuda and not self.alpha.is_cuda:
            self.alpha = self.alpha.to(self.device)
        alpha = self.alpha[ids.view(-1)]

        # p_t: probability assigned to the true class of each sample.
        probs = (P * class_mask).sum(1).view(-1, 1)
        # log(p_t) taken from log_softmax for numerical stability
        # (equivalent to probs.log(), but avoids log(0) for tiny probabilities).
        log_p = F.log_softmax(inputs, dim=1).gather(1, ids)

        # The focal term (1 - p_t)^gamma down-weights easy examples.
        batch_loss = -alpha * torch.pow(1 - probs, self.gamma) * log_p

        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()
        return loss


# Alternative, more compact formulation kept for reference. Note that
# CrossEntropyLoss must use reduction='none' here: with the default mean
# reduction, exp(-logp) would be computed from the batch-averaged loss
# rather than per sample, which is not the focal loss.
#
# import torch
# import torch.nn as nn
#
#
# class FocalLoss(nn.Module):
#
#     def __init__(self, gamma=0, eps=1e-7):
#         super().__init__()
#         self.gamma = gamma
#         self.eps = eps
#         self.ce = nn.CrossEntropyLoss(reduction='none')
#
#     def forward(self, input, target):
#         logp = self.ce(input, target)   # per-sample cross entropy = -log(p_t)
#         p = torch.exp(-logp)            # recover p_t
#         loss = (1 - p) ** self.gamma * logp
#         return loss.mean()
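

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). The shapes
# below are assumptions: a batch of 4 samples over 5 classes, with integer
# class indices as targets, which is what FocalLoss.forward expects.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    num_classes = 5
    criterion = FocalLoss(class_num=num_classes, gamma=2, device='cpu')

    logits = torch.randn(4, num_classes, requires_grad=True)  # raw, unnormalized scores
    targets = torch.randint(0, num_classes, (4,))             # class indices in [0, C)

    loss = criterion(logits, targets)
    loss.backward()  # gradients flow through the softmax/log_softmax terms
    print('focal loss:', loss.item())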