import math

import torch
from torch import nn

# NOTE: torch.nn.utils.weight_norm is deprecated since PyTorch 2.1 in favor of
# torch.nn.utils.parametrizations.weight_norm; it is kept here for compatibility.
from torch.nn.utils import weight_norm


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


# Scripting this brings a ~1.4x model speedup
@torch.jit.script
def snake(x, alpha):
    # Snake activation: x + (1/alpha) * sin^2(alpha * x); the 1e-9 guards
    # against division by zero as alpha approaches 0.
    shape = x.shape
    x = x.reshape(shape[0], shape[1], -1)
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x


class Snake1d(nn.Module):
    def __init__(self, channels):
        super().__init__()
        # One learnable frequency parameter per channel
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        return snake(x, self.alpha)


class ResidualUnit(nn.Module):
    def __init__(self, dim: int = 16, dilation: int = 1):
        super().__init__()
        pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=1),
        )

    def forward(self, x):
        y = self.block(x)
        # Center-crop the residual path if the conv stack shortened the sequence
        pad = (x.shape[-1] - y.shape[-1]) // 2
        if pad > 0:
            x = x[..., pad:-pad]
        return x + y


class EncoderBlock(nn.Module):
    def __init__(self, dim: int = 16, stride: int = 1):
        super().__init__()
        self.block = nn.Sequential(
            ResidualUnit(dim // 2, dilation=1),
            ResidualUnit(dim // 2, dilation=3),
            ResidualUnit(dim // 2, dilation=9),
            Snake1d(dim // 2),
            # Strided conv doubles the channels while downsampling by `stride`
            WNConv1d(
                dim // 2,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
        )

    def forward(self, x):
        return self.block(x)


class Encoder(nn.Module):
    def __init__(
        self,
        input_channel: int = 2,
        n_filters: int = 128,
        strides: list = [2, 4, 8, 8],
        d_latent: int = 64,
    ):
        super().__init__()
        self.input_channel = input_channel

        # Create first convolution
        self.block = [WNConv1d(self.input_channel, n_filters, kernel_size=7, padding=3)]

        # Create EncoderBlocks that double channels as they downsample by `stride`
        for stride in strides:
            n_filters *= 2
            self.block += [EncoderBlock(n_filters, stride=stride)]

        # Create last convolution
        self.block += [
            Snake1d(n_filters),
            WNConv1d(n_filters, d_latent, kernel_size=3, padding=1),
        ]

        # Wrap block into nn.Sequential
        self.block = nn.Sequential(*self.block)
        self.enc_dim = n_filters

    def forward(self, x):
        return self.block(x)


class DecoderBlock(nn.Module):
    def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
        super().__init__()
        self.block = nn.Sequential(
            Snake1d(input_dim),
            # Transposed conv halves the channels while upsampling by `stride`
            WNConvTranspose1d(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
            ResidualUnit(output_dim, dilation=1),
            ResidualUnit(output_dim, dilation=3),
            ResidualUnit(output_dim, dilation=9),
        )

    def forward(self, x):
        return self.block(x)


class Decoder(nn.Module):
    def __init__(
        self,
        d_latent: int,
        n_filters: int,
        rates: list,
        out_channel: int = 2,
    ):
        super().__init__()
        channels = n_filters * (2 ** len(rates))

        # Add first conv layer
        layers = [WNConv1d(d_latent, channels, kernel_size=7, padding=3)]

        # Add upsampling + residual-unit blocks, halving channels at each stage
        for i, stride in enumerate(rates):
            input_dim = channels // 2 ** i
            output_dim = channels // 2 ** (i + 1)
            layers += [DecoderBlock(input_dim, output_dim, stride)]

        # Add final conv layer
        layers += [
            Snake1d(output_dim),
            WNConv1d(output_dim, out_channel, kernel_size=7, padding=3),
            # nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
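

# ---------------------------------------------------------------------------
# Minimal shape-check sketch (an illustrative addition, not part of the module
# above). It assumes stereo audio, the default encoder strides [2, 4, 8, 8]
# (a 2 * 4 * 8 * 8 = 512x temporal compression), and a decoder built with the
# same rates in reverse so the output length matches the input length.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    strides = [2, 4, 8, 8]
    encoder = Encoder(input_channel=2, n_filters=128, strides=strides, d_latent=64)
    decoder = Decoder(d_latent=64, n_filters=128, rates=strides[::-1], out_channel=2)

    # ~1 second of stereo 44.1 kHz audio, padded to a multiple of 512 samples
    x = torch.randn(1, 2, 44032)
    z = encoder(x)  # (1, 64, 86): 44032 / 512 = 86 latent frames
    y = decoder(z)  # (1, 2, 44032): each DecoderBlock upsamples by its rate
    print(z.shape, y.shape)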