import torch
from torch import nn
from torch.nn.utils import clip_grad_norm_


class SpeakerEncoder(nn.Module):
    """Learns a fixed-size speaker representation from speech utterances of
    arbitrary length, following the GE2E approach.
    """

    def __init__(self, device, loss_device):
        super().__init__()
        self.loss_device = loss_device

        # 3-layer LSTM: takes 80-channel log-mel spectrogram frames as input
        # and projects them to a 256-dimensional hidden state.
        self.lstm = nn.LSTM(
            input_size=80,
            hidden_size=256,
            num_layers=3,
            batch_first=True,
            dropout=0,
            bidirectional=False
        ).to(device)
        self.linear = nn.Linear(in_features=256, out_features=256).to(device)
        self.relu = nn.ReLU().to(device)

        # Epsilon term for numerical stability (i.e. avoids division by zero).
        self.epsilon = 1e-5

        # Learnable scale and offset applied to the cosine similarity matrix.
        # The tensors are created directly on loss_device: wrapping
        # nn.Parameter(...).to(loss_device) would return a plain tensor
        # rather than a registered parameter.
        self.sim_weight = nn.Parameter(torch.tensor([5.], device=loss_device))
        self.sim_bias = nn.Parameter(torch.tensor([-1.], device=loss_device))

    def forward(self, utterances, h_init=None, c_init=None):
        """
        Computes speaker embeddings, implementing section 2.1 of
        https://arxiv.org/pdf/1806.04558.pdf.
        :param utterances: batch of mel spectrograms, shape (batch, frames, 80)
        :return: L2-normalized speaker embeddings, shape (batch, 256)
        """
        if h_init is None or c_init is None:
            out, (hidden, cell) = self.lstm(utterances)
        else:
            out, (hidden, cell) = self.lstm(utterances, (h_init, c_init))

        # Compute the speaker embedding from the final layer's last hidden state.
        final_hidden = hidden[-1]
        speaker_embedding = self.relu(self.linear(final_hidden))

        # L2-normalize the speaker embedding.
        speaker_embedding = speaker_embedding / (torch.norm(speaker_embedding, dim=1, keepdim=True) + self.epsilon)
        return speaker_embedding

    def gradient_clipping(self):
        # Scale down the gradients on the similarity weight and bias by 0.01,
        # as recommended in the GE2E paper.
        self.sim_weight.grad *= 0.01
        self.sim_bias.grad *= 0.01

        # Clip all gradients so their global L2 norm does not exceed max_norm.
        clip_grad_norm_(self.parameters(), max_norm=3, norm_type=2)

    def similarity_matrix(self, embeds, debug=False):
        """
        Computes S_ji,k from section 2.1 of the GE2E paper: the cosine
        similarity between each utterance embedding and each speaker centroid.
        Embeddings are expected to already live on loss_device.
        :param embeds: shape (speakers, utterances, embedding size)
        :return: similarity matrix, shape (speakers, utterances, speakers)
        """
        # Speaker centroids: the average of the utterance embeddings for each
        # speaker. Used for negative examples (comparing an utterance against
        # a false speaker). Equation 1 in the paper.
        # size: (speakers, 1, embedding size)
        speaker_centroid = torch.mean(embeds, dim=1, keepdim=True)

        # Utterance-exclusive centroids: the average of a speaker's utterance
        # embeddings, excluding the ith utterance. Used for positive examples
        # (comparing an utterance against its true speaker; the centroid
        # excludes the utterance itself). Equation 8 in the paper.
        # size: (speakers, utterances, embedding size)
        num_utterance = embeds.shape[1]
        utter_ex_centroid = (torch.sum(embeds, dim=1, keepdim=True) - embeds) / (num_utterance - 1)

        if debug:
            print("e", embeds.shape)
            print(embeds)
            print("sc", speaker_centroid.shape)
            print(speaker_centroid)
            print("uc", utter_ex_centroid.shape)
            print(utter_ex_centroid)

        # Build index masks for the positive (j == k) and negative (j != k)
        # utterance/speaker pairs.
        num_speaker = embeds.shape[0]
        i = torch.eye(num_speaker, dtype=torch.bool)
        pos_mask = torch.where(i)
        neg_mask = torch.where(~i)

        if debug:
            print("pm", len(pos_mask), len(pos_mask[0]))
            print(pos_mask)
            print("nm", len(neg_mask), len(neg_mask[0]))
            print(neg_mask)

        # Compile the similarity matrix.
        # Final size: (speakers, utterances, speakers); it is first built as
        # (speakers, speakers, utterances) for easier vectorization.
        sim_matrix = torch.zeros(num_speaker, num_speaker, num_utterance, device=self.loss_device)
        sim_matrix[pos_mask] = nn.functional.cosine_similarity(embeds, utter_ex_centroid, dim=2)
        sim_matrix[neg_mask] = nn.functional.cosine_similarity(embeds[neg_mask[0]], speaker_centroid[neg_mask[1]], dim=2)

        if debug:
            print("sm", sim_matrix.shape)
            print("pos vals", sim_matrix[pos_mask])
            print("neg vals", sim_matrix[neg_mask])
            print(sim_matrix)

        sim_matrix = sim_matrix.permute(0, 2, 1)

        if debug:
            print("sm", sim_matrix.shape)
            print(sim_matrix)
            print("cos sim weight", self.sim_weight)
            print("cos sim bias", self.sim_bias)

        # Apply the learnable weight and bias.
        sim_matrix = sim_matrix * self.sim_weight + self.sim_bias
        return sim_matrix

    def softmax_loss(self, embeds):
        """
        Computes the softmax loss as defined by equation 6 in the GE2E paper.
        Per section 2.1 of the paper, this variant performs slightly better
        on text-independent speaker verification tasks.
        :param embeds: shape (speakers, utterances, embedding size)
        :return: computed softmax loss
        """
        speaker_count = embeds.shape[0]
        # size: (speakers, utterances, speakers)
        similarities = self.similarity_matrix(embeds)

        # Equation 6: -S_ji,j + log(sum_k exp(S_ji,k)); logsumexp is used for
        # numerical stability.
        loss_matrix = -similarities[torch.arange(0, speaker_count), :, torch.arange(0, speaker_count)] + \
                      torch.logsumexp(similarities, dim=2)

        # Equation 10: sum over all speakers and utterances.
        return torch.sum(loss_matrix)

    def contrast_loss(self, embeds):
        """
        Computes the contrast loss as defined by equation 7 in the GE2E paper.
        Per section 2.1 of the paper, this variant performs slightly better
        on text-dependent speaker verification tasks.
        :param embeds: shape (speakers, utterances, embedding size)
        :return: computed contrast loss
        """
        speaker_count, utterance_count = embeds.shape[0:2]
        # size: (speakers, utterances, speakers)
        similarities = self.similarity_matrix(embeds)

        # Mask out the true speaker (k == j) so the max below runs only over
        # the other speakers (k != j).
        mask = torch.ones(similarities.shape, dtype=torch.bool)
        mask[torch.arange(speaker_count), :, torch.arange(speaker_count)] = False
        closest_neighbors, _ = torch.max(similarities[mask].reshape(speaker_count, utterance_count, speaker_count - 1), dim=2)

        # Similarity of each utterance to its own speaker's centroid.
        matching_embedding = similarities[torch.arange(0, speaker_count), :, torch.arange(0, speaker_count)]

        # Equation 7: 1 - sigmoid(S_ji,j) + max_{k != j} sigmoid(S_ji,k);
        # sigmoid is monotonic, so the max can be taken before applying it.
        loss_matrix = 1 - torch.sigmoid(matching_embedding) + torch.sigmoid(closest_neighbors)

        # Equation 10: sum over all speakers and utterances.
        return torch.sum(loss_matrix)

    def accuracy(self, embeds):
        """
        Computes argmax accuracy of assigning each utterance to its speaker.
        :param embeds: shape (speakers, utterances, embedding size)
        :return: accuracy
        """
        num_speaker, num_utter = embeds.shape[:2]
        similarities = self.similarity_matrix(embeds)
        preds = torch.argmax(similarities, dim=2)
        preds_one_hot = torch.nn.functional.one_hot(preds, num_classes=num_speaker)

        # Ground truth: utterance (j, i) belongs to speaker j. Created on the
        # same device as the predictions to avoid a device mismatch.
        actual = torch.arange(num_speaker, device=preds.device).unsqueeze(1).repeat(1, num_utter)
        actual_one_hot = torch.nn.functional.one_hot(actual, num_classes=num_speaker)
        return torch.sum(preds_one_hot * actual_one_hot) / (num_speaker * num_utter)
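

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module): builds the
    # encoder on CPU and pushes random data through the forward pass, the
    # similarity matrix, both losses, and the accuracy metric. The batch
    # layout (speakers x utterances) and the 80-channel mel input follow the
    # shape conventions documented above; the sizes here are arbitrary.
    device = loss_device = torch.device("cpu")
    encoder = SpeakerEncoder(device, loss_device)

    speakers, utterances, frames = 4, 5, 160
    # Fake batch of log-mel spectrograms: (batch, frames, 80 mel channels).
    mels = torch.randn(speakers * utterances, frames, 80)
    embeds = encoder(mels).view(speakers, utterances, -1)

    print("embeddings:", embeds.shape)                             # (4, 5, 256)
    print("similarity:", encoder.similarity_matrix(embeds).shape)  # (4, 5, 4)
    print("softmax loss:", encoder.softmax_loss(embeds).item())
    print("contrast loss:", encoder.contrast_loss(embeds).item())
    print("accuracy:", encoder.accuracy(embeds).item())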