import torch
from torch import nn
from torch.nn.utils import clip_grad_norm_


class SpeakerEncoder(nn.Module):
    """
    Learns a speaker representation from speech utterances of arbitrary length,
    following the GE2E paper (https://arxiv.org/abs/1710.10467).
    """

    def __init__(self, device, loss_device):
        super().__init__()
        self.loss_device = loss_device

        # LSTM block of 3 layers; takes 80-channel log-mel spectrogram frames
        # and projects them to a 256-dimensional hidden state.
        self.lstm = nn.LSTM(
            input_size=80,
            hidden_size=256,
            num_layers=3,
            batch_first=True,
            dropout=0,
            bidirectional=False
        ).to(device)
        self.linear = nn.Linear(in_features=256, out_features=256).to(device)
        self.relu = nn.ReLU().to(device)

        # Epsilon term for numerical stability (i.e., avoids division by zero).
        self.epsilon = 1e-5

        # Learnable cosine similarity scale and offset. Note that .to() must be
        # applied to the tensor *before* wrapping it in nn.Parameter:
        # nn.Parameter(...).to(...) returns a plain, non-leaf tensor that the
        # optimizer would never update.
        self.sim_weight = nn.Parameter(torch.tensor([5.]).to(loss_device))
        self.sim_bias = nn.Parameter(torch.tensor([-1.]).to(loss_device))

    def forward(self, utterances, h_init=None, c_init=None):
        # Implements section 2.1 of https://arxiv.org/pdf/1806.04558.pdf
        if h_init is None or c_init is None:
            out, (hidden, cell) = self.lstm(utterances)
        else:
            out, (hidden, cell) = self.lstm(utterances, (h_init, c_init))

        # Compute the speaker embedding from the final layer's hidden state.
        final_hidden = hidden[-1]
        speaker_embedding = self.relu(self.linear(final_hidden))

        # L2-normalize the speaker embedding.
        speaker_embedding = speaker_embedding / (torch.norm(speaker_embedding, dim=1, keepdim=True) + self.epsilon)
        return speaker_embedding

    def gradient_clipping(self):
        # Scale down the gradients of the similarity weight and bias, then let
        # PyTorch clip all gradients whose norm exceeds the maximum.
        self.sim_weight.grad *= 0.01
        self.sim_bias.grad *= 0.01
        clip_grad_norm_(self.parameters(), max_norm=3, norm_type=2)

    def similarity_matrix(self, embeds, debug=False):
        # Calculates S_ji,k from section 2.1 of the GE2E paper: the output
        # matrix holds the cosine similarity between each utterance embedding
        # and the centroid of each speaker.
        # embeds input size: (speakers, utterances, embedding size)

        # Speaker centroids: the average of the utterance embeddings for each
        # speaker (Eq. 1 in the paper). Used for negative examples, where an
        # utterance is compared against a false speaker.
        # size: (speakers, 1, embedding size)
        speaker_centroid = torch.mean(embeds, dim=1, keepdim=True)

        # Utterance-exclusive centroids: the average of a speaker's utterance
        # embeddings excluding the ith utterance (Eq. 8 in the paper). Used for
        # positive examples, where an utterance is compared against its true
        # speaker; the centroid excludes the utterance itself.
        # size: (speakers, utterances, embedding size)
        num_utterance = embeds.shape[1]
        utter_ex_centroid = (torch.sum(embeds, dim=1, keepdim=True) - embeds) / (num_utterance - 1)

        if debug:
            print("e", embeds.shape)
            print(embeds)
            print("sc", speaker_centroid.shape)
            print(speaker_centroid)
            print("uc", utter_ex_centroid.shape)
            print(utter_ex_centroid)

        # Create positive and negative index masks from the identity matrix.
        num_speaker = embeds.shape[0]
        i = torch.eye(num_speaker, dtype=torch.bool)
        pos_mask = torch.where(i)
        neg_mask = torch.where(~i)

        if debug:
            print("pm", len(pos_mask), len(pos_mask[0]))
            print(pos_mask)
            print("nm", len(neg_mask), len(neg_mask[0]))
            print(neg_mask)

        # Compile the similarity matrix.
        # Final size: (speakers, utterances, speakers); it is built as
        # (speakers, speakers, utterances) first for easier vectorization.
        sim_matrix = torch.zeros(num_speaker, num_speaker, num_utterance).to(self.loss_device)
        sim_matrix[pos_mask] = nn.functional.cosine_similarity(embeds, utter_ex_centroid, dim=2)
        sim_matrix[neg_mask] = nn.functional.cosine_similarity(embeds[neg_mask[0]], speaker_centroid[neg_mask[1]], dim=2)

        if debug:
            print("sm", sim_matrix.shape)
            print("pos vals", sim_matrix[pos_mask])
            print("neg vals", sim_matrix[neg_mask])
            print(sim_matrix)

        sim_matrix = sim_matrix.permute(0, 2, 1)

        if debug:
            print("sm", sim_matrix.shape)
            print(sim_matrix)
            print("cos sim weight", self.sim_weight)
            print("cos sim bias", self.sim_bias)

        # Apply the learnable scale and offset (Eq. 5 in the paper).
        sim_matrix = sim_matrix * self.sim_weight + self.sim_bias
        return sim_matrix

    def softmax_loss(self, embeds):
        """
        Computes the softmax loss defined by Eq. 6 in the GE2E paper.

        :param embeds: shape (speakers, utterances, embedding size)
        :return: computed softmax loss
        """
        # Per section 2.1 of the GE2E paper, the softmax loss performs slightly
        # better on Text-Independent Speaker Verification tasks.
        speaker_count = embeds.shape[0]

        # (speaker, utterance, speaker)
        similarities = self.similarity_matrix(embeds)

        # Eq. 6: -S_ji,j + log(sum_k exp(S_ji,k)); logsumexp is used for
        # numerical stability.
        loss_matrix = -similarities[torch.arange(0, speaker_count), :, torch.arange(0, speaker_count)] + \
            torch.logsumexp(similarities, dim=2)

        # Eq. 10: the total loss is the sum over all utterances.
        return torch.sum(loss_matrix)

    def contrast_loss(self, embeds):
        """
        Computes the contrast loss defined by Eq. 7 in the GE2E paper.

        :param embeds: shape (speakers, utterances, embedding size)
        :return: computed contrast loss
        """
        # Per section 2.1 of the GE2E paper, the contrast loss performs slightly
        # better on Text-Dependent Speaker Verification tasks.
        speaker_count, utterance_count = embeds.shape[0:2]

        # (speaker, utterance, speaker)
        similarities = self.similarity_matrix(embeds)

        # Mask out the true speaker (k == j) so the max below only considers
        # the closest centroid among the *other* speakers.
        mask = torch.ones(similarities.shape, dtype=torch.bool)
        mask[torch.arange(speaker_count), :, torch.arange(speaker_count)] = False
        closest_neighbors, _ = torch.max(
            similarities[mask].reshape(speaker_count, utterance_count, speaker_count - 1), dim=2)

        # Positive influence over the matching embeddings (k == j).
        matching_embedding = similarities[torch.arange(0, speaker_count), :, torch.arange(0, speaker_count)]

        # Eq. 7: 1 - sigmoid(S_ji,j) + max_{k != j} sigmoid(S_ji,k); since the
        # sigmoid is monotonic, taking the max before applying it is equivalent.
        loss_matrix = 1 - torch.sigmoid(matching_embedding) + torch.sigmoid(closest_neighbors)

        # Eq. 10: the total loss is the sum over all utterances.
        return torch.sum(loss_matrix)

    def accuracy(self, embeds):
        """
        Computes argmax accuracy: the fraction of utterances whose most similar
        centroid belongs to their true speaker.

        :param embeds: shape (speakers, utterances, embedding size)
        :return: accuracy
        """
        num_speaker, num_utter = embeds.shape[:2]
        similarities = self.similarity_matrix(embeds)
        preds = torch.argmax(similarities, dim=2)
        preds_one_hot = nn.functional.one_hot(preds, num_classes=num_speaker)
        # Build the ground-truth labels on the same device as the predictions.
        actual = torch.arange(num_speaker, device=preds.device).unsqueeze(1).repeat(1, num_utter)
        actual_one_hot = nn.functional.one_hot(actual, num_classes=num_speaker)
        return torch.sum(preds_one_hot * actual_one_hot) / (num_speaker * num_utter)