Spaces:
Build error
Build error
| # Copyright (c) Microsoft Corporation. | |
| # Licensed under the MIT license. | |
| import sys | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class MLECriterion(nn.Module):
    """Negative log-likelihood criterion computed against the decoder targets.

    Config keys read from ``opt``:

    * ``IGNORE_INDEX`` -- target id excluded from the loss. When the key is
      absent, ``module.tokenizer.pad_token_id`` is used instead.
    * ``USE_BOS_TOKEN`` -- when the key is present (its value is never
      inspected), the first target position is dropped.
    * ``USE_EOS_TOKEN`` -- when the key is present (its value is never
      inspected), the last prediction step is dropped.
    """

    def __init__(self, opt, module):
        """Store the config and resolve which target id the loss skips.

        Args:
            opt: Mapping-like config supporting ``in`` and ``[]``.
            module: Object exposing ``tokenizer.pad_token_id``; it is only
                accessed when ``opt`` has no ``IGNORE_INDEX`` entry.
        """
        super().__init__()
        self.opt = opt
        # Kept as an explicit branch so ``module`` is never dereferenced when
        # the config already supplies the id.
        if "IGNORE_INDEX" in self.opt:
            self.ignore_index = self.opt["IGNORE_INDEX"]
        else:
            self.ignore_index = module.tokenizer.pad_token_id

    def forward(self, vocab_logprob, batch):
        """Return the NLL loss of ``vocab_logprob`` w.r.t. the batch targets.

        Args:
            vocab_logprob: Tensor with at least three dimensions whose third
                dimension is the class axis -- presumably
                ``(batch, steps, vocab)`` log-probabilities; TODO confirm
                against the caller.
            batch: Mapping holding the target ids under
                ``"decoder_input_ids"`` -- presumably ``(batch, steps)``;
                TODO confirm against the caller.

        Returns:
            The value of ``F.nll_loss`` (default reduction) over all
            flattened positions, skipping targets equal to
            ``self.ignore_index``.
        """
        # Class count is read before any trimming; trimming only touches
        # the step axis, never the class axis.
        num_classes = vocab_logprob.size(2)
        targets = batch["decoder_input_ids"]

        # Both switches are presence-only flags.
        drop_first_target = "USE_BOS_TOKEN" in self.opt
        drop_last_step = "USE_EOS_TOKEN" in self.opt

        if drop_first_target:
            targets = targets[:, 1:]
        if drop_last_step:
            vocab_logprob = vocab_logprob[:, :-1, :]

        # Collapse every leading axis so each row is one prediction step.
        flat_logprob = vocab_logprob.reshape(-1, num_classes)
        flat_targets = targets.reshape(-1)
        return F.nll_loss(
            flat_logprob, flat_targets, ignore_index=self.ignore_index
        )