Spaces:
Configuration error
Configuration error
| #! /usr/bin/env python | |
| # -*- coding: utf-8 -*- | |
| # vim:fenc=utf-8 | |
| # | |
| # Copyright (c) 2021 Kazuhiro KOBAYASHI <root.4mac@gmail.com> | |
| # | |
| # Distributed under terms of the MIT license. | |
| """ | |
| """ | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class ConvLayers(nn.Module):
    """Configurable stack of 1-D convolutions.

    Maps a (B, T, in_channels) tensor to (B, T, out_channels).  All
    hyper-parameters are class attributes that act as defaults and may be
    overridden via keyword arguments at construction time.
    """

    # Default hyper-parameters (any of these may be passed as **kwargs).
    in_channels = 80
    conv_channels = 256
    out_channels = 222
    kernel_size = 3
    dilation_size = 1
    group_size = 8
    n_conv_layers = 3
    use_causal = False
    conv_type = "original"

    def __init__(self, **kwargs):
        super().__init__()
        # Reject unknown options before overriding the class-level defaults.
        for key, value in kwargs.items():
            if key not in vars(type(self)):
                raise ValueError(f"{key} not in arguments {self.__class__}.")
            setattr(self, key, value)

        builder = {
            "ddsconv": self.ddsconv,
            "original": self.original_conv,
        }.get(self.conv_type)
        if builder is None:
            raise ValueError(f"Unsupported conv_type: {self.conv_type}")
        self.conv_layers = builder()

    def forward(self, x):
        """Apply the convolution stack.

        x: (B, T, in_channels)
        returns: (B, T, out_channels)
        """
        return self.conv_layers(x)

    def original_conv(self):
        """Build the plain Conv1d/ReLU stack used when conv_type == "original"."""
        layers = [
            Conv1d(
                self.in_channels,
                self.conv_channels,
                self.kernel_size,
                self.dilation_size,
                1,
                self.use_causal,
            ),
            nn.ReLU(),
        ]
        for _ in range(self.n_conv_layers):
            layers.append(
                Conv1d(
                    self.conv_channels,
                    self.conv_channels,
                    self.kernel_size,
                    self.dilation_size,
                    self.group_size,
                    self.use_causal,
                )
            )
            layers.append(nn.ReLU())
        # Final projection to out_channels (no activation).
        layers.append(
            Conv1d(
                self.conv_channels,
                self.out_channels,
                self.kernel_size,
                self.dilation_size,
                1,
                self.use_causal,
            )
        )
        return nn.Sequential(*layers)

    def ddsconv(self):
        """Build the depthwise-separable stack used when conv_type == "ddsconv"."""
        layers = [
            Conv1d(
                in_channels=self.in_channels,
                out_channels=self.conv_channels,
                kernel_size=1,
                dilation_size=1,
                group_size=1,
                use_causal=self.use_causal,
            )
        ]
        # Dilation grows exponentially per layer; the base is kernel_size
        # unless an explicit dilation_size (> 1) was configured.
        base = self.kernel_size if self.dilation_size == 1 else self.dilation_size
        for i in range(self.n_conv_layers):
            layers.append(
                DepthSeparableConv1d(
                    channels=self.conv_channels,
                    kernel_size=self.kernel_size,
                    dilation_size=base ** i,
                    use_causal=self.use_causal,
                )
            )
        layers.append(
            Conv1d(
                in_channels=self.conv_channels,
                out_channels=self.out_channels,
                kernel_size=1,
                dilation_size=1,
                group_size=1,
                use_causal=self.use_causal,
            )
        )
        return nn.Sequential(*layers)

    def remove_weight_norm(self):
        """Strip weight normalization from every submodule that carries it."""

        def _strip(module):
            try:
                torch.nn.utils.remove_weight_norm(module)
            except ValueError:
                pass  # module has no weight norm attached

        self.apply(_strip)
class Conv1d(nn.Module):
    """1-D convolution over time-major tensors (B, T, D).

    Pads symmetrically via nn.Conv1d, then trims the padding so the output
    time length equals the input's: causal mode drops only the right-hand
    (future) samples, non-causal mode trims both ends.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        dilation_size=1,
        group_size=1,
        use_causal=False,
    ):
        super().__init__()
        self.use_causal = use_causal
        self.kernel_size = kernel_size
        # Total padding added per side by nn.Conv1d; trimmed in forward().
        self.padding = (kernel_size - 1) * dilation_size
        self.conv1d = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            padding=self.padding,
            dilation=dilation_size,
            groups=group_size,
        )
        nn.init.kaiming_normal_(self.conv1d.weight)

    def forward(self, x):
        """
        x: (B, T, D)
        y: (B, T, D)
        """
        h = self.conv1d(x.transpose(1, 2))
        pad = self.padding
        # NOTE(k2kobayashi): kernel_size=1 does not discard padding
        if self.kernel_size > 1 and self.conv1d.padding != (0,):
            # Causal keeps only the "past" context; otherwise trim evenly.
            h = h[..., :-pad] if self.use_causal else h[..., pad // 2 : -pad // 2]
        return h.transpose(1, 2)
class DepthSeparableConv1d(nn.Module):
    """Depthwise-separable 1-D convolution block with a residual connection.

    depthwise conv -> LayerNorm -> GELU -> pointwise conv -> LayerNorm -> GELU,
    with the input added back to the result.  Input/output: (B, T, channels).
    """

    def __init__(self, channels, kernel_size, dilation_size, use_causal=False):
        super().__init__()
        depthwise = Conv1d(
            channels,
            channels,
            kernel_size,
            dilation_size,
            group_size=channels,  # one group per channel -> depthwise conv
            use_causal=use_causal,
        )
        pointwise = Conv1d(
            channels,
            channels,
            kernel_size=1,
            dilation_size=1,
            group_size=1,
            use_causal=use_causal,
        )
        self.layers = nn.Sequential(
            depthwise,
            nn.LayerNorm(channels),
            nn.GELU(),
            pointwise,
            nn.LayerNorm(channels),
            nn.GELU(),
        )

    def forward(self, x):
        """Residual add around the separable conv stack."""
        return x + self.layers(x)
class DFTLayer(nn.Module):
    """Naive DFT/IDFT implemented as 1-D convolutions with fixed sin/cos kernels.

    dft() matches the standard FFT sign convention
    X_k = sum_n x_n * exp(-2j*pi*k*n/N); idft() applies the 1/N-normalized
    inverse.  Kernels are registered as non-persistent buffers so they move
    with the module across devices but are excluded from state_dict.
    """

    def __init__(self, n_fft=1024):
        """
        Args:
            n_fft: DFT size (number of frequency bins and kernel length).
        """
        super().__init__()
        self.n_fft = n_fft
        wsin, wcos = self._generate_fourier_kernels(n_fft=n_fft)
        self.register_buffer(
            "wsin", torch.tensor(wsin, dtype=torch.float), persistent=False
        )
        self.register_buffer(
            "wcos", torch.tensor(wcos, dtype=torch.float), persistent=False
        )

    @staticmethod
    def _generate_fourier_kernels(n_fft, window="hann"):
        """Return (wsin, wcos) float32 conv kernels, each (n_fft, 1, n_fft).

        BUG FIX: this was declared as an instance method without ``self``, so
        the call in ``__init__`` bound ``self`` to ``n_fft`` positionally and
        then passed ``n_fft`` again by keyword, raising
        ``TypeError: got multiple values for argument 'n_fft'`` on every
        construction.  It uses no instance state, hence ``@staticmethod``.

        NOTE(review): ``window`` is accepted but currently unused — no
        windowing is applied; confirm whether a window was intended.
        """
        freq_bins = n_fft
        s = np.arange(0, n_fft, 1.0)
        wsin = np.empty((freq_bins, 1, n_fft))
        wcos = np.empty((freq_bins, 1, n_fft))
        for k in range(freq_bins):
            wsin[k, 0, :] = np.sin(2 * np.pi * k * s / n_fft)
            wcos[k, 0, :] = np.cos(2 * np.pi * k * s / n_fft)
        return wsin.astype(np.float32), wcos.astype(np.float32)

    def forward(self, x, imag=None, inverse=False):
        """Dispatch to dft(x) or, when inverse=True, idft(x, imag)."""
        if inverse:
            return self.idft(x, imag)
        return self.dft(x)

    def dft(self, real):
        """Forward DFT; returns (real_part, imag_part).

        The negation of the sine projection yields the imaginary part under
        the exp(-2j*pi*k*n/N) convention.
        """
        real = real.transpose(0, 1)
        # NOTE(review): stride=n_fft + 1 (not n_fft) implies one extra sample
        # between frames — confirm against callers' framing.
        imag = F.conv1d(real, self.wsin, stride=self.n_fft + 1).permute(2, 0, 1)
        real = F.conv1d(real, self.wcos, stride=self.n_fft + 1).permute(2, 0, 1)
        return real, -imag

    def idft(self, real, imag):
        """Inverse DFT of a (real, imag) spectrum; returns 1/N-normalized
        (real_part, imag_part) of the reconstruction."""
        real = real.transpose(0, 1)
        imag = imag.transpose(0, 1)
        # (R + jI)(cos + j sin): real = R*cos - I*sin, imag = R*sin + I*cos.
        a1 = F.conv1d(real, self.wcos, stride=self.n_fft + 1)
        a2 = F.conv1d(real, self.wsin, stride=self.n_fft + 1)
        b1 = F.conv1d(imag, self.wcos, stride=self.n_fft + 1)
        b2 = F.conv1d(imag, self.wsin, stride=self.n_fft + 1)
        imag = a2 + b1
        real = a1 - b2
        return (
            (real / self.n_fft).permute(2, 0, 1),
            (imag / self.n_fft).permute(2, 0, 1),
        )