#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Label smoothing module."""

import torch
from torch import nn

from funasr_detach.models.transformer.utils.nets_utils import make_pad_mask


class LabelSmoothingLoss(nn.Module):
    """Label-smoothing loss.

    :param int size: the number of classes
    :param int padding_idx: ignored class id
    :param float smoothing: smoothing rate (0.0 means the conventional CE)
    :param bool normalize_length: normalize loss by sequence length if True
    :param torch.nn.Module criterion: loss function to be smoothed
    """

    def __init__(
        self,
        size,
        padding_idx,
        smoothing,
        normalize_length=False,
        criterion=nn.KLDivLoss(reduction="none"),
    ):
        """Construct a LabelSmoothingLoss object."""
        super(LabelSmoothingLoss, self).__init__()
        self.criterion = criterion
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.true_dist = None
        self.normalize_length = normalize_length

    def forward(self, x, target):
        """Compute loss between x and target.

        :param torch.Tensor x: prediction (batch, seqlen, class)
        :param torch.Tensor target:
            target signal masked with self.padding_idx (batch, seqlen)
        :return: scalar float value
        :rtype: torch.Tensor
        """
        assert x.size(2) == self.size
        batch_size = x.size(0)
        x = x.view(-1, self.size)
        target = target.view(-1)
        with torch.no_grad():
            # Build the smoothed target distribution: self.confidence on the
            # target class, the remaining mass spread over the other classes.
            true_dist = x.clone()
            true_dist.fill_(self.smoothing / (self.size - 1))
            ignore = target == self.padding_idx  # (batch * seqlen,)
            total = len(target) - ignore.sum().item()
            target = target.masked_fill(ignore, 0)  # avoid -1 index
            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
        denom = total if self.normalize_length else batch_size
        return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
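

# A minimal usage sketch, added for illustration; it is not part of the
# original module. The vocabulary size, padding_idx, and tensor shapes are
# assumptions chosen to match the forward() docstring above.
def _demo_label_smoothing_loss():
    vocab = 8
    loss_fn = LabelSmoothingLoss(size=vocab, padding_idx=0, smoothing=0.1)
    logits = torch.randn(2, 4, vocab)  # (batch, seqlen, class)
    target = torch.randint(1, vocab, (2, 4))  # class ids; 0 marks padding here
    target[0, -1] = 0  # pretend the last step of the first sequence is padding
    return loss_fn(logits, target)  # scalar tensor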


class SequenceBinaryCrossEntropy(nn.Module):
    """Binary cross-entropy over padded sequences.

    :param bool normalize_length: normalize loss by sequence length if True
    :param torch.nn.Module criterion: elementwise loss applied to the logits
    """

    def __init__(
        self, normalize_length=False, criterion=nn.BCEWithLogitsLoss(reduction="none")
    ):
        super().__init__()
        self.normalize_length = normalize_length
        self.criterion = criterion

    def forward(self, pred, label, lengths):
        # True at padded frames; used to zero them out of the loss.
        pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1]).to(pred.device)
        loss = self.criterion(pred, label)
        denom = (~pad_mask).sum() if self.normalize_length else pred.shape[0]
        return loss.masked_fill(pad_mask.unsqueeze(-1), 0).sum() / denom
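

# A minimal usage sketch, added for illustration; it is not part of the
# original module. The shapes are assumptions: pred and label are
# (batch, maxlen, num_targets), and lengths holds each sequence's valid length.
def _demo_sequence_bce():
    pred = torch.randn(2, 5, 3)  # raw logits
    label = torch.randint(0, 2, (2, 5, 3)).float()  # binary targets
    lengths = torch.tensor([5, 3])  # frames beyond these are padding
    loss_fn = SequenceBinaryCrossEntropy(normalize_length=True)
    return loss_fn(pred, label, lengths)  # scalar tensor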


class NllLoss(nn.Module):
    """Negative log-likelihood loss.

    :param int size: the number of classes
    :param int padding_idx: ignored class id
    :param bool normalize_length: normalize loss by sequence length if True
    :param torch.nn.Module criterion: loss function
    """

    def __init__(
        self,
        size,
        padding_idx,
        normalize_length=False,
        criterion=nn.NLLLoss(reduction="none"),
    ):
        """Construct an NllLoss object."""
        super(NllLoss, self).__init__()
        self.criterion = criterion
        self.padding_idx = padding_idx
        self.size = size
        self.true_dist = None
        self.normalize_length = normalize_length

    def forward(self, x, target):
        """Compute loss between x and target.

        :param torch.Tensor x: prediction log-probabilities (batch, seqlen, class)
        :param torch.Tensor target:
            target signal masked with self.padding_idx (batch, seqlen)
        :return: scalar float value
        :rtype: torch.Tensor
        """
        assert x.size(2) == self.size
        batch_size = x.size(0)
        x = x.view(-1, self.size)
        target = target.view(-1)
        with torch.no_grad():
            ignore = target == self.padding_idx  # (batch * seqlen,)
            total = len(target) - ignore.sum().item()
            target = target.masked_fill(ignore, 0)  # avoid -1 index
        loss = self.criterion(x, target)
        denom = total if self.normalize_length else batch_size
        return loss.masked_fill(ignore, 0).sum() / denom
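

# A minimal usage sketch, added for illustration; it is not part of the
# original module. nn.NLLLoss expects log-probabilities, so unlike
# LabelSmoothingLoss this demo applies log_softmax before calling the loss.
# Shapes and padding_idx are assumptions.
def _demo_nll_loss():
    vocab = 8
    loss_fn = NllLoss(size=vocab, padding_idx=0, normalize_length=True)
    log_probs = torch.log_softmax(torch.randn(2, 4, vocab), dim=-1)
    target = torch.randint(1, vocab, (2, 4))  # class ids; 0 marks padding here
    return loss_fn(log_probs, target)  # scalar tensor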