| |
| from typing import Union |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from mmseg.registry import MODELS |
| from .utils import weight_reduce_loss |
|
|
|
|
| def _expand_onehot_labels_dice(pred: torch.Tensor, |
| target: torch.Tensor) -> torch.Tensor: |
| """Expand onehot labels to match the size of prediction. |
| |
| Args: |
| pred (torch.Tensor): The prediction, has a shape (N, num_class, H, W). |
| target (torch.Tensor): The learning label of the prediction, |
| has a shape (N, H, W). |
| |
| Returns: |
| torch.Tensor: The target after one-hot encoding, |
| has a shape (N, num_class, H, W). |
| """ |
| num_classes = pred.shape[1] |
| one_hot_target = torch.clamp(target, min=0, max=num_classes) |
| one_hot_target = torch.nn.functional.one_hot(one_hot_target, |
| num_classes + 1) |
| one_hot_target = one_hot_target[..., :num_classes].permute(0, 3, 1, 2) |
| return one_hot_target |
|
|
|
|
def dice_loss(pred: torch.Tensor,
              target: torch.Tensor,
              weight: Union[torch.Tensor, None] = None,
              eps: float = 1e-3,
              reduction: Union[str, None] = 'mean',
              naive_dice: Union[bool, None] = False,
              avg_factor: Union[int, None] = None,
              ignore_index: Union[int, None] = 255) -> torch.Tensor:
    """Calculate dice loss, there are two forms of dice loss supported:

        - the one proposed in `V-Net: Fully Convolutional Neural
            Networks for Volumetric Medical Image Segmentation
            <https://arxiv.org/abs/1606.04797>`_.
        - the dice loss in which the power of the number in the
            denominator is the first power instead of the second
            power.

    Args:
        pred (torch.Tensor): The prediction, has a shape (n, c, *).
        target (torch.Tensor): The learning label of the prediction,
            shape (n, c, *), same shape of pred.
        weight (torch.Tensor, optional): The weight of loss for each
            prediction, has a shape (n,). Defaults to None.
        eps (float): Avoid dividing by zero. Default: 1e-3.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'.
            Options are "none", "mean" and "sum".
        naive_dice (bool, optional): If false, use the dice
            loss defined in the V-Net paper, otherwise, use the
            naive dice loss in which the power of the number in the
            denominator is the first power instead of the second
            power. Defaults to False.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        ignore_index (int, optional): Index of the channel (dimension 1)
            that is removed from both ``pred`` and ``target`` before the
            loss is computed. A value outside ``[0, c)`` removes nothing.
            Defaults to 255.

    Returns:
        torch.Tensor: The per-sample loss ``1 - dice`` after being passed
            through ``weight_reduce_loss`` together with ``weight``,
            ``reduction`` and ``avg_factor``.
    """
    num_classes = pred.shape[1]
    # Only build and apply the mask when it actually removes a channel; an
    # out-of-range index (the default 255 in most setups) would select every
    # channel and merely copy the tensors.
    if ignore_index is not None and 0 <= ignore_index < num_classes:
        keep = torch.arange(num_classes, device=pred.device) != ignore_index
        # Indexing only dimension 1 keeps this valid for any number of
        # trailing dimensions, not just (N, C, H, W).
        pred = pred[:, keep]
        target = target[:, keep]
    assert pred.shape[1] != 0

    flat_pred = pred.flatten(1)
    flat_target = target.flatten(1).float()
    intersection = torch.sum(flat_pred * flat_target, 1)
    if naive_dice:
        # First-power denominator; eps smooths numerator and denominator.
        pred_sum = torch.sum(flat_pred, 1)
        target_sum = torch.sum(flat_target, 1)
        dice = (2 * intersection + eps) / (pred_sum + target_sum + eps)
    else:
        # V-Net form: squared terms in the denominator, eps added to each.
        pred_sum = torch.sum(flat_pred * flat_pred, 1) + eps
        target_sum = torch.sum(flat_target * flat_target, 1) + eps
        dice = (2 * intersection) / (pred_sum + target_sum)

    loss = 1 - dice
    if weight is not None:
        assert weight.ndim == loss.ndim
        assert len(weight) == len(pred)
    return weight_reduce_loss(loss, weight, reduction, avg_factor)
|
|
|
|
@MODELS.register_module()
class DiceLoss(nn.Module):
    """Dice loss module wrapping :func:`dice_loss`."""

    def __init__(self,
                 use_sigmoid=True,
                 activate=True,
                 reduction='mean',
                 naive_dice=False,
                 loss_weight=1.0,
                 ignore_index=255,
                 eps=1e-3,
                 loss_name='loss_dice'):
        """Store the loss configuration.

        Args:
            use_sigmoid (bool, optional): When ``activate`` is True, apply
                sigmoid to the prediction if True, otherwise apply softmax
                over dimension 1 (skipped when that dimension has size 1).
                Defaults to True.
            activate (bool): Whether to activate the prediction inside
                ``forward``. When False the prediction is used as given.
                Defaults to True.
            reduction (str, optional): The method used
                to reduce the loss. Options are "none",
                "mean" and "sum". Defaults to 'mean'.
            naive_dice (bool, optional): If false, use the dice
                loss defined in the V-Net paper, otherwise, use the
                naive dice loss in which the power of the number in the
                denominator is the first power instead of the second
                power. Defaults to False.
            loss_weight (float, optional): Weight of loss. Defaults to 1.0.
            ignore_index (int, optional): The index forwarded to
                ``dice_loss`` as ``ignore_index``. Default: 255.
            eps (float): Avoid dividing by zero. Defaults to 1e-3.
            loss_name (str, optional): Name of the loss item. If you want this
                loss item to be included into the backward graph, `loss_` must
                be the prefix of the name. Defaults to 'loss_dice'.
        """
        super().__init__()
        self._loss_name = loss_name
        self.ignore_index = ignore_index
        self.activate = activate
        self.use_sigmoid = use_sigmoid
        self.naive_dice = naive_dice
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                ignore_index=255,
                **kwargs):
        """Compute the weighted dice loss for a batch.

        Args:
            pred (torch.Tensor): The prediction, has a shape (n, *).
            target (torch.Tensor): The label of the prediction. If its shape
                differs from ``pred`` it is one-hot expanded with
                ``_expand_onehot_labels_dice`` first.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction, has a shape (n,). Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Options are "none", "mean" and "sum".
            ignore_index (int, optional): Accepted but not read by this
                method; the value given to the constructor is what is
                forwarded to ``dice_loss``.

        Returns:
            torch.Tensor: The calculated loss
        """
        # Bring the label to the prediction's layout when they disagree.
        if pred.shape == target.shape:
            one_hot_target = target
        else:
            one_hot_target = _expand_onehot_labels_dice(pred, target)

        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = reduction_override or self.reduction

        if self.activate and self.use_sigmoid:
            pred = pred.sigmoid()
        elif self.activate and pred.shape[1] != 1:
            # Softmax is skipped when dimension 1 has a single entry.
            pred = pred.softmax(dim=1)

        return self.loss_weight * dice_loss(
            pred,
            one_hot_target,
            weight,
            eps=self.eps,
            reduction=reduction,
            naive_dice=self.naive_dice,
            avg_factor=avg_factor,
            ignore_index=self.ignore_index)

    @property
    def loss_name(self):
        """Loss Name.

        The name of this loss item. This name will be used to combine
        different loss items by simple sum operation. In addition, if you
        want this loss item to be included into the backward graph, `loss_`
        must be the prefix of the name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
|
|