| import torch | |
| from torch import Tensor | |
| from torch.nn import CrossEntropyLoss | |
class CrossEntropyLossWithZLoss(CrossEntropyLoss):
    """Cross-entropy loss plus an auxiliary z-loss regulariser.

    The z-loss term is ``eps * logsumexp(logits) ** 2`` computed per
    prediction, which penalises the log-normaliser of the softmax for
    drifting away from zero.  It is combined with the standard
    cross-entropy term using the same ``reduction`` and the same
    ``ignore_index`` handling.

    Args:
        eps: Coefficient applied to the z-loss term.
        weight: Per-class rescaling weights, forwarded to
            ``CrossEntropyLoss``.  They affect only the cross-entropy
            term; the z-loss term is unweighted.
        size_average: Legacy reduction flag, forwarded to
            ``CrossEntropyLoss``.
        ignore_index: Target value whose positions contribute to neither
            the cross-entropy term nor the z-loss term.
        reduce: Legacy reduction flag, forwarded to ``CrossEntropyLoss``.
        reduction: ``"none"``, ``"mean"`` or ``"sum"``; applied to both
            terms.
        label_smoothing: Forwarded to ``CrossEntropyLoss``.
    """

    def __init__(
        self,
        eps: float = 1e-4,
        weight: Tensor = None,
        size_average=None,
        ignore_index: int = -100,
        reduce=None,
        reduction: str = "mean",
        label_smoothing: float = 0.0,
    ) -> None:
        super().__init__(weight, size_average, ignore_index, reduce, reduction, label_smoothing)
        self.eps = eps

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        """Return the cross-entropy loss plus ``eps`` times the z-loss.

        Args:
            input: Unnormalised logits, laid out as ``CrossEntropyLoss``
                expects: ``(C,)``, ``(N, C)`` or ``(N, C, d1, ...)``.
            target: Class indices (integer dtype) or class probabilities
                (floating dtype), as accepted by ``CrossEntropyLoss``.

        Returns:
            A scalar for ``reduction`` of ``"mean"``/``"sum"``, otherwise a
            per-prediction tensor.
        """
        ce_loss = super().forward(input, target)

        # CrossEntropyLoss puts classes on dim 1 for batched input and on
        # dim 0 for an unbatched 1-D input; the z-loss must normalise over
        # that same axis rather than blindly over the last one.
        class_dim = 0 if input.dim() == 1 else 1
        z_sq = torch.square(torch.logsumexp(input, dim=class_dim))

        if target.is_floating_point():
            # Probability targets: ignore_index does not apply.
            count = z_sq.numel()
        else:
            # Class-index targets: drop ignored positions (e.g. padding) so
            # they add neither loss nor gradient.  masked_fill (rather than
            # multiplying by 0) keeps non-finite logits at ignored
            # positions from leaking NaNs into the result.
            keep = target != self.ignore_index
            z_sq = z_sq.masked_fill(~keep, 0.0)
            # Clamp so an all-ignored batch does not divide by zero here.
            count = keep.sum().clamp(min=1)

        if self.reduction == "none":
            z_loss = z_sq
        elif self.reduction == "sum":
            z_loss = z_sq.sum()
        else:
            z_loss = z_sq.sum() / count

        return ce_loss + self.eps * z_loss