File size: 813 Bytes
ef814bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from torch import nn

class MLMLoss(nn.Module):
    """
    Masked Language Modeling loss.

    Computes a per-position loss and averages it over the positions selected
    by ``mask``. Two modes are supported:

    * ``mse_based=False`` (default): classification with
      ``nn.CrossEntropyLoss``.
    * ``mse_based=True``: regression with ``nn.MSELoss``.

    Args:
        mse_based: If True, use mean-squared error instead of cross-entropy.
    """

    def __init__(self, mse_based=False):
        super().__init__()
        self.mse_based = mse_based
        # reduction='none' keeps the per-position losses so they can be
        # masked before averaging in ``forward``.
        if self.mse_based:
            self.loss_fn = nn.MSELoss(reduction='none')
        else:
            self.loss_fn = nn.CrossEntropyLoss(reduction='none')

    def forward(self, predictions, targets, mask):
        """
        Return the loss averaged over the masked-in positions.

        Args:
            predictions: Cross-entropy mode: logits shaped
                (batch_size, seq_len, vocab_size). MSE mode: the trailing
                dimension is squeezed, so presumably (batch_size, seq_len, 1)
                -- TODO confirm against callers.
            targets: Cross-entropy mode: class indices shaped
                (batch_size, seq_len); cast to ``long``. MSE mode: values
                matching the squeezed predictions; non-floating targets are
                cast to the predictions' dtype so ``MSELoss`` accepts them.
            mask: Per-position selector/weights matching the per-position
                loss shape; the loss is multiplied by ``mask.float()``.

        Returns:
            Scalar tensor ``sum(loss * mask) / sum(mask)``. If ``mask`` sums
            to zero the result is 0 (with finite gradients) instead of NaN.
        """
        if self.mse_based:
            predictions = predictions.squeeze(-1)
            # Mirror the cross-entropy branch's target cast: MSELoss needs
            # floating targets. Floating targets are left untouched.
            if not targets.is_floating_point():
                targets = targets.to(predictions.dtype)
        else:
            predictions = predictions.permute(0, 2, 1)  # (batch_size, vocab_size, seq_len)
            targets = targets.long()

        masked_loss = self.loss_fn(predictions, targets)
        masked_loss = masked_loss * mask.float()
        numerator = masked_loss.sum()

        denominator = mask.sum().to(numerator.dtype)
        # An all-zero mask would give 0/0 = NaN. Divide by 1 in that case;
        # the numerator is already 0, so the loss is 0 and gradients stay
        # finite.
        denominator = denominator.masked_fill(denominator == 0, 1)
        return numerator / denominator