|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from fairseq.criterions import register_criterion |
|
|
from fairseq.criterions.sentence_prediction import ( |
|
|
SentencePredictionCriterion, |
|
|
SentencePredictionConfig, |
|
|
) |
|
|
|
|
|
|
|
|
@register_criterion("sentence_prediction_adapters", dataclass=SentencePredictionConfig)
class SentencePredictionCriterionAdapters(SentencePredictionCriterion):
    """Sentence-prediction criterion that also forwards ``lang_id`` to the model.

    The loss computation is the classification (NLL) or regression (MSE) loss
    selected by ``self.regression_target``; in addition to the regular network
    inputs, the model is called with a ``lang_id`` keyword argument.
    """

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Args:
            model: model exposing ``classification_heads`` (containing
                ``self.classification_head_name``) and ``get_targets``; it is
                called with ``features_only``, ``classification_head_name``
                and ``lang_id`` keyword arguments.
            sample: mapping with the keys ``"net_input"`` and ``"ntokens"``,
                and optionally ``"lang_id"``. ``"nsentences"`` is read only
                when ``"lang_id"`` is absent.
            reduce: not used by this implementation.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        assert (
            hasattr(model, "classification_heads")
            and self.classification_head_name in model.classification_heads
        ), "model must provide sentence classification head for --criterion=sentence_prediction_adapters"

        # ``sample`` is indexed like a mapping, so the presence of a language id
        # has to be checked with a key lookup; ``hasattr`` on a dict never sees
        # its keys and would always select the fallback.
        if "lang_id" in sample:
            lang_id = sample["lang_id"]
        else:
            # No language id supplied: use "en_XX" for every sentence.
            lang_id = ["en_XX"] * sample["nsentences"]

        logits, _ = model(
            **sample["net_input"],
            features_only=True,
            classification_head_name=self.classification_head_name,
            lang_id=lang_id,
        )
        targets = model.get_targets(sample, [logits]).view(-1)
        sample_size = targets.numel()

        if not self.regression_target:
            # Classification: summed negative log-likelihood, computed in float32.
            lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
            loss = F.nll_loss(lprobs, targets, reduction="sum")
        else:
            # Regression: summed squared error over flattened float predictions.
            logits = logits.view(-1).float()
            targets = targets.float()
            loss = F.mse_loss(logits, targets, reduction="sum")

        logging_output = {
            "loss": loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample_size,
            "sample_size": sample_size,
        }
        if not self.regression_target:
            # Number of correct predictions is only logged for classification.
            preds = logits.argmax(dim=1)
            logging_output["ncorrect"] = (preds == targets).sum()

        return loss, sample_size, logging_output
|
|
|