import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from torch.nn.utils.parametrizations import weight_norm from torch.nn.utils import remove_weight_norm from torch.autograd.function import InplaceFunction def get_padding(kernel_size, dilation=1): return int((kernel_size * dilation - dilation) / 2) # Scripting this brings model speed up 1.4x @torch.jit.script def snake(x, alpha): shape = x.shape x = x.reshape(shape[0], shape[1], -1) x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) x = x.reshape(shape) return x class Snake1d(nn.Module): def __init__(self, channels): super().__init__() self.alpha = nn.Parameter(torch.ones(1, channels, 1)) def forward(self, x): return snake(x, self.alpha) class Conv1d(nn.Conv1d): def __init__( self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, dilation: int = 1, groups: int = 1, padding_mode: str = "zeros", bias: bool = True, padding=None, causal: bool = False, w_init_gain=None, ): self.causal = causal if padding is None: if causal: padding = 0 self.left_padding = dilation * (kernel_size - 1) else: padding = get_padding(kernel_size, dilation) super(Conv1d, self).__init__( in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, padding_mode=padding_mode, bias=bias, ) if w_init_gain is not None: torch.nn.init.xavier_uniform_( self.weight, gain=torch.nn.init.calculate_gain(w_init_gain) ) def forward(self, x): if self.causal: x = F.pad(x.unsqueeze(2), (self.left_padding, 0, 0, 0)).squeeze(2) return super(Conv1d, self).forward(x) class ConvTranspose1d(nn.ConvTranspose1d): def __init__( self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, output_padding: int = 0, groups: int = 1, bias: bool = True, dilation: int = 1, padding=None, padding_mode: str = "zeros", causal: bool = False, ): if padding is None: padding = 0 if causal else (kernel_size - stride) // 2 if causal: assert padding == 0, "padding is not allowed in causal ConvTranspose1d." assert ( kernel_size == 2 * stride ), "kernel_size must be equal to 2*stride is not allowed in causal ConvTranspose1d." super(ConvTranspose1d, self).__init__( in_channels, out_channels, kernel_size, stride=stride, padding=padding, output_padding=output_padding, groups=groups, bias=bias, dilation=dilation, padding_mode=padding_mode, ) self.causal = causal self.stride = stride def forward(self, x): x = super(ConvTranspose1d, self).forward(x) if self.causal: x = x[:, :, : -self.stride] return x class PreProcessor(nn.Module): def __init__(self, n_in, n_out, num_samples, kernel_size=7, causal=False): super(PreProcessor, self).__init__() self.pooling = torch.nn.AvgPool1d(kernel_size=num_samples) self.conv = Conv1d(n_in, n_out, kernel_size=kernel_size, causal=causal) self.activation = nn.PReLU() def forward(self, x): output = self.activation(self.conv(x)) output = self.pooling(output) return output class PostProcessor(nn.Module): def __init__(self, n_in, n_out, num_samples, kernel_size=7, causal=False): super(PostProcessor, self).__init__() self.num_samples = num_samples self.conv = Conv1d(n_in, n_out, kernel_size=kernel_size, causal=causal) self.activation = nn.PReLU() def forward(self, x): x = torch.transpose(x, 1, 2) B, T, C = x.size() x = x.repeat(1, 1, self.num_samples).view(B, -1, C) x = torch.transpose(x, 1, 2) output = self.activation(self.conv(x)) return output class ResidualUnit(nn.Module): def __init__(self, n_in, n_out, dilation, res_kernel_size=7, causal=False): super(ResidualUnit, self).__init__() self.conv1 = weight_norm( Conv1d( n_in, n_out, kernel_size=res_kernel_size, dilation=dilation, causal=causal, ) ) self.conv2 = weight_norm(Conv1d(n_in, n_out, kernel_size=1, causal=causal)) self.activation1 = nn.PReLU() self.activation2 = nn.PReLU() def forward(self, x): output = self.activation1(self.conv1(x)) output = self.activation2(self.conv2(output)) return output + x class ResEncoderBlock(nn.Module): def __init__( self, n_in, n_out, stride, down_kernel_size, res_kernel_size=7, causal=False ): super(ResEncoderBlock, self).__init__() self.convs = nn.ModuleList( [ ResidualUnit( n_in, n_out // 2, dilation=1, res_kernel_size=res_kernel_size, causal=causal, ), ResidualUnit( n_out // 2, n_out // 2, dilation=3, res_kernel_size=res_kernel_size, causal=causal, ), ResidualUnit( n_out // 2, n_out // 2, dilation=5, res_kernel_size=res_kernel_size, causal=causal, ), ResidualUnit( n_out // 2, n_out // 2, dilation=7, res_kernel_size=res_kernel_size, causal=causal, ), ResidualUnit( n_out // 2, n_out // 2, dilation=9, res_kernel_size=res_kernel_size, causal=causal, ), ] ) self.down_conv = DownsampleLayer( n_in, n_out, down_kernel_size, stride=stride, causal=causal ) def forward(self, x): for conv in self.convs: x = conv(x) x = self.down_conv(x) return x class ResDecoderBlock(nn.Module): def __init__( self, n_in, n_out, stride, up_kernel_size, res_kernel_size=7, causal=False ): super(ResDecoderBlock, self).__init__() self.up_conv = UpsampleLayer( n_in, n_out, kernel_size=up_kernel_size, stride=stride, causal=causal, activation=None, ) self.convs = nn.ModuleList( [ ResidualUnit( n_out, n_out, dilation=1, res_kernel_size=res_kernel_size, causal=causal, ), ResidualUnit( n_out, n_out, dilation=3, res_kernel_size=res_kernel_size, causal=causal, ), ResidualUnit( n_out, n_out, dilation=5, res_kernel_size=res_kernel_size, causal=causal, ), ResidualUnit( n_out, n_out, dilation=7, res_kernel_size=res_kernel_size, causal=causal, ), ResidualUnit( n_out, n_out, dilation=9, res_kernel_size=res_kernel_size, causal=causal, ), ] ) def forward(self, x): x = self.up_conv(x) for conv in self.convs: x = conv(x) return x class DownsampleLayer(nn.Module): def __init__( self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, causal: bool = False, activation=nn.PReLU(), use_weight_norm: bool = True, pooling: bool = False, ): super(DownsampleLayer, self).__init__() self.pooling = pooling self.stride = stride self.activation = nn.PReLU() self.use_weight_norm = use_weight_norm if pooling: self.layer = Conv1d(in_channels, out_channels, kernel_size, causal=causal) self.pooling = nn.AvgPool1d(kernel_size=stride) else: self.layer = Conv1d( in_channels, out_channels, kernel_size, stride=stride, causal=causal ) if use_weight_norm: self.layer = weight_norm(self.layer) def forward(self, x): x = self.layer(x) x = self.activation(x) if self.activation is not None else x if self.pooling: x = self.pooling(x) return x def remove_weight_norm(self): if self.use_weight_norm: remove_weight_norm(self.layer) class UpsampleLayer(nn.Module): def __init__( self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, causal: bool = False, activation=nn.PReLU(), use_weight_norm: bool = True, repeat: bool = False, ): super(UpsampleLayer, self).__init__() self.repeat = repeat self.stride = stride self.activation = activation self.use_weight_norm = use_weight_norm if repeat: self.layer = Conv1d(in_channels, out_channels, kernel_size, causal=causal) else: self.layer = ConvTranspose1d( in_channels, out_channels, kernel_size, stride=stride, causal=causal ) if use_weight_norm: self.layer = weight_norm(self.layer) def forward(self, x): x = self.layer(x) x = self.activation(x) if self.activation is not None else x if self.repeat: x = torch.transpose(x, 1, 2) B, T, C = x.size() x = x.repeat(1, 1, self.stride).view(B, -1, C) x = torch.transpose(x, 1, 2) return x def remove_weight_norm(self): if self.use_weight_norm: remove_weight_norm(self.layer) class round_func9(InplaceFunction): @staticmethod def forward(ctx, input): ctx.input = input return torch.round(9 * input) / 9 @staticmethod def backward(ctx, grad_output): grad_input = grad_output.clone() return grad_input class ScalarModel(nn.Module): def __init__( self, num_bands, sample_rate, causal, num_samples, downsample_factors, downsample_kernel_sizes, upsample_factors, upsample_kernel_sizes, latent_hidden_dim, default_kernel_size, delay_kernel_size, init_channel, res_kernel_size, mode="pre_proj", ): super(ScalarModel, self).__init__() # self.args = args self.encoder = [] self.decoder = [] self.vq = round_func9() # using 9 self.mode = mode # Encoder parts self.encoder.append( weight_norm( Conv1d( num_bands, init_channel, kernel_size=default_kernel_size, causal=causal, ) ) ) if num_samples > 1: # Downsampling self.encoder.append( PreProcessor( init_channel, init_channel, num_samples, kernel_size=default_kernel_size, causal=causal, ) ) for i, down_factor in enumerate(downsample_factors): self.encoder.append( ResEncoderBlock( init_channel * np.power(2, i), init_channel * np.power(2, i + 1), down_factor, downsample_kernel_sizes[i], res_kernel_size, causal=causal, ) ) self.encoder.append( weight_norm( Conv1d( init_channel * np.power(2, len(downsample_factors)), latent_hidden_dim, kernel_size=default_kernel_size, causal=causal, ) ) ) # Decoder # look ahead self.decoder.append( weight_norm( Conv1d( latent_hidden_dim, init_channel * np.power(2, len(upsample_factors)), kernel_size=delay_kernel_size, ) ) ) for i, upsample_factor in enumerate(upsample_factors): self.decoder.append( ResDecoderBlock( init_channel * np.power(2, len(upsample_factors) - i), init_channel * np.power(2, len(upsample_factors) - i - 1), upsample_factor, upsample_kernel_sizes[i], res_kernel_size, causal=causal, ) ) if num_samples > 1: self.decoder.append( PostProcessor( init_channel, init_channel, num_samples, kernel_size=default_kernel_size, causal=causal, ) ) self.decoder.append( weight_norm( Conv1d( init_channel, num_bands, kernel_size=default_kernel_size, causal=causal, ) ) ) self.encoder = nn.ModuleList(self.encoder) self.decoder = nn.ModuleList(self.decoder) def forward(self, x): for i, layer in enumerate(self.encoder): if i != len(self.encoder) - 1: x = layer(x) else: x = F.tanh(layer(x)) # import pdb; pdb.set_trace() x = self.vq.apply(x) # vq for i, layer in enumerate(self.decoder): x = layer(x) return x def inference(self, x): for i, layer in enumerate(self.encoder): if i != len(self.encoder) - 1: x = layer(x) else: x = F.tanh(layer(x)) # reverse to tanh emb = x # import pdb; pdb.set_trace() emb_quant = self.vq.apply(emb) # vq x = emb_quant for i, layer in enumerate(self.decoder): x = layer(x) return emb, emb_quant, x def encode(self, x): for i, layer in enumerate(self.encoder): if i != len(self.encoder) - 1: x = layer(x) else: x = F.tanh(layer(x)) # reverse to tanh emb = x # import pdb; pdb.set_trace() emb_quant = self.vq.apply(emb) # vq return emb def decode(self, x): x = self.vq.apply( x ) # make sure the prediction follow the similar disctribution for i, layer in enumerate(self.decoder): x = layer(x) return x