import torch
from torch import nn
from torch.nn import KLDivLoss, LogSoftmax


class LabelSmoothingLoss(nn.Module):
    def __init__(self, label_smoothing=0.0, unreliable_label=None, ignore_index=-100):
        """
        KL-divergence loss against a smoothed target distribution.

        If label_smoothing == 0.0, this is equivalent to standard
        cross-entropy. When unreliable_label is set, smoothing is applied
        only to positions whose target equals that label; all other
        positions use hard (one-hot) targets.
        """
        assert 0.0 <= label_smoothing <= 1.0
        super().__init__()
        self.ignore_index = ignore_index
        self.label_smoothing = label_smoothing
        self.loss_fn = KLDivLoss(reduction='batchmean')
        self.unreliable_label = unreliable_label
        self.max_gap = 100.  # kept from the original; unused in this module
        self.log_softmax = LogSoftmax(dim=1)

    def forward(self, output, target):
        """
        output: logits of shape (batch, vocab_size)
        target: integer labels of shape (batch,)
        """
        vocab_size = output.shape[1]

        # Drop positions marked with ignore_index before computing the loss.
        mask = (target != self.ignore_index)
        output, target = output[mask], target[mask]

        # KLDivLoss expects log-probabilities as its input.
        output = self.log_softmax(output)

        def get_smooth_prob(ls):
            # Spread ls uniformly over the (vocab_size - 1) wrong labels and
            # put the remaining 1 - ls probability mass on the true label.
            smoothing_value = ls / (vocab_size - 1)
            prob = output.new_full((target.size(0), vocab_size), smoothing_value)
            prob.scatter_(1, target.unsqueeze(1), 1 - ls)
            return prob

        if self.unreliable_label is not None:
            # Smooth only the unreliable positions; keep hard one-hot
            # targets everywhere else.
            smoothed_prob = get_smooth_prob(self.label_smoothing)
            hard_prob = get_smooth_prob(0.0)
            unreliable_mask = (target == self.unreliable_label).to(torch.float)
            model_prob = ((smoothed_prob.T * unreliable_mask)
                          + (hard_prob.T * (1 - unreliable_mask))).T
        else:
            model_prob = get_smooth_prob(self.label_smoothing)

        loss = self.loss_fn(output, model_prob)
        return loss
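

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of calling the loss, assuming logits of shape
# (batch, vocab_size) and integer targets of shape (batch,). The label
# id 3 and the tensor sizes below are arbitrary stand-ins.
if __name__ == "__main__":
    criterion = LabelSmoothingLoss(label_smoothing=0.1, unreliable_label=3)
    logits = torch.randn(8, 10, requires_grad=True)  # (batch, vocab_size)
    targets = torch.randint(0, 10, (8,))             # (batch,)
    targets[0] = -100                                # ignored position
    loss = criterion(logits, targets)
    loss.backward()
    print(loss.item())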