import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable

from ..utils import ext_loader

ext_module = ext_loader.load_ext('_ext', [
    'sigmoid_focal_loss_forward', 'sigmoid_focal_loss_backward',
    'softmax_focal_loss_forward', 'softmax_focal_loss_backward'
])
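
# Both ops below implement the focal loss of Lin et al., "Focal Loss for
# Dense Object Detection" (ICCV 2017):
#
#     FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t)
#
# where p_t is the predicted probability of the ground-truth class. The
# element-wise math lives in the compiled extension ops loaded above; the
# Python wrappers validate inputs and apply the requested reduction.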


class SigmoidFocalLossFunction(Function):

    @staticmethod
    def symbolic(g, input, target, gamma, alpha, weight, reduction):
        # ONNX export: emit a custom `mmcv::MMCVSigmoidFocalLoss` node that
        # carries the loss parameters instead of tracing through the CUDA op.
        return g.op(
            'mmcv::MMCVSigmoidFocalLoss',
            input,
            target,
            gamma_f=gamma,
            alpha_f=alpha,
            weight_f=weight,
            reduction_s=reduction)

    @staticmethod
    def forward(ctx,
                input,
                target,
                gamma=2.0,
                alpha=0.25,
                weight=None,
                reduction='mean'):

        # Expect (N, C) logits and a (N, ) vector of integer class labels.
        assert isinstance(target, (torch.LongTensor, torch.cuda.LongTensor))
        assert input.dim() == 2
        assert target.dim() == 1
        assert input.size(0) == target.size(0)
        if weight is None:
            # An empty tensor tells the kernel to skip per-class weighting.
            weight = input.new_empty(0)
        else:
            assert weight.dim() == 1
            assert input.size(1) == weight.size(0)
        ctx.reduction_dict = {'none': 0, 'mean': 1, 'sum': 2}
        assert reduction in ctx.reduction_dict.keys()

        ctx.gamma = float(gamma)
        ctx.alpha = float(alpha)
        ctx.reduction = ctx.reduction_dict[reduction]

        # The kernel writes the unreduced (N, C) loss into `output`.
        output = input.new_zeros(input.size())

        ext_module.sigmoid_focal_loss_forward(
            input, target, weight, output, gamma=ctx.gamma, alpha=ctx.alpha)
        if ctx.reduction == ctx.reduction_dict['mean']:
            output = output.sum() / input.size(0)
        elif ctx.reduction == ctx.reduction_dict['sum']:
            output = output.sum()
        ctx.save_for_backward(input, target, weight)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        input, target, weight = ctx.saved_tensors

        grad_input = input.new_zeros(input.size())

        # The kernel fills `grad_input` with the gradient of the unreduced
        # loss; the chain rule and reduction scaling are applied below.
        ext_module.sigmoid_focal_loss_backward(
            input,
            target,
            weight,
            grad_input,
            gamma=ctx.gamma,
            alpha=ctx.alpha)

        grad_input *= grad_output
        if ctx.reduction == ctx.reduction_dict['mean']:
            grad_input /= input.size(0)
        return grad_input, None, None, None, None, None


sigmoid_focal_loss = SigmoidFocalLossFunction.apply
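

# For reference, a pure-PyTorch sketch of the per-element loss the sigmoid
# kernel is expected to compute (an illustrative assumption based on the
# focal loss definition, not part of the original file; the compiled op
# remains the authoritative implementation):
def _sigmoid_focal_loss_reference(input, target, gamma=2.0, alpha=0.25):
    # One-vs-rest binary focal loss over all C classes for each sample.
    one_hot = nn.functional.one_hot(target, input.size(1)).type_as(input)
    prob = input.sigmoid()
    # p_t is the probability assigned to the correct binary decision.
    pt = prob * one_hot + (1 - prob) * (1 - one_hot)
    alpha_t = alpha * one_hot + (1 - alpha) * (1 - one_hot)
    return -alpha_t * (1 - pt).pow(gamma) * pt.log()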


class SigmoidFocalLoss(nn.Module):

    def __init__(self, gamma, alpha, weight=None, reduction='mean'):
        super(SigmoidFocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        # Register as a buffer so an optional per-class weight moves with
        # the module across devices and is saved in the state dict.
        self.register_buffer('weight', weight)
        self.reduction = reduction

    def forward(self, input, target):
        return sigmoid_focal_loss(input, target, self.gamma, self.alpha,
                                  self.weight, self.reduction)

    def __repr__(self):
        s = self.__class__.__name__
        s += f'(gamma={self.gamma}, '
        s += f'alpha={self.alpha}, '
        s += f'reduction={self.reduction})'
        return s
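

# Usage sketch (illustrative; assumes a CUDA build of the `_ext` module):
#
#     criterion = SigmoidFocalLoss(gamma=2.0, alpha=0.25).cuda()
#     logits = torch.randn(8, 21, device='cuda', requires_grad=True)
#     labels = torch.randint(0, 21, (8, ), device='cuda')
#     loss = criterion(logits, labels)
#     loss.backward()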


class SoftmaxFocalLossFunction(Function):

    @staticmethod
    def symbolic(g, input, target, gamma, alpha, weight, reduction):
        # ONNX export hook, mirroring the sigmoid variant above.
        return g.op(
            'mmcv::MMCVSoftmaxFocalLoss',
            input,
            target,
            gamma_f=gamma,
            alpha_f=alpha,
            weight_f=weight,
            reduction_s=reduction)

    @staticmethod
    def forward(ctx,
                input,
                target,
                gamma=2.0,
                alpha=0.25,
                weight=None,
                reduction='mean'):

        # Expect (N, C) logits and a (N, ) vector of integer class labels.
        assert isinstance(target, (torch.LongTensor, torch.cuda.LongTensor))
        assert input.dim() == 2
        assert target.dim() == 1
        assert input.size(0) == target.size(0)
        if weight is None:
            # An empty tensor tells the kernel to skip per-class weighting.
            weight = input.new_empty(0)
        else:
            assert weight.dim() == 1
            assert input.size(1) == weight.size(0)
        ctx.reduction_dict = {'none': 0, 'mean': 1, 'sum': 2}
        assert reduction in ctx.reduction_dict.keys()

        ctx.gamma = float(gamma)
        ctx.alpha = float(alpha)
        ctx.reduction = ctx.reduction_dict[reduction]

        # Numerically stable softmax: subtract the per-row maximum before
        # exponentiating, then normalize by the per-row sum.
        channel_stats, _ = torch.max(input, dim=1)
        input_softmax = input - channel_stats.unsqueeze(1).expand_as(input)
        input_softmax.exp_()

        channel_stats = input_softmax.sum(dim=1)
        input_softmax /= channel_stats.unsqueeze(1).expand_as(input)

        # The kernel writes one loss value per sample into `output`.
        output = input.new_zeros(input.size(0))
        ext_module.softmax_focal_loss_forward(
            input_softmax,
            target,
            weight,
            output,
            gamma=ctx.gamma,
            alpha=ctx.alpha)

        if ctx.reduction == ctx.reduction_dict['mean']:
            output = output.sum() / input.size(0)
        elif ctx.reduction == ctx.reduction_dict['sum']:
            output = output.sum()
        # Save the softmax probabilities (not the raw logits) for backward.
        ctx.save_for_backward(input_softmax, target, weight)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_softmax, target, weight = ctx.saved_tensors
        # `buff` is per-sample scratch space used by the backward kernel.
        buff = input_softmax.new_zeros(input_softmax.size(0))
        grad_input = input_softmax.new_zeros(input_softmax.size())

        ext_module.softmax_focal_loss_backward(
            input_softmax,
            target,
            weight,
            buff,
            grad_input,
            gamma=ctx.gamma,
            alpha=ctx.alpha)

        grad_input *= grad_output
        if ctx.reduction == ctx.reduction_dict['mean']:
            grad_input /= input_softmax.size(0)
        return grad_input, None, None, None, None, None


softmax_focal_loss = SoftmaxFocalLossFunction.apply
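

# For reference, a pure-PyTorch sketch of the per-sample loss the softmax
# kernel is expected to compute (an illustrative assumption, ignoring the
# optional per-class weight; the compiled op is authoritative):
def _softmax_focal_loss_reference(input, target, gamma=2.0, alpha=0.25):
    # p_t: softmax probability of the ground-truth class for each sample.
    prob = input.softmax(dim=1)
    pt = prob.gather(1, target.unsqueeze(1)).squeeze(1)
    return -alpha * (1 - pt).pow(gamma) * pt.log()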


class SoftmaxFocalLoss(nn.Module):

    def __init__(self, gamma, alpha, weight=None, reduction='mean'):
        super(SoftmaxFocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        # Registered as a buffer for device movement and state-dict saving.
        self.register_buffer('weight', weight)
        self.reduction = reduction

    def forward(self, input, target):
        return softmax_focal_loss(input, target, self.gamma, self.alpha,
                                  self.weight, self.reduction)

    def __repr__(self):
        s = self.__class__.__name__
        s += f'(gamma={self.gamma}, '
        s += f'alpha={self.alpha}, '
        s += f'reduction={self.reduction})'
        return s
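

# Usage sketch (illustrative; assumes a CUDA build of the `_ext` module):
#
#     criterion = SoftmaxFocalLoss(gamma=2.0, alpha=0.25).cuda()
#     logits = torch.randn(8, 21, device='cuda', requires_grad=True)
#     labels = torch.randint(0, 21, (8, ), device='cuda')
#     loss = criterion(logits, labels)
#     loss.backward()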