import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable

from ..utils import ext_loader

# Load the compiled C++/CUDA extension that provides the focal loss kernels.
ext_module = ext_loader.load_ext('_ext', [
    'sigmoid_focal_loss_forward', 'sigmoid_focal_loss_backward',
    'softmax_focal_loss_forward', 'softmax_focal_loss_backward'
])


class SigmoidFocalLossFunction(Function):

    @staticmethod
    def symbolic(g, input, target, gamma, alpha, weight, reduction):
        # Symbolic registration so the op can be exported to ONNX as a
        # custom ``mmcv::MMCVSigmoidFocalLoss`` node.
        return g.op(
            'mmcv::MMCVSigmoidFocalLoss',
            input,
            target,
            gamma_f=gamma,
            alpha_f=alpha,
            weight_f=weight,
            reduction_s=reduction)

    @staticmethod
    def forward(ctx,
                input,
                target,
                gamma=2.0,
                alpha=0.25,
                weight=None,
                reduction='mean'):
        # `input` holds (N, num_classes) logits; `target` holds N integer
        # class indices. `weight` is an optional per-class rescaling factor.
        assert isinstance(target, (torch.LongTensor, torch.cuda.LongTensor))
        assert input.dim() == 2
        assert target.dim() == 1
        assert input.size(0) == target.size(0)
        if weight is None:
            weight = input.new_empty(0)
        else:
            assert weight.dim() == 1
            assert input.size(1) == weight.size(0)
        ctx.reduction_dict = {'none': 0, 'mean': 1, 'sum': 2}
        assert reduction in ctx.reduction_dict.keys()

        ctx.gamma = float(gamma)
        ctx.alpha = float(alpha)
        ctx.reduction = ctx.reduction_dict[reduction]

        output = input.new_zeros(input.size())

        ext_module.sigmoid_focal_loss_forward(
            input, target, weight, output, gamma=ctx.gamma, alpha=ctx.alpha)
        if ctx.reduction == ctx.reduction_dict['mean']:
            output = output.sum() / input.size(0)
        elif ctx.reduction == ctx.reduction_dict['sum']:
            output = output.sum()
        ctx.save_for_backward(input, target, weight)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        input, target, weight = ctx.saved_tensors

        grad_input = input.new_zeros(input.size())

        ext_module.sigmoid_focal_loss_backward(
            input,
            target,
            weight,
            grad_input,
            gamma=ctx.gamma,
            alpha=ctx.alpha)

        # Chain rule: scale the local gradient by the incoming gradient, and
        # fold in the 1/N factor when the forward used the 'mean' reduction.
        grad_input *= grad_output
        if ctx.reduction == ctx.reduction_dict['mean']:
            grad_input /= input.size(0)
        return grad_input, None, None, None, None, None


sigmoid_focal_loss = SigmoidFocalLossFunction.apply
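
# A minimal usage sketch of the functional interface (illustrative only: it
# assumes the `_ext` extension was built with CUDA support, and the shapes
# and values below are made up for the example):
#
#   logits = torch.randn(8, 4, device='cuda')            # (N, num_classes)
#   labels = torch.randint(0, 4, (8, ), device='cuda')   # N class indices
#   loss = sigmoid_focal_loss(logits, labels, 2.0, 0.25, None, 'mean')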


class SigmoidFocalLoss(nn.Module):

    def __init__(self, gamma, alpha, weight=None, reduction='mean'):
        super(SigmoidFocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        # Register `weight` as a buffer so it follows the module across
        # devices (e.g. via `.cuda()`).
        self.register_buffer('weight', weight)
        self.reduction = reduction

    def forward(self, input, target):
        return sigmoid_focal_loss(input, target, self.gamma, self.alpha,
                                  self.weight, self.reduction)

    def __repr__(self):
        s = self.__class__.__name__
        s += f'(gamma={self.gamma}, '
        s += f'alpha={self.alpha}, '
        s += f'reduction={self.reduction})'
        return s
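
# Module-style usage sketch (hypothetical values, CUDA build assumed):
#
#   criterion = SigmoidFocalLoss(gamma=2.0, alpha=0.25).cuda()
#   loss = criterion(logits, labels)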


class SoftmaxFocalLossFunction(Function):

    @staticmethod
    def symbolic(g, input, target, gamma, alpha, weight, reduction):
        # Symbolic registration for ONNX export as a custom
        # ``mmcv::MMCVSoftmaxFocalLoss`` node.
        return g.op(
            'mmcv::MMCVSoftmaxFocalLoss',
            input,
            target,
            gamma_f=gamma,
            alpha_f=alpha,
            weight_f=weight,
            reduction_s=reduction)

    @staticmethod
    def forward(ctx,
                input,
                target,
                gamma=2.0,
                alpha=0.25,
                weight=None,
                reduction='mean'):
        assert isinstance(target, (torch.LongTensor, torch.cuda.LongTensor))
        assert input.dim() == 2
        assert target.dim() == 1
        assert input.size(0) == target.size(0)
        if weight is None:
            weight = input.new_empty(0)
        else:
            assert weight.dim() == 1
            assert input.size(1) == weight.size(0)
        ctx.reduction_dict = {'none': 0, 'mean': 1, 'sum': 2}
        assert reduction in ctx.reduction_dict.keys()

        ctx.gamma = float(gamma)
        ctx.alpha = float(alpha)
        ctx.reduction = ctx.reduction_dict[reduction]

        # Numerically stable softmax: subtract the per-row maximum before
        # exponentiating, then normalize by the per-row sum.
        channel_stats, _ = torch.max(input, dim=1)
        input_softmax = input - channel_stats.unsqueeze(1).expand_as(input)
        input_softmax.exp_()

        channel_stats = input_softmax.sum(dim=1)
        input_softmax /= channel_stats.unsqueeze(1).expand_as(input)

        output = input.new_zeros(input.size(0))
        ext_module.softmax_focal_loss_forward(
            input_softmax,
            target,
            weight,
            output,
            gamma=ctx.gamma,
            alpha=ctx.alpha)

        if ctx.reduction == ctx.reduction_dict['mean']:
            output = output.sum() / input.size(0)
        elif ctx.reduction == ctx.reduction_dict['sum']:
            output = output.sum()
        # Save the softmax probabilities (not the raw logits); they are what
        # the backward kernel consumes.
        ctx.save_for_backward(input_softmax, target, weight)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_softmax, target, weight = ctx.saved_tensors
        buff = input_softmax.new_zeros(input_softmax.size(0))
        grad_input = input_softmax.new_zeros(input_softmax.size())

        ext_module.softmax_focal_loss_backward(
            input_softmax,
            target,
            weight,
            buff,
            grad_input,
            gamma=ctx.gamma,
            alpha=ctx.alpha)

        grad_input *= grad_output
        if ctx.reduction == ctx.reduction_dict['mean']:
            grad_input /= input_softmax.size(0)
        return grad_input, None, None, None, None, None


softmax_focal_loss = SoftmaxFocalLossFunction.apply
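
# Usage sketch (illustrative; `softmax_focal_loss` expects raw logits and
# applies the softmax normalization internally):
#
#   logits = torch.randn(8, 4, device='cuda')
#   labels = torch.randint(0, 4, (8, ), device='cuda')
#   loss = softmax_focal_loss(logits, labels, 2.0, 0.25, None, 'mean')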


class SoftmaxFocalLoss(nn.Module):

    def __init__(self, gamma, alpha, weight=None, reduction='mean'):
        super(SoftmaxFocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.register_buffer('weight', weight)
        self.reduction = reduction

    def forward(self, input, target):
        return softmax_focal_loss(input, target, self.gamma, self.alpha,
                                  self.weight, self.reduction)

    def __repr__(self):
        s = self.__class__.__name__
        s += f'(gamma={self.gamma}, '
        s += f'alpha={self.alpha}, '
        s += f'reduction={self.reduction})'
        return s
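
# Module-style usage sketch (hypothetical values, CUDA build assumed):
#
#   criterion = SoftmaxFocalLoss(gamma=2.0, alpha=0.25).cuda()
#   loss = criterion(logits, labels)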