# Copyright (c) 2022 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import weight_norm, remove_weight_norm

import FlashSR.BigVGAN.activations as activations
from FlashSR.BigVGAN.utils import init_weights, get_padding
from FlashSR.BigVGAN.alias_free_torch import Activation1d

LRELU_SLOPE = 0.1
class SRVocoder(torch.nn.Module):
    def __init__(self,
                 num_mels = 256,
                 upsample_initial_channel = 1536,
                 resblock_kernel_sizes = [3, 7, 11],
                 resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                 upsample_rates = [10, 6, 2, 2, 2], # product must equal the mel hop length (480 here); alternative from the original comments: [5, 4, 3, 2, 2, 2]
                 upsample_kernel_sizes = None, # defaults to [2 * r for r in upsample_rates]; alternative from the original comments: [7, 8, 7, 4, 4, 4]
                 activation = 'snakebeta',
                 snake_logscale = True
                 ):
super(SRVocoder, self).__init__()
if upsample_kernel_sizes is None:
upsample_kernel_sizes = [upsample_rate * 2 for upsample_rate in upsample_rates]
        # Waveform branch: embed the low-resolution audio, then progressively
        # downsample it so each scale can be added to the matching upsampler output.
        self.audio_block = nn.ModuleDict()
        self.audio_block["downsamples"] = nn.ModuleList()
        self.audio_block["emb"] = Conv1d(1, upsample_initial_channel // (2 ** len(upsample_rates)), 7, bias=True, padding=(7 - 1) // 2)
        for i in reversed(range(len(upsample_kernel_sizes))):
            self.audio_block["downsamples"] += [
                nn.Sequential(
                    nn.Conv1d(
                        upsample_initial_channel // (2 ** (i + 1)),
                        upsample_initial_channel // (2 ** i),
                        upsample_kernel_sizes[i],
                        upsample_rates[i],
                        padding=upsample_rates[i] - (upsample_kernel_sizes[i] % 2 == 0),
                        bias=True,
                    ),
                    nn.LeakyReLU(negative_slope=LRELU_SLOPE)
                )
            ]
        # Note: forward() only consumes the first num_upsamples - 1 downsample
        # blocks; the deepest one is constructed but never applied.
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
# pre conv
self.conv_pre = weight_norm(Conv1d(num_mels, upsample_initial_channel, 7, 1, padding=3))
# define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
resblock = AMPBlock1
# transposed conv-based upsamplers. does not apply anti-aliasing
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
self.ups.append(nn.ModuleList([
weight_norm(ConvTranspose1d(upsample_initial_channel // (2 ** i),
upsample_initial_channel // (2 ** (i + 1)),
k, u, padding=(k - u) // 2))
]))
# residual blocks using anti-aliased multi-periodicity composition modules (AMP)
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(resblock(ch, k, d, activation=activation, snake_logscale=snake_logscale))
        # post conv (`ch` still holds the channel count of the final upsampler output)
if activation == "snake": # periodic nonlinearity with snake function and anti-aliasing
activation_post = activations.Snake(ch, alpha_logscale=snake_logscale)
self.activation_post = Activation1d(activation=activation_post)
elif activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing
activation_post = activations.SnakeBeta(ch, alpha_logscale=snake_logscale)
self.activation_post = Activation1d(activation=activation_post)
else:
raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
# weight initialization
for i in range(len(self.ups)):
self.ups[i].apply(init_weights)
self.conv_post.apply(init_weights)
    # Mel-spectrogram configuration used for audio super-resolution:
    #   sampling_rate = 48000
    #   filter_length = 2048
    #   hop_length    = 480   (= the product of upsample_rates)
    #   win_length    = 2048
    #   n_mel         = 256
    #   mel_fmin      = 20
    #   mel_fmax      = 24000
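    # A sketch of a front-end matching the settings above (an assumption for
    # illustration only: this file does not show how the mels are computed,
    # and torchaudio is not a stated dependency of this repo):
    #
    #   import torchaudio
    #   mel_fn = torchaudio.transforms.MelSpectrogram(
    #       sample_rate=48000, n_fft=2048, win_length=2048, hop_length=480,
    #       f_min=20.0, f_max=24000.0, n_mels=256)
    #   mel_spec = mel_fn(lr_audio)  # [batch, n_mels, frames]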
    def forward(self,
                mel_spec: Tensor, # [batch, num_mels, time // hop_length]
                lr_audio: Tensor, # [batch, time]
                ) -> dict: # {'pred_hr_audio': Tensor of shape [batch, time]}
        # Build a multi-scale feature pyramid from the low-resolution waveform.
        audio_emb: Tensor = self.audio_block["emb"](lr_audio.unsqueeze(1))
        audio_emb_list: list = [audio_emb]
        for i in range(self.num_upsamples - 1):
            audio_emb = self.audio_block["downsamples"][i](audio_emb)
            audio_emb_list += [audio_emb]
# pre conv
x = self.conv_pre(mel_spec)
        for i in range(self.num_upsamples):
            # upsampling, with the matching audio-pyramid scale added as a skip connection
            for i_up in range(len(self.ups[i])):
                x = self.ups[i][i_up](x) + audio_emb_list[-1 - i]
# AMP blocks
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
# post conv
x = self.activation_post(x)
x = self.conv_post(x)
x = torch.tanh(x).squeeze(1)
return {'pred_hr_audio': x }
def remove_weight_norm(self):
print('Removing weight norm...')
for l in self.ups:
for l_i in l:
remove_weight_norm(l_i)
for l in self.resblocks:
l.remove_weight_norm()
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
class AMPBlock1(torch.nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), activation=None, snake_logscale=True):
super(AMPBlock1, self).__init__()
self.convs1 = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2])))
])
self.convs1.apply(init_weights)
self.convs2 = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1)))
])
self.convs2.apply(init_weights)
self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
activation=activations.Snake(channels, alpha_logscale=snake_logscale))
for _ in range(self.num_layers)
])
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
activation=activations.SnakeBeta(channels, alpha_logscale=snake_logscale))
for _ in range(self.num_layers)
])
else:
raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
def forward(self, x):
acts1, acts2 = self.activations[::2], self.activations[1::2]
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
xt = a1(x)
xt = c1(xt)
xt = a2(xt)
xt = c2(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
for l in self.convs2:
remove_weight_norm(l)
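

if __name__ == '__main__':
    # A minimal smoke test (a sketch, not part of the original training or
    # inference pipeline): run the vocoder on random tensors with the default
    # configuration. The mel hop length is 480 (the product of upsample_rates),
    # so a mel with `frames` columns pairs with `frames * 480` waveform samples.
    frames = 100
    model = SRVocoder()
    mel_spec = torch.randn(1, 256, frames)   # [batch, num_mels, frames]
    lr_audio = torch.randn(1, frames * 480)  # [batch, frames * hop_length]
    with torch.no_grad():
        out = model(mel_spec, lr_audio)
    print(out['pred_hr_audio'].shape)        # torch.Size([1, 48000])

    # Before deployment, weight norm can be folded into the conv weights:
    model.remove_weight_norm()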