Spaces:
Running
Running
| # Copyright (c) 2023 Amphion. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
class ResBlock(nn.Module):
    """Residual block built from two bias-free 1x1 convolutions with batch norm.

    The branch computes conv -> batch norm -> ReLU -> conv -> batch norm and
    its result is added to the unmodified input, so input and output have the
    same number of channels.

    Args:
        dims: Number of input (and output) channels.
    """

    def __init__(self, dims):
        super().__init__()
        # Both convolutions share the same configuration; the bias is dropped
        # because each convolution is directly followed by a BatchNorm1d.
        conv_kwargs = {"kernel_size": 1, "bias": False}
        self.conv1 = nn.Conv1d(dims, dims, **conv_kwargs)
        self.conv2 = nn.Conv1d(dims, dims, **conv_kwargs)
        self.batch_norm1 = nn.BatchNorm1d(dims)
        self.batch_norm2 = nn.BatchNorm1d(dims)

    def forward(self, x):
        """Return the branch output added to ``x`` (the skip connection)."""
        hidden = F.relu(self.batch_norm1(self.conv1(x)))
        branch = self.batch_norm2(self.conv2(hidden))
        return branch + x
class MelResNet(nn.Module):
    """1-D residual network: input conv, a stack of ``ResBlock``s, output conv.

    Args:
        res_blocks: Number of ``ResBlock`` layers in the stack.
        in_dims: Channels of the input fed to the first convolution.
        compute_dims: Channels used inside the residual stack.
        res_out_dims: Channels produced by the final 1x1 convolution.
        pad: Half-width of the input convolution; its kernel size is
            ``2 * pad + 1``.
    """

    def __init__(self, res_blocks, in_dims, compute_dims, res_out_dims, pad):
        super().__init__()
        # The input convolution uses no padding, so the last axis of the
        # output is ``2 * pad`` shorter than that of the input.
        self.conv_in = nn.Conv1d(
            in_dims, compute_dims, kernel_size=2 * pad + 1, bias=False
        )
        self.batch_norm = nn.BatchNorm1d(compute_dims)
        self.layers = nn.ModuleList(
            [ResBlock(compute_dims) for _ in range(res_blocks)]
        )
        self.conv_out = nn.Conv1d(compute_dims, res_out_dims, kernel_size=1)

    def forward(self, x):
        """Apply the input conv, every residual block in order, then the output conv."""
        hidden = F.relu(self.batch_norm(self.conv_in(x)))
        for block in self.layers:
            hidden = block(hidden)
        return self.conv_out(hidden)
class Stretch2d(nn.Module):
    """Upsample a 4-D tensor by repeating every element along its last two axes.

    Args:
        x_scale: Repeat count along the last axis.
        y_scale: Repeat count along the second-to-last axis.
    """

    def __init__(self, x_scale, y_scale):
        super().__init__()
        self.x_scale, self.y_scale = x_scale, y_scale

    def forward(self, x):
        """Return ``x`` with each value repeated ``y_scale`` x ``x_scale`` times.

        The input is unpacked as a 4-D tensor ``(b, c, h, w)``; the result has
        shape ``(b, c, h * y_scale, w * x_scale)``.
        """
        batch, channels, height, width = x.size()
        # Insert a singleton axis after each spatial axis -> (b, c, h, 1, w, 1),
        # tile those singleton axes, then fold them back into h and w.
        expanded = x.unsqueeze(3).unsqueeze(-1)
        tiled = expanded.repeat(1, 1, 1, self.y_scale, 1, self.x_scale)
        out_height = height * self.y_scale
        out_width = width * self.x_scale
        return tiled.view(batch, channels, out_height, out_width)
class UpsampleNetwork(nn.Module):
    """Upsample conditioning features and auxiliary ResNet features in time.

    Two outputs are produced from the same input ``m``:

    * ``m`` itself, stretched by every factor in ``upsample_scales`` with an
      averaging convolution after each stretch, then trimmed by ``indent``
      samples on both ends.
    * ``aux``: the ``MelResNet`` output stretched by the total scale. The
      ResNet's unpadded input convolution already removes ``pad`` frames per
      side, which is why only ``m`` needs trimming.

    Args:
        feat_dims: Channels of the input features.
        upsample_scales: Iterable of integer stretch factors, applied in order.
        compute_dims: Hidden channels of the ``MelResNet``.
        res_blocks: Number of residual blocks in the ``MelResNet``.
        res_out_dims: Output channels of the ``MelResNet``.
        pad: Frames of context consumed on each side by the ``MelResNet``.
    """

    def __init__(
        self, feat_dims, upsample_scales, compute_dims, res_blocks, res_out_dims, pad
    ):
        super().__init__()
        # ``np.cumproduct`` no longer exists in NumPy 2.x; ``np.prod`` gives the
        # same total. Casting to ``int`` keeps ``indent`` and the stretch
        # factor plain Python integers instead of NumPy scalars.
        total_scale = int(np.prod(upsample_scales))
        self.indent = pad * total_scale
        self.resnet = MelResNet(res_blocks, feat_dims, compute_dims, res_out_dims, pad)
        self.resnet_stretch = Stretch2d(total_scale, 1)
        self.up_layers = nn.ModuleList()
        for scale in upsample_scales:
            kernel_size = (1, scale * 2 + 1)
            padding = (0, scale)
            stretch = Stretch2d(scale, 1)
            conv = nn.Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False)
            # Every tap gets weight 1 / kernel width, so the convolution starts
            # out as a moving average that smooths the repeated samples.
            nn.init.constant_(conv.weight, 1.0 / kernel_size[1])
            self.up_layers.append(stretch)
            self.up_layers.append(conv)

    def forward(self, m):
        """Return ``(upsampled_m, upsampled_aux)``, both transposed on dims 1 and 2.

        ``m`` is passed to Conv1d layers and unsqueezed at dim 1, so it is
        presumably ``(batch, feat_dims, frames)`` -- TODO confirm with callers.
        """
        aux = self.resnet(m).unsqueeze(1)
        aux = self.resnet_stretch(aux)
        aux = aux.squeeze(1)

        m = m.unsqueeze(1)
        for layer in self.up_layers:
            m = layer(m)
        m = m.squeeze(1)
        # Use an explicit end index: ``[indent:-indent]`` collapses to an empty
        # tensor when ``indent`` is 0 (i.e. ``pad == 0``).
        m = m[:, :, self.indent : m.size(-1) - self.indent]
        return m.transpose(1, 2), aux.transpose(1, 2)
class WaveRNN(nn.Module):
    """Recurrent vocoder: two GRUs and three linear layers over upsampled features.

    All hyper-parameters are read from ``cfg.VOCODER`` (``MEL_FRAME_PAD``,
    ``MODE``, ``BITS``, ``RNN_DIMS``, ``RES_OUT_DIMS``, ``HOP_LENGTH``,
    ``FC_DIMS``, ``UPSAMPLE_FACTORS``, ``INPUT_DIM``, ``COMPUTE_DIMS``,
    ``RES_BLOCKS``).

    Args:
        cfg: Configuration object exposing a ``VOCODER`` attribute namespace.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        vocoder_cfg = self.cfg.VOCODER
        self.pad = vocoder_cfg.MEL_FRAME_PAD

        if vocoder_cfg.MODE == "mu_law_quantize":
            self.n_classes = 2**vocoder_cfg.BITS
        else:
            # NOTE(review): the original condition was
            # ``MODE == "mu_law" or self.cfg.VOCODER``; the second operand only
            # tests the truthiness of the config object, so in practice every
            # other mode landed here. That fallback is kept, now explicitly.
            # Presumably only "mu_law" was meant to select 30 outputs --
            # confirm before rejecting unknown modes.
            self.n_classes = 30

        self._to_flatten = []

        self.rnn_dims = vocoder_cfg.RNN_DIMS
        # The auxiliary features are consumed as four equal channel slices.
        self.aux_dims = vocoder_cfg.RES_OUT_DIMS // 4
        self.hop_length = vocoder_cfg.HOP_LENGTH
        self.fc_dims = vocoder_cfg.FC_DIMS
        self.upsample_factors = vocoder_cfg.UPSAMPLE_FACTORS
        self.feat_dims = vocoder_cfg.INPUT_DIM
        self.compute_dims = vocoder_cfg.COMPUTE_DIMS
        self.res_out_dims = vocoder_cfg.RES_OUT_DIMS
        self.res_blocks = vocoder_cfg.RES_BLOCKS

        self.upsample = UpsampleNetwork(
            self.feat_dims,
            self.upsample_factors,
            self.compute_dims,
            self.res_blocks,
            self.res_out_dims,
            self.pad,
        )
        # Input projection: one sample value + upsampled features + first aux slice.
        self.I = nn.Linear(self.feat_dims + self.aux_dims + 1, self.rnn_dims)
        self.rnn1 = nn.GRU(self.rnn_dims, self.rnn_dims, batch_first=True)
        self.rnn2 = nn.GRU(
            self.rnn_dims + self.aux_dims, self.rnn_dims, batch_first=True
        )
        self._to_flatten += [self.rnn1, self.rnn2]
        self.fc1 = nn.Linear(self.rnn_dims + self.aux_dims, self.fc_dims)
        self.fc2 = nn.Linear(self.fc_dims + self.aux_dims, self.fc_dims)
        self.fc3 = nn.Linear(self.fc_dims, self.n_classes)

        # Prints the trainable-parameter count once at construction time.
        self.num_params()
        self._flatten_parameters()

    def forward(self, x, mels):
        """Compute per-step outputs of size ``n_classes``.

        Args:
            x: Tensor that is unsqueezed on its last axis and concatenated with
                the upsampled features -- presumably ``(batch, steps)`` sample
                values; TODO confirm with callers.
            mels: Conditioning features passed to the upsample network.

        Returns:
            The output of the final linear layer; its last dimension is
            ``n_classes``.
        """
        device = next(self.parameters()).device
        self._flatten_parameters()

        batch_size = x.size(0)
        h1 = torch.zeros(1, batch_size, self.rnn_dims, device=device)
        h2 = torch.zeros(1, batch_size, self.rnn_dims, device=device)

        mels, aux = self.upsample(mels)

        # Split the first ``4 * aux_dims`` auxiliary channels into four
        # consecutive slices, one for each stage that consumes them.
        step = self.aux_dims
        a1, a2, a3, a4 = (aux[:, :, i * step : (i + 1) * step] for i in range(4))

        x = torch.cat([x.unsqueeze(-1), mels, a1], dim=2)
        x = self.I(x)

        # First GRU with a residual connection around it.
        res = x
        x, _ = self.rnn1(x, h1)
        x = x + res

        # Second GRU, also residual; the aux slice is only part of its input.
        res = x
        x = torch.cat([x, a2], dim=2)
        x, _ = self.rnn2(x, h2)
        x = x + res

        x = torch.cat([x, a3], dim=2)
        x = F.relu(self.fc1(x))

        x = torch.cat([x, a4], dim=2)
        x = F.relu(self.fc2(x))
        return self.fc3(x)

    def num_params(self, print_out=True):
        """Return the number of trainable parameters, in millions.

        Args:
            print_out: When true, also print the count to stdout.
        """
        parameters = (
            sum(p.numel() for p in self.parameters() if p.requires_grad) / 1_000_000
        )
        if print_out:
            print("Trainable Parameters: %.3fM" % parameters)
        return parameters

    def _flatten_parameters(self):
        """Call ``flatten_parameters()`` on every module in ``_to_flatten``."""
        for module in self._to_flatten:
            module.flatten_parameters()