import torch
from torch import nn
from torch.nn import functional as F


class CrossEntropyLossWithLabelSmoothing(nn.Module):
    r"""
    PyTorch :class:`~torch.nn.CrossEntropyLoss` with label smoothing. Quoting
    documentation from the original PyTorch module:

    It is useful when training a classification problem with ``C`` classes.
    ``inputs`` is expected to contain raw, unnormalized scores for each class
    and has to be a Tensor of size ``(N, C)``. This criterion expects a class
    index in the range ``[0, C - 1]`` as the ``targets`` for each value of a
    1D tensor of size ``minibatch``; if ``ignore_index`` is specified, this
    criterion also accepts this class index (this index may not necessarily
    be in the class range).

    Parameters
    ----------
    smoothing: float, optional (default = 0.0)
        Label smoothing value. It sets the target class weight to
        ``(1 - smoothing)`` and all other class weights to
        ``smoothing / (C - 1)``. Setting this to zero falls back to vanilla
        cross entropy loss.
    ignore_index: int, optional (default = -100)
        Target value that is ignored and does not contribute to the loss.
    """
    def __init__(self, smoothing: float = 0.0, ignore_index: int = -100):
        super().__init__()
        self.smoothing = smoothing
        self.ignore_index = ignore_index

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        if self.smoothing == 0.0:
            # Use PyTorch cross entropy when smoothing is 0. This is slightly
            # faster than the manual computation below.
            return F.cross_entropy(
                inputs, targets, ignore_index=self.ignore_index, reduction="mean"
            )

        # Remove entries matching ``ignore_index`` so they do not contribute to
        # the loss. The mask also handles negative sentinel values such as the
        # default ``-100``.
        keep = targets != self.ignore_index
        _targets = targets[keep]
        _inputs = inputs[keep]

        # shape: (batch_size, num_classes)
        logprobs = F.log_softmax(_inputs, dim=-1)

        # shape: (batch_size, num_classes)
        # Smoothed target distribution: ``smoothing / (C - 1)`` for every class,
        # overwritten with ``1 - smoothing`` at the ground-truth class index.
        weights = torch.ones_like(_inputs) * self.smoothing / (_inputs.size(-1) - 1.0)
        weights.scatter_(-1, _targets.unsqueeze(-1), 1.0 - self.smoothing)

        # shape: (batch_size, )
        loss = (-weights * logprobs).sum(dim=-1)
        return loss.mean()
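if __name__ == "__main__":
    # Minimal usage sketch: sanity-check the loss on random logits. The tensor
    # shapes, the smoothing value of 0.1, and the use of -100 as a padding
    # target are illustrative assumptions, not requirements of the loss.
    torch.manual_seed(0)
    batch_size, num_classes = 4, 10

    logits = torch.randn(batch_size, num_classes)
    targets = torch.randint(0, num_classes, (batch_size,))
    targets[-1] = -100  # mark one entry as padding; it should be ignored

    criterion = CrossEntropyLossWithLabelSmoothing(smoothing=0.1, ignore_index=-100)
    print("smoothed loss:", criterion(logits, targets).item())

    # With smoothing == 0 the module falls back to F.cross_entropy, so the two
    # values below should match exactly.
    vanilla = CrossEntropyLossWithLabelSmoothing(smoothing=0.0, ignore_index=-100)
    reference = F.cross_entropy(logits, targets, ignore_index=-100)
    print("vanilla loss:", vanilla(logits, targets).item(), reference.item())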