| |
| |
|
|
| |
| |
| from torch import Tensor |
|
|
| from TorchJaekwon.Util.Util import Util |
| from TorchJaekwon.Util.UtilData import UtilData |
| from TorchJaekwon.Util.UtilAudioMelSpec import UtilAudioMelSpec |
| |
| |
|
|
| import torch |
| import torch.nn.functional as F |
| import torch.nn as nn |
| from torch.nn import Conv1d, ConvTranspose1d, Conv2d |
| from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm |
|
|
| import FlashSR.BigVGAN.activations as activations |
| from FlashSR.BigVGAN.utils import init_weights, get_padding |
| from FlashSR.BigVGAN.alias_free_torch import * |
|
|
| LRELU_SLOPE = 0.1 |  # Presumably the leaky-ReLU negative slope; not referenced in the visible code (the audio downsample stack in SRVocoder hard-codes negative_slope = 0.1) - TODO confirm whether other code imports it.
|
|
class SRVocoder(torch.nn.Module):
    """Mel-to-waveform generator additionally conditioned on a low-resolution waveform.

    The mel spectrogram is projected by ``conv_pre`` and then passed through
    ``len(upsample_rates)`` stages. Every stage

    1. upsamples with a weight-normalised ``ConvTranspose1d``,
    2. adds an embedding of the low-resolution audio that has been strided
       down by ``audio_block`` to the matching channel count, and
    3. averages the outputs of ``len(resblock_kernel_sizes)`` ``AMPBlock1``
       residual blocks.

    A final activation, a 1-channel convolution and ``tanh`` produce the output.

    Note:
        ``audio_block["downsamples"]`` holds one strided convolution per
        upsample stage, but ``forward`` only calls the first
        ``len(upsample_rates) - 1`` of them. The last one is still registered
        so that the module's parameter set (state_dict keys) stays unchanged.

    Reference mel settings recorded by the original author ("In audio sr"):
    sampling_rate = 48000, filter_length = 2048, hop_length = 480,
    win_length = 2048, n_mel = 256, mel_fmin = 20, mel_fmax = 24000.
    (The product of the default ``upsample_rates`` is 480.)

    Args:
        num_mels: Number of input channels of ``conv_pre`` (mel bins).
        upsample_initial_channel: Channel count after ``conv_pre``; it is
            halved (integer division) at every upsample stage.
        resblock_kernel_sizes: Kernel size of each residual block in a stage.
        resblock_dilation_sizes: Dilation triple of each residual block;
            paired with ``resblock_kernel_sizes`` by position.
        upsample_rates: Stride of each transposed convolution.
        upsample_kernel_sizes: Kernel size of each transposed convolution.
            Defaults to twice the corresponding upsample rate.
        activation: ``'snake'`` or ``'snakebeta'``.
        snake_logscale: Passed as ``alpha_logscale`` to the activations.

    Raises:
        ValueError: If the size lists are empty or of inconsistent length.
        NotImplementedError: If ``activation`` is not a supported name.
    """

    def __init__(self,
                 num_mels = 256,
                 upsample_initial_channel = 1536,
                 resblock_kernel_sizes = (3, 7, 11),
                 resblock_dilation_sizes = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
                 upsample_rates = (10, 6, 2, 2, 2),
                 upsample_kernel_sizes = None,
                 activation = 'snakebeta',
                 snake_logscale = True
                 ):
        super(SRVocoder, self).__init__()
        if upsample_kernel_sizes is None:
            upsample_kernel_sizes = [rate * 2 for rate in upsample_rates]

        # Fail early with a clear message; these configurations previously
        # crashed later with IndexError / NameError / channel mismatches.
        if len(upsample_rates) == 0:
            raise ValueError("upsample_rates must contain at least one rate.")
        if len(upsample_kernel_sizes) != len(upsample_rates):
            raise ValueError(
                "upsample_kernel_sizes and upsample_rates must have the same length "
                f"(got {len(upsample_kernel_sizes)} and {len(upsample_rates)})."
            )
        if len(resblock_dilation_sizes) < len(resblock_kernel_sizes):
            raise ValueError(
                "resblock_dilation_sizes needs one entry per resblock kernel size "
                f"(got {len(resblock_dilation_sizes)} for {len(resblock_kernel_sizes)} kernels)."
            )

        # --- low-resolution audio branch -------------------------------------
        # "emb" lifts the raw waveform to the channel count of the last
        # upsample stage; "downsamples" mirrors the upsample stack in reverse,
        # doubling channels while striding by the matching upsample rate.
        self.audio_block = nn.ModuleDict()
        self.audio_block["downsamples"] = nn.ModuleList()
        self.audio_block["emb"] = Conv1d(
            1,
            upsample_initial_channel // (2 ** len(upsample_rates)),
            7,
            bias=True,
            padding=(7 - 1) // 2,
        )
        for i in reversed(range(len(upsample_kernel_sizes))):
            kernel_size = upsample_kernel_sizes[i]
            stride = upsample_rates[i]
            self.audio_block["downsamples"].append(
                nn.Sequential(
                    nn.Conv1d(
                        upsample_initial_channel // (2 ** (i + 1)),
                        upsample_initial_channel // (2 ** i),
                        kernel_size,
                        stride,
                        # One sample less padding when the kernel is even.
                        padding=stride - int(kernel_size % 2 == 0),
                        bias=True,
                    ),
                    nn.LeakyReLU(negative_slope=0.1),
                )
            )

        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)

        # --- mel branch ------------------------------------------------------
        self.conv_pre = weight_norm(Conv1d(num_mels, upsample_initial_channel, 7, 1, padding=3))

        # Each stage is wrapped in its own ModuleList (one layer per stage).
        self.ups = nn.ModuleList()
        for i, (rate, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(nn.ModuleList([
                weight_norm(ConvTranspose1d(upsample_initial_channel // (2 ** i),
                                            upsample_initial_channel // (2 ** (i + 1)),
                                            kernel_size, rate,
                                            padding=(kernel_size - rate) // 2))
            ]))

        # Residual blocks are stored flat: stage i owns indices
        # [i * num_kernels, (i + 1) * num_kernels).
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for kernel_size, dilation in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                # Forward snake_logscale so the constructor argument applies to
                # the residual blocks too (it was silently ignored for them).
                self.resblocks.append(
                    AMPBlock1(ch, kernel_size, dilation,
                              activation=activation,
                              snake_logscale=snake_logscale)
                )

        # `ch` now holds the channel count of the final stage.
        if activation == "snake":
            activation_post = activations.Snake(ch, alpha_logscale=snake_logscale)
            self.activation_post = Activation1d(activation=activation_post)
        elif activation == "snakebeta":
            activation_post = activations.SnakeBeta(ch, alpha_logscale=snake_logscale)
            self.activation_post = Activation1d(activation=activation_post)
        else:
            raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")

        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))

        # Weight initialisation of the upsample layers and the output conv.
        for up_stage in self.ups:
            up_stage.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self,
                mel_spec: Tensor,
                lr_audio: Tensor,
                ) -> dict:
        """Generate a waveform from a mel spectrogram and low-resolution audio.

        Args:
            mel_spec: Input to ``conv_pre``, i.e. ``(batch, num_mels, frames)``.
            lr_audio: Waveform without a channel axis, ``(batch, samples)``;
                a channel axis is inserted before the embedding convolution.
                Its strided-down embeddings are added to the upsampled mel
                features, so the lengths have to line up stage by stage
                (presumably ``samples == frames * prod(upsample_rates)`` -
                confirm against the caller).

        Returns:
            ``{'pred_hr_audio': waveform}`` where ``waveform`` is the
            ``tanh``-bounded output with the channel axis squeezed out.
        """
        # Audio embeddings from finest (first) to coarsest (last) resolution.
        audio_emb: Tensor = self.audio_block["emb"](lr_audio.unsqueeze(1))
        audio_emb_list: list = [audio_emb]
        for i in range(self.num_upsamples - 1):
            audio_emb = self.audio_block["downsamples"][i](audio_emb)
            audio_emb_list.append(audio_emb)

        x = self.conv_pre(mel_spec)

        for i in range(self.num_upsamples):
            # Stage i consumes the embeddings back to front (coarsest first).
            for up_layer in self.ups[i]:
                x = up_layer(x) + audio_emb_list[-1 - i]

            # Average the parallel residual blocks of this stage.
            first_block = i * self.num_kernels
            xs = None
            for j in range(self.num_kernels):
                block_out = self.resblocks[first_block + j](x)
                if xs is None:
                    xs = block_out
                else:
                    xs += block_out
            x = xs / self.num_kernels

        x = self.activation_post(x)
        x = self.conv_post(x)
        x = torch.tanh(x).squeeze(1)

        return {'pred_hr_audio': x}

    def remove_weight_norm(self):
        """Strip weight normalisation from every weight-normalised layer.

        Covers the upsample layers, all residual blocks, ``conv_pre`` and
        ``conv_post``. The ``audio_block`` convolutions are created without
        weight norm and are therefore not touched.
        """
        print('Removing weight norm...')
        for up_stage in self.ups:
            for up_layer in up_stage:
                remove_weight_norm(up_layer)
        for block in self.resblocks:
            block.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
|
|
|
|
class AMPBlock1(torch.nn.Module):
    """Residual block made of three (dilated conv, plain conv) pairs.

    Each of the six convolutions is weight-normalised, keeps the channel
    count, and is preceded by its own ``Activation1d``-wrapped Snake or
    SnakeBeta activation. The output of every pair is added back to its input.

    Args:
        channels: Input and output channel count of every convolution.
        kernel_size: Kernel size shared by all convolutions.
        dilation: Dilations of the first convolution of each pair; exactly
            the first three entries are used.
        activation: ``'snake'`` or ``'snakebeta'``.
        snake_logscale: Passed as ``alpha_logscale`` to the activations.
            The default used to be the string ``'snakebeta'`` (a copy-paste
            slip). It is now ``True``, which is equivalent provided the
            activations treat the flag by truthiness - the activations module
            is not visible here, TODO confirm.

    Raises:
        NotImplementedError: If ``activation`` is not a supported name.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), activation=None, snake_logscale = True):
        super(AMPBlock1, self).__init__()

        # First conv of each pair: dilated.
        self.convs1 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=d,
                               padding=get_padding(kernel_size, d)))
            for d in (dilation[0], dilation[1], dilation[2])
        ])
        self.convs1.apply(init_weights)

        # Second conv of each pair: dilation 1.
        self.convs2 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1)))
            for _ in range(3)
        ])
        self.convs2.apply(init_weights)

        # One activation in front of every convolution.
        self.num_layers = len(self.convs1) + len(self.convs2)

        if activation == 'snake':
            self.activations = nn.ModuleList([
                Activation1d(
                    activation=activations.Snake(channels, alpha_logscale=snake_logscale))
                for _ in range(self.num_layers)
            ])
        elif activation == 'snakebeta':
            self.activations = nn.ModuleList([
                Activation1d(
                    activation=activations.SnakeBeta(channels, alpha_logscale=snake_logscale))
                for _ in range(self.num_layers)
            ])
        else:
            raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")

    def forward(self, x):
        """Apply the three residual (activation, conv, activation, conv) pairs to ``x``."""
        for idx, (conv1, conv2) in enumerate(zip(self.convs1, self.convs2)):
            # Activations are interleaved: even index precedes conv1, odd precedes conv2.
            # Indexing avoids building new ModuleList slices on every call.
            xt = self.activations[2 * idx](x)
            xt = conv1(xt)
            xt = self.activations[2 * idx + 1](xt)
            xt = conv2(xt)
            x = xt + x

        return x

    def remove_weight_norm(self):
        """Strip weight normalisation from all convolutions of this block."""
        for conv in self.convs1:
            remove_weight_norm(conv)
        for conv in self.convs2:
            remove_weight_norm(conv)
|
|