# swc2: update change 2 (commit 7eddfc5)
#!/usr/bin/env python
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from .norm import ChannelwiseLayerNorm, GlobalLayerNorm, CumLN
class Conv1D(nn.Conv1d):
    """
    1D convolution based on nn.Conv1d that accepts 2D or 3D tensors.

    Input: 2D or 3D tensor with [N, L_in] or [N, C_in, L_in];
           a 2D input is treated as [N, 1, L_in].
    Output: 3D tensor with [N, C_out, L_out] by default.
            If squeeze is true, all singleton dimensions are removed
            (e.g. C_out=1 -> [N, L_out]).
    """
    def __init__(self, *args, **kwargs):
        super(Conv1D, self).__init__(*args, **kwargs)
    def forward(self, x, squeeze=False):
        """Apply the convolution; see class docstring for shapes."""
        if x.dim() not in [2, 3]:
            # Fix: nn.Module has no `__name__` attribute, so the original
            # `self.__name__` raised AttributeError instead of this error.
            raise RuntimeError("{} require a 2/3D tensor input".format(
                type(self).__name__))
        x = super().forward(x if x.dim() == 3 else th.unsqueeze(x, 1))
        if squeeze:
            # NOTE: squeezes *all* singleton dims (including batch N=1),
            # matching the original behavior.
            x = th.squeeze(x)
        return x
class ConvTrans1D(nn.ConvTranspose1d):
    """
    1D transposed convolution based on nn.ConvTranspose1d for 2D/3D tensors.

    Input: 2D or 3D tensor with [N, L_in] or [N, C_in, L_in];
           a 2D input is treated as [N, 1, L_in].
    Output: channel dim 1 is squeezed, so for C_out=1 the result is
            a 2D tensor [N, L_out] (otherwise [N, C_out, L_out]).
    """
    def __init__(self, *args, **kwargs):
        super(ConvTrans1D, self).__init__(*args, **kwargs)
    def forward(self, x):
        """Apply the transposed convolution; see class docstring for shapes."""
        if x.dim() not in [2, 3]:
            # Fix: nn.Module has no `__name__` attribute, so the original
            # `self.__name__` raised AttributeError instead of this error.
            raise RuntimeError("{} require a 2/3D tensor input".format(
                type(self).__name__))
        x = super().forward(x if x.dim() == 3 else th.unsqueeze(x, 1))
        # squeeze the channel dimension 1 after reconstructing the signal
        return th.squeeze(x, 1)
class TCNBlock(nn.Module):
    """
    Temporal convolutional network block:
        1x1Conv - PReLU - Norm - DConv (depthwise) - PReLU - Norm - SConv
    with a residual connection from input to output.

    Input: 3D tensor with [N, C_in, L_in]
    Output: 3D tensor with [N, C_in, L_in] (residual add keeps the shape)

    Args:
        in_channels: channels on the residual path (block input/output)
        conv_channels: channels used inside the block
        kernel_size: depthwise conv kernel size
        dilation: depthwise conv dilation
        causal: if True, pad only on the left and trim the tail so that
            no future frames are used
        norm_type: 'gLN' | 'cLN' | 'cgLN' (see the .norm module)

    Raises:
        ValueError: if norm_type is not one of the supported values
            (the original silently skipped norm creation and failed later
            with AttributeError in forward()).
    """
    def __init__(self,
                 in_channels=256,
                 conv_channels=512,
                 kernel_size=3,
                 dilation=1,
                 causal=False,
                 norm_type='gLN'):
        super(TCNBlock, self).__init__()
        self.conv1x1 = Conv1D(in_channels, conv_channels, 1)
        self.prelu1 = nn.PReLU()
        norm_map = {
            'gLN': GlobalLayerNorm,
            'cLN': ChannelwiseLayerNorm,
            'cgLN': CumLN,
        }
        if norm_type not in norm_map:
            # fail fast instead of AttributeError at forward time
            raise ValueError("Unsupported norm type: {}".format(norm_type))
        self.norm1 = norm_map[norm_type](conv_channels,
                                         elementwise_affine=True)
        self.norm2 = norm_map[norm_type](conv_channels,
                                         elementwise_affine=True)
        # causal: pad fully on the left (the excess tail is trimmed in
        # forward); non-causal: symmetric padding keeps L_out == L_in
        dconv_pad = (dilation * (kernel_size - 1)) // 2 if not causal else (
            dilation * (kernel_size - 1))
        self.dconv = nn.Conv1d(
            conv_channels,
            conv_channels,
            kernel_size,
            groups=conv_channels,  # depthwise convolution
            padding=dconv_pad,
            dilation=dilation,
            bias=True)
        self.prelu2 = nn.PReLU()
        self.sconv = nn.Conv1d(conv_channels, in_channels, 1, bias=True)
        self.causal = causal
        self.dconv_pad = dconv_pad
    def forward(self, x):
        """Run the block on x: [N, C_in, L] -> [N, C_in, L]."""
        y = self.conv1x1(x)
        y = self.norm1(self.prelu1(y))
        y = self.dconv(y)
        # Guard dconv_pad > 0: with kernel_size == 1 the original slice
        # `[:, :, :-0]` would produce an empty tensor.
        if self.causal and self.dconv_pad > 0:
            # drop trailing frames introduced by the left-only padding
            y = y[:, :, :-self.dconv_pad]
        y = self.norm2(self.prelu2(y))
        y = self.sconv(y)
        y += x
        return y
class TCNBlock_Spk(nn.Module):
    """
    Temporal convolutional network block:
        1x1Conv - PReLU - Norm - DConv (depthwise) - PReLU - Norm - SConv
    The first TCN block additionally takes a speaker embedding as input
    and fuses it with the feature representation.

    Input: 3D tensor with [N, C_in, L_in]
    Input speaker embedding: 2D tensor with [N, D]
    Output: 3D tensor with [N, C_in, L_in] (residual add keeps the shape)

    Args:
        in_channels: channels on the residual path (block input/output)
        spk_embed_dim: speaker embedding dimension D
        conv_channels: channels used inside the block
        kernel_size: depthwise conv kernel size (also stride of the fixed
            averaging conv used by 'att' fusion)
        dilation: depthwise conv dilation
        causal: if True, pad only on the left and trim the tail
        norm_type: 'gLN' | 'cLN' | 'cgLN' (see the .norm module)
        fusion_type: 'cat' | 'add' | 'mul' | 'film' | 'att'

    Raises:
        ValueError: on unsupported fusion_type or norm_type (the original
            silently skipped submodule creation and failed later with
            NameError/AttributeError in forward()).
    """
    def __init__(self,
                 in_channels=256,
                 spk_embed_dim=100,
                 conv_channels=512,
                 kernel_size=3,
                 dilation=1,
                 causal=False,
                 norm_type='gLN',
                 fusion_type='cat'):
        super(TCNBlock_Spk, self).__init__()
        self.fusion_type = fusion_type
        if fusion_type == 'cat':
            # embedding concatenated along channels -> wider 1x1 conv
            self.conv1x1 = Conv1D(in_channels + spk_embed_dim,
                                  conv_channels, 1)
        elif fusion_type in ('add', 'mul'):
            self.fusion_linear = nn.Linear(spk_embed_dim, in_channels)
            self.conv1x1 = Conv1D(in_channels, conv_channels, 1)
        elif fusion_type == 'film':
            self.fusion_linear_1 = nn.Linear(spk_embed_dim, in_channels)
            self.fusion_linear_2 = nn.Linear(spk_embed_dim, in_channels)
            self.conv1x1 = Conv1D(in_channels, conv_channels, 1)
        elif fusion_type == 'att':
            self.fusion_linear = nn.Linear(spk_embed_dim, in_channels)
            # fixed (non-trainable) per-channel moving average with
            # stride == kernel_size, i.e. non-overlapping windows
            self.average = Conv1D(in_channels, in_channels, kernel_size,
                                  kernel_size, groups=in_channels)
            self.average.weight = nn.Parameter(
                th.ones(in_channels, 1, kernel_size) / kernel_size)
            self.average.bias = nn.Parameter(th.zeros(in_channels))
            for p in self.average.parameters():
                p.requires_grad = False
            self.conv1x1 = Conv1D(in_channels, conv_channels, 1)
        else:
            # fail fast instead of NameError at forward time
            raise ValueError("Unsupported fusion type: {}".format(fusion_type))
        self.prelu1 = nn.PReLU()
        norm_map = {
            'gLN': GlobalLayerNorm,
            'cLN': ChannelwiseLayerNorm,
            'cgLN': CumLN,
        }
        if norm_type not in norm_map:
            # fail fast instead of AttributeError at forward time
            raise ValueError("Unsupported norm type: {}".format(norm_type))
        self.norm1 = norm_map[norm_type](conv_channels,
                                         elementwise_affine=True)
        self.norm2 = norm_map[norm_type](conv_channels,
                                         elementwise_affine=True)
        # causal: pad fully on the left (the excess tail is trimmed in
        # forward); non-causal: symmetric padding keeps L_out == L_in
        dconv_pad = (dilation * (kernel_size - 1)) // 2 if not causal else (
            dilation * (kernel_size - 1))
        self.dconv = nn.Conv1d(
            conv_channels,
            conv_channels,
            kernel_size,
            groups=conv_channels,  # depthwise convolution
            padding=dconv_pad,
            dilation=dilation,
            bias=True)
        self.prelu2 = nn.PReLU()
        self.sconv = nn.Conv1d(conv_channels, in_channels, 1, bias=True)
        self.causal = causal
        self.dconv_pad = dconv_pad
        self.dilation = dilation
    def _concatenation(self, aux, output, L):
        """Tile aux over time and concatenate along channels."""
        aux_concat = th.unsqueeze(aux, -1)
        aux_concat = aux_concat.repeat(1, 1, L)
        # -> [B, N(embeddings_size), L]
        output = th.cat([output, aux_concat], 1)
        # -> [B, N(input_size + embeddings_size), L]
        return output
    def _addition(self, aux, output, L, fusion_linear):
        """Project aux to in_channels, tile over time and add."""
        aux_add = fusion_linear(aux)
        # -> [B, N(input_size)]
        aux_add = th.unsqueeze(aux_add, -1)
        aux_add = aux_add.repeat(1, 1, L)
        # -> [B, N(input_size), L]
        output = output + aux_add
        # -> [B, N(input_size), L]
        return output
    def _multiplication(self, aux, output, L, fusion_linear):
        """Project aux to in_channels, tile over time and multiply."""
        aux_mul = fusion_linear(aux)
        # -> [B, N(input_size)]
        aux_mul = th.unsqueeze(aux_mul, -1)
        aux_mul = aux_mul.repeat(1, 1, L)
        # -> [B, N(input_size), L]
        output = output * aux_mul
        # -> [B, N(input_size), L]
        return output
    def _attention(self, aux, output, fusion_linear):
        """Attention weights over time from the aux/feature similarity.

        NOTE(review): returns att + aux_att (not att alone) — kept exactly
        as in the original implementation.
        """
        L = output.shape[-1]
        aux_att = fusion_linear(aux)
        aux_att = th.unsqueeze(aux_att, -1)
        aux_att = aux_att.repeat(1, 1, L)
        att = th.sum(output * aux_att, 1, keepdim=True)
        att = F.softmax(att, -1)
        att = att * aux_att
        return att + aux_att
    def _film(self, aux, output, L):
        """FiLM conditioning: per-channel scale then shift from aux."""
        output = self._multiplication(aux, output, L, self.fusion_linear_1)
        # -> [B, N(input_size), L]
        output = self._addition(aux, output, L, self.fusion_linear_2)
        # -> [B, N(input_size), L]
        return output
    def forward(self, x, aux):
        """Fuse speaker embedding aux into x, then run the TCN block."""
        T = x.shape[-1]
        if self.fusion_type == 'cat':
            y = self._concatenation(aux, x, T)
            # -> [B, N(input_size + embeddings_size), L]
        elif self.fusion_type == 'add':
            y = self._addition(aux, x, T, self.fusion_linear)
            # -> [B, N(input_size), L]
        elif self.fusion_type == 'mul':
            y = self._multiplication(aux, x, T, self.fusion_linear)
            # -> [B, N(input_size), L]
        elif self.fusion_type == 'film':
            y = self._film(aux, x, T)
            # -> [B, N(input_size), L]
        else:  # 'att' (validated in __init__)
            output_avg = self.average(x)
            att_out = self._attention(aux, output_avg, self.fusion_linear)
            # F.interpolate replaces the per-call nn.Upsample construction
            # of the original; nn.Upsample.forward is exactly this call
            att_out = F.interpolate(att_out, size=T, mode='nearest')
            y = x * att_out
        y = self.conv1x1(y)
        y = self.norm1(self.prelu1(y))
        y = self.dconv(y)
        # Guard dconv_pad > 0: with kernel_size == 1 the original slice
        # `[:, :, :-0]` would produce an empty tensor.
        if self.causal and self.dconv_pad > 0:
            y = y[:, :, :-self.dconv_pad]
        y = self.norm2(self.prelu2(y))
        y = self.sconv(y)
        y += x
        return y
class ResBlock(nn.Module):
    """
    Residual block for the speaker encoder that produces the speaker
    embedding: Conv1x1 - BN - PReLU - Conv1x1 - BN - (+shortcut) - PReLU,
    followed by max-pooling over time (kernel 3).

    When in_dims != out_dims the shortcut is projected with a 1x1 conv.

    ref to
    https://github.com/fatchord/WaveRNN/blob/master/models/fatchord_version.py
    and
    https://github.com/Jungjee/RawNet/blob/master/PyTorch/model_RawNet.py
    """
    def __init__(self, in_dims, out_dims):
        super(ResBlock, self).__init__()
        self.conv1 = nn.Conv1d(in_dims, out_dims, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(out_dims, out_dims, kernel_size=1, bias=False)
        self.batch_norm1 = nn.BatchNorm1d(out_dims)
        self.batch_norm2 = nn.BatchNorm1d(out_dims)
        self.prelu1 = nn.PReLU()
        self.prelu2 = nn.PReLU()
        self.maxpool = nn.MaxPool1d(3)
        # project the shortcut only when the channel count changes
        self.downsample = in_dims != out_dims
        if self.downsample:
            self.conv_downsample = nn.Conv1d(
                in_dims, out_dims, kernel_size=1, bias=False)
    def forward(self, x):
        """Run the block on x: [N, in_dims, L] -> [N, out_dims, L // 3]."""
        shortcut = self.conv_downsample(x) if self.downsample else x
        out = self.prelu1(self.batch_norm1(self.conv1(x)))
        out = self.batch_norm2(self.conv2(out))
        out = self.prelu2(out + shortcut)
        return self.maxpool(out)