PyTorch
ssl-aasist
custom_code
ssl-aasist / fairseq /examples /attention_head_selection /src /loss /attention_head_selection.py
ash56's picture
Add files using upload-large-folder tool
878264b verified
raw
history blame
872 Bytes
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
from torch.nn.modules.loss import _Loss
class HeadSelectionLoss(_Loss):
def __init__(self, args):
super().__init__()
self.args = args
self.kl_weight = getattr(args, "kl_weight", 0.0)
def forward(self, head_samples, sample_sizes, prior=0.5, eps=1e-7):
"""
head_scores: (num_tasks, num_layers, num_heads)
sample_sizes: (num_tasks, )
"""
kl_loss = (head_samples * (torch.log(head_samples + eps) - math.log(prior))).sum(-1).sum(-1)
kl_loss /= (torch.numel(head_samples) / head_samples.size(0))
kl_loss = self.kl_weight * torch.matmul(kl_loss, sample_sizes)
return kl_loss