Spaces:
Sleeping
Sleeping
File size: 813 Bytes
ef814bf | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 | from torch import nn
class MLMLoss(nn.Module):
    """Masked Language Modeling loss, averaged over the masked positions.

    A per-element loss is computed with ``reduction='none'``, multiplied by
    ``mask`` and the resulting sum is divided by ``mask.sum()``.

    Args:
        mse_based: When true the element-wise loss is ``nn.MSELoss``;
            otherwise it is ``nn.CrossEntropyLoss``.
    """

    def __init__(self, mse_based=False):
        super().__init__()
        self.mse_based = mse_based
        if self.mse_based:
            self.loss_fn = nn.MSELoss(reduction='none')
        else:
            self.loss_fn = nn.CrossEntropyLoss(reduction='none')

    def forward(self, predictions, targets, mask):
        """Return the mask-weighted mean of the element-wise loss.

        Args:
            predictions: Model outputs. In cross-entropy mode this must be a
                3-D tensor laid out as (batch_size, seq_len, vocab_size); it
                is permuted so the class dimension comes second. In MSE mode
                a trailing size-1 dimension is squeezed away
                (presumably (batch_size, seq_len, 1) -- confirm with caller).
            targets: Class indices (cross-entropy mode) or regression values
                (MSE mode), one per position.
            mask: Per-position weights; non-zero entries mark the positions
                that contribute to the loss. Assumed to share the shape of
                the per-position loss -- TODO confirm with caller.

        Returns:
            Scalar tensor: sum of masked losses divided by ``mask.sum()``.
            When the mask selects nothing the result is ``0`` instead of
            ``NaN``.
        """
        if self.mse_based:
            predictions = predictions.squeeze(-1)
            # Integer targets cannot be regressed against floating-point
            # predictions reliably; float targets are left untouched.
            if not targets.is_floating_point():
                targets = targets.to(predictions.dtype)
        else:
            # CrossEntropyLoss expects the class dimension second:
            # (batch_size, vocab_size, seq_len)
            predictions = predictions.permute(0, 2, 1)
            targets = targets.long()

        per_position = self.loss_fn(predictions, targets)
        total = (per_position * mask.float()).sum()

        # Guard against an all-zero mask: 0 / 0 would yield NaN and poison
        # the optimisation step. Only an exactly-zero denominator is
        # replaced, so every other input gives the same value as before.
        count = mask.sum()
        count = count.masked_fill(count == 0, 1)
        return total / count