| | import torch |
| | from torch import nn |
| | from torch.nn.utils import weight_norm |
| |
|
| | from TTS.utils.io import load_fsspec |
| | from TTS.vocoder.layers.melgan import ResidualStack |
| |
|
| |
|
class MelganGenerator(nn.Module):
    """MelGAN generator: maps mel-spectrogram frames to a raw waveform.

    The network is a fully-convolutional 1D stack: an input projection,
    a series of transposed-convolution upsampling stages (each followed by a
    residual stack), and an output projection squashed with ``tanh``.
    All convolutions are weight-normalized.

    Args:
        in_channels (int): number of input (mel) channels.
        out_channels (int): number of output (waveform) channels.
        proj_kernel (int): kernel size of the input/output projection convs; must be odd.
        base_channels (int): channel width after the input projection; halved at
            each upsampling stage.
        upsample_factors (tuple): per-stage upsampling rates; their product is the
            total upsampling (hop length).
        res_kernel (int): kernel size inside each residual stack.
        num_res_blocks (int): number of residual blocks per stack.
    """

    def __init__(
        self,
        in_channels=80,
        out_channels=1,
        proj_kernel=7,
        base_channels=512,
        upsample_factors=(8, 8, 2, 2),
        res_kernel=3,
        num_res_blocks=3,
    ):
        super().__init__()

        # An odd kernel lets symmetric reflection padding keep the time length unchanged.
        assert (proj_kernel - 1) % 2 == 0, " [!] proj_kernel should be an odd number."

        base_padding = (proj_kernel - 1) // 2
        act_slope = 0.2
        # Replicate-padding applied to the input at inference time to reduce edge artifacts.
        self.inference_padding = 2

        # Input projection: mel channels -> base_channels, time length preserved.
        layers = []
        layers += [
            nn.ReflectionPad1d(base_padding),
            weight_norm(nn.Conv1d(in_channels, base_channels, kernel_size=proj_kernel, stride=1, bias=True)),
        ]

        # Upsampling stages: each halves the channel count and stretches the time
        # axis by `upsample_factor`. padding/output_padding are chosen so the
        # output length is exactly input_length * upsample_factor.
        for idx, upsample_factor in enumerate(upsample_factors):
            layer_in_channels = base_channels // (2**idx)
            layer_out_channels = base_channels // (2 ** (idx + 1))
            layer_filter_size = upsample_factor * 2
            layer_stride = upsample_factor
            layer_output_padding = upsample_factor % 2
            layer_padding = upsample_factor // 2 + layer_output_padding
            layers += [
                nn.LeakyReLU(act_slope),
                weight_norm(
                    nn.ConvTranspose1d(
                        layer_in_channels,
                        layer_out_channels,
                        layer_filter_size,
                        stride=layer_stride,
                        padding=layer_padding,
                        output_padding=layer_output_padding,
                        bias=True,
                    )
                ),
                ResidualStack(channels=layer_out_channels, num_res_blocks=num_res_blocks, kernel_size=res_kernel),
            ]

        layers += [nn.LeakyReLU(act_slope)]

        # Output projection: collapse to `out_channels` and bound the waveform to [-1, 1].
        layers += [
            nn.ReflectionPad1d(base_padding),
            weight_norm(nn.Conv1d(layer_out_channels, out_channels, proj_kernel, stride=1, bias=True)),
            nn.Tanh(),
        ]
        self.layers = nn.Sequential(*layers)

    def forward(self, c):
        """Run the generator on conditioning features `c` (expected shape (B, C_in, T) — see Conv1d input)."""
        return self.layers(c)

    def inference(self, c):
        """Like :meth:`forward`, but first moves `c` onto the model's device and
        replicate-pads the time axis by `self.inference_padding` on each side to
        suppress artifacts at clip boundaries."""
        # layers[1] is the first weight-normalized conv; its weight tells us the model device.
        c = c.to(self.layers[1].weight.device)
        c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
        return self.layers(c)

    def remove_weight_norm(self):
        """Strip weight normalization from every layer (call once before inference/export)."""
        for layer in self.layers:
            # Parameter-less layers (activations, padding) have empty state dicts; skip them.
            if len(layer.state_dict()) != 0:
                try:
                    nn.utils.remove_weight_norm(layer)
                except ValueError:
                    # Composite modules (e.g. ResidualStack) strip their own sub-modules.
                    layer.remove_weight_norm()

    def load_checkpoint(
        self, config, checkpoint_path, eval=False, cache=False
    ):  # pylint: disable=unused-argument, redefined-builtin
        """Load model weights from `checkpoint_path`.

        Args:
            config: unused here; kept for a uniform loader signature across models.
            checkpoint_path (str): path/URL readable by fsspec.
            eval (bool): if True, switch to eval mode and strip weight norm for inference.
            cache (bool): forwarded to `load_fsspec` to cache remote checkpoints.
        """
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
        self.load_state_dict(state["model"])
        if eval:
            self.eval()
            assert not self.training
            self.remove_weight_norm()
| |
|