from speaker_encoder.params_model import *
from speaker_encoder.params_data import *
from scipy.interpolate import interp1d
from sklearn.metrics import roc_curve
from torch.nn.utils import clip_grad_norm_
from scipy.optimize import brentq
from torch import nn
import numpy as np
import torch


class SpeakerEncoder(nn.Module):
    def __init__(self, device, loss_device):
        super().__init__()
        self.loss_device = loss_device

        # Network definition
        self.lstm = nn.LSTM(input_size=mel_n_channels,
                            hidden_size=model_hidden_size,
                            num_layers=model_num_layers,
                            batch_first=True).to(device)
        self.linear = nn.Linear(in_features=model_hidden_size,
                                out_features=model_embedding_size).to(device)
        self.relu = torch.nn.ReLU().to(device)

        # Cosine similarity scaling (with fixed initial parameter values). Calling
        # .to() on an nn.Parameter returns a new, non-leaf tensor whose .grad stays
        # None, so the tensor is moved to the device *before* being wrapped.
        self.similarity_weight = nn.Parameter(torch.tensor([10.]).to(loss_device))
        self.similarity_bias = nn.Parameter(torch.tensor([-5.]).to(loss_device))

        # Loss
        self.loss_fn = nn.CrossEntropyLoss().to(loss_device)
    def do_gradient_ops(self):
        # Scale down the gradients of the similarity scaling parameters.
        self.similarity_weight.grad *= 0.01
        self.similarity_bias.grad *= 0.01

        # Clip the global gradient norm to 3.
        clip_grad_norm_(self.parameters(), 3, norm_type=2)
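    # A hedged usage sketch (not part of the original file): do_gradient_ops()
    # touches .grad directly, so it assumes gradients are already populated and
    # must run between backward() and the optimizer step, e.g.:
    #   loss, eer = model.loss(embeds)
    #   loss.backward()
    #   model.do_gradient_ops()
    #   optimizer.step()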

    def forward(self, utterances, hidden_init=None):
        """
        Computes the embeddings of a batch of utterance spectrograms.

        :param utterances: batch of mel-scale filterbanks of the same duration, as a tensor of
        shape (batch_size, n_frames, n_channels)
        :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers,
        batch_size, hidden_size). Will default to a tensor of zeros if None.
        :return: the embeddings as a tensor of shape (batch_size, embedding_size)
        """
        # Pass the input through the LSTM layers and retrieve the outputs and the
        # final hidden and cell states.
        out, (hidden, cell) = self.lstm(utterances, hidden_init)

        # Take only the hidden state of the last layer.
        embeds_raw = self.relu(self.linear(hidden[-1]))

        # L2-normalize the embeddings.
        embeds = embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)

        return embeds
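    # A minimal usage sketch (an illustration assuming the default hyperparameters
    # from params_model/params_data; not part of the original file):
    #   encoder = SpeakerEncoder(torch.device("cpu"), torch.device("cpu"))
    #   frames = torch.randn(10, 160, mel_n_channels)  # 10 utterances, 160 frames
    #   embeds = encoder(frames)  # (10, model_embedding_size), rows L2-normalized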

    def similarity_matrix(self, embeds):
        """
        Computes the similarity matrix according to section 2.1 of GE2E.

        :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, embedding_size)
        :return: the similarity matrix as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, speakers_per_batch)
        """
        speakers_per_batch, utterances_per_speaker = embeds.shape[:2]

        # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation.
        centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
        centroids_incl = centroids_incl.clone() / torch.norm(centroids_incl, dim=2, keepdim=True)

        # Exclusive centroids (1 per utterance): each utterance is compared against the
        # centroid of its own speaker computed *without* that utterance.
        centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
        centroids_excl /= (utterances_per_speaker - 1)
        centroids_excl = centroids_excl.clone() / torch.norm(centroids_excl, dim=2, keepdim=True)

        # Similarity matrix. The cosine similarity of already L2-normed vectors is simply
        # their dot product (an element-wise multiplication reduced by a sum). Note that
        # np.int was removed in NumPy 1.24, so the builtin int is used instead.
        sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
                                 speakers_per_batch).to(self.loss_device)
        mask_matrix = 1 - np.eye(speakers_per_batch, dtype=int)
        for j in range(speakers_per_batch):
            mask = np.where(mask_matrix[j])[0]
            sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
            sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)

        sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
        return sim_matrix
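    # For reference, a restatement of the scaled similarity this method computes,
    # with e_ji the i-th embedding of speaker j, c_k the inclusive centroid of
    # speaker k, c_j^(-i) the exclusive centroid, w = similarity_weight and
    # b = similarity_bias:
    #   S_ji,k = w * cos(e_ji, c_k)      + b   if k != j
    #   S_ji,j = w * cos(e_ji, c_j^(-i)) + b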

    def loss(self, embeds):
        """
        Computes the softmax loss according to section 2.1 of GE2E.

        :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, embedding_size)
        :return: the loss and the EER for this batch of embeddings.
        """
        speakers_per_batch, utterances_per_speaker = embeds.shape[:2]

        # Loss: each utterance should be classified as its own speaker, so the rows of
        # the similarity matrix are treated as logits for a cross-entropy objective.
        sim_matrix = self.similarity_matrix(embeds)
        sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker,
                                         speakers_per_batch))
        ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
        target = torch.from_numpy(ground_truth).long().to(self.loss_device)
        loss = self.loss_fn(sim_matrix, target)

        # EER (not backpropagated). As above, the builtin int replaces the removed np.int.
        with torch.no_grad():
            inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=int)[0]
            labels = np.array([inv_argmax(i) for i in ground_truth])
            preds = sim_matrix.detach().cpu().numpy()

            # The EER is the point on the ROC curve where the false positive rate
            # equals the false negative rate (1 - TPR).
            fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
            eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)

        return loss, eer
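

if __name__ == "__main__":
    # A minimal smoke test, not part of the original module. It assumes the
    # params_data/params_model star imports define mel_n_channels, and uses random
    # mel frames as stand-ins for real data; the batch shape below is illustrative.
    device = torch.device("cpu")
    model = SpeakerEncoder(device, device)

    speakers_per_batch, utterances_per_speaker, n_frames = 4, 5, 160
    frames = torch.randn(speakers_per_batch * utterances_per_speaker,
                         n_frames, mel_n_channels)

    # Embed the utterances, then group them by speaker for the GE2E loss.
    embeds = model(frames).view(speakers_per_batch, utterances_per_speaker, -1)
    loss, eer = model.loss(embeds)
    print(f"loss={loss.item():.4f}  eer={eer:.4f}")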