import math

import torch
from torch import nn
from torch.nn.utils import weight_norm


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))

# Scripting this brings model speed up 1.4x
def snake(x, alpha):
    shape = x.shape
    x = x.reshape(shape[0], shape[1], -1)
    # Snake activation: x + (1 / alpha) * sin^2(alpha * x). The 1e-9 keeps the
    # reciprocal finite if alpha is learned close to zero.
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x

class Snake1d(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        return snake(x, self.alpha)

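
# Illustrative usage sketch, not part of the original module: Snake1d wraps
# snake() with a learnable per-channel alpha, so the activation never changes
# the input shape. The _demo_* helper below is hypothetical, for illustration.
def _demo_snake1d():
    x = torch.randn(1, 16, 100)  # (batch, channels, time)
    y = Snake1d(16)(x)
    assert y.shape == x.shape  # pointwise activation, shape preserved
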
class ResidualUnit(nn.Module):
    def __init__(self, dim: int = 16, dilation: int = 1):
        super().__init__()
        pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=1),
        )

    def forward(self, x):
        y = self.block(x)
        # Center-trim x if the convolutions shortened y, so the skip lines up.
        pad = (x.shape[-1] - y.shape[-1]) // 2
        if pad > 0:
            x = x[..., pad:-pad]
        return x + y

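
# Illustrative sketch, not part of the original module: padding of
# ((7 - 1) * dilation) // 2 keeps the dilated conv length-preserving, so the
# residual addition needs no trimming here. Helper name is hypothetical.
def _demo_residual_unit():
    x = torch.randn(1, 16, 100)
    y = ResidualUnit(dim=16, dilation=9)(x)
    assert y.shape == x.shape
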
class EncoderBlock(nn.Module):
    def __init__(self, dim: int = 16, stride: int = 1):
        super().__init__()
        self.block = nn.Sequential(
            ResidualUnit(dim // 2, dilation=1),
            ResidualUnit(dim // 2, dilation=3),
            ResidualUnit(dim // 2, dilation=9),
            Snake1d(dim // 2),
            WNConv1d(
                dim // 2,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
        )

    def forward(self, x):
        return self.block(x)

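
# Illustrative sketch, not part of the original module: an EncoderBlock takes
# dim // 2 channels in, emits dim channels, and downsamples time by `stride`
# (kernel_size=2 * stride, padding=ceil(stride / 2)). Helper is hypothetical.
def _demo_encoder_block():
    x = torch.randn(1, 32, 100)  # dim // 2 = 32 input channels
    y = EncoderBlock(dim=64, stride=2)(x)
    assert y.shape == (1, 64, 50)  # channels doubled, time halved
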
class Encoder(nn.Module):
    def __init__(
        self,
        input_channel: int = 2,
        n_filters: int = 128,
        strides: list = [2, 4, 8, 8],
        d_latent: int = 64,
    ):
        super().__init__()
        self.input_channel = input_channel
        # Create first convolution
        self.block = [WNConv1d(self.input_channel, n_filters, kernel_size=7, padding=3)]
        # Create EncoderBlocks that double channels as they downsample by `stride`
        for stride in strides:
            n_filters *= 2
            self.block += [EncoderBlock(n_filters, stride=stride)]
        # Create last convolution
        self.block += [
            Snake1d(n_filters),
            WNConv1d(n_filters, d_latent, kernel_size=3, padding=1),
        ]
        # Wrap block into nn.Sequential
        self.block = nn.Sequential(*self.block)
        self.enc_dim = n_filters

    def forward(self, x):
        return self.block(x)

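
# Illustrative sketch, not part of the original module: the encoder's total
# hop is the product of its strides (2 * 4 * 8 * 8 = 512 here). n_filters=16
# is a deliberately small value to keep the demo light; the default is 128.
def _demo_encoder():
    enc = Encoder(input_channel=2, n_filters=16, strides=[2, 4, 8, 8], d_latent=64)
    x = torch.randn(1, 2, 1024)  # stereo waveform, 1024 samples
    z = enc(x)
    assert z.shape == (1, 64, 2)  # 1024 samples / 512 hop = 2 latent frames
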
class DecoderBlock(nn.Module):
    def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
        super().__init__()
        self.block = nn.Sequential(
            Snake1d(input_dim),
            WNConvTranspose1d(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
            ResidualUnit(output_dim, dilation=1),
            ResidualUnit(output_dim, dilation=3),
            ResidualUnit(output_dim, dilation=9),
        )

    def forward(self, x):
        return self.block(x)

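
# Illustrative sketch, not part of the original module: DecoderBlock mirrors
# EncoderBlock, halving channels while the transposed conv upsamples time by
# `stride`. Helper name is hypothetical.
def _demo_decoder_block():
    x = torch.randn(1, 64, 10)
    y = DecoderBlock(input_dim=64, output_dim=32, stride=2)(x)
    assert y.shape == (1, 32, 20)  # channels halved, time doubled
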
class Decoder(nn.Module):
    def __init__(
        self,
        d_latent,
        n_filters,
        rates,
        out_channel: int = 2,
    ):
        super().__init__()
        channels = n_filters * (2 ** len(rates))
        # Add first conv layer
        layers = [WNConv1d(d_latent, channels, kernel_size=7, padding=3)]
        # Add upsampling + MRF (multi-receptive-field residual) blocks
        for i, stride in enumerate(rates):
            input_dim = channels // 2 ** i
            output_dim = channels // 2 ** (i + 1)
            layers += [DecoderBlock(input_dim, output_dim, stride)]
        # Add final conv layer
        layers += [
            Snake1d(output_dim),
            WNConv1d(output_dim, out_channel, kernel_size=7, padding=3),
            # nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

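
# Illustrative round-trip sketch, not part of the original module: with the
# decoder's rates equal to the encoder's strides reversed, a waveform decodes
# back to its original shape. All _demo_* names are hypothetical.
def _demo_round_trip():
    enc = Encoder(input_channel=2, n_filters=16, strides=[2, 4, 8, 8], d_latent=64)
    dec = Decoder(d_latent=64, n_filters=16, rates=[8, 8, 4, 2], out_channel=2)
    x = torch.randn(1, 2, 1024)
    z = enc(x)  # (1, 64, 2) latent
    y = dec(z)
    assert y.shape == x.shape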