#!/usr/bin/env python

import torch as th
import torch.nn as nn
import torch.nn.functional as F

from .norm import ChannelwiseLayerNorm, GlobalLayerNorm, CumLN


def _build_norm(norm_type, channels):
    """
    Create one normalization layer over `channels` feature channels
    Supported norm_type values: 'gLN', 'cLN', 'cgLN'
    Raises ValueError for any other value, so a typo is reported at
    construction time instead of surfacing as a missing attribute in forward()
    """
    if norm_type == 'gLN':
        return GlobalLayerNorm(channels, elementwise_affine=True)
    if norm_type == 'cLN':
        return ChannelwiseLayerNorm(channels, elementwise_affine=True)
    if norm_type == 'cgLN':
        return CumLN(channels, elementwise_affine=True)
    raise ValueError(
        "Unsupported norm_type: {!r} (expected 'gLN', 'cLN' or 'cgLN')".format(
            norm_type))


class Conv1D(nn.Conv1d):
    """
    1D Conv based on nn.Conv1d for 2D or 3D tensor
    Input: 2D or 3D tensor with [N, L_in] or [N, C_in, L_in]
    Output: Default 3D tensor with [N, C_out, L_out]
            If C_out=1 and squeeze is true, return 2D tensor
    """

    def forward(self, x, squeeze=False):
        """
        Apply the convolution, adding a channel axis first for 2D input
        Raises RuntimeError if x is neither 2D nor 3D
        """
        if x.dim() not in (2, 3):
            # Module instances have no __name__; read it from the class.
            raise RuntimeError("{} require a 2/3D tensor input".format(
                type(self).__name__))
        if x.dim() == 2:
            x = th.unsqueeze(x, 1)
        x = super().forward(x)
        if squeeze:
            # NOTE: squeezes every size-1 dimension, so the batch axis is
            # dropped as well when N == 1. Kept for backward compatibility.
            x = th.squeeze(x)
        return x


class ConvTrans1D(nn.ConvTranspose1d):
    """
    1D Transposed Conv based on nn.ConvTranspose1d for 2D or 3D tensor
    Input: 2D or 3D tensor with [N, L_in] or [N, C_in, L_in]
    Output: 2D tensor with [N, L_out]
    """

    def forward(self, x):
        """
        Apply the transposed convolution, adding a channel axis for 2D input
        Raises RuntimeError if x is neither 2D nor 3D
        """
        if x.dim() not in (2, 3):
            # Module instances have no __name__; read it from the class.
            raise RuntimeError("{} require a 2/3D tensor input".format(
                type(self).__name__))
        if x.dim() == 2:
            x = th.unsqueeze(x, 1)
        x = super().forward(x)
        # squeeze the channel dimension 1 after reconstructing the signal
        # (a no-op unless that dimension has size 1)
        return th.squeeze(x, 1)


class TCNBlock(nn.Module):
    """
    Temporal convolutional network block,
        1x1Conv - PReLU - Norm - DConv - PReLU - Norm - SConv
    Input: 3D tensor with [N, C_in, L_in]
    Output: 3D tensor with [N, C_out, L_out]
    """

    def __init__(self,
                 in_channels=256,
                 conv_channels=512,
                 kernel_size=3,
                 dilation=1,
                 causal=False,
                 norm_type='gLN'):
        """
        in_channels: channels of the block input and output
        conv_channels: channels used inside the block
        kernel_size: kernel size of the depthwise convolution
        dilation: dilation of the depthwise convolution
        causal: pad by the full receptive field and trim the tail in forward()
        norm_type: 'gLN', 'cLN' or 'cgLN'; ValueError otherwise
        """
        super().__init__()
        # Sub-modules are registered in a fixed order so that parameter and
        # state_dict ordering stays stable.
        self.conv1x1 = Conv1D(in_channels, conv_channels, 1)
        self.prelu1 = nn.PReLU()
        self.norm1 = _build_norm(norm_type, conv_channels)
        self.norm2 = _build_norm(norm_type, conv_channels)
        receptive = dilation * (kernel_size - 1)
        # nn.Conv1d pads both sides: half the span keeps the length in the
        # non-causal case; the causal case pads the full span and the
        # surplus frames at the end are removed in forward().
        dconv_pad = receptive if causal else receptive // 2
        self.dconv = nn.Conv1d(
            conv_channels,
            conv_channels,
            kernel_size,
            groups=conv_channels,
            padding=dconv_pad,
            dilation=dilation,
            bias=True)
        self.prelu2 = nn.PReLU()
        self.sconv = nn.Conv1d(conv_channels, in_channels, 1, bias=True)
        self.causal = causal
        self.dconv_pad = dconv_pad

    def forward(self, x):
        """
        Run the block on x and add the residual connection
        """
        y = self.conv1x1(x)
        y = self.norm1(self.prelu1(y))
        y = self.dconv(y)
        # With zero padding (kernel_size == 1) there is nothing to trim, and
        # the slice [:-0] would wrongly return an empty tensor.
        if self.causal and self.dconv_pad > 0:
            y = y[:, :, :-self.dconv_pad]
        y = self.norm2(self.prelu2(y))
        y = self.sconv(y)
        y += x
        return y


class TCNBlock_Spk(nn.Module):
    """
    Temporal convolutional network block,
        1x1Conv - PReLU - Norm - DConv - PReLU - Norm - SConv
    The first tcn block takes additional speaker embedding as inputs
    Input: 3D tensor with [N, C_in, L_in]
    Input Speaker Embedding: 2D tensor with [N, D]
    Output: 3D tensor with [N, C_out, L_out]
    """

    def __init__(self,
                 in_channels=256,
                 spk_embed_dim=100,
                 conv_channels=512,
                 kernel_size=3,
                 dilation=1,
                 causal=False,
                 norm_type='gLN',
                 fusion_type='cat'):
        """
        in_channels: channels of the block input and output
        spk_embed_dim: size D of the speaker embedding
        conv_channels: channels used inside the block
        kernel_size: kernel size of the depthwise convolution (also the
                     window and stride of the averaging used by 'att')
        dilation: dilation of the depthwise convolution
        causal: pad by the full receptive field and trim the tail in forward()
        norm_type: 'gLN', 'cLN' or 'cgLN'; ValueError otherwise
        fusion_type: 'cat', 'add', 'mul', 'film' or 'att'; ValueError otherwise
        """
        super().__init__()
        self.fusion_type = fusion_type
        if fusion_type == 'cat':
            self.conv1x1 = Conv1D(in_channels + spk_embed_dim, conv_channels, 1)
        elif fusion_type in ('add', 'mul'):
            self.fusion_linear = nn.Linear(spk_embed_dim, in_channels)
            self.conv1x1 = Conv1D(in_channels, conv_channels, 1)
        elif fusion_type == 'film':
            self.fusion_linear_1 = nn.Linear(spk_embed_dim, in_channels)
            self.fusion_linear_2 = nn.Linear(spk_embed_dim, in_channels)
            self.conv1x1 = Conv1D(in_channels, conv_channels, 1)
        elif fusion_type == 'att':
            self.fusion_linear = nn.Linear(spk_embed_dim, in_channels)
            # Fixed (non-trainable) per-channel moving average with
            # window == stride == kernel_size.
            self.average = Conv1D(
                in_channels,
                in_channels,
                kernel_size,
                stride=kernel_size,
                groups=in_channels)
            self.average.weight = nn.Parameter(
                th.ones(in_channels, 1, kernel_size) / kernel_size,
                requires_grad=False)
            self.average.bias = nn.Parameter(
                th.zeros(in_channels), requires_grad=False)
            self.conv1x1 = Conv1D(in_channels, conv_channels, 1)
        else:
            raise ValueError(
                "Unsupported fusion_type: {!r} (expected 'cat', 'add', "
                "'mul', 'film' or 'att')".format(fusion_type))
        self.prelu1 = nn.PReLU()
        self.norm1 = _build_norm(norm_type, conv_channels)
        self.norm2 = _build_norm(norm_type, conv_channels)
        receptive = dilation * (kernel_size - 1)
        # See TCNBlock: symmetric padding, surplus trimmed when causal.
        dconv_pad = receptive if causal else receptive // 2
        self.dconv = nn.Conv1d(
            conv_channels,
            conv_channels,
            kernel_size,
            groups=conv_channels,
            padding=dconv_pad,
            dilation=dilation,
            bias=True)
        self.prelu2 = nn.PReLU()
        self.sconv = nn.Conv1d(conv_channels, in_channels, 1, bias=True)
        self.causal = causal
        self.dconv_pad = dconv_pad
        self.dilation = dilation

    def _concatenation(self, aux, output, L):
        """Tile aux over L frames and concatenate it to output on dim 1"""
        aux_concat = th.unsqueeze(aux, -1)
        aux_concat = aux_concat.repeat(1, 1, L)  # -> [B, N(embeddings_size), L]
        # -> [B, N(input_size + embeddings_size), L]
        return th.cat([output, aux_concat], 1)

    def _addition(self, aux, output, L, fusion_linear):
        """Project aux with fusion_linear, tile over L frames, add to output"""
        aux_add = fusion_linear(aux)  # -> [B, N(input_size)]
        aux_add = th.unsqueeze(aux_add, -1)
        aux_add = aux_add.repeat(1, 1, L)  # -> [B, N(input_size), L]
        return output + aux_add  # -> [B, N(input_size), L]

    def _multiplication(self, aux, output, L, fusion_linear):
        """Project aux with fusion_linear, tile over L frames, scale output"""
        aux_mul = fusion_linear(aux)  # -> [B, N(input_size)]
        aux_mul = th.unsqueeze(aux_mul, -1)
        aux_mul = aux_mul.repeat(1, 1, L)  # -> [B, N(input_size), L]
        return output * aux_mul  # -> [B, N(input_size), L]

    def _attention(self, aux, output, fusion_linear):
        """
        Weight the projected embedding by a softmax (over frames) of its
        channel-wise dot product with output, then add the embedding back
        """
        L = output.shape[-1]
        aux_att = fusion_linear(aux)
        aux_att = th.unsqueeze(aux_att, -1)
        aux_att = aux_att.repeat(1, 1, L)
        att = th.sum(output * aux_att, 1, keepdim=True)
        att = F.softmax(att, -1)
        att = att * aux_att
        return att + aux_att

    def _film(self, aux, output, L):
        """Scale output with fusion_linear_1(aux), shift with fusion_linear_2(aux)"""
        output = self._multiplication(
            aux, output, L, self.fusion_linear_1)  # -> [B, N(input_size), L]
        output = self._addition(
            aux, output, L, self.fusion_linear_2)  # -> [B, N(input_size), L]
        return output

    def forward(self, x, aux):
        """
        Fuse the speaker embedding aux into x according to fusion_type, run
        the TCN block on the result and add the residual connection
        Raises ValueError if fusion_type was changed to an unsupported value
        """
        T = x.shape[-1]
        if self.fusion_type == 'cat':
            # Repeatedly concatenate aux to each frame of the representation x
            # -> [B, N(input_size + embeddings_size), L]
            y = self._concatenation(aux, x, T)
        elif self.fusion_type == 'add':
            # -> [B, N(input_size), L]
            y = self._addition(aux, x, T, self.fusion_linear)
        elif self.fusion_type == 'mul':
            # -> [B, N(input_size), L]
            y = self._multiplication(aux, x, T, self.fusion_linear)
        elif self.fusion_type == 'film':
            # -> [B, N(input_size), L]
            y = self._film(aux, x, T)
        elif self.fusion_type == 'att':
            output_avg = self.average(x)
            att_out = self._attention(aux, output_avg, self.fusion_linear)
            # Bring the attention map back to the T input frames; same
            # operation as nn.Upsample(size=T, mode='nearest') without
            # building a module on every call.
            att_out = F.interpolate(att_out, size=T, mode='nearest')
            y = x * att_out
        else:
            raise ValueError(
                "Unsupported fusion_type: {!r}".format(self.fusion_type))
        y = self.conv1x1(y)
        y = self.norm1(self.prelu1(y))
        y = self.dconv(y)
        # With zero padding (kernel_size == 1) there is nothing to trim, and
        # the slice [:-0] would wrongly return an empty tensor.
        if self.causal and self.dconv_pad > 0:
            y = y[:, :, :-self.dconv_pad]
        y = self.norm2(self.prelu2(y))
        y = self.sconv(y)
        y += x
        return y


class ResBlock(nn.Module):
    """
    Resnet block for speaker encoder to obtain speaker embedding
    ref to
        https://github.com/fatchord/WaveRNN/blob/master/models/fatchord_version.py
        and
        https://github.com/Jungjee/RawNet/blob/master/PyTorch/model_RawNet.py
    """

    def __init__(self, in_dims, out_dims):
        """
        in_dims: number of input channels
        out_dims: number of output channels; when it differs from in_dims
                  the skip path goes through an extra 1x1 convolution
        """
        super().__init__()
        self.conv1 = nn.Conv1d(in_dims, out_dims, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(out_dims, out_dims, kernel_size=1, bias=False)
        self.batch_norm1 = nn.BatchNorm1d(out_dims)
        self.batch_norm2 = nn.BatchNorm1d(out_dims)
        self.prelu1 = nn.PReLU()
        self.prelu2 = nn.PReLU()
        self.maxpool = nn.MaxPool1d(3)
        if in_dims != out_dims:
            self.downsample = True
            self.conv_downsample = nn.Conv1d(
                in_dims, out_dims, kernel_size=1, bias=False)
        else:
            self.downsample = False

    def forward(self, x):
        """
        Conv - BN - PReLU - Conv - BN, add the skip path, PReLU, max-pool
        """
        y = self.conv1(x)
        y = self.batch_norm1(y)
        y = self.prelu1(y)
        y = self.conv2(y)
        y = self.batch_norm2(y)
        # Match channel counts on the skip path before the residual sum.
        if self.downsample:
            y += self.conv_downsample(x)
        else:
            y += x
        y = self.prelu2(y)
        return self.maxpool(y)