Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from torch.nn.utils.parametrizations import weight_norm | |
| from torch.nn.utils import remove_weight_norm | |
| from torch.autograd.function import InplaceFunction | |
| def get_padding(kernel_size, dilation=1): | |
| return int((kernel_size * dilation - dilation) / 2) | |
| # Scripting this brings model speed up 1.4x | |
| def snake(x, alpha): | |
| shape = x.shape | |
| x = x.reshape(shape[0], shape[1], -1) | |
| x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) | |
| x = x.reshape(shape) | |
| return x | |
| class Snake1d(nn.Module): | |
| def __init__(self, channels): | |
| super().__init__() | |
| self.alpha = nn.Parameter(torch.ones(1, channels, 1)) | |
| def forward(self, x): | |
| return snake(x, self.alpha) | |
| class Conv1d(nn.Conv1d): | |
| def __init__( | |
| self, | |
| in_channels: int, | |
| out_channels: int, | |
| kernel_size: int, | |
| stride: int = 1, | |
| dilation: int = 1, | |
| groups: int = 1, | |
| padding_mode: str = "zeros", | |
| bias: bool = True, | |
| padding=None, | |
| causal: bool = False, | |
| w_init_gain=None, | |
| ): | |
| self.causal = causal | |
| if padding is None: | |
| if causal: | |
| padding = 0 | |
| self.left_padding = dilation * (kernel_size - 1) | |
| else: | |
| padding = get_padding(kernel_size, dilation) | |
| super(Conv1d, self).__init__( | |
| in_channels, | |
| out_channels, | |
| kernel_size, | |
| stride=stride, | |
| padding=padding, | |
| dilation=dilation, | |
| groups=groups, | |
| padding_mode=padding_mode, | |
| bias=bias, | |
| ) | |
| if w_init_gain is not None: | |
| torch.nn.init.xavier_uniform_( | |
| self.weight, gain=torch.nn.init.calculate_gain(w_init_gain) | |
| ) | |
| def forward(self, x): | |
| if self.causal: | |
| x = F.pad(x.unsqueeze(2), (self.left_padding, 0, 0, 0)).squeeze(2) | |
| return super(Conv1d, self).forward(x) | |
| class ConvTranspose1d(nn.ConvTranspose1d): | |
| def __init__( | |
| self, | |
| in_channels: int, | |
| out_channels: int, | |
| kernel_size: int, | |
| stride: int = 1, | |
| output_padding: int = 0, | |
| groups: int = 1, | |
| bias: bool = True, | |
| dilation: int = 1, | |
| padding=None, | |
| padding_mode: str = "zeros", | |
| causal: bool = False, | |
| ): | |
| if padding is None: | |
| padding = 0 if causal else (kernel_size - stride) // 2 | |
| if causal: | |
| assert padding == 0, "padding is not allowed in causal ConvTranspose1d." | |
| assert ( | |
| kernel_size == 2 * stride | |
| ), "kernel_size must be equal to 2*stride is not allowed in causal ConvTranspose1d." | |
| super(ConvTranspose1d, self).__init__( | |
| in_channels, | |
| out_channels, | |
| kernel_size, | |
| stride=stride, | |
| padding=padding, | |
| output_padding=output_padding, | |
| groups=groups, | |
| bias=bias, | |
| dilation=dilation, | |
| padding_mode=padding_mode, | |
| ) | |
| self.causal = causal | |
| self.stride = stride | |
| def forward(self, x): | |
| x = super(ConvTranspose1d, self).forward(x) | |
| if self.causal: | |
| x = x[:, :, : -self.stride] | |
| return x | |
| class PreProcessor(nn.Module): | |
| def __init__(self, n_in, n_out, num_samples, kernel_size=7, causal=False): | |
| super(PreProcessor, self).__init__() | |
| self.pooling = torch.nn.AvgPool1d(kernel_size=num_samples) | |
| self.conv = Conv1d(n_in, n_out, kernel_size=kernel_size, causal=causal) | |
| self.activation = nn.PReLU() | |
| def forward(self, x): | |
| output = self.activation(self.conv(x)) | |
| output = self.pooling(output) | |
| return output | |
| class PostProcessor(nn.Module): | |
| def __init__(self, n_in, n_out, num_samples, kernel_size=7, causal=False): | |
| super(PostProcessor, self).__init__() | |
| self.num_samples = num_samples | |
| self.conv = Conv1d(n_in, n_out, kernel_size=kernel_size, causal=causal) | |
| self.activation = nn.PReLU() | |
| def forward(self, x): | |
| x = torch.transpose(x, 1, 2) | |
| B, T, C = x.size() | |
| x = x.repeat(1, 1, self.num_samples).view(B, -1, C) | |
| x = torch.transpose(x, 1, 2) | |
| output = self.activation(self.conv(x)) | |
| return output | |
| class ResidualUnit(nn.Module): | |
| def __init__(self, n_in, n_out, dilation, res_kernel_size=7, causal=False): | |
| super(ResidualUnit, self).__init__() | |
| self.conv1 = weight_norm( | |
| Conv1d( | |
| n_in, | |
| n_out, | |
| kernel_size=res_kernel_size, | |
| dilation=dilation, | |
| causal=causal, | |
| ) | |
| ) | |
| self.conv2 = weight_norm(Conv1d(n_in, n_out, kernel_size=1, causal=causal)) | |
| self.activation1 = nn.PReLU() | |
| self.activation2 = nn.PReLU() | |
| def forward(self, x): | |
| output = self.activation1(self.conv1(x)) | |
| output = self.activation2(self.conv2(output)) | |
| return output + x | |
| class ResEncoderBlock(nn.Module): | |
| def __init__( | |
| self, n_in, n_out, stride, down_kernel_size, res_kernel_size=7, causal=False | |
| ): | |
| super(ResEncoderBlock, self).__init__() | |
| self.convs = nn.ModuleList( | |
| [ | |
| ResidualUnit( | |
| n_in, | |
| n_out // 2, | |
| dilation=1, | |
| res_kernel_size=res_kernel_size, | |
| causal=causal, | |
| ), | |
| ResidualUnit( | |
| n_out // 2, | |
| n_out // 2, | |
| dilation=3, | |
| res_kernel_size=res_kernel_size, | |
| causal=causal, | |
| ), | |
| ResidualUnit( | |
| n_out // 2, | |
| n_out // 2, | |
| dilation=5, | |
| res_kernel_size=res_kernel_size, | |
| causal=causal, | |
| ), | |
| ResidualUnit( | |
| n_out // 2, | |
| n_out // 2, | |
| dilation=7, | |
| res_kernel_size=res_kernel_size, | |
| causal=causal, | |
| ), | |
| ResidualUnit( | |
| n_out // 2, | |
| n_out // 2, | |
| dilation=9, | |
| res_kernel_size=res_kernel_size, | |
| causal=causal, | |
| ), | |
| ] | |
| ) | |
| self.down_conv = DownsampleLayer( | |
| n_in, n_out, down_kernel_size, stride=stride, causal=causal | |
| ) | |
| def forward(self, x): | |
| for conv in self.convs: | |
| x = conv(x) | |
| x = self.down_conv(x) | |
| return x | |
| class ResDecoderBlock(nn.Module): | |
| def __init__( | |
| self, n_in, n_out, stride, up_kernel_size, res_kernel_size=7, causal=False | |
| ): | |
| super(ResDecoderBlock, self).__init__() | |
| self.up_conv = UpsampleLayer( | |
| n_in, | |
| n_out, | |
| kernel_size=up_kernel_size, | |
| stride=stride, | |
| causal=causal, | |
| activation=None, | |
| ) | |
| self.convs = nn.ModuleList( | |
| [ | |
| ResidualUnit( | |
| n_out, | |
| n_out, | |
| dilation=1, | |
| res_kernel_size=res_kernel_size, | |
| causal=causal, | |
| ), | |
| ResidualUnit( | |
| n_out, | |
| n_out, | |
| dilation=3, | |
| res_kernel_size=res_kernel_size, | |
| causal=causal, | |
| ), | |
| ResidualUnit( | |
| n_out, | |
| n_out, | |
| dilation=5, | |
| res_kernel_size=res_kernel_size, | |
| causal=causal, | |
| ), | |
| ResidualUnit( | |
| n_out, | |
| n_out, | |
| dilation=7, | |
| res_kernel_size=res_kernel_size, | |
| causal=causal, | |
| ), | |
| ResidualUnit( | |
| n_out, | |
| n_out, | |
| dilation=9, | |
| res_kernel_size=res_kernel_size, | |
| causal=causal, | |
| ), | |
| ] | |
| ) | |
| def forward(self, x): | |
| x = self.up_conv(x) | |
| for conv in self.convs: | |
| x = conv(x) | |
| return x | |
| class DownsampleLayer(nn.Module): | |
| def __init__( | |
| self, | |
| in_channels: int, | |
| out_channels: int, | |
| kernel_size: int, | |
| stride: int = 1, | |
| causal: bool = False, | |
| activation=nn.PReLU(), | |
| use_weight_norm: bool = True, | |
| pooling: bool = False, | |
| ): | |
| super(DownsampleLayer, self).__init__() | |
| self.pooling = pooling | |
| self.stride = stride | |
| self.activation = nn.PReLU() | |
| self.use_weight_norm = use_weight_norm | |
| if pooling: | |
| self.layer = Conv1d(in_channels, out_channels, kernel_size, causal=causal) | |
| self.pooling = nn.AvgPool1d(kernel_size=stride) | |
| else: | |
| self.layer = Conv1d( | |
| in_channels, out_channels, kernel_size, stride=stride, causal=causal | |
| ) | |
| if use_weight_norm: | |
| self.layer = weight_norm(self.layer) | |
| def forward(self, x): | |
| x = self.layer(x) | |
| x = self.activation(x) if self.activation is not None else x | |
| if self.pooling: | |
| x = self.pooling(x) | |
| return x | |
| def remove_weight_norm(self): | |
| if self.use_weight_norm: | |
| remove_weight_norm(self.layer) | |
| class UpsampleLayer(nn.Module): | |
| def __init__( | |
| self, | |
| in_channels: int, | |
| out_channels: int, | |
| kernel_size: int, | |
| stride: int = 1, | |
| causal: bool = False, | |
| activation=nn.PReLU(), | |
| use_weight_norm: bool = True, | |
| repeat: bool = False, | |
| ): | |
| super(UpsampleLayer, self).__init__() | |
| self.repeat = repeat | |
| self.stride = stride | |
| self.activation = activation | |
| self.use_weight_norm = use_weight_norm | |
| if repeat: | |
| self.layer = Conv1d(in_channels, out_channels, kernel_size, causal=causal) | |
| else: | |
| self.layer = ConvTranspose1d( | |
| in_channels, out_channels, kernel_size, stride=stride, causal=causal | |
| ) | |
| if use_weight_norm: | |
| self.layer = weight_norm(self.layer) | |
| def forward(self, x): | |
| x = self.layer(x) | |
| x = self.activation(x) if self.activation is not None else x | |
| if self.repeat: | |
| x = torch.transpose(x, 1, 2) | |
| B, T, C = x.size() | |
| x = x.repeat(1, 1, self.stride).view(B, -1, C) | |
| x = torch.transpose(x, 1, 2) | |
| return x | |
| def remove_weight_norm(self): | |
| if self.use_weight_norm: | |
| remove_weight_norm(self.layer) | |
| class round_func9(InplaceFunction): | |
| def forward(ctx, input): | |
| ctx.input = input | |
| return torch.round(9 * input) / 9 | |
| def backward(ctx, grad_output): | |
| grad_input = grad_output.clone() | |
| return grad_input | |
| class ScalarModel(nn.Module): | |
| def __init__( | |
| self, | |
| num_bands, | |
| sample_rate, | |
| causal, | |
| num_samples, | |
| downsample_factors, | |
| downsample_kernel_sizes, | |
| upsample_factors, | |
| upsample_kernel_sizes, | |
| latent_hidden_dim, | |
| default_kernel_size, | |
| delay_kernel_size, | |
| init_channel, | |
| res_kernel_size, | |
| mode="pre_proj", | |
| ): | |
| super(ScalarModel, self).__init__() | |
| # self.args = args | |
| self.encoder = [] | |
| self.decoder = [] | |
| self.vq = round_func9() # using 9 | |
| self.mode = mode | |
| # Encoder parts | |
| self.encoder.append( | |
| weight_norm( | |
| Conv1d( | |
| num_bands, | |
| init_channel, | |
| kernel_size=default_kernel_size, | |
| causal=causal, | |
| ) | |
| ) | |
| ) | |
| if num_samples > 1: | |
| # Downsampling | |
| self.encoder.append( | |
| PreProcessor( | |
| init_channel, | |
| init_channel, | |
| num_samples, | |
| kernel_size=default_kernel_size, | |
| causal=causal, | |
| ) | |
| ) | |
| for i, down_factor in enumerate(downsample_factors): | |
| self.encoder.append( | |
| ResEncoderBlock( | |
| init_channel * np.power(2, i), | |
| init_channel * np.power(2, i + 1), | |
| down_factor, | |
| downsample_kernel_sizes[i], | |
| res_kernel_size, | |
| causal=causal, | |
| ) | |
| ) | |
| self.encoder.append( | |
| weight_norm( | |
| Conv1d( | |
| init_channel * np.power(2, len(downsample_factors)), | |
| latent_hidden_dim, | |
| kernel_size=default_kernel_size, | |
| causal=causal, | |
| ) | |
| ) | |
| ) | |
| # Decoder | |
| # look ahead | |
| self.decoder.append( | |
| weight_norm( | |
| Conv1d( | |
| latent_hidden_dim, | |
| init_channel * np.power(2, len(upsample_factors)), | |
| kernel_size=delay_kernel_size, | |
| ) | |
| ) | |
| ) | |
| for i, upsample_factor in enumerate(upsample_factors): | |
| self.decoder.append( | |
| ResDecoderBlock( | |
| init_channel * np.power(2, len(upsample_factors) - i), | |
| init_channel * np.power(2, len(upsample_factors) - i - 1), | |
| upsample_factor, | |
| upsample_kernel_sizes[i], | |
| res_kernel_size, | |
| causal=causal, | |
| ) | |
| ) | |
| if num_samples > 1: | |
| self.decoder.append( | |
| PostProcessor( | |
| init_channel, | |
| init_channel, | |
| num_samples, | |
| kernel_size=default_kernel_size, | |
| causal=causal, | |
| ) | |
| ) | |
| self.decoder.append( | |
| weight_norm( | |
| Conv1d( | |
| init_channel, | |
| num_bands, | |
| kernel_size=default_kernel_size, | |
| causal=causal, | |
| ) | |
| ) | |
| ) | |
| self.encoder = nn.ModuleList(self.encoder) | |
| self.decoder = nn.ModuleList(self.decoder) | |
| def forward(self, x): | |
| for i, layer in enumerate(self.encoder): | |
| if i != len(self.encoder) - 1: | |
| x = layer(x) | |
| else: | |
| x = F.tanh(layer(x)) | |
| # import pdb; pdb.set_trace() | |
| x = self.vq.apply(x) # vq | |
| for i, layer in enumerate(self.decoder): | |
| x = layer(x) | |
| return x | |
| def inference(self, x): | |
| for i, layer in enumerate(self.encoder): | |
| if i != len(self.encoder) - 1: | |
| x = layer(x) | |
| else: | |
| x = F.tanh(layer(x)) # reverse to tanh | |
| emb = x | |
| # import pdb; pdb.set_trace() | |
| emb_quant = self.vq.apply(emb) # vq | |
| x = emb_quant | |
| for i, layer in enumerate(self.decoder): | |
| x = layer(x) | |
| return emb, emb_quant, x | |
| def encode(self, x): | |
| for i, layer in enumerate(self.encoder): | |
| if i != len(self.encoder) - 1: | |
| x = layer(x) | |
| else: | |
| x = F.tanh(layer(x)) # reverse to tanh | |
| emb = x | |
| # import pdb; pdb.set_trace() | |
| emb_quant = self.vq.apply(emb) # vq | |
| return emb | |
| def decode(self, x): | |
| x = self.vq.apply( | |
| x | |
| ) # make sure the prediction follow the similar disctribution | |
| for i, layer in enumerate(self.decoder): | |
| x = layer(x) | |
| return x | |