import torch.nn as nn
import math
from src.models.utils import capture_init, weights_init
from src.models.modules import WNConv1d, WNConvTranspose1d
from torchaudio.functional import resample
from torch.nn import functional as F


class ResnetBlock(nn.Module):
    """Residual block: a dilated 3-tap weight-normalized convolution followed by a
    1x1 convolution, with a 1x1 weight-normalized convolution on the shortcut."""

    def __init__(self, dim, dilation=1):
        super().__init__()
        self.block = nn.Sequential(
            nn.LeakyReLU(0.2),
            nn.ReflectionPad1d(dilation),
            WNConv1d(dim, dim, kernel_size=3, dilation=dilation),
            nn.LeakyReLU(0.2),
            WNConv1d(dim, dim, kernel_size=1),
        )
        self.shortcut = WNConv1d(dim, dim, kernel_size=1)

    def forward(self, x):
        return self.shortcut(x) + self.block(x)


class Seanet(nn.Module):
    """SEANet-style encoder/decoder generator with U-Net skip connections.

    The encoder downsamples the waveform by `ratios` into a latent sequence and
    the decoder mirrors it back; when `upsample` is set, the input is resampled
    from `lr_sr` to `hr_sr` before entering the network.
    """

    @capture_init
    def __init__(self,
                 latent_space_size=128,
                 ngf=32, n_residual_layers=3,
                 resample=1,
                 normalize=True,
                 floor=1e-3,
                 ratios=[8, 8, 2, 2],
                 in_channels=1,
                 out_channels=1,
                 lr_sr=16000,
                 hr_sr=16000,
                 upsample=True):
        super().__init__()

        self.resample = resample
        self.normalize = normalize
        self.floor = floor
        self.lr_sr = lr_sr
        self.hr_sr = hr_sr
        self.scale_factor = int(self.hr_sr / self.lr_sr)
        self.upsample = upsample

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        self.ratios = ratios
        mult = int(2 ** len(ratios))
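
        # `mult` is the channel multiplier at the bottleneck (2 ** len(ratios)).
        # The encoder is assembled from the bottleneck outwards with insert(0, ...)
        # and the decoder from the bottleneck onwards with append(...), so both
        # lists end up in execution order. The two wrapper layers below map between
        # `mult * ngf` feature channels and the latent space.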
        decoder_wrapper_conv_layer = [
            nn.LeakyReLU(0.2),
            nn.ReflectionPad1d(3),
            WNConv1d(latent_space_size, mult * ngf, kernel_size=7, padding=0),
        ]

        encoder_wrapper_conv_layer = [
            nn.LeakyReLU(0.2),
            nn.ReflectionPad1d(3),
            WNConv1d(mult * ngf, latent_space_size, kernel_size=7, padding=0),
        ]

        self.encoder.insert(0, nn.Sequential(*encoder_wrapper_conv_layer))
        self.decoder.append(nn.Sequential(*decoder_wrapper_conv_layer))

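        # Each ratio r adds one downsampling stage to the encoder (a stack of
        # dilated ResnetBlocks followed by a stride-r convolution) and the
        # mirrored upsampling stage to the decoder (a stride-r transposed
        # convolution followed by the ResnetBlocks). `mult` halves as the stages
        # move away from the bottleneck towards the waveform.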
        for i, r in enumerate(ratios):
            encoder_block = [
                nn.LeakyReLU(0.2),
                WNConv1d(mult * ngf // 2,
                         mult * ngf,
                         kernel_size=r * 2,
                         stride=r,
                         padding=r // 2 + r % 2,
                         ),
            ]

            decoder_block = [
                nn.LeakyReLU(0.2),
                WNConvTranspose1d(
                    mult * ngf,
                    mult * ngf // 2,
                    kernel_size=r * 2,
                    stride=r,
                    padding=r // 2 + r % 2,
                    output_padding=r % 2,
                ),
            ]

            # Residual blocks are prepended in reverse so the stage runs them with
            # increasing dilation (1, 3, 9, ...) before the strided convolution.
            for j in range(n_residual_layers - 1, -1, -1):
                encoder_block = [ResnetBlock(mult * ngf // 2, dilation=3 ** j)] + encoder_block

            for j in range(n_residual_layers):
                decoder_block += [ResnetBlock(mult * ngf // 2, dilation=3 ** j)]

            mult //= 2

            self.encoder.insert(0, nn.Sequential(*encoder_block))
            self.decoder.append(nn.Sequential(*decoder_block))

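        # Outermost (waveform-side) layers: the first encoder layer lifts the
        # input waveform to ngf channels and the last decoder layer projects back
        # to the output waveform, both ending in a Tanh.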
        encoder_wrapper_conv_layer = [
            nn.ReflectionPad1d(3),
            WNConv1d(in_channels, ngf, kernel_size=7, padding=0),
            nn.Tanh(),
        ]
        self.encoder.insert(0, nn.Sequential(*encoder_wrapper_conv_layer))

        decoder_wrapper_conv_layer = [
            nn.LeakyReLU(0.2),
            nn.ReflectionPad1d(3),
            WNConv1d(ngf, out_channels, kernel_size=7, padding=0),
            nn.Tanh(),
        ]
        self.decoder.append(nn.Sequential(*decoder_wrapper_conv_layer))

        self.apply(weights_init)

    def estimate_output_length(self, length):
        """
        Return the nearest valid length to use with the model so that no time
        steps are left over in the strided convolutions, i.e. for every layer,
        (input length + 2 * padding - kernel_size) % stride == 0.

        If the input already has a valid length, the output will have exactly
        the same length. For example, with the default ratios [8, 8, 2, 2] a
        length of 16000 maps to the valid length 16128.
        """
        depth = len(self.ratios)
        # Trace the length through the encoder (strides run in reverse ratio order).
        for idx in range(depth - 1, -1, -1):
            stride = self.ratios[idx]
            kernel_size = 2 * stride
            padding = stride // 2 + stride % 2
            length = math.ceil((length - kernel_size + 2 * padding) / stride) + 1
            length = max(length, 1)
        # Trace it back up through the decoder's transposed convolutions.
        for idx in range(depth):
            stride = self.ratios[idx]
            kernel_size = 2 * stride
            padding = stride // 2 + stride % 2
            output_padding = stride % 2
            length = (length - 1) * stride + kernel_size - 2 * padding + output_padding
        return int(length)

    def pad_to_valid_length(self, signal):
        valid_length = self.estimate_output_length(signal.shape[-1])
        padding_len = valid_length - signal.shape[-1]
        signal = F.pad(signal, (0, padding_len))
        return signal, padding_len

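    # forward(): normalize by the mono standard deviation, optionally resample
    # from lr_sr to hr_sr, right-pad to a valid length, run the encoder while
    # stacking skip activations, run the decoder adding them back in reverse,
    # then trim to the target length and undo the normalization.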
    def forward(self, signal):
        target_len = signal.shape[-1]
        if self.upsample:
            target_len *= self.scale_factor
        if self.normalize:
            # `floor` keeps the division stable on silent or near-silent inputs.
            mono = signal.mean(dim=1, keepdim=True)
            std = mono.std(dim=-1, keepdim=True)
            signal = signal / (self.floor + std)
        else:
            std = 1
        x = signal
        if self.upsample:
            x = resample(x, self.lr_sr, self.hr_sr)

        x, padding_len = self.pad_to_valid_length(x)
        skips = []
        for i, encode in enumerate(self.encoder):
            skips.append(x)
            x = encode(x)
        for j, decode in enumerate(self.decoder):
            x = decode(x)
            skip = skips.pop(-1)
            x = x + skip
        if target_len < x.shape[-1]:
            x = x[..., :target_len]
        return std * x
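

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (an illustrative addition, not part of the model).
# It assumes `src.models.modules` / `src.models.utils` are importable and that
# WNConv1d / WNConvTranspose1d are weight-normalized Conv1d wrappers accepting
# the arguments used above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    # 2x bandwidth-extension configuration: 8 kHz in, 16 kHz out.
    model = Seanet(lr_sr=8000, hr_sr=16000, upsample=True)
    dummy = torch.randn(1, 1, 8000)  # (batch, channels, time): one second of audio
    with torch.no_grad():
        out = model(dummy)
    # Expect (1, 1, 16000): the output is trimmed to scale_factor * input length.
    print(tuple(dummy.shape), "->", tuple(out.shape))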