Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from src.models.speaker.yvector.tdnn import TDNNLayer | |
| from src.models.speaker.yvector.wav2spk import ( | |
| ConvFeatureExtractionModel, Fp32GroupNorm, norm_block | |
| ) | |
| import numpy as np | |
| ################################################################################ | |
| # Y-Vector implementation of Zhu et al. (2021) | |
| ################################################################################ | |
class SEBlock(nn.Module):
    """Squeeze-and-excitation style gating over channels and time.

    A sigmoid channel gate is computed from the time-averaged features,
    then a per-frame scalar sigmoid gate rescales each frame.
    """

    def __init__(self, channels):
        super().__init__()
        # NOTE: keep construction order (fgate before tgate) so parameter
        # initialization stays reproducible under a fixed RNG seed.
        self.fgate = nn.Sequential(nn.Linear(channels, channels), nn.Sigmoid())
        self.tgate = nn.Sequential(nn.Linear(channels, 1), nn.Sigmoid())

    def forward(self, x):
        # x is treated as (batch, channels, time) — the (0, 2, 1) permute
        # below flattens it to one channel-vector per frame.
        batch, channels, frames = x.shape
        # Channel ("frequency") gate from the temporal mean.
        channel_gate = self.fgate(x.mean(dim=-1))
        gated = x * channel_gate.unsqueeze(-1)
        # Time gate: one sigmoid score per frame, broadcast over channels.
        frame_vectors = gated.permute(0, 2, 1).contiguous().view(-1, channels)
        time_gate = self.tgate(frame_vectors).view(batch, frames).unsqueeze(1)
        return gated * time_gate
class MultiScaleConvFeatureExtractionModel(nn.Module):
    """Multi-scale waveform encoder of the y-vector model (Zhu et al., 2021).

    Three parallel conv front-ends with different kernel/stride settings
    encode the raw waveform; their outputs are truncated to a common
    length and concatenated (512 channels total), then passed through a
    strided conv stack with squeeze-excitation gates and max-pool skip
    connections. The three time-aligned streams are concatenated to a
    512*3-channel feature map and gated once more.
    """

    def __init__(
            self,
            dropout=0.0,
            non_affine_group_norm=False,
            activation=nn.ReLU(),):
        super().__init__()

        def block(n_in, n_out, k, stride, padding=0):
            # conv -> dropout -> instance norm -> activation
            return nn.Sequential(
                nn.Conv1d(n_in, n_out, k, stride=stride, bias=False, padding=padding),
                nn.Dropout(p=dropout),
                norm_block(is_layer_norm=False, dim=n_out, affine=not non_affine_group_norm,
                           is_instance_norm=True),
                activation)

        self.conv_front = nn.ModuleList()
        # Three parallel front-ends; overall stride of each branch is 18
        # (18*1, 9*2, 6*3), so their frame rates match up to edge effects.
        self.conv_front.append(nn.Sequential(block(1, 90, 36, 18, 0), block(90, 192, 5, 1, 2)))
        self.conv_front.append(nn.Sequential(block(1, 90, 18, 9, 0), block(90, 160, 5, 2, 0)))
        self.conv_front.append(nn.Sequential(block(1, 90, 12, 6, 0), block(90, 160, 5, 3, 0)))
        # Skip paths downsample the earlier feature maps to the final rate.
        self.skip1 = nn.MaxPool1d(kernel_size=5, stride=8)
        self.skip2 = nn.MaxPool1d(kernel_size=3, stride=4, padding=1)
        # self.skip3 = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
        self.conv1 = block(512, 512, 5, 2)
        self.conv2 = block(512, 512, 3, 2)
        self.conv3 = block(512, 512, 3, 2, padding=2)
        self.am1 = SEBlock(512)
        self.am2 = SEBlock(512)
        self.am3 = SEBlock(512)
        self.am4 = SEBlock(512*3)

    def forward(self, x):
        """Encode a waveform batch.

        Args:
            x: waveform tensor, assumed (batch, 1, samples) — TODO confirm
               against callers.

        Returns:
            Feature map of shape (batch, 512*3, frames).
        """
        # Run each front-end ONCE and reuse the result for its length.
        # (The original called conv(x) twice per branch — once for the
        # output and once just for the shape — doubling front-end compute.)
        enc = [conv(x) for conv in self.conv_front]
        # Branch outputs can differ slightly in length; crop to the shortest.
        ft_max = min(e.shape[-1] for e in enc)
        enc = torch.cat([e[:, :, :ft_max] for e in enc], dim=1)
        # Main stack with SE gates and max-pool skip connections.
        skip1_out = self.skip1(enc)
        out = self.am1(self.conv1(enc))
        skip2_out = self.skip2(out)
        out = self.am2(self.conv2(out))
        # skip3 path is disabled in this configuration.
        out = self.am3(self.conv3(out))
        # Align the three streams in time, concatenate, and gate once more.
        t_max = min(skip1_out.shape[-1], skip2_out.shape[-1], out.shape[-1])
        out = torch.cat((skip1_out[:, :, :t_max], skip2_out[:, :, :t_max], out[:, :, :t_max]), dim=1)
        out = self.am4(out)
        return out
class TDNN_Block(nn.Module):
    """A single TDNN layer followed by normalization and ReLU.

    `norm` selects the normalizer: 'bn' (batch norm), 'ln' (layer norm
    over channels via a single-group GroupNorm), or 'in' (instance norm
    via per-channel GroupNorm, always non-affine).
    """

    def __init__(self, input_dim, output_dim=512, context_size=5, dilation=1, norm='bn', affine=True):
        super(TDNN_Block, self).__init__()
        # Guard clause: reject unknown norm names up front.
        if norm not in ('bn', 'ln', 'in'):
            raise ValueError('Norm should be {bn, ln, in}.')
        if norm == 'bn':
            normalizer = nn.BatchNorm1d(output_dim, affine=affine)
        elif norm == 'ln':
            # Single-group GroupNorm in fp32 == layer norm across channels.
            normalizer = Fp32GroupNorm(1, output_dim, affine=affine)
        else:
            # 'in': one group per channel; affine deliberately off here.
            normalizer = nn.GroupNorm(output_dim, output_dim, affine=False)
        self.tdnn_layer = nn.Sequential(
            TDNNLayer(input_dim, output_dim, context_size, dilation),
            normalizer,
            nn.ReLU(),
        )

    def forward(self, x):
        return self.tdnn_layer(x)
class xvecTDNN(nn.Module):
    """X-vector style aggregator: TDNN stack, statistics pooling, bottleneck.

    Frame-level features go through five TDNN blocks (last one widening
    to 1500 channels); mean and std over time are concatenated (3000-dim)
    and projected to `embed_dim` through a BN/LeakyReLU/dropout bottleneck.
    """

    def __init__(self, feature_dim=512, embed_dim=512, norm='bn', p_dropout=0.0):
        super(xvecTDNN, self).__init__()
        # NOTE: module construction order is kept identical for stable init.
        self.tdnn = nn.Sequential(
            TDNN_Block(feature_dim, 512, 5, 1, norm=norm),
            TDNN_Block(512, 512, 3, 2, norm=norm),
            TDNN_Block(512, 512, 3, 3, norm=norm),
            TDNN_Block(512, 512, 1, 1, norm=norm),
            TDNN_Block(512, 1500, 1, 1, norm=norm),
        )
        self.fc1 = nn.Linear(3000, 512)
        self.bn = nn.BatchNorm1d(512)
        self.dropout_fc1 = nn.Dropout(p=p_dropout)
        self.lrelu = nn.LeakyReLU(0.2)
        self.fc2 = nn.Linear(512, embed_dim)

    def forward(self, x):
        # x must be (batch_size, feat_dim, chunk_len).
        frames = self.tdnn(x)
        # Statistics pooling over the time axis.
        mu = frames.mean(dim=2)
        sigma = frames.std(dim=2)
        pooled = torch.cat((mu, sigma), dim=1)
        hidden = self.dropout_fc1(self.lrelu(self.bn(self.fc1(pooled))))
        return self.fc2(hidden)
class YVector(nn.Module):
    """End-to-end y-vector speaker embedding network on raw waveforms."""

    def __init__(self, embed_dim=512):
        super().__init__()
        # NOTE(review): `embed_dim` is currently unused — the aggregator is
        # hard-wired to a 128-dim embedding. Confirm whether this is intended.
        self.feature_encoder = MultiScaleConvFeatureExtractionModel()
        self.tdnn_aggregator = xvecTDNN(feature_dim=512*3, embed_dim=128, norm='ln')

    def forward(self, x):
        """Map waveforms of shape (B, T) or (B, 1, T) to embeddings."""
        # Require at least batch and signal dimensions.
        assert x.ndim >= 2
        n_batch, *channel_dims, signal_len = x.shape
        # Promote (B, T) to (B, 1, T) for the Conv1d front-end.
        if not channel_dims:
            x = x.unsqueeze(1)
        features = self.feature_encoder(x)
        return self.tdnn_aggregator(features)
if __name__ == "__main__":
    # Smoke test: push 3 s of 16 kHz noise (batch of 4) through the model.
    model = YVector()
    print(model)
    dummy_wave = torch.randn(4, 1, 48000)
    embedding = model(dummy_wave)
    print(embedding.shape)