import torch
import torch.nn as nn
import torch.nn.functional as F

from mmseg.registry import MODELS

@MODELS.register_module()
class KLDivLoss(nn.Module):

    def __init__(self,
                 temperature: float = 1.0,
                 reduction: str = 'mean',
                 loss_name: str = 'loss_kld'):
| """Kullback-Leibler divergence Loss. |
| |
| <https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence> |
| |
| Args: |
| temperature (float, optional): Temperature param |
| reduction (str, optional): The method to reduce the loss into a |
| scalar. Default is "mean". Options are "none", "sum", |
| and "mean" |
| """ |
        assert isinstance(temperature, (float, int)), \
            'Expected temperature to be ' \
            f'float or int, but got {temperature.__class__.__name__} instead'
        assert temperature != 0., 'Temperature must not be zero'

        assert reduction in ['mean', 'none', 'sum'], \
            'Reduction must be one of the options ("mean", ' \
            f'"sum", "none"), but got {reduction}'

        super().__init__()
        self.temperature = temperature
        self.reduction = reduction
        self._loss_name = loss_name

    def forward(self, input: torch.Tensor,
                target: torch.Tensor) -> torch.Tensor:
| """Forward function. Calculate KL divergence Loss. |
| |
| Args: |
| input (Tensor): Logit tensor, |
| the data type is float32 or float64. |
| The shape is (N, C) where N is batchsize and C is number of |
| channels. |
| If there more than 2 dimensions, shape is (N, C, D1, D2, ... |
| Dk), k>= 1 |
| target (Tensor): Logit tensor, |
| the data type is float32 or float64. |
| input and target must be with the same shape. |
| |
| Returns: |
| (Tensor): Reduced loss. |
| """ |
        assert isinstance(input, torch.Tensor), 'Expected input to ' \
            f'be Tensor, but got {input.__class__.__name__} instead'
        assert isinstance(target, torch.Tensor), 'Expected target to ' \
            f'be Tensor, but got {target.__class__.__name__} instead'

        assert input.shape == target.shape, 'Input and target ' \
            'must have same shape, ' \
            f'but got shapes {input.shape} and {target.shape}'

        # `F.kl_div` expects log-probabilities as its first argument, so
        # the input logits go through log_softmax, while the target is a
        # plain softmax distribution (hence log_target=False).
        input = F.log_softmax(input / self.temperature, dim=1)
        target = F.softmax(target / self.temperature, dim=1)

        loss = F.kl_div(input, target, reduction='none', log_target=False)
        # Scale by T^2 so that gradient magnitudes stay comparable across
        # temperature settings, as in temperature-scaled distillation.
        loss = loss * self.temperature**2
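        # At this point `loss` holds the element-wise terms
        # q * (log q - log p), scaled by T**2, where p = softmax(input / T)
        # and q = softmax(target / T); summed over channels this is the
        # (scaled) KL divergence KL(q || p).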

        batch_size = input.shape[0]
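        # Note: the reductions below are applied per sample over all
        # non-batch dimensions, so a tensor of shape (N,) is returned
        # rather than a scalar (unlike `F.kl_div`'s built-in reductions).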

        if self.reduction == 'sum':
            # Flatten so the sum runs over every non-batch dimension.
            loss = loss.view(batch_size, -1)
            return torch.sum(loss, dim=1)

        elif self.reduction == 'mean':
            # Flatten so the mean runs over every non-batch dimension.
            loss = loss.view(batch_size, -1)
            return torch.mean(loss, dim=1)

        # reduction == 'none': return the unreduced element-wise loss.
        return loss

    @property
    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss
        items by simple sum operation. In addition, if you want this loss
        item to be included into the backward graph, `loss_` must be the
        prefix of the name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
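

# A minimal usage sketch (hypothetical shapes, assuming mmseg is installed
# so the MODELS registry import above resolves): the loss compares student
# and teacher segmentation logits channel-wise.
#
#     loss_fn = KLDivLoss(temperature=2.0, reduction='mean')
#     student_logits = torch.randn(4, 19, 64, 64)  # (N, C, H, W)
#     teacher_logits = torch.randn(4, 19, 64, 64)
#     loss = loss_fn(student_logits, teacher_logits)  # shape (4,)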