| import math |
| import torch |
| import typing as tp |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers.utils import ModelOutput |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.modeling_outputs import SequenceClassifierOutput |
|
|
| from .helpers_ecapa import Fbank |
| from .configuration_ecapa import EcapaConfig |
|
|
|
|
class InputNormalization(nn.Module):
    """Performs mean and variance normalization of the input features.

    Depending on ``norm_type`` the statistics are computed:
    - "sentence": per sequence, from its own (non-padded) frames;
    - "batch": from the current batch;
    - "speaker": as a running average per speaker id;
    - "global": as a running average over all training batches, frozen
      after ``update_until_epoch``.

    Arguments
    ---------
    mean_norm : bool
        If True, the mean is subtracted.
    std_norm : bool
        If True, features are divided by the standard deviation.
    norm_type : str
        One of "sentence", "batch", "speaker", "global".
    avg_factor : float, optional
        Fixed weight for the running-average update. If None, the weight
        decays as 1/count.
    requires_grad : bool
        Kept for API compatibility; the statistics are always detached.
    update_until_epoch : int
        Epoch (exclusive) until which the global statistics are updated.
    """

    spk_dict_mean: tp.Dict[int, torch.Tensor]
    spk_dict_std: tp.Dict[int, torch.Tensor]
    spk_dict_count: tp.Dict[int, int]

    def __init__(
        self,
        mean_norm=True,
        std_norm=True,
        norm_type="global",
        avg_factor=None,
        requires_grad=False,
        update_until_epoch=3,
    ):
        super().__init__()
        self.mean_norm = mean_norm
        self.std_norm = std_norm
        self.norm_type = norm_type
        self.avg_factor = avg_factor
        self.requires_grad = requires_grad
        self.glob_mean = torch.tensor([0])
        self.glob_std = torch.tensor([0])
        self.spk_dict_mean = {}
        self.spk_dict_std = {}
        self.spk_dict_count = {}
        self.weight = 1.0
        self.count = 0
        self.eps = 1e-10
        self.update_until_epoch = update_until_epoch

    def forward(self, input_values, lengths=None, spk_ids=None, epoch=0):
        """Normalizes the input tensor.

        Note: for "sentence" and "speaker" modes the input tensor is
        modified in place (row assignment), matching the original
        behavior.

        Arguments
        ---------
        input_values : torch.Tensor
            A batch of tensors of shape [batch, time, features].
        lengths : torch.Tensor, optional
            Relative length of each sequence (e.g. [0.7, 0.9, 1.0]), used
            to avoid computing stats on zero-padded steps. If None, all
            frames are used.
        spk_ids : torch.Tensor, optional
            Ids of each speaker (e.g. [0, 10, 6]); used to perform
            per-speaker normalization when norm_type='speaker'.
        epoch : int
            Current training epoch; controls the global-statistics update.

        Returns
        -------
        torch.Tensor
            The normalized tensor, same shape as ``input_values``.
        """
        x = input_values
        N_batches = x.shape[0]

        # Fix: the original crashed with lengths=None even though the
        # signature allows it; treat missing lengths as "no padding".
        if lengths is None:
            lengths = torch.ones(N_batches, device=x.device)
        # Fix: avoid a shared mutable (tensor) default argument.
        if spk_ids is None:
            spk_ids = torch.tensor([])

        current_means = []
        current_stds = []

        for snt_id in range(N_batches):
            # Number of valid (non-padded) frames for this sentence.
            actual_size = torch.round(lengths[snt_id] * x.shape[1]).int()

            current_mean, current_std = self._compute_current_stats(
                x[snt_id, 0:actual_size, ...]
            )
            current_means.append(current_mean)
            current_stds.append(current_std)

            if self.norm_type == "sentence":
                x[snt_id] = (x[snt_id] - current_mean.data) / current_std.data

            if self.norm_type == "speaker":
                spk_id = int(spk_ids[snt_id][0])

                if self.training:
                    if spk_id not in self.spk_dict_mean:
                        # First time we see this speaker: initialize stats.
                        self.spk_dict_mean[spk_id] = current_mean
                        self.spk_dict_std[spk_id] = current_std
                        self.spk_dict_count[spk_id] = 1
                    else:
                        self.spk_dict_count[spk_id] = (
                            self.spk_dict_count[spk_id] + 1
                        )

                        if self.avg_factor is None:
                            self.weight = 1 / self.spk_dict_count[spk_id]
                        else:
                            self.weight = self.avg_factor

                        # Running average of the per-speaker statistics.
                        self.spk_dict_mean[spk_id] = (
                            (1 - self.weight) * self.spk_dict_mean[spk_id]
                            + self.weight * current_mean
                        )
                        self.spk_dict_std[spk_id] = (
                            (1 - self.weight) * self.spk_dict_std[spk_id]
                            + self.weight * current_std
                        )

                    # Fix: detach() is not in-place; the original discarded
                    # its result, leaving the running stats in the graph.
                    self.spk_dict_mean[spk_id] = self.spk_dict_mean[
                        spk_id
                    ].detach()
                    self.spk_dict_std[spk_id] = self.spk_dict_std[
                        spk_id
                    ].detach()

                    speaker_mean = self.spk_dict_mean[spk_id].data
                    speaker_std = self.spk_dict_std[spk_id].data
                else:
                    # Evaluation: fall back to the sentence's own stats for
                    # unseen speakers.
                    if spk_id in self.spk_dict_mean:
                        speaker_mean = self.spk_dict_mean[spk_id].data
                        speaker_std = self.spk_dict_std[spk_id].data
                    else:
                        speaker_mean = current_mean.data
                        speaker_std = current_std.data

                x[snt_id] = (x[snt_id] - speaker_mean) / speaker_std

        if self.norm_type == "batch" or self.norm_type == "global":
            # Average the per-sentence statistics over the batch.
            current_mean = torch.mean(torch.stack(current_means), dim=0)
            current_std = torch.mean(torch.stack(current_stds), dim=0)

            if self.norm_type == "batch":
                x = (x - current_mean.data) / (current_std.data)

            if self.norm_type == "global":
                if self.training:
                    if self.count == 0:
                        self.glob_mean = current_mean
                        self.glob_std = current_std
                    elif epoch < self.update_until_epoch:
                        if self.avg_factor is None:
                            self.weight = 1 / (self.count + 1)
                        else:
                            self.weight = self.avg_factor

                        self.glob_mean = (
                            1 - self.weight
                        ) * self.glob_mean + self.weight * current_mean

                        self.glob_std = (
                            1 - self.weight
                        ) * self.glob_std + self.weight * current_std

                    # Fix: assign the detached tensors back (detach() is
                    # not in-place).
                    self.glob_mean = self.glob_mean.detach()
                    self.glob_std = self.glob_std.detach()

                    self.count = self.count + 1

                x = (x - self.glob_mean.data) / (self.glob_std.data)

        return x

    def _compute_current_stats(self, x):
        """Computes mean and std along dim 0 of the given tensor.

        Arguments
        ---------
        x : torch.Tensor
            Valid (non-padded) frames of one sentence, shape
            [time, features].

        Returns
        -------
        (torch.Tensor, torch.Tensor)
            Detached mean and std, with std clamped away from zero.
        """
        if self.mean_norm:
            current_mean = torch.mean(x, dim=0).detach().data
        else:
            current_mean = torch.tensor([0.0], device=x.device)

        if self.std_norm:
            current_std = torch.std(x, dim=0).detach().data
        else:
            current_std = torch.tensor([1.0], device=x.device)

        # Improves numerical stability (avoids division by ~0).
        current_std = torch.max(
            current_std, self.eps * torch.ones_like(current_std)
        )

        return current_mean, current_std

    def _statistics_dict(self):
        """Fills the dictionary containing the normalization statistics."""
        state = {}
        state["count"] = self.count
        state["glob_mean"] = self.glob_mean
        state["glob_std"] = self.glob_std
        state["spk_dict_mean"] = self.spk_dict_mean
        state["spk_dict_std"] = self.spk_dict_std
        state["spk_dict_count"] = self.spk_dict_count

        return state

    def _load_statistics_dict(self, state):
        """Loads the dictionary containing the statistics.

        Arguments
        ---------
        state : dict
            A dictionary containing the normalization statistics.
        """
        self.count = state["count"]
        self.glob_mean = state["glob_mean"]
        self.glob_std = state["glob_std"]

        # Fix: the original referenced the undefined attribute
        # ``self.device_inp``; place the speaker stats on the device of
        # the global statistics instead.
        device = (
            self.glob_mean.device
            if torch.is_tensor(self.glob_mean)
            else torch.device("cpu")
        )

        self.spk_dict_mean = {
            spk: mean.to(device)
            for spk, mean in state["spk_dict_mean"].items()
        }
        self.spk_dict_std = {
            spk: std.to(device)
            for spk, std in state["spk_dict_std"].items()
        }
        self.spk_dict_count = state["spk_dict_count"]

        return state

    def to(self, device):
        """Puts the needed tensors in the right device."""
        self = super(InputNormalization, self).to(device)
        self.glob_mean = self.glob_mean.to(device)
        self.glob_std = self.glob_std.to(device)
        for spk in self.spk_dict_mean:
            self.spk_dict_mean[spk] = self.spk_dict_mean[spk].to(device)
            self.spk_dict_std[spk] = self.spk_dict_std[spk].to(device)
        return self
|
|
|
class TdnnLayer(nn.Module):
    """1D time-delay (dilated convolution) layer.

    Applies explicit reflect-padding so the time dimension is preserved,
    then Conv1d -> activation -> non-affine BatchNorm1d.

    Arguments
    ---------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    kernel_size : int
        Kernel size of the convolution.
    dilation : int
        Dilation factor of the convolution.
    stride : int
        Used only for the padding computation; the convolution itself
        runs with its default stride of 1.
    groups : int
        Number of blocked connections from input to output channels.
    padding : int
        Extra padding passed to Conv1d (default 0; padding is handled
        explicitly in forward).
    padding_mode : str
        Mode handed to ``F.pad`` (default "reflect").
    activation : torch class
        Class used to build the activation layer.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dilation=1,
        stride=1,
        groups=1,
        padding=0,
        padding_mode="reflect",
        activation=torch.nn.LeakyReLU,
    ):
        super(TdnnLayer, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.stride = stride
        self.groups = groups
        self.padding = padding
        self.padding_mode = padding_mode
        self.activation = activation()

        self.conv = nn.Conv1d(
            self.in_channels,
            self.out_channels,
            self.kernel_size,
            dilation=self.dilation,
            padding=self.padding,
            groups=self.groups,
        )

        # Non-affine batch norm, as in the reference ECAPA implementation.
        self.norm = nn.BatchNorm1d(out_channels, affine=False)

    def forward(self, x):
        """Runs pad -> conv -> activation -> batch norm on [N, C, L] input."""
        padded = self._manage_padding(
            x, self.kernel_size, self.dilation, self.stride
        )
        return self.norm(self.activation(self.conv(padded)))

    def _manage_padding(
        self, x, kernel_size: int, dilation: int, stride: int,
    ):
        """Reflect-pads x along time so the convolution preserves length.

        NOTE(review): the length fed to ``get_padding_elem`` is
        ``self.in_channels`` rather than the actual time length; for
        stride == 1 the resulting padding depends only on kernel_size and
        dilation, so this is equivalent in practice — confirm intent.
        """
        num_pad = get_padding_elem(
            self.in_channels, stride, kernel_size, dilation
        )
        return F.pad(x, num_pad, mode=self.padding_mode)
|
|
|
|
def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
    """This function computes the number of elements to add for zero-padding.

    Returns a symmetric [left, right] padding. For stride > 1 the padding
    is half the kernel size; otherwise it is chosen so the output length
    matches the input length (up to flooring).

    Arguments
    ---------
    L_in : int
    stride: int
    kernel_size : int
    dilation : int
    """
    if stride > 1:
        half_kernel = kernel_size // 2
        return [half_kernel, half_kernel]

    # stride == 1: output shrinks by dilation * (kernel_size - 1).
    L_out = (L_in - dilation * (kernel_size - 1) - 1) // stride + 1
    pad = (L_in - L_out) // 2
    return [pad, pad]
|
|
|
|
class Res2NetBlock(torch.nn.Module):
    """An implementation of Res2NetBlock w/ dilation.

    The input is split into ``scale`` channel groups; each group after
    the first is passed through a TDNN layer, with the previous group's
    output added to its input (hierarchical residual connections).

    Arguments
    ---------
    in_channels : int
        The number of channels expected in the input.
    out_channels : int
        The number of output channels.
    scale : int
        The scale of the Res2Net block.
    kernel_size: int
        The kernel size of the Res2Net block.
    dilation : int
        The dilation of the Res2Net block.

    Example
    -------
    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
    >>> layer = Res2NetBlock(64, 64, scale=4, dilation=3)
    >>> out_tensor = layer(inp_tensor).transpose(1, 2)
    >>> out_tensor.shape
    torch.Size([8, 120, 64])
    """

    def __init__(
        self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1
    ):
        super(Res2NetBlock, self).__init__()
        assert in_channels % scale == 0
        assert out_channels % scale == 0

        width_in = in_channels // scale
        width_out = out_channels // scale

        # One TDNN per channel group, except the first (identity) group.
        self.blocks = nn.ModuleList(
            TdnnLayer(
                width_in,
                width_out,
                kernel_size=kernel_size,
                dilation=dilation,
            )
            for _ in range(scale - 1)
        )
        self.scale = scale

    def forward(self, x):
        """Processes the input tensor x and returns an output tensor."""
        processed = []
        out = None
        for idx, chunk in enumerate(torch.chunk(x, self.scale, dim=1)):
            if idx == 0:
                # First group passes through untouched.
                out = chunk
            else:
                # Later groups receive the previous group's output.
                fed = chunk if idx == 1 else chunk + out
                out = self.blocks[idx - 1](fed)
            processed.append(out)
        return torch.cat(processed, dim=1)
|
|
|
|
class SEBlock(nn.Module):
    """An implementation of squeeze-and-excitation block.

    Arguments
    ---------
    in_channels : int
        The number of input channels.
    se_channels : int
        The number of output channels after squeeze.
    out_channels : int
        The number of output channels.

    Example
    -------
    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
    >>> se_layer = SEBlock(64, 16, 64)
    >>> lengths = torch.rand((8,))
    >>> out_tensor = se_layer(inp_tensor, lengths).transpose(1, 2)
    >>> out_tensor.shape
    torch.Size([8, 120, 64])
    """

    def __init__(self, in_channels, se_channels, out_channels):
        super(SEBlock, self).__init__()

        # Squeeze (1x1 conv down), excite (1x1 conv up), sigmoid gate.
        self.conv1 = nn.Conv1d(
            in_channels=in_channels, out_channels=se_channels, kernel_size=1
        )
        self.relu = torch.nn.ReLU(inplace=True)
        self.conv2 = nn.Conv1d(
            in_channels=se_channels, out_channels=out_channels, kernel_size=1
        )
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x, lengths=None):
        """Applies channel-wise squeeze-and-excitation gating to x."""
        if lengths is None:
            # No padding info: plain mean over the time axis.
            squeezed = x.mean(dim=2, keepdim=True)
        else:
            # Masked mean so zero-padded frames do not bias the squeeze.
            T = x.shape[-1]
            mask = length_to_mask(lengths * T, max_len=T, device=x.device)
            mask = mask.unsqueeze(1)
            denom = mask.sum(dim=2, keepdim=True)
            squeezed = (x * mask).sum(dim=2, keepdim=True) / denom

        gate = self.sigmoid(self.conv2(self.relu(self.conv1(squeezed))))
        return gate * x
|
|
|
|
def length_to_mask(length, max_len=None, dtype=None, device=None):
    """Creates a binary mask for each sequence.

    Reference: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397/3

    Arguments
    ---------
    length : torch.LongTensor
        Containing the length of each sequence in the batch. Must be 1D.
    max_len : int
        Max length for the mask, also the size of the second dimension.
    dtype : torch.dtype, default: None
        The dtype of the generated mask.
    device: torch.device, default: None
        The device to put the mask variable.

    Returns
    -------
    mask : tensor
        The binary mask.

    Example
    -------
    >>> length=torch.Tensor([1,2,3])
    >>> mask=length_to_mask(length)
    >>> mask
    tensor([[1., 0., 0.],
            [1., 1., 0.],
            [1., 1., 1.]])
    """
    assert len(length.shape) == 1

    if max_len is None:
        # Default to the longest sequence in the batch.
        max_len = length.max().long().item()

    # Positions [0, max_len) compared against each length give the mask.
    positions = torch.arange(
        max_len, device=length.device, dtype=length.dtype
    )
    mask = positions.expand(len(length), max_len) < length.unsqueeze(1)

    return torch.as_tensor(
        mask,
        dtype=length.dtype if dtype is None else dtype,
        device=length.device if device is None else device,
    )
|
|
|
|
class AttentiveStatisticsPooling(nn.Module):
    """This class implements an attentive statistic pooling layer for each channel.
    It returns the concatenated mean and std of the input tensor.

    Arguments
    ---------
    channels: int
        The number of input channels.
    attention_channels: int
        The number of attention channels.

    Example
    -------
    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
    >>> asp_layer = AttentiveStatisticsPooling(64)
    >>> lengths = torch.rand((8,))
    >>> out_tensor = asp_layer(inp_tensor, lengths).transpose(1, 2)
    >>> out_tensor.shape
    torch.Size([8, 1, 128])
    """

    def __init__(self, channels, attention_channels=128, global_context=True):
        super().__init__()

        self.eps = 1e-12
        self.global_context = global_context
        # With global context the attention input is [x; mean; std],
        # i.e. 3x channels.
        attn_in = channels * 3 if global_context else channels
        self.tdnn = TdnnLayer(attn_in, attention_channels, 1, 1)
        self.tanh = nn.Tanh()
        self.conv = nn.Conv1d(
            in_channels=attention_channels, out_channels=channels, kernel_size=1
        )

    def forward(self, x, lengths=None):
        """Calculates attention-weighted mean and std for a batch.

        Arguments
        ---------
        x : torch.Tensor
            Tensor of shape [N, C, L].

        Returns
        -------
        torch.Tensor
            Pooled statistics of shape [N, 2*C, 1].
        """
        T = x.shape[-1]

        def _weighted_stats(values, weights, dim=2):
            # Weighted mean and std; std clamped at eps for stability.
            mu = (weights * values).sum(dim)
            var = (weights * (values - mu.unsqueeze(dim)).pow(2)).sum(dim)
            return mu, torch.sqrt(var.clamp(self.eps))

        if lengths is None:
            lengths = torch.ones(x.shape[0], device=x.device)

        # Binary mask excluding zero-padded frames.
        mask = length_to_mask(lengths * T, max_len=T, device=x.device)
        mask = mask.unsqueeze(1)

        if self.global_context:
            # Append utterance-level mean/std to every frame.
            denom = mask.sum(dim=2, keepdim=True).float()
            mu, sigma = _weighted_stats(x, mask / denom)
            attn_input = torch.cat(
                [
                    x,
                    mu.unsqueeze(2).repeat(1, 1, T),
                    sigma.unsqueeze(2).repeat(1, 1, T),
                ],
                dim=1,
            )
        else:
            attn_input = x

        # Attention scores over time, with padded frames masked out.
        attn = self.conv(self.tanh(self.tdnn(attn_input)))
        attn = attn.masked_fill(mask == 0, float("-inf"))
        attn = F.softmax(attn, dim=2)

        mu, sigma = _weighted_stats(x, attn)
        return torch.cat((mu, sigma), dim=1).unsqueeze(2)
|
|
|
|
|
|
class SERes2NetBlock(nn.Module):
    """An implementation of building block in ECAPA-TDNN, i.e.,
    TDNN-Res2Net-TDNN-SEBlock, with a residual connection.

    Arguments
    ----------
    in_channels : int
        The number of input channels.
    out_channels: int
        The number of output channels.
    res2net_scale: int
        The scale of the Res2Net block.
    se_channels : int
        The number of squeeze channels of the SE block.
    kernel_size: int
        The kernel size of the TDNN blocks.
    dilation: int
        The dilation of the Res2Net block.
    activation : torch class
        A class for constructing the activation layers.
    groups: int
        Number of blocked connections from input channels to output channels.

    Example
    -------
    >>> x = torch.rand(8, 120, 64).transpose(1, 2)
    >>> conv = SERes2NetBlock(64, 64, res2net_scale=4)
    >>> out = conv(x).transpose(1, 2)
    >>> out.shape
    torch.Size([8, 120, 64])
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        res2net_scale=8,
        se_channels=128,
        kernel_size=1,
        dilation=1,
        activation=torch.nn.ReLU,
        groups=1,
    ):
        super().__init__()
        self.out_channels = out_channels

        # 1x1 TDNN in -> Res2Net core -> 1x1 TDNN out -> SE gating.
        self.tdnn1 = TdnnLayer(
            in_channels,
            out_channels,
            kernel_size=1,
            dilation=1,
            activation=activation,
            groups=groups,
        )
        self.res2net_block = Res2NetBlock(
            out_channels, out_channels, res2net_scale, kernel_size, dilation
        )
        self.tdnn2 = TdnnLayer(
            out_channels,
            out_channels,
            kernel_size=1,
            dilation=1,
            activation=activation,
            groups=groups,
        )
        self.se_block = SEBlock(out_channels, se_channels, out_channels)

        # 1x1 conv shortcut only when the channel count changes.
        self.shortcut = None
        if in_channels != out_channels:
            self.shortcut = nn.Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
            )

    def forward(self, x, lengths=None):
        """Processes the input tensor x and returns an output tensor."""
        residual = self.shortcut(x) if self.shortcut else x

        out = self.tdnn1(x)
        out = self.res2net_block(out)
        out = self.tdnn2(out)
        out = self.se_block(out, lengths)

        return out + residual
|
|
|
|
class EcapaEmbedder(nn.Module):
    """ECAPA-TDNN speaker-embedding network.

    Pipeline: initial TDNN layer -> stack of SERes2Net blocks ->
    multi-layer feature aggregation (MFA) -> attentive statistics pooling
    -> 1x1 conv projection to ``hidden_size``.

    Arguments
    ---------
    in_channels : int
        Number of input feature channels (e.g. mel bins).
    hidden_size : int
        Dimensionality of the output embedding.
    activation : torch class
        Class used to construct the activation layers.
    channels : list of int
        Output channels per stage; the last entry is the MFA output size.
    kernel_sizes : list of int
        Kernel size per stage.
    dilations : list of int
        Dilation per stage.
    attention_channels : int
        Attention channels of the statistics pooling.
    res2net_scale : int
        Scale of the Res2Net blocks.
    se_channels : int
        Squeeze channels of the SE blocks.
    global_context : bool
        Whether the pooling attention uses global context.
    groups : list of int
        Conv groups per stage.
    """

    def __init__(
        self,
        in_channels=80,
        hidden_size=192,
        activation=torch.nn.ReLU,
        # NOTE(review): mutable (list) defaults are shared across calls;
        # they are only read here, never mutated, so this is benign.
        channels=[512, 512, 512, 512, 1536],
        kernel_sizes=[5, 3, 3, 3, 1],
        dilations=[1, 2, 3, 4, 1],
        attention_channels=128,
        res2net_scale=8,
        se_channels=128,
        global_context=True,
        groups=[1, 1, 1, 1, 1],
    ) -> None:
        super(EcapaEmbedder, self).__init__()
        self.channels = channels
        self.blocks = nn.ModuleList()

        # Initial TDNN layer on the raw feature channels.
        self.blocks.append(
            TdnnLayer(
                in_channels,
                channels[0],
                kernel_sizes[0],
                dilations[0],
                activation=activation,
                groups=groups[0],
            )
        )

        # SERes2Net blocks: all stages except the first and the last.
        for i in range(1, len(channels) - 1):
            self.blocks.append(
                SERes2NetBlock(
                    channels[i - 1],
                    channels[i],
                    res2net_scale=res2net_scale,
                    se_channels=se_channels,
                    kernel_size=kernel_sizes[i],
                    dilation=dilations[i],
                    activation=activation,
                    groups=groups[i],
                )
            )

        # Multi-layer feature aggregation over the concatenated
        # SERes2Net outputs.
        self.mfa = TdnnLayer(
            channels[-2] * (len(channels) - 2),
            channels[-1],
            kernel_sizes[-1],
            dilations[-1],
            activation=activation,
            groups=groups[-1],
        )

        # Attentive statistics pooling + batch norm over [mean; std].
        self.asp = AttentiveStatisticsPooling(
            channels[-1],
            attention_channels=attention_channels,
            global_context=global_context,
        )
        self.asp_bn = nn.BatchNorm1d(channels[-1] * 2)

        # Final 1x1 conv projecting pooled stats to the embedding size.
        self.fc = nn.Conv1d(
            in_channels=channels[-1] * 2,
            out_channels=hidden_size,
            kernel_size=1,
        )

    def forward(self, input_values, lengths=None):
        """Computes the speaker embedding.

        Arguments
        ---------
        input_values : torch.Tensor
            Features of shape [batch, time, channels].
        lengths : torch.Tensor, optional
            Relative length of each sequence, used to mask padding.

        Returns
        -------
        ModelOutput
            With ``pooler_output`` of shape [batch, hidden_size].
        """
        # [batch, time, channels] -> [batch, channels, time] for Conv1d.
        x = input_values.transpose(1, 2)

        xl = []
        for layer in self.blocks:
            # NOTE(review): layers that do not accept `lengths` are
            # detected via TypeError; this also silences TypeErrors
            # raised *inside* a layer — confirm this is acceptable.
            try:
                x = layer(x, lengths)
            except TypeError:
                x = layer(x)
            xl.append(x)

        # Multi-layer feature aggregation (skips the initial TDNN output).
        x = torch.cat(xl[1:], dim=1)
        x = self.mfa(x)

        # Attentive statistical pooling -> [batch, 2*channels, 1].
        x = self.asp(x, lengths)
        x = self.asp_bn(x)

        # Final embedding projection.
        x = self.fc(x)

        # [batch, hidden, 1] -> [batch, 1, hidden] -> [batch, hidden].
        pooler_output = x.transpose(1, 2)
        pooler_output = pooler_output.squeeze(1)
        return ModelOutput(
            pooler_output=pooler_output
        )
|
|
|
|
class CosineSimilarityHead(torch.nn.Module):
    """
    This class implements the cosine similarity on the top of features.

    Optionally applies ``lin_blocks`` pairs of (BatchNorm1d, Linear)
    before scoring each input against a learned per-class prototype
    with cosine similarity.
    """
    def __init__(
        self,
        in_channels,
        lin_blocks=0,
        hidden_size=192,
        num_classes=1211,
    ):
        super().__init__()
        self.blocks = nn.ModuleList()

        for _ in range(lin_blocks):
            self.blocks.append(nn.BatchNorm1d(num_features=in_channels))
            self.blocks.append(
                nn.Linear(in_features=in_channels, out_features=hidden_size)
            )
            # Subsequent blocks operate on the projected size.
            in_channels = hidden_size

        # One prototype vector per class.
        self.weight = nn.Parameter(
            torch.FloatTensor(num_classes, in_channels)
        )
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x):
        """Returns the output probabilities over speakers.

        Arguments
        ---------
        x : torch.Tensor
            Torch tensor.
        """
        for layer in self.blocks:
            x = layer(x)

        # Cosine similarity = dot product of L2-normalized vectors.
        return F.linear(F.normalize(x), F.normalize(self.weight))
|
|
|
|
class EcapaPreTrainedModel(PreTrainedModel):
    """Base class hooking the ECAPA modules into the HF ``PreTrainedModel`` API."""

    config_class = EcapaConfig
    base_model_prefix = "ecapa"
    main_input_name = "input_values"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights of a single submodule."""
        if isinstance(module, nn.Linear):
            # Linear weights: normal distribution scaled via the config.
            nn.init.normal_(
                module.weight.data, mean=0.0, std=self.config.initializer_range
            )
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            # Norm layers start as the identity transform.
            nn.init.zeros_(module.bias.data)
            nn.init.ones_(module.weight.data)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight.data)

        # Biases of linear/conv layers start at zero.
        if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
            nn.init.zeros_(module.bias.data)
|
|
|
|
class EcapaModel(EcapaPreTrainedModel):
    """ECAPA-TDNN model: raw audio -> Fbank features -> normalization -> embedding."""

    def __init__(self, config):
        super().__init__(config)
        # Filterbank feature extraction from raw waveforms.
        self.compute_features = Fbank(
            n_mels=config.n_mels,
            sample_rate=config.sample_rate,
            win_length=config.win_length,
            hop_length=config.hop_length,
        )
        # Mean/variance normalization of the features.
        self.mean_var_norm = InputNormalization(
            mean_norm=config.mean_norm,
            std_norm=config.std_norm,
            norm_type=config.norm_type,
        )
        # ECAPA-TDNN embedding extractor.
        self.embedding_model = EcapaEmbedder(
            in_channels=config.n_mels,
            channels=config.channels,
            kernel_sizes=config.kernel_sizes,
            dilations=config.dilations,
            attention_channels=config.attention_channels,
            res2net_scale=config.res2net_scale,
            se_channels=config.se_channels,
            global_context=config.global_context,
            groups=config.groups,
            hidden_size=config.hidden_size,
        )

    def forward(self, input_values, lengths=None):
        """Computes speaker embeddings from raw audio.

        Arguments
        ---------
        input_values : torch.Tensor
            Batch of raw waveforms.
        lengths : torch.Tensor, optional
            Relative length of each waveform, used to ignore padding.

        Returns
        -------
        ModelOutput
            With ``pooler_output`` holding the speaker embeddings.
        """
        features = self.compute_features(input_values)
        features = self.mean_var_norm(features, lengths)
        embedded = self.embedding_model(features, lengths)
        return ModelOutput(
            pooler_output=embedded.pooler_output,
        )
|
|
|
|