import math
import typing as tp

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.utils import ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput

from .helpers_xvector import Fbank
from .configuration_xvector import XvectorConfig


class InputNormalization(nn.Module):
    """Performs mean/std normalization of the input features.

    Depending on ``norm_type``, statistics are computed per sentence,
    per speaker, per batch, or globally (running estimates that are
    updated during training).
    """

    spk_dict_mean: tp.Dict[int, torch.Tensor]
    spk_dict_std: tp.Dict[int, torch.Tensor]
    spk_dict_count: tp.Dict[int, int]

    def __init__(
        self,
        mean_norm=True,
        std_norm=True,
        norm_type="global",
        avg_factor=None,
        requires_grad=False,
        update_until_epoch=3,
    ):
        super().__init__()
        self.mean_norm = mean_norm
        self.std_norm = std_norm
        self.norm_type = norm_type
        self.avg_factor = avg_factor
        self.requires_grad = requires_grad
        self.glob_mean = torch.tensor([0])
        self.glob_std = torch.tensor([0])
        self.spk_dict_mean = {}
        self.spk_dict_std = {}
        self.spk_dict_count = {}
        self.weight = 1.0
        self.count = 0
        self.eps = 1e-10
        self.update_until_epoch = update_until_epoch
        # Device of the most recent input; used when restoring statistics.
        self.device_inp = "cpu"

    def forward(self, input_values, lengths=None, spk_ids=torch.tensor([]), epoch=0):
        """Returns the normalized input tensor.

        Arguments
        ---------
        input_values : torch.Tensor
            A batch of feature tensors of shape (batch, time, features).
        lengths : torch.Tensor
            The relative length of each sentence (e.g., [0.7, 0.9, 1.0]).
            It is used to avoid computing stats on zero-padded steps.
        spk_ids : torch.Tensor
            The id of each speaker (e.g., [0, 10, 6]). It is used to
            perform per-speaker normalization when norm_type='speaker'.
        epoch : int
            The current epoch; global statistics stop updating after
            `update_until_epoch`.
        """
        x = input_values
        self.device_inp = x.device
        N_batches = x.shape[0]

        # Without explicit lengths, treat every sequence as unpadded.
        if lengths is None:
            lengths = torch.ones(N_batches, device=x.device)

        current_means = []
        current_stds = []

        for snt_id in range(N_batches):
            # Avoid computing stats on the zero-padded steps.
            actual_size = torch.round(lengths[snt_id] * x.shape[1]).int()

            current_mean, current_std = self._compute_current_stats(
                x[snt_id, 0:actual_size, ...]
            )

            current_means.append(current_mean)
            current_stds.append(current_std)

            if self.norm_type == "sentence":
                x[snt_id] = (x[snt_id] - current_mean.data) / current_std.data

| | if self.norm_type == "speaker": |
| | spk_id = int(spk_ids[snt_id][0]) |
| |
|
| | if self.training: |
| | if spk_id not in self.spk_dict_mean: |
| | |
| | self.spk_dict_mean[spk_id] = current_mean |
| | self.spk_dict_std[spk_id] = current_std |
| | self.spk_dict_count[spk_id] = 1 |
| |
|
| | else: |
| | self.spk_dict_count[spk_id] = ( |
| | self.spk_dict_count[spk_id] + 1 |
| | ) |
| |
|
| | if self.avg_factor is None: |
| | self.weight = 1 / self.spk_dict_count[spk_id] |
| | else: |
| | self.weight = self.avg_factor |
| |
|
| | self.spk_dict_mean[spk_id] = ( |
| | (1 - self.weight) * self.spk_dict_mean[spk_id] |
| | + self.weight * current_mean |
| | ) |
| | self.spk_dict_std[spk_id] = ( |
| | (1 - self.weight) * self.spk_dict_std[spk_id] |
| | + self.weight * current_std |
| | ) |
| |
|
| | self.spk_dict_mean[spk_id].detach() |
| | self.spk_dict_std[spk_id].detach() |
| |
|
| | speaker_mean = self.spk_dict_mean[spk_id].data |
| | speaker_std = self.spk_dict_std[spk_id].data |
| | else: |
| | if spk_id in self.spk_dict_mean: |
| | speaker_mean = self.spk_dict_mean[spk_id].data |
| | speaker_std = self.spk_dict_std[spk_id].data |
| | else: |
| | speaker_mean = current_mean.data |
| | speaker_std = current_std.data |
| |
|
| | x[snt_id] = (x[snt_id] - speaker_mean) / speaker_std |
| |
|
| | if self.norm_type == "batch" or self.norm_type == "global": |
| | current_mean = torch.mean(torch.stack(current_means), dim=0) |
| | current_std = torch.mean(torch.stack(current_stds), dim=0) |
| |
|
| | if self.norm_type == "batch": |
| | x = (x - current_mean.data) / (current_std.data) |
| |
|
| | if self.norm_type == "global": |
| | if self.training: |
| | if self.count == 0: |
| | self.glob_mean = current_mean |
| | self.glob_std = current_std |
| |
|
| | elif epoch < self.update_until_epoch: |
| | if self.avg_factor is None: |
| | self.weight = 1 / (self.count + 1) |
| | else: |
| | self.weight = self.avg_factor |
| |
|
| | self.glob_mean = ( |
| | 1 - self.weight |
| | ) * self.glob_mean + self.weight * current_mean |
| |
|
| | self.glob_std = ( |
| | 1 - self.weight |
| | ) * self.glob_std + self.weight * current_std |
| |
|
| | self.glob_mean.detach() |
| | self.glob_std.detach() |
| |
|
| | self.count = self.count + 1 |
| |
|
| | x = (x - self.glob_mean.data) / (self.glob_std.data) |
| |
|
| | return x |
| |
|
    def _compute_current_stats(self, x):
        """Computes the mean and std of the input tensor.

        Arguments
        ---------
        x : torch.Tensor
            A (time, features) slice of the batch.
        """
        # Compute current mean.
        if self.mean_norm:
            current_mean = torch.mean(x, dim=0).detach().data
        else:
            current_mean = torch.tensor([0.0], device=x.device)

        # Compute current std.
        if self.std_norm:
            current_std = torch.std(x, dim=0).detach().data
        else:
            current_std = torch.tensor([1.0], device=x.device)

        # Improve numerical stability of the std.
        current_std = torch.max(
            current_std, self.eps * torch.ones_like(current_std)
        )

        return current_mean, current_std

    def _statistics_dict(self):
        """Fills the dictionary containing the normalization statistics."""
        state = {}
        state["count"] = self.count
        state["glob_mean"] = self.glob_mean
        state["glob_std"] = self.glob_std
        state["spk_dict_mean"] = self.spk_dict_mean
        state["spk_dict_std"] = self.spk_dict_std
        state["spk_dict_count"] = self.spk_dict_count

        return state

    def _load_statistics_dict(self, state):
        """Loads the dictionary containing the statistics.

        Arguments
        ---------
        state : dict
            A dictionary containing the normalization statistics.
        """
        self.count = state["count"]
        if isinstance(state["glob_mean"], int):
            # Statistics were never updated; keep the integer placeholders.
            self.glob_mean = state["glob_mean"]
            self.glob_std = state["glob_std"]
        else:
            self.glob_mean = state["glob_mean"].to(self.device_inp)
            self.glob_std = state["glob_std"].to(self.device_inp)

        # Load the per-speaker statistics on the right device.
        self.spk_dict_mean = {}
        for spk in state["spk_dict_mean"]:
            self.spk_dict_mean[spk] = state["spk_dict_mean"][spk].to(
                self.device_inp
            )

        self.spk_dict_std = {}
        for spk in state["spk_dict_std"]:
            self.spk_dict_std[spk] = state["spk_dict_std"][spk].to(
                self.device_inp
            )

        self.spk_dict_count = state["spk_dict_count"]

        return state

    def to(self, device):
        """Puts the needed tensors in the right device."""
        self = super().to(device)
        self.glob_mean = self.glob_mean.to(device)
        self.glob_std = self.glob_std.to(device)
        for spk in self.spk_dict_mean:
            self.spk_dict_mean[spk] = self.spk_dict_mean[spk].to(device)
            self.spk_dict_std[spk] = self.spk_dict_std[spk].to(device)
        return self
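

# Usage sketch (illustrative only; shapes and values are invented):
# per-sentence normalization of a zero-padded batch of fbank features.
#
#   norm = InputNormalization(norm_type="sentence")
#   feats = torch.randn(3, 100, 40)          # (batch, time, features)
#   lengths = torch.tensor([0.7, 0.9, 1.0])  # relative lengths, padding excluded
#   out = norm(feats, lengths)               # same shape; modifies feats in place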


class TdnnLayer(nn.Module):
    """A 1D convolutional (TDNN) block: Conv1d -> activation -> BatchNorm1d."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dilation=1,
        stride=1,
        padding=0,
        padding_mode="reflect",
        activation=torch.nn.LeakyReLU,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        # Instantiate the activation once instead of on every forward pass.
        self.activation = activation()

        self.conv = nn.Conv1d(
            self.in_channels,
            self.out_channels,
            self.kernel_size,
            stride=self.stride,
            dilation=self.dilation,
            padding=self.padding,
        )

        # Normalize activations without learnable affine parameters.
        self.norm = nn.BatchNorm1d(out_channels, affine=False)

    def forward(self, x):
        """Applies padding, convolution, activation, and batch norm."""
        x = self._manage_padding(x, self.kernel_size, self.dilation, self.stride)
        out = self.conv(x)
        out = self.activation(out)
        out = self.norm(out)
        return out

    def _manage_padding(
        self, x, kernel_size: int, dilation: int, stride: int,
    ):
        # Length of the input time axis.
        L_in = x.shape[-1]

        # Padding amount needed to preserve the time resolution.
        padding = get_padding_elem(L_in, stride, kernel_size, dilation)

        # Apply padding.
        x = F.pad(x, padding, mode=self.padding_mode)

        return x
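

# Usage sketch (illustrative only): one TDNN block over a 40-channel input.
# With stride=1, the reflect padding preserves the time length.
#
#   layer = TdnnLayer(in_channels=40, out_channels=512, kernel_size=5)
#   x = torch.randn(2, 40, 100)   # (batch, channels, time)
#   layer(x).shape                # -> torch.Size([2, 512, 100])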


def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
    """Computes the number of elements to add for zero-padding.

    Arguments
    ---------
    L_in : int
    stride : int
    kernel_size : int
    dilation : int
    """
    if stride > 1:
        padding = [math.floor(kernel_size / 2), math.floor(kernel_size / 2)]
    else:
        L_out = (
            math.floor((L_in - dilation * (kernel_size - 1) - 1) / stride) + 1
        )
        padding = [
            math.floor((L_in - L_out) / 2),
            math.floor((L_in - L_out) / 2),
        ]
    return padding
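

# Worked example (illustrative): for stride=1 the padding depends only on the
# receptive field, since L_in - L_out == dilation * (kernel_size - 1).
# With kernel_size=3 and dilation=2, L_out = L_in - 4, so the call
#   get_padding_elem(100, 1, 3, 2)
# returns [2, 2], and the padded convolution preserves the time length.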


class StatisticsPooling(nn.Module):
    """Pools frame-level features into per-utterance mean/std statistics."""

    def __init__(self, return_mean=True, return_std=True):
        super().__init__()

        # Small value used for the Gaussian noise and std stabilization.
        self.eps = 1e-5
        self.return_mean = return_mean
        self.return_std = return_std
        if not (self.return_mean or self.return_std):
            raise ValueError(
                "both return_mean and return_std are False; "
                "enable at least one of the two statistics"
            )

    def forward(self, input_values, lengths=None):
        """Calculates mean and std for a batch (input tensor).

        Arguments
        ---------
        input_values : torch.Tensor
            A mini-batch tensor of shape (batch, time, features).
        lengths : torch.Tensor
            The relative length of each sentence; used to exclude
            zero-padded steps from the statistics.
        """
        x = input_values
        if lengths is None:
            if self.return_mean:
                mean = x.mean(dim=1)
            if self.return_std:
                std = x.std(dim=1)
        else:
            mean = []
            std = []
            for snt_id in range(x.shape[0]):
                # Avoid computing stats on the zero-padded steps.
                actual_size = int(torch.round(lengths[snt_id] * x.shape[1]))

                if self.return_mean:
                    mean.append(
                        torch.mean(x[snt_id, 0:actual_size, ...], dim=0)
                    )
                if self.return_std:
                    std.append(torch.std(x[snt_id, 0:actual_size, ...], dim=0))
            if self.return_mean:
                mean = torch.stack(mean)
            if self.return_std:
                std = torch.stack(std)

        if self.return_mean:
            # Add a small Gaussian noise to avoid exactly-zero statistics.
            gnoise = self._get_gauss_noise(mean.size(), device=mean.device)
            mean += gnoise
        if self.return_std:
            std = std + self.eps

        # Concatenate the pooled statistics along the feature axis.
        if self.return_mean and self.return_std:
            pooled_stats = torch.cat((mean, std), dim=1)
            pooled_stats = pooled_stats.unsqueeze(1)
        elif self.return_mean:
            pooled_stats = mean.unsqueeze(1)
        elif self.return_std:
            pooled_stats = std.unsqueeze(1)

        return pooled_stats

    def _get_gauss_noise(self, shape_of_tensor, device="cpu"):
        """Returns a tensor of epsilon-scaled Gaussian noise.

        Arguments
        ---------
        shape_of_tensor : torch.Size
            The size of the tensor of Gaussian noise to generate.
        """
        gnoise = torch.randn(shape_of_tensor, device=device)
        gnoise -= torch.min(gnoise)
        gnoise /= torch.max(gnoise)
        # Rescale the [0, 1] noise into [eps, 9 * eps].
        gnoise = self.eps * ((1 - 9) * gnoise + 9)

        return gnoise
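

# Usage sketch (illustrative only): pooling frame-level features of shape
# (batch, time, features) into one (batch, 1, 2 * features) vector of
# concatenated means and stds.
#
#   pooler = StatisticsPooling()
#   frames = torch.randn(3, 100, 1500)
#   stats = pooler(frames)   # -> torch.Size([3, 1, 3000])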


class XvectorEmbedder(nn.Module):
    """Computes x-vector speaker embeddings with a TDNN stack followed by
    statistics pooling and a linear projection."""

    def __init__(
        self,
        in_channels=40,
        activation=torch.nn.LeakyReLU,
        tdnn_blocks=5,
        tdnn_channels=[512, 512, 512, 512, 1500],
        tdnn_kernel_sizes=[5, 3, 3, 1, 1],
        tdnn_dilations=[1, 2, 3, 1, 1],
        hidden_size=512,
    ) -> None:
        super().__init__()
        self.activation = activation
        self.blocks = nn.ModuleList()
        for block_index in range(tdnn_blocks):
            out_channels = tdnn_channels[block_index]
            tdnn = TdnnLayer(
                in_channels,
                out_channels,
                kernel_size=tdnn_kernel_sizes[block_index],
                dilation=tdnn_dilations[block_index],
                activation=activation,
            )
            self.blocks.append(tdnn)
            in_channels = tdnn_channels[block_index]
        self.pooler = StatisticsPooling()
        # Statistics pooling doubles the feature dimension (mean + std).
        self.fc = nn.Linear(2 * out_channels, hidden_size)

    def forward(self, input_values, lengths=None):
        x = input_values
        # (batch, time, features) -> (batch, features, time) for Conv1d.
        x = x.permute(0, 2, 1)
        for block in self.blocks:
            x = block(x)
        last_hidden_state = x.permute(0, 2, 1)
        pooler_output = self.pooler(last_hidden_state, lengths)
        pooler_output = self.fc(pooler_output.squeeze(1))
        return ModelOutput(
            last_hidden_state=last_hidden_state,
            pooler_output=pooler_output,
        )
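

# Usage sketch (illustrative only): embedding a batch of 40-dim fbank
# features; with the defaults above, pooler_output is the 512-dim x-vector.
#
#   embedder = XvectorEmbedder(in_channels=40, hidden_size=512)
#   feats = torch.randn(2, 300, 40)
#   out = embedder(feats)
#   out.pooler_output.shape   # -> torch.Size([2, 512])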


class CosineSimilarityHead(torch.nn.Module):
    """Computes the cosine similarity between features and per-class
    embeddings, on top of an optional stack of linear blocks."""

    def __init__(
        self,
        in_channels,
        lin_blocks=0,
        hidden_size=192,
        num_classes=1211,
    ):
        super().__init__()
        self.blocks = nn.ModuleList()

        for block_index in range(lin_blocks):
            self.blocks.extend(
                [
                    nn.BatchNorm1d(num_features=in_channels),
                    nn.Linear(in_features=in_channels, out_features=hidden_size),
                ]
            )
            in_channels = hidden_size

        # One learnable embedding per class, compared by cosine similarity.
        self.weight = nn.Parameter(
            torch.FloatTensor(num_classes, in_channels)
        )
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x):
        """Returns the cosine similarity scores between the input features
        and each class embedding.

        Arguments
        ---------
        x : torch.Tensor
            Torch tensor of shape (batch, in_channels).
        """
        for layer in self.blocks:
            x = layer(x)

        # Cosine similarity: dot product of L2-normalized vectors.
        x = F.linear(F.normalize(x), F.normalize(self.weight))
        return x
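

# Usage sketch (illustrative only): scoring 512-dim embeddings against the
# class embeddings; outputs lie in [-1, 1] and are typically scaled before
# a cross-entropy loss.
#
#   head = CosineSimilarityHead(in_channels=512, num_classes=1211)
#   scores = head(torch.randn(2, 512))   # -> torch.Size([2, 1211])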


class XvectorPreTrainedModel(PreTrainedModel):
    """Base class that handles weight initialization and the interface for
    loading pretrained x-vector checkpoints."""

    config_class = XvectorConfig
    base_model_prefix = "xvector"
    main_input_name = "input_values"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            # Normal init scaled by the configured initializer range.
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight.data)

        if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
            module.bias.data.zero_()


class XvectorModel(XvectorPreTrainedModel):
    """Bare x-vector model: fbank feature extraction, input normalization,
    and the TDNN embedder."""

    def __init__(self, config):
        super().__init__(config)
        self.compute_features = Fbank(
            n_mels=config.n_mels,
            sample_rate=config.sample_rate,
            win_length=config.win_length,
            hop_length=config.hop_length,
        )
        self.mean_var_norm = InputNormalization(
            mean_norm=config.mean_norm,
            std_norm=config.std_norm,
            norm_type=config.norm_type,
        )
        self.embedding_model = XvectorEmbedder(
            in_channels=config.n_mels,
            activation=nn.LeakyReLU,
            tdnn_blocks=config.tdnn_blocks,
            tdnn_channels=config.tdnn_channels,
            tdnn_kernel_sizes=config.tdnn_kernel_sizes,
            tdnn_dilations=config.tdnn_dilations,
            hidden_size=config.hidden_size,
        )

    def forward(self, input_values, lengths=None):
        x = input_values
        # Waveform -> fbank features -> normalized features -> embeddings.
        x = self.compute_features(x)
        x = self.mean_var_norm(x, lengths)
        output = self.embedding_model(x, lengths)
        return output
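

# Minimal end-to-end sketch (illustrative; assumes an XvectorConfig with the
# default field names used above and a batch of 16 kHz mono waveforms):
#
#   config = XvectorConfig()
#   model = XvectorModel(config)
#   waveforms = torch.randn(2, 32000)   # 2 seconds at 16 kHz
#   lengths = torch.tensor([1.0, 0.8])  # relative lengths
#   output = model(waveforms, lengths)
#   output.pooler_output.shape          # -> (2, config.hidden_size)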