STAR / fairseq /criterions /sentence_prediction_adapters.py
Yixuan Li
add fairseq folder
85ba398
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from fairseq.criterions import register_criterion
from fairseq.criterions.sentence_prediction import (
SentencePredictionCriterion,
SentencePredictionConfig,
)
@register_criterion("sentence_prediction_adapters", dataclass=SentencePredictionConfig)
class SentencePredictionCriterionAdapters(SentencePredictionCriterion):
    """Sentence-prediction criterion that also forwards a per-sample ``lang_id``.

    Behaves like :class:`SentencePredictionCriterion` (cross-entropy for
    classification, summed MSE for regression), but additionally passes a
    ``lang_id`` keyword argument to the model. The model presumably uses it to
    select language-specific adapters; confirm against the model implementation.
    """

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Args:
            model: model exposing ``classification_heads`` that contains
                ``self.classification_head_name``. It is called with the
                sample's ``net_input`` plus ``features_only=True``,
                ``classification_head_name`` and ``lang_id``.
            sample: batch dict with at least ``"net_input"``, ``"ntokens"``
                and ``"nsentences"``. An optional ``"lang_id"`` entry is
                forwarded to the model. When it is absent, every sentence
                defaults to ``"en_XX"``.
            reduce: accepted for interface compatibility. The loss is always
                summed over the batch.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        assert (
            hasattr(model, "classification_heads")
            and self.classification_head_name in model.classification_heads
        ), "model must provide sentence classification head for --criterion=sentence_prediction_adapters"

        # ``sample`` is a dict, so test key membership. ``hasattr(sample, ...)``
        # is always False on a dict and would silently discard any language
        # IDs provided by the data pipeline.
        if "lang_id" not in sample:
            # If no language ID is given, we fall back to English
            lang_id = ["en_XX"] * sample["nsentences"]
        else:
            lang_id = sample["lang_id"]

        logits, _ = model(
            **sample["net_input"],
            features_only=True,
            classification_head_name=self.classification_head_name,
            lang_id=lang_id,
        )
        targets = model.get_targets(sample, [logits]).view(-1)
        sample_size = targets.numel()

        if not self.regression_target:
            # Compute log-probabilities in float32 for numerical stability
            # (e.g. under fp16 training).
            lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
            loss = F.nll_loss(lprobs, targets, reduction="sum")
        else:
            # Regression: flatten predictions to align with the flat targets.
            logits = logits.view(-1).float()
            targets = targets.float()
            loss = F.mse_loss(logits, targets, reduction="sum")

        logging_output = {
            "loss": loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample_size,
            "sample_size": sample_size,
        }
        if not self.regression_target:
            preds = logits.argmax(dim=1)
            logging_output["ncorrect"] = (preds == targets).sum()
        return loss, sample_size, logging_output