Spaces:
Running
Running
| #!/usr/bin/env python | |
| import torch as th | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from .norm import ChannelwiseLayerNorm, GlobalLayerNorm, CumLN | |
class Conv1D(nn.Conv1d):
    """
    1D convolution (nn.Conv1d) that accepts a 2D or 3D tensor.

    Input:  [N, L_in] (a channel dim of size 1 is inserted) or [N, C_in, L_in]
    Output: [N, C_out, L_out] by default; with squeeze=True, every size-1
            dimension is removed (e.g. C_out == 1 yields [N, L_out]).
    """

    def __init__(self, *args, **kwargs):
        super(Conv1D, self).__init__(*args, **kwargs)

    def forward(self, x, squeeze=False):
        """
        Args:
            x: 2D or 3D input tensor.
            squeeze: if True, squeeze all size-1 dims out of the output.
        Raises:
            RuntimeError: if x is not 2D or 3D.
        """
        if x.dim() not in [2, 3]:
            # Fix: nn.Module instances have no `__name__` attribute, so the
            # original `self.__name__` raised AttributeError instead of the
            # intended RuntimeError with the class name.
            raise RuntimeError("{} require a 2/3D tensor input".format(
                type(self).__name__))
        x = super().forward(x if x.dim() == 3 else th.unsqueeze(x, 1))
        if squeeze:
            # NOTE: th.squeeze(x) drops *every* size-1 dim, including a
            # batch dimension of 1 — kept as-is for backward compatibility.
            x = th.squeeze(x)
        return x
class ConvTrans1D(nn.ConvTranspose1d):
    """
    1D transposed convolution (nn.ConvTranspose1d) that accepts a 2D or
    3D tensor; used to reconstruct the waveform from the feature maps.

    Input:  [N, L_in] (a channel dim of size 1 is inserted) or [N, C_in, L_in]
    Output: [N, L_out] when C_out == 1 (channel dim squeezed), otherwise
            [N, C_out, L_out].
    """

    def __init__(self, *args, **kwargs):
        super(ConvTrans1D, self).__init__(*args, **kwargs)

    def forward(self, x):
        """
        Args:
            x: 2D or 3D input tensor.
        Raises:
            RuntimeError: if x is not 2D or 3D.
        """
        if x.dim() not in [2, 3]:
            # Fix: nn.Module instances have no `__name__` attribute, so the
            # original `self.__name__` raised AttributeError instead of the
            # intended RuntimeError with the class name.
            raise RuntimeError("{} require a 2/3D tensor input".format(
                type(self).__name__))
        x = super().forward(x if x.dim() == 3 else th.unsqueeze(x, 1))
        # Squeeze the channel dimension (only if it has size 1) after
        # reconstructing the signal.
        return th.squeeze(x, 1)
class TCNBlock(nn.Module):
    """
    Temporal convolutional network block:
        1x1Conv - PReLU - Norm - DConv - PReLU - Norm - SConv
    with a residual connection from input to output.

    Input:  3D tensor [N, C_in, L_in]
    Output: 3D tensor [N, C_in, L_in] (residual add keeps the shape)
    """

    def __init__(self,
                 in_channels=256,
                 conv_channels=512,
                 kernel_size=3,
                 dilation=1,
                 causal=False,
                 norm_type='gLN'):
        """
        Args:
            in_channels: channels on the residual path (input == output).
            conv_channels: internal channels of the block.
            kernel_size: depthwise conv kernel size.
            dilation: depthwise conv dilation.
            causal: if True, pad only on the left so no future frames leak.
            norm_type: one of 'gLN', 'cLN', 'cgLN'.
        Raises:
            ValueError: on an unknown norm_type.
        """
        super(TCNBlock, self).__init__()
        self.conv1x1 = Conv1D(in_channels, conv_channels, 1)
        self.prelu1 = nn.PReLU()
        # Fail fast on an unknown norm_type instead of silently leaving
        # norm1/norm2 undefined (the original if/elif chain had no else,
        # deferring the failure to an AttributeError in forward()).
        self.norm1 = self._build_norm(norm_type, conv_channels)
        self.norm2 = self._build_norm(norm_type, conv_channels)
        # Causal: pad fully on the left (the right side is trimmed in
        # forward); non-causal: symmetric "same" padding.
        dconv_pad = (dilation * (kernel_size - 1)) // 2 if not causal else (
            dilation * (kernel_size - 1))
        # Depthwise (grouped) dilated convolution.
        self.dconv = nn.Conv1d(
            conv_channels,
            conv_channels,
            kernel_size,
            groups=conv_channels,
            padding=dconv_pad,
            dilation=dilation,
            bias=True)
        self.prelu2 = nn.PReLU()
        # Pointwise conv back to the residual channel count.
        self.sconv = nn.Conv1d(conv_channels, in_channels, 1, bias=True)
        self.causal = causal
        self.dconv_pad = dconv_pad

    @staticmethod
    def _build_norm(norm_type, dim):
        """Map a norm_type string to a normalization layer over `dim` channels."""
        if norm_type == 'gLN':
            return GlobalLayerNorm(dim, elementwise_affine=True)
        if norm_type == 'cLN':
            return ChannelwiseLayerNorm(dim, elementwise_affine=True)
        if norm_type == 'cgLN':
            return CumLN(dim, elementwise_affine=True)
        raise ValueError("Unsupported norm_type: {}".format(norm_type))

    def forward(self, x):
        y = self.conv1x1(x)
        y = self.norm1(self.prelu1(y))
        y = self.dconv(y)
        if self.causal and self.dconv_pad > 0:
            # Trim the right-side padding so output frame t depends only on
            # inputs <= t. Guard dconv_pad > 0: with kernel_size == 1 the
            # pad is 0 and slicing with [:-0] would empty the tensor.
            y = y[:, :, :-self.dconv_pad]
        y = self.norm2(self.prelu2(y))
        y = self.sconv(y)
        y += x
        return y
class TCNBlock_Spk(nn.Module):
    """
    TCN block that fuses a speaker embedding into its input before the
    1x1Conv - PReLU - Norm - DConv - PReLU - Norm - SConv pipeline, with a
    residual connection from the (un-fused) input to the output. Used as the
    first TCN block, which takes the speaker embedding as an extra input.

    Input:  3D tensor [N, C_in, L_in]
    Input speaker embedding: 2D tensor [N, D]
    Output: 3D tensor [N, C_in, L_in]
    """

    def __init__(self,
                 in_channels=256,
                 spk_embed_dim=100,
                 conv_channels=512,
                 kernel_size=3,
                 dilation=1,
                 causal=False,
                 norm_type='gLN',
                 fusion_type='cat'):
        """
        Args:
            in_channels: channels on the residual path.
            spk_embed_dim: speaker embedding dimension D.
            conv_channels: internal channels of the block.
            kernel_size: depthwise conv kernel size.
            dilation: depthwise conv dilation.
            causal: if True, pad only on the left (no future context).
            norm_type: one of 'gLN', 'cLN', 'cgLN'.
            fusion_type: one of 'cat', 'add', 'mul', 'film', 'att'.
        Raises:
            ValueError: on an unknown norm_type or fusion_type.
        """
        super(TCNBlock_Spk, self).__init__()
        self.fusion_type = fusion_type
        # Build the fusion front-end. Fail fast on an unknown fusion_type:
        # the original non-exclusive if chain silently skipped every branch
        # and left self.conv1x1 undefined, crashing later in forward().
        if fusion_type == 'cat':
            # Embedding is tiled over time and concatenated along channels.
            self.conv1x1 = Conv1D(in_channels + spk_embed_dim, conv_channels, 1)
        elif fusion_type in ('add', 'mul'):
            self.fusion_linear = nn.Linear(spk_embed_dim, in_channels)
            self.conv1x1 = Conv1D(in_channels, conv_channels, 1)
        elif fusion_type == 'film':
            # FiLM: linear_1 produces the per-channel scale, linear_2 the shift.
            self.fusion_linear_1 = nn.Linear(spk_embed_dim, in_channels)
            self.fusion_linear_2 = nn.Linear(spk_embed_dim, in_channels)
            self.conv1x1 = Conv1D(in_channels, conv_channels, 1)
        elif fusion_type == 'att':
            self.fusion_linear = nn.Linear(spk_embed_dim, in_channels)
            # Frozen depthwise averaging conv with stride == kernel_size,
            # i.e. non-overlapping mean over consecutive frames.
            self.average = Conv1D(in_channels, in_channels, kernel_size,
                                  kernel_size, groups=in_channels)
            self.average.weight = nn.Parameter(
                th.ones(in_channels, 1, kernel_size) / kernel_size)
            self.average.bias = nn.Parameter(th.zeros(in_channels))
            for p in self.average.parameters():
                p.requires_grad = False
            self.conv1x1 = Conv1D(in_channels, conv_channels, 1)
        else:
            raise ValueError("Unsupported fusion_type: {}".format(fusion_type))
        self.prelu1 = nn.PReLU()
        # Fail fast on an unknown norm_type instead of silently leaving
        # norm1/norm2 undefined.
        self.norm1 = self._build_norm(norm_type, conv_channels)
        self.norm2 = self._build_norm(norm_type, conv_channels)
        # Causal: pad fully on the left (the right side is trimmed in
        # forward); non-causal: symmetric "same" padding.
        dconv_pad = (dilation * (kernel_size - 1)) // 2 if not causal else (
            dilation * (kernel_size - 1))
        # Depthwise (grouped) dilated convolution.
        self.dconv = nn.Conv1d(
            conv_channels,
            conv_channels,
            kernel_size,
            groups=conv_channels,
            padding=dconv_pad,
            dilation=dilation,
            bias=True)
        self.prelu2 = nn.PReLU()
        # Pointwise conv back to the residual channel count.
        self.sconv = nn.Conv1d(conv_channels, in_channels, 1, bias=True)
        self.causal = causal
        self.dconv_pad = dconv_pad
        self.dilation = dilation

    @staticmethod
    def _build_norm(norm_type, dim):
        """Map a norm_type string to a normalization layer over `dim` channels."""
        if norm_type == 'gLN':
            return GlobalLayerNorm(dim, elementwise_affine=True)
        if norm_type == 'cLN':
            return ChannelwiseLayerNorm(dim, elementwise_affine=True)
        if norm_type == 'cgLN':
            return CumLN(dim, elementwise_affine=True)
        raise ValueError("Unsupported norm_type: {}".format(norm_type))

    def _concatenation(self, aux, output, L):
        """Tile aux over time and concatenate along channels:
        [B, C, L] with [B, D] -> [B, C + D, L]."""
        aux_concat = th.unsqueeze(aux, -1)
        aux_concat = aux_concat.repeat(1, 1, L)
        return th.cat([output, aux_concat], 1)

    def _addition(self, aux, output, L, fusion_linear):
        """Project aux to C channels, tile over time and add: -> [B, C, L]."""
        aux_add = fusion_linear(aux)
        aux_add = th.unsqueeze(aux_add, -1)
        aux_add = aux_add.repeat(1, 1, L)
        return output + aux_add

    def _multiplication(self, aux, output, L, fusion_linear):
        """Project aux to C channels, tile over time and multiply: -> [B, C, L]."""
        aux_mul = fusion_linear(aux)
        aux_mul = th.unsqueeze(aux_mul, -1)
        aux_mul = aux_mul.repeat(1, 1, L)
        return output * aux_mul

    def _attention(self, aux, output, fusion_linear):
        """Attention fusion: per-frame similarity between the projected
        embedding and the features, softmax-normalized over time, then
        modulated by (and added back to) the tiled embedding."""
        L = output.shape[-1]
        aux_att = fusion_linear(aux)
        aux_att = th.unsqueeze(aux_att, -1)
        aux_att = aux_att.repeat(1, 1, L)
        att = th.sum(output * aux_att, 1, keepdim=True)
        att = F.softmax(att, -1)
        att = att * aux_att
        return att + aux_att

    def _film(self, aux, output, L):
        """FiLM fusion: scale by linear_1(aux), then shift by linear_2(aux)."""
        output = self._multiplication(aux, output, L, self.fusion_linear_1)
        return self._addition(aux, output, L, self.fusion_linear_2)

    def forward(self, x, aux):
        """
        Args:
            x: [N, C_in, L] feature tensor.
            aux: [N, D] speaker embedding.
        Returns:
            [N, C_in, L] tensor with the residual added.
        """
        T = x.shape[-1]
        # Fuse the speaker embedding into every frame of x.
        if self.fusion_type == 'cat':
            y = self._concatenation(aux, x, T)
        elif self.fusion_type == 'add':
            y = self._addition(aux, x, T, self.fusion_linear)
        elif self.fusion_type == 'mul':
            y = self._multiplication(aux, x, T, self.fusion_linear)
        elif self.fusion_type == 'film':
            y = self._film(aux, x, T)
        elif self.fusion_type == 'att':
            output_avg = self.average(x)
            att_out = self._attention(aux, output_avg, self.fusion_linear)
            # Nearest-neighbour upsample back to the original length;
            # equivalent to the original per-call nn.Upsample module but
            # without allocating a new module on every forward pass.
            att_out = F.interpolate(att_out, size=T, mode='nearest')
            y = x * att_out
        else:
            raise ValueError(
                "Unsupported fusion_type: {}".format(self.fusion_type))
        y = self.conv1x1(y)
        y = self.norm1(self.prelu1(y))
        y = self.dconv(y)
        if self.causal and self.dconv_pad > 0:
            # Trim the right-side padding so output frame t depends only on
            # inputs <= t; guard dconv_pad > 0 since [:-0] empties the tensor.
            y = y[:, :, :-self.dconv_pad]
        y = self.norm2(self.prelu2(y))
        y = self.sconv(y)
        y += x
        return y
class ResBlock(nn.Module):
    """
    Residual block for the speaker encoder that extracts the speaker
    embedding: two 1x1 Conv1d + BatchNorm + PReLU stages with a shortcut
    (projected through a 1x1 conv when the channel count changes),
    followed by a max-pool of size 3 along time.

    References:
        https://github.com/fatchord/WaveRNN/blob/master/models/fatchord_version.py
        https://github.com/Jungjee/RawNet/blob/master/PyTorch/model_RawNet.py
    """

    def __init__(self, in_dims, out_dims):
        super(ResBlock, self).__init__()
        self.conv1 = nn.Conv1d(in_dims, out_dims, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(out_dims, out_dims, kernel_size=1, bias=False)
        self.batch_norm1 = nn.BatchNorm1d(out_dims)
        self.batch_norm2 = nn.BatchNorm1d(out_dims)
        self.prelu1 = nn.PReLU()
        self.prelu2 = nn.PReLU()
        self.maxpool = nn.MaxPool1d(3)
        # Project the shortcut only when the channel count changes.
        self.downsample = in_dims != out_dims
        if self.downsample:
            self.conv_downsample = nn.Conv1d(
                in_dims, out_dims, kernel_size=1, bias=False)

    def forward(self, x):
        out = self.batch_norm1(self.conv1(x))
        out = self.conv2(self.prelu1(out))
        out = self.batch_norm2(out)
        # Residual connection; project x when channel counts differ.
        shortcut = self.conv_downsample(x) if self.downsample else x
        out = out + shortcut
        return self.maxpool(self.prelu2(out))