Spaces:
Runtime error
Runtime error
| """ | |
| Misc Losses | |
| Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) | |
| Please cite our work if the code is helpful to you. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from .builder import LOSSES | |
class CrossEntropyLoss(nn.Module):
    """Thin wrapper around ``torch.nn.CrossEntropyLoss`` with a scalar loss weight.

    Args:
        weight: Optional per-class rescaling weights (anything accepted by
            ``torch.tensor``). ``None`` disables class weighting.
        size_average: Forwarded unchanged to ``nn.CrossEntropyLoss``.
        reduce: Forwarded unchanged to ``nn.CrossEntropyLoss``.
        reduction: Forwarded unchanged to ``nn.CrossEntropyLoss``.
        label_smoothing: Forwarded unchanged to ``nn.CrossEntropyLoss``.
        loss_weight: Scalar the computed loss is multiplied by.
        ignore_index: Target value excluded from the loss.
    """

    def __init__(
        self,
        weight=None,
        size_average=None,
        reduce=None,
        reduction="mean",
        label_smoothing=0.0,
        loss_weight=1.0,
        ignore_index=-1,
    ):
        super(CrossEntropyLoss, self).__init__()
        if weight is not None:
            weight = torch.tensor(weight)
            # Only move to the GPU when one exists; an unconditional .cuda()
            # makes the loss impossible to construct on CPU-only hosts.
            if torch.cuda.is_available():
                weight = weight.cuda()
        self.loss_weight = loss_weight
        self.loss = nn.CrossEntropyLoss(
            weight=weight,
            size_average=size_average,
            ignore_index=ignore_index,
            reduce=reduce,
            reduction=reduction,
            label_smoothing=label_smoothing,
        )

    def forward(self, pred, target):
        """Return the wrapped cross-entropy of ``pred`` vs ``target`` times ``loss_weight``."""
        class_weight = self.loss.weight
        # Keep the class-weight buffer on the same device as the prediction so
        # the loss works regardless of where the weights were created.
        if class_weight is not None and class_weight.device != pred.device:
            self.loss.weight = class_weight.to(pred.device)
        return self.loss(pred, target) * self.loss_weight
class SmoothCELoss(nn.Module):
    """Cross-entropy against label-smoothed one-hot targets.

    The true class receives probability ``1 - smoothing_ratio``; the remaining
    ``smoothing_ratio`` is split evenly over the other classes.

    Args:
        smoothing_ratio: Probability mass taken from the true class.
    """

    def __init__(self, smoothing_ratio=0.1):
        super(SmoothCELoss, self).__init__()
        self.smoothing_ratio = smoothing_ratio

    def forward(self, pred, target):
        """Compute the mean smoothed cross-entropy.

        Args:
            pred (torch.Tensor): Logits; classes are read from dimension 1 and
                the target is scattered as a column, so a 2-D (N, C) tensor
                is expected.
            target (torch.Tensor): Integer class indices, flattened to (N, 1)
                before scattering.

        Returns:
            torch.Tensor: Scalar mean over the samples whose loss is finite.
        """
        eps = self.smoothing_ratio
        n_class = pred.size(1)
        one_hot = torch.zeros_like(pred).scatter(1, target.view(-1, 1), 1)
        # Smoothed distribution: (1 - eps) on the true class, eps spread
        # uniformly over the remaining (n_class - 1) classes.
        one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
        log_prb = F.log_softmax(pred, dim=1)
        # Tensors expose .sum(), not .total(); reduce over the class dimension.
        loss = -(one_hot * log_prb).sum(dim=1)
        # Drop non-finite per-sample losses before averaging.
        loss = loss[torch.isfinite(loss)].mean()
        return loss
class BinaryFocalLoss(nn.Module):
    def __init__(self, gamma=2.0, alpha=0.5, logits=True, reduce=True, loss_weight=1.0):
        """Binary focal loss (https://arxiv.org/abs/1708.02002).

        Args:
            gamma: Exponent applied to ``(1 - pt)``.
            alpha: Balance factor in (0, 1); weights positives by ``alpha``
                and negatives by ``1 - alpha``.
            logits: If true, ``pred`` is treated as raw logits; otherwise as
                probabilities.
            reduce: If true, the element-wise loss is averaged to a scalar.
            loss_weight: Scalar the final loss is multiplied by.
        """
        super().__init__()
        assert 0 < alpha < 1
        self.gamma, self.alpha = gamma, alpha
        self.logits, self.reduce = logits, reduce
        self.loss_weight = loss_weight

    def forward(self, pred, target, **kwargs):
        """Evaluate the focal loss.

        Args:
            pred (torch.Tensor): Predictions (logits or probabilities,
                depending on ``logits``).
            target (torch.Tensor): Ground truth with values in [0, 1] and the
                same shape as ``pred``.

        Returns:
            torch.Tensor: Scalar if ``reduce`` is set, otherwise the
            element-wise loss; scaled by ``loss_weight`` in both cases.
        """
        bce_fn = (
            F.binary_cross_entropy_with_logits
            if self.logits
            else F.binary_cross_entropy
        )
        elementwise_bce = bce_fn(pred, target, reduction="none")
        # exp(-BCE) recovers the probability assigned to the true label.
        prob_true = torch.exp(-elementwise_bce)
        balance = self.alpha * target + (1 - self.alpha) * (1 - target)
        result = balance * (1 - prob_true) ** self.gamma * elementwise_bce
        if self.reduce:
            result = result.mean()
        return result * self.loss_weight
class FocalLoss(nn.Module):
    def __init__(
        self, gamma=2.0, alpha=0.5, reduction="mean", loss_weight=1.0, ignore_index=-1
    ):
        """Multi-class focal loss built on per-class sigmoid BCE.

        See https://arxiv.org/abs/1708.02002.

        Args:
            gamma (float): Exponent applied to ``(1 - pt)``.
            alpha (float | list): Balance factor; a list is converted to a
                tensor and broadcast against the (N, C) one-hot target
                (presumably one entry per class -- confirm with callers).
            reduction (str): ``"mean"`` or ``"sum"``.
            loss_weight (float): Scalar the final loss is multiplied by.
            ignore_index (int): Target value excluded from the loss.
        """
        super(FocalLoss, self).__init__()
        assert reduction in (
            "mean",
            "sum",
        ), "AssertionError: reduction should be 'mean' or 'sum'"
        assert isinstance(
            alpha, (float, list)
        ), "AssertionError: alpha should be of type float"
        assert isinstance(gamma, float), "AssertionError: gamma should be of type float"
        assert isinstance(
            loss_weight, float
        ), "AssertionError: loss_weight should be of type float"
        assert isinstance(ignore_index, int), "ignore_index must be of type int"
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.ignore_index = ignore_index

    def forward(self, pred, target, **kwargs):
        """Forward function.

        Args:
            pred (torch.Tensor): Logits of shape (N, C) or
                (B, C, d_1, ..., d_k), where C is the number of classes.
            target (torch.Tensor): Class indices with one entry per
                prediction; entries equal to ``ignore_index`` are skipped.

        Returns:
            torch.Tensor: The reduced loss scaled by ``loss_weight``. When
            every target is ignored, the Python float ``0.0`` is returned.
        """
        # [B, C, d_1, d_2, ..., d_k] -> [C, B, d_1, d_2, ..., d_k]
        pred = pred.transpose(0, 1)
        # [C, B, d_1, d_2, ..., d_k] -> [C, N]
        pred = pred.reshape(pred.size(0), -1)
        # [C, N] -> [N, C]
        pred = pred.transpose(0, 1).contiguous()
        # (B, d_1, d_2, ..., d_k) --> (B * d_1 * d_2 * ... * d_k,)
        target = target.view(-1).contiguous()
        assert pred.size(0) == target.size(
            0
        ), "The shape of pred doesn't match the shape of target"
        valid_mask = target != self.ignore_index
        target = target[valid_mask]
        pred = pred[valid_mask]

        if len(target) == 0:
            return 0.0

        num_classes = pred.size(1)
        target = F.one_hot(target, num_classes=num_classes)

        alpha = self.alpha
        if isinstance(alpha, list):
            alpha = pred.new_tensor(alpha)
        pred_sigmoid = pred.sigmoid()
        target = target.type_as(pred)
        # (1 - pt): probability mass assigned to the wrong side of each
        # per-class binary decision.
        one_minus_pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
        focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * one_minus_pt.pow(
            self.gamma
        )

        loss = (
            F.binary_cross_entropy_with_logits(pred, target, reduction="none")
            * focal_weight
        )
        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            # Tensors expose .sum(), not .total().
            loss = loss.sum()
        return self.loss_weight * loss
class DiceLoss(nn.Module):
    def __init__(self, smooth=1, exponent=2, loss_weight=1.0, ignore_index=-1):
        """Soft Dice loss averaged over classes.

        Introduced in "V-Net: Fully Convolutional Neural Networks for
        Volumetric Medical Image Segmentation"
        (https://arxiv.org/abs/1606.04797).

        Args:
            smooth: Constant added to both numerator and denominator.
            exponent: Power applied to predictions and targets in the
                denominator.
            loss_weight: Scalar the final loss is multiplied by.
            ignore_index: Target value excluded from the loss; a class whose
                index equals it is also skipped in the per-class sum.
        """
        super().__init__()
        self.smooth, self.exponent = smooth, exponent
        self.loss_weight, self.ignore_index = loss_weight, ignore_index

    def forward(self, pred, target, **kwargs):
        """Compute the Dice loss of logits ``pred`` against class-index ``target``.

        ``pred`` is (N, C) or (B, C, d_1, ..., d_k); ``target`` carries one
        class index per prediction. Returns the per-class Dice losses summed,
        divided by C, and scaled by ``loss_weight``.
        """
        # [B, C, d_1, ..., d_k] -> [C, B, d_1, ..., d_k] -> [C, N] -> [N, C]
        channels_first = pred.transpose(0, 1)
        logits = (
            channels_first.reshape(channels_first.size(0), -1)
            .transpose(0, 1)
            .contiguous()
        )
        # (B, d_1, ..., d_k) -> (N,)
        labels = target.view(-1).contiguous()
        assert logits.size(0) == labels.size(
            0
        ), "The shape of pred doesn't match the shape of target"

        keep = labels != self.ignore_index
        labels, logits = labels[keep], logits[keep]

        probs = F.softmax(logits, dim=1)
        num_classes = probs.shape[1]
        # Clamp so out-of-range labels cannot break the one-hot encoding.
        onehot = F.one_hot(
            torch.clamp(labels.long(), 0, num_classes - 1), num_classes=num_classes
        )

        total_loss = 0
        for cls in range(num_classes):
            if cls == self.ignore_index:
                continue
            p_cls, t_cls = probs[:, cls], onehot[:, cls]
            num = torch.sum(torch.mul(p_cls, t_cls)) * 2 + self.smooth
            den = (
                torch.sum(p_cls.pow(self.exponent) + t_cls.pow(self.exponent))
                + self.smooth
            )
            total_loss += 1 - num / den
        # The average is always over every class, including a skipped one.
        return self.loss_weight * (total_loss / num_classes)