# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import absolute_import, division, print_function, unicode_literals

import logging
import math

import torch
import torch.nn.functional as F

from fairseq import utils
from fairseq.criterions import FairseqCriterion, register_criterion
class CrossEntropyWithAccCriterion(FairseqCriterion):
    """Cross-entropy criterion that additionally reports token accuracy.

    Similar to fairseq's ``CrossEntropyCriterion``, but the logging output
    also carries the number of correctly predicted non-padding tokens
    (``correct`` / ``total``) so accuracy can be displayed during training.

    NOTE(review): ``register_criterion`` is imported at module level but this
    class is not decorated with it; the registration decorator (e.g.
    ``@register_criterion("cross_entropy_acc")``) appears to have been lost —
    confirm against the original source before relying on CLI registration.
    """

    def __init__(self, task, sentence_avg):
        super().__init__(task)
        # If True, the gradient denominator is the number of sentences;
        # otherwise it is the number of target tokens.
        self.sentence_avg = sentence_avg

    def compute_loss(self, model, net_output, target, reduction, log_probs):
        """Compute the NLL loss over flattened (log-)probabilities.

        Args:
            model: model that produced ``net_output``; must implement
                ``get_normalized_probs``.
            net_output: raw network output, forwarded to the model.
            target (Tensor): target indices of shape (N, T).
            reduction (str): reduction mode passed to ``F.nll_loss``
                (e.g. "sum").
            log_probs (bool): whether to request log-probabilities.

        Returns:
            tuple: ``(lprobs, loss)`` where ``lprobs`` has shape (N*T, D)
            and ``loss`` is a scalar tensor.
        """
        # N, T -> N * T
        target = target.view(-1)
        lprobs = model.get_normalized_probs(net_output, log_probs=log_probs)
        if not hasattr(lprobs, "batch_first"):
            logging.warning(
                "ERROR: we need to know whether "
                "batch first for the net output; "
                "you need to set batch_first attribute for the return value of "
                "model.get_normalized_probs. Now, we assume this is true, but "
                "in the future, we will raise exception instead. "
            )
        # Default to batch-first when the attribute is absent (see warning).
        batch_first = getattr(lprobs, "batch_first", True)
        if not batch_first:
            # (T, N, D) -> (N, T, D) so flattening lines up with the target.
            lprobs = lprobs.transpose(0, 1)

        # N, T, D -> N * T, D
        lprobs = lprobs.view(-1, lprobs.size(-1))
        loss = F.nll_loss(
            lprobs, target, ignore_index=self.padding_idx, reduction=reduction
        )
        return lprobs, loss

    def get_logging_output(self, sample, target, lprobs, loss):
        """Build the per-batch logging dict, including accuracy counts.

        Returns:
            tuple: ``(sample_size, logging_output)`` where ``sample_size``
            is nsentences or ntokens depending on ``self.sentence_avg``.
        """
        target = target.view(-1)
        # Exclude padding positions from the accuracy computation.
        mask = target != self.padding_idx
        correct = torch.sum(
            lprobs.argmax(1).masked_select(mask) == target.masked_select(mask)
        )
        total = torch.sum(mask)
        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )
        logging_output = {
            "loss": utils.item(loss.data),  # * sample['ntokens'],
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
            "correct": utils.item(correct.data),
            "total": utils.item(total.data),
            "nframes": torch.sum(sample["net_input"]["src_lengths"]).item(),
        }
        return sample_size, logging_output

    def forward(self, model, sample, reduction="sum", log_probs=True):
        """Computes the cross entropy with accuracy metric for the given sample.

        This is similar to CrossEntropyCriterion in fairseq, but also
        computes accuracy metrics as part of logging

        Args:
            logprobs (Torch.tensor) of shape N, T, D i.e.
                batchsize, timesteps, dimensions
            targets (Torch.tensor) of shape N, T i.e batchsize, timesteps

        Returns:
            tuple: With three elements:
                1) the loss
                2) the sample size, which is used as the denominator for the
                   gradient
                3) logging outputs to display while training

        TODO:
            * Currently this Criterion will only work with LSTMEncoderModels or
              FairseqModels which have decoder, or Models which return TorchTensor
              as net_output.
              We need to make a change to support all FairseqEncoder models.
        """
        net_output = model(**sample["net_input"])
        target = model.get_targets(sample, net_output)
        lprobs, loss = self.compute_loss(
            model, net_output, target, reduction, log_probs
        )
        sample_size, logging_output = self.get_logging_output(
            sample, target, lprobs, loss
        )
        return loss, sample_size, logging_output

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training.

        Fix: declared ``@staticmethod`` — the function takes no ``self``,
        so without the decorator any call through the class or an instance
        would misbind the first argument.
        """
        correct_sum = sum(log.get("correct", 0) for log in logging_outputs)
        total_sum = sum(log.get("total", 0) for log in logging_outputs)
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        nframes = sum(log.get("nframes", 0) for log in logging_outputs)
        agg_output = {
            # Loss is reported in base-2 (bits). If sentence_avg was set,
            # sample_size is nsentences and this is a per-sentence loss;
            # otherwise sample_size is ntokens and it is a per-token loss.
            "loss": loss_sum / sample_size / math.log(2) if sample_size > 0 else 0.0,
            "ntokens": ntokens,
            "nsentences": nsentences,
            "nframes": nframes,
            "sample_size": sample_size,
            "acc": correct_sum * 100.0 / total_sum if total_sum > 0 else 0.0,
            "correct": correct_sum,
            # total is the number of validated (non-padding) tokens.
            "total": total_sum,
        }
        if sample_size != ntokens:
            # Sentence-averaged case: "loss" above is per-sentence, so also
            # report a per-token nll_loss for comparability across runs.
            # (The original comments had these two roles swapped.)
            agg_output["nll_loss"] = loss_sum / ntokens / math.log(2)
        return agg_output