import torch
import torch.nn as nn


class SelfAttentionPooling(nn.Module):
    """
    Implementation of SelfAttentionPooling
    Original Paper: Self-Attention Encoding and Pooling for Speaker Recognition
    https://arxiv.org/pdf/2008.01077v1.pdf

    A single linear layer scores every time step; the softmax of those scores
    over the time axis is used to compute an attention-weighted mean and an
    attention-weighted standard deviation of the input sequence.
    """

    def __init__(self, input_dim):
        """
        input:
            input_dim : size of the last (hidden) dimension of the sequences
                        that will be passed to ``forward``
        """
        super().__init__()
        # Projects each H-dimensional frame to one scalar attention logit.
        self.W = nn.Linear(input_dim, 1)

    def forward(self, batch_rep, att_mask):
        """
        input:
            batch_rep : size (N, T, H), N: batch size, T: sequence length,
                        H: Hidden dimension
            att_mask  : additive mask; only ``att_mask[:, :, 0]`` is used and
                        it is added to the (N, T) attention logits before the
                        softmax, so it must broadcast against (N, T).
                        Presumably 0 at valid frames and a large negative
                        value at padded frames -- confirm against the caller.

        attention_weight:
            att_w : size (N, T, 1)

        return:
            utter_rep    : size (N, H), attention-weighted mean over T
            attn_out_std : size (N, H), attention-weighted standard deviation
                           over T (exactly 0 where the weighted variance is 0)
        """
        att_logits = self.W(batch_rep).squeeze(-1)
        att_logits = att_mask[:, :, 0] + att_logits
        att_w = nn.functional.softmax(att_logits, dim=-1).unsqueeze(-1)

        utter_rep = torch.sum(batch_rep * att_w, dim=1)

        centered = batch_rep - utter_rep.unsqueeze(1)
        variance = torch.sum(att_w * centered ** 2, dim=1)

        # d/dx sqrt(x) is infinite at x == 0, and autograd then produces
        # inf * 0 = NaN, which poisons every gradient. That happens whenever
        # the weighted variance is exactly zero (T == 1, a single unmasked
        # frame, identical frames). Take the sqrt of a harmless placeholder
        # at those positions and select 0 for the result: forward values are
        # unchanged, gradients there become 0 instead of NaN. NaN variances
        # compare unequal to 0 and therefore still propagate as before.
        zero_var = variance == 0
        safe_variance = torch.where(
            zero_var, torch.ones_like(variance), variance
        )
        attn_out_std = torch.where(
            zero_var, torch.zeros_like(variance), torch.sqrt(safe_variance)
        )

        return utter_rep, attn_out_std