| """ refer from https://github.com/zceng/LVCNet """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch import nn | |
| from torch.nn.utils.parametrizations import weight_norm | |
| from .amp import AMPBlock | |


class KernelPredictor(torch.nn.Module):
    """Kernel predictor for the location-variable convolutions."""

    def __init__(
        self,
        cond_channels,
        conv_in_channels,
        conv_out_channels,
        conv_layers,
        conv_kernel_size=3,
        kpnet_hidden_channels=64,
        kpnet_conv_size=3,
        kpnet_dropout=0.0,
        kpnet_nonlinear_activation="LeakyReLU",
        kpnet_nonlinear_activation_params={"negative_slope": 0.1},
    ):
| """ | |
| Args: | |
| cond_channels (int): number of channel for the conditioning sequence, | |
| conv_in_channels (int): number of channel for the input sequence, | |
| conv_out_channels (int): number of channel for the output sequence, | |
| conv_layers (int): number of layers | |
| """ | |
        super().__init__()
        self.conv_in_channels = conv_in_channels
        self.conv_out_channels = conv_out_channels
        self.conv_kernel_size = conv_kernel_size
        self.conv_layers = conv_layers

        kpnet_kernel_channels = (
            conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers
        )  # l_w
        kpnet_bias_channels = conv_out_channels * conv_layers  # l_b

        self.input_conv = nn.Sequential(
            weight_norm(
                nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)
            ),
            getattr(nn, kpnet_nonlinear_activation)(
                **kpnet_nonlinear_activation_params
            ),
        )
        self.residual_convs = nn.ModuleList()
        padding = (kpnet_conv_size - 1) // 2
        for _ in range(3):
            self.residual_convs.append(
                nn.Sequential(
                    nn.Dropout(kpnet_dropout),
                    weight_norm(
                        nn.Conv1d(
                            kpnet_hidden_channels,
                            kpnet_hidden_channels,
                            kpnet_conv_size,
                            padding=padding,
                            bias=True,
                        )
                    ),
                    getattr(nn, kpnet_nonlinear_activation)(
                        **kpnet_nonlinear_activation_params
                    ),
                    weight_norm(
                        nn.Conv1d(
                            kpnet_hidden_channels,
                            kpnet_hidden_channels,
                            kpnet_conv_size,
                            padding=padding,
                            bias=True,
                        )
                    ),
                    getattr(nn, kpnet_nonlinear_activation)(
                        **kpnet_nonlinear_activation_params
                    ),
                )
            )
        self.kernel_conv = weight_norm(
            nn.Conv1d(
                kpnet_hidden_channels,
                kpnet_kernel_channels,
                kpnet_conv_size,
                padding=padding,
                bias=True,
            )
        )
        self.bias_conv = weight_norm(
            nn.Conv1d(
                kpnet_hidden_channels,
                kpnet_bias_channels,
                kpnet_conv_size,
                padding=padding,
                bias=True,
            )
        )

    def forward(self, c):
        """
        Args:
            c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)

        Returns:
            kernels (Tensor): (batch, conv_layers, conv_in_channels, conv_out_channels, conv_kernel_size, cond_length)
            bias (Tensor): (batch, conv_layers, conv_out_channels, cond_length)
        """
        batch, _, cond_length = c.shape
        c = self.input_conv(c)
        for residual_conv in self.residual_convs:
            residual_conv.to(c.device)
            c = c + residual_conv(c)
        k = self.kernel_conv(c)  # (batch, kpnet_kernel_channels, cond_length)
        b = self.bias_conv(c)  # (batch, kpnet_bias_channels, cond_length)
        kernels = k.contiguous().view(
            batch,
            self.conv_layers,
            self.conv_in_channels,
            self.conv_out_channels,
            self.conv_kernel_size,
            cond_length,
        )
        bias = b.contiguous().view(
            batch,
            self.conv_layers,
            self.conv_out_channels,
            cond_length,
        )
        return kernels, bias
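

# A minimal shape-check sketch for KernelPredictor. Illustrative only: the
# hyperparameters below are assumptions for the example, not values taken from
# any particular model configuration.
def _example_kernel_predictor_shapes():
    predictor = KernelPredictor(
        cond_channels=80,  # assumed: mel-spectrogram conditioning
        conv_in_channels=32,
        conv_out_channels=64,
        conv_layers=4,
    )
    c = torch.randn(2, 80, 100)  # (batch, cond_channels, cond_length)
    kernels, bias = predictor(c)
    print(kernels.shape)  # torch.Size([2, 4, 32, 64, 3, 100])
    print(bias.shape)  # torch.Size([2, 4, 64, 100])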


class LVCBlock(torch.nn.Module):
    """The location-variable convolutions."""

    def __init__(
        self,
        in_channels,
        cond_channels,
        stride,
        dilations=(1, 3, 9, 27),
        lReLU_slope=0.2,
        conv_kernel_size=3,
        cond_hop_length=256,
        kpnet_hidden_channels=64,
        kpnet_conv_size=3,
        kpnet_dropout=0.0,
        add_extra_noise=False,
        downsampling=False,
    ):
        super().__init__()
        self.add_extra_noise = add_extra_noise
        self.cond_hop_length = cond_hop_length
        self.conv_layers = len(dilations)
        self.conv_kernel_size = conv_kernel_size

        self.kernel_predictor = KernelPredictor(
            cond_channels=cond_channels,
            conv_in_channels=in_channels,
            conv_out_channels=2 * in_channels,
            conv_layers=len(dilations),
            conv_kernel_size=conv_kernel_size,
            kpnet_hidden_channels=kpnet_hidden_channels,
            kpnet_conv_size=kpnet_conv_size,
            kpnet_dropout=kpnet_dropout,
            kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope},
        )
        if downsampling:
            self.convt_pre = nn.Sequential(
                nn.LeakyReLU(lReLU_slope),
                weight_norm(
                    nn.Conv1d(in_channels, in_channels, 2 * stride + 1, padding="same")
                ),
                nn.AvgPool1d(stride, stride),
            )
        else:
            if stride == 1:
                self.convt_pre = nn.Sequential(
                    nn.LeakyReLU(lReLU_slope),
                    weight_norm(nn.Conv1d(in_channels, in_channels, 1)),
                )
            else:
                self.convt_pre = nn.Sequential(
                    nn.LeakyReLU(lReLU_slope),
                    weight_norm(
                        nn.ConvTranspose1d(
                            in_channels,
                            in_channels,
                            2 * stride,
                            stride=stride,
                            padding=stride // 2 + stride % 2,
                            output_padding=stride % 2,
                        )
                    ),
                )

        self.amp_block = AMPBlock(in_channels)

        self.conv_blocks = nn.ModuleList()
        for d in dilations:
            self.conv_blocks.append(
                nn.Sequential(
                    nn.LeakyReLU(lReLU_slope),
                    weight_norm(
                        nn.Conv1d(
                            in_channels,
                            in_channels,
                            conv_kernel_size,
                            dilation=d,
                            padding="same",
                        )
                    ),
                    nn.LeakyReLU(lReLU_slope),
                )
            )

    def forward(self, x, c):
        """Forward propagation of the location-variable convolutions.

        Args:
            x (Tensor): the input sequence (batch, in_channels, in_length)
            c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)

        Returns:
            Tensor: the output sequence (batch, in_channels, in_length)
        """
        _, in_channels, _ = x.shape  # (B, c_g, L')
        x = self.convt_pre(x)  # (B, c_g, stride * L')

        # Add one AMP block just after the upsampling
        x = self.amp_block(x)  # (B, c_g, stride * L')

        kernels, bias = self.kernel_predictor(c)

        if self.add_extra_noise:
            # Add extra noise to half of the channels
            a, b = x.chunk(2, dim=1)
            b = b + torch.randn_like(b) * 0.1
            x = torch.cat([a, b], dim=1)

        for i, conv in enumerate(self.conv_blocks):
            output = conv(x)  # (B, c_g, stride * L')

            k = kernels[:, i, :, :, :, :]  # (B, c_g, 2 * c_g, kernel_size, cond_length)
            b = bias[:, i, :, :]  # (B, 2 * c_g, cond_length)

            output = self.location_variable_convolution(
                output, k, b, hop_size=self.cond_hop_length
            )  # (B, 2 * c_g, stride * L'): LVC

            # Gated activation unit: sigmoid gate on the first half of the
            # channels, tanh on the second half
            x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh(
                output[:, in_channels:, :]
            )  # (B, c_g, stride * L'): GAU
        return x

    def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256):
        """Perform a location-variable convolution on the input sequence (x) using the local convolution kernel.

        Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), tested on an NVIDIA V100.

        Args:
            x (Tensor): the input sequence (batch, in_channels, in_length)
            kernel (Tensor): the local convolution kernel (batch, in_channels, out_channels, kernel_size, kernel_length)
            bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
            dilation (int): the dilation of the convolution
            hop_size (int): the hop size of the conditioning sequence

        Returns:
            (Tensor): the output sequence after the local convolution (batch, out_channels, in_length)
        """
        batch, _, in_length = x.shape
        batch, _, out_channels, kernel_size, kernel_length = kernel.shape
        assert in_length == (
            kernel_length * hop_size
        ), f"length of (x, kernel) is not matched, {in_length} != {kernel_length} * {hop_size}"

        padding = dilation * int((kernel_size - 1) / 2)
        x = F.pad(
            x, (padding, padding), "constant", 0
        )  # (batch, in_channels, in_length + 2*padding)
        x = x.unfold(
            2, hop_size + 2 * padding, hop_size
        )  # (batch, in_channels, kernel_length, hop_size + 2*padding)

        if hop_size < dilation:
            x = F.pad(x, (0, dilation), "constant", 0)
        x = x.unfold(
            3, dilation, dilation
        )  # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
        x = x[:, :, :, :, :hop_size]
        x = x.transpose(
            3, 4
        )  # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
        x = x.unfold(
            4, kernel_size, 1
        )  # (batch, in_channels, kernel_length, dilation, _, kernel_size)

        o = torch.einsum("bildsk,biokl->bolsd", x, kernel)
        o = o.to(memory_format=torch.channels_last_3d)
        bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
        o = o + bias
        o = o.contiguous().view(batch, out_channels, -1)
        return o
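

# A minimal smoke-test sketch for LVCBlock. Illustrative only: the channel
# counts, stride, and lengths below are made up for the example, and running
# this requires the package context for the relative `.amp` import (e.g. via
# `python -m <package>.<module>`). The lengths must satisfy
# stride * in_length == cond_length * cond_hop_length for the LVC assertion.
if __name__ == "__main__":
    block = LVCBlock(in_channels=32, cond_channels=80, stride=8, cond_hop_length=256)
    x = torch.randn(2, 32, 320)  # (batch, in_channels, in_length)
    c = torch.randn(2, 80, 10)  # (batch, cond_channels, cond_length)
    y = block(x, c)
    print(y.shape)  # expected: torch.Size([2, 32, 2560])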