import torch
import torch.nn as nn
import torch.nn.functional as F
from src.models.speaker.yvector.tdnn import TDNNLayer
from src.models.speaker.yvector.wav2spk import Fp32GroupNorm, norm_block
################################################################################
# Y-vector speaker embedding network, after Zhu et al. (2021),
# "Y-vector: Multiscale Waveform Encoder for Speaker Embedding" (Interspeech)
################################################################################
class SEBlock(nn.Module):
    """Squeeze-and-excitation-style gating over channels, then over time.

    `fgate` yields one sigmoid gate per channel from the time-averaged
    features; `tgate` then yields one sigmoid gate per frame from the
    channel-gated features.
    """

    def __init__(self, channels):
        super().__init__()
        self.fgate = nn.Sequential(nn.Linear(channels, channels), nn.Sigmoid())
        self.tgate = nn.Sequential(nn.Linear(channels, 1), nn.Sigmoid())

    def forward(self, x):
        # x: (B, C, T). Channel gate from the time-averaged representation.
        fg = self.fgate(x.mean(dim=-1))     # (B, C)
        x = x * fg.unsqueeze(-1)            # broadcast over time
        # Frame gate: nn.Linear acts on the last dim, so move channels there.
        tg = self.tgate(x.transpose(1, 2))  # (B, T, 1)
        out = x * tg.transpose(1, 2)        # broadcast over channels
        return out
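
# Shape walk-through: for x of shape (B, C, T) = (2, 512, 100), fg is
# (2, 512), the frame gate broadcasts as (2, 1, 100), and the output keeps
# the input shape (2, 512, 100).
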
class MultiScaleConvFeatureExtractionModel(nn.Module):
    """Multiscale waveform encoder: three parallel conv front ends whose
    outputs are concatenated along channels (192 + 160 + 160 = 512),
    followed by a strided conv trunk with max-pool skip branches."""

    def __init__(self, dropout=0.0, non_affine_group_norm=False, activation=nn.ReLU()):
        super().__init__()

        def block(n_in, n_out, k, stride, padding=0):
            return nn.Sequential(
                nn.Conv1d(n_in, n_out, k, stride=stride, bias=False, padding=padding),
                nn.Dropout(p=dropout),
                norm_block(is_layer_norm=False, dim=n_out,
                           affine=not non_affine_group_norm, is_instance_norm=True),
                activation,
            )

        # Three front-end branches with first-layer strides 18, 9, and 6,
        # capturing the waveform at three time scales.
        self.conv_front = nn.ModuleList()
        self.conv_front.append(nn.Sequential(block(1, 90, 36, 18, 0), block(90, 192, 5, 1, 2)))
        self.conv_front.append(nn.Sequential(block(1, 90, 18, 9, 0), block(90, 160, 5, 2, 0)))
        self.conv_front.append(nn.Sequential(block(1, 90, 12, 6, 0), block(90, 160, 5, 3, 0)))
        # Max-pool skip branches whose strides (8 and 4) match the cumulative
        # stride of the conv trunk below, keeping the branches time-aligned.
        self.skip1 = nn.MaxPool1d(kernel_size=5, stride=8)
        self.skip2 = nn.MaxPool1d(kernel_size=3, stride=4, padding=1)
        # self.skip3 = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
        self.conv1 = block(512, 512, 5, 2)
        self.conv2 = block(512, 512, 3, 2)
        self.conv3 = block(512, 512, 3, 2, padding=2)
        # Attention (gating) modules after each trunk stage and after the
        # final concatenation of the three 512-channel branches.
        self.am1 = SEBlock(512)
        self.am2 = SEBlock(512)
        self.am3 = SEBlock(512)
        self.am4 = SEBlock(512 * 3)
    def forward(self, x):
        # x: (B, 1, T) raw waveform.
        # Run each front-end branch once, then crop all branches to the
        # shortest time axis so they can be concatenated along channels.
        enc = [conv(x) for conv in self.conv_front]
        ft_max = min(e.shape[-1] for e in enc)
        enc = torch.cat([e[:, :, :ft_max] for e in enc], dim=1)  # (B, 512, T')
        # Skip branches: max-pooled copies of intermediate activations.
skip1_out = self.skip1(enc)
out = self.conv1(enc)
out = self.am1(out)
skip2_out = self.skip2(out)
out = self.conv2(out)
out = self.am2(out)
# skip3_out = self.skip3(out)
out = self.conv3(out)
out = self.am3(out)
        # Crop the two skip branches and the trunk output to a common length,
        # then concatenate along channels (3 * 512 = 1536).
        t_max = min(skip1_out.shape[-1], skip2_out.shape[-1], out.shape[-1])
        out = torch.cat((skip1_out[:, :, :t_max], skip2_out[:, :, :t_max], out[:, :, :t_max]), dim=1)
out = self.am4(out)
return out
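
# Shape sketch, following the conv arithmetic above: for a 3 s, 16 kHz input
# of shape (B, 1, 48000), the front-end branches produce 2664-2665 frames
# (cropped to 2664), enc is (B, 512, 2664), and the final output is
# (B, 1536, 333), i.e. roughly 111 frames per second.
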
class TDNN_Block(nn.Module):
    """A TDNN (dilated 1-D conv) layer followed by normalization and ReLU."""

    def __init__(self, input_dim, output_dim=512, context_size=5, dilation=1,
                 norm='bn', affine=True):
        super().__init__()
        if norm == 'bn':
            norm_layer = nn.BatchNorm1d(output_dim, affine=affine)
        elif norm == 'ln':
            # Single-group GroupNorm == layer norm over channels; the Fp32
            # variant computes the normalization in float32 for stability.
            norm_layer = Fp32GroupNorm(1, output_dim, affine=affine)
        elif norm == 'in':
            # One group per channel == instance norm (always non-affine here,
            # regardless of the `affine` argument).
            norm_layer = nn.GroupNorm(output_dim, output_dim, affine=False)
        else:
            raise ValueError(f"norm must be one of 'bn', 'ln', 'in'; got {norm!r}")
        self.tdnn_layer = nn.Sequential(
            TDNNLayer(input_dim, output_dim, context_size, dilation),
            norm_layer,
            nn.ReLU(),
        )
def forward(self, x):
return self.tdnn_layer(x)
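
# Frame-count note (assuming TDNNLayer applies no padding, as in the standard
# x-vector recipe): each block shortens the time axis by
# dilation * (context_size - 1) frames.
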
class xvecTDNN(nn.Module):
    """X-vector-style aggregator: five TDNN blocks, statistics pooling of the
    1500-dim frame features (mean and std -> 3000 dims), then two fully
    connected layers producing the embedding."""

    def __init__(self, feature_dim=512, embed_dim=512, norm='bn', p_dropout=0.0):
        super().__init__()
        self.tdnn = nn.Sequential(
            TDNN_Block(feature_dim, 512, 5, 1, norm=norm),
            TDNN_Block(512, 512, 3, 2, norm=norm),
            TDNN_Block(512, 512, 3, 3, norm=norm),
            TDNN_Block(512, 512, 1, 1, norm=norm),
            TDNN_Block(512, 1500, 1, 1, norm=norm),
        )
        self.fc1 = nn.Linear(3000, 512)
        self.bn = nn.BatchNorm1d(512)
        self.dropout_fc1 = nn.Dropout(p=p_dropout)
        self.lrelu = nn.LeakyReLU(0.2)
        self.fc2 = nn.Linear(512, embed_dim)
def forward(self, x):
# Note: x must be (batch_size, feat_dim, chunk_len)
x = self.tdnn(x)
        # Statistics pooling: per-channel mean and std over time (1500 -> 3000).
        stats = torch.cat((x.mean(dim=2), x.std(dim=2)), dim=1)
x = self.dropout_fc1(self.lrelu(self.bn(self.fc1(stats))))
x = self.fc2(x)
return x
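
# With the no-padding assumption above, the TDNN stack consumes
# 4 + 4 + 6 = 14 frames of context, so inputs need at least 15 frames;
# statistics pooling then makes the embedding length-independent.
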
class YVector(nn.Module):
    """Full y-vector model: multiscale waveform encoder followed by the
    x-vector-style TDNN aggregator."""

    def __init__(self, embed_dim=128):
        super().__init__()
        self.feature_encoder = MultiScaleConvFeatureExtractionModel()
        self.tdnn_aggregator = xvecTDNN(feature_dim=512 * 3, embed_dim=embed_dim, norm='ln')

    def forward(self, x):
        # Accept (batch, time) or (batch, channel, time) input.
        assert x.ndim >= 2
        # Add a channel dimension to (batch, time) input.
        if x.ndim == 2:
            x = x.unsqueeze(1)
        out = self.feature_encoder(x)
        out = self.tdnn_aggregator(out)
        return out
if __name__ == "__main__":
model = YVector()
print(model)
    # Four 3-second utterances at 16 kHz.
    wav_input_16khz = torch.randn(4, 1, 48000)
    c = model(wav_input_16khz)
    print(c.shape)  # torch.Size([4, 128])
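
    # A minimal scoring sketch (illustrative): speaker embeddings are usually
    # compared with cosine similarity, so normalize to unit length and take
    # pairwise dot products.
    emb = F.normalize(c, dim=1)  # (4, 128) unit-length embeddings
    scores = emb @ emb.t()       # (4, 4) pairwise cosine similarities
    print(scores)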