# CoLMbo/encoder/attentive_pooling.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttentionPooling(nn.Module):
    """
    Implementation of SelfAttentionPooling.

    Original paper: Self-Attention Encoding and Pooling for Speaker Recognition
    https://arxiv.org/pdf/2008.01077v1.pdf
    """

    def __init__(self, input_dim):
        super().__init__()
        self.W = nn.Linear(input_dim, 1)
    def forward(self, batch_rep, att_mask):
        """
        input:
            batch_rep : size (N, T, H), N: batch size, T: sequence length, H: hidden dimension
            att_mask  : additive mask, size (N, T, 1) or (N, T, H); 0 for valid frames,
                        a large negative value for padded frames
        attention weight:
            att_w : size (N, T, 1)
        return:
            utter_rep    : attention-weighted mean, size (N, H)
            attn_out_std : attention-weighted standard deviation, size (N, H)
        """
        # Per-frame attention logits: (N, T, 1) -> (N, T).
        att_logits = self.W(batch_rep).squeeze(-1)
        # Add the first channel of the additive mask so that padded frames
        # receive near-zero attention weight after the softmax.
        att_logits = att_mask[:, :, 0] + att_logits
        att_w = F.softmax(att_logits, dim=-1).unsqueeze(-1)  # (N, T, 1)
        # Attention-weighted mean over time.
        utter_rep = torch.sum(batch_rep * att_w, dim=1)  # (N, H)
        # Attention-weighted standard deviation over time
        # (attentive statistics pooling).
        attn_out_std = torch.sqrt(
            torch.sum(att_w * (batch_rep - utter_rep.unsqueeze(1)) ** 2, dim=1)
        )
        return utter_rep, attn_out_std
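

# Minimal usage sketch (not part of the original file): illustrates the
# calling convention implied by the code above, assuming an additive
# attention mask with 0 for valid frames and a large negative value
# (e.g. -1e9) for padded frames. Shapes and values are illustrative only.
if __name__ == "__main__":
    N, T, H = 4, 100, 256
    pool = SelfAttentionPooling(input_dim=H)

    batch_rep = torch.randn(N, T, H)  # frame-level features, (N, T, H)
    att_mask = torch.zeros(N, T, 1)   # all frames valid by default
    att_mask[:, 80:, :] = -1e9        # mask out the last 20 frames as padding

    utter_rep, attn_out_std = pool(batch_rep, att_mask)
    print(utter_rep.shape, attn_out_std.shape)  # torch.Size([4, 256]) twice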