# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch
from nemo.collections.tts.modules.submodules import Invertible1x1Conv, WaveNet
from nemo.collections.tts.parts.utils.helpers import OperationMode, remove, split_view
from nemo.core.classes import Exportable, NeuralModule, typecheck
from nemo.core.neural_types.elements import (
AudioSignal,
IntType,
MelSpectrogramType,
NormalDistributionSamplesType,
VoidType,
)
from nemo.core.neural_types.neural_type import NeuralType


class WaveGlowModule(NeuralModule, Exportable):
def __init__(
self,
n_mel_channels: int,
n_flows: int,
n_group: int,
n_early_every: int,
n_early_size: int,
n_wn_channels: int,
n_wn_layers: int,
wn_kernel_size: int,
):
"""
WaveGlow module
Args:
n_mel_channels (int): Number of mel channels to output.
n_flows (int): Number of flow layers
n_group (int): Number of groups to respace the inputs
n_early_every (int): Every n_early_every layers, n_early_size gets skip connected to the output
n_early_size (int): The size of the chunk to be skip connected
n_wn_channels (int): Number of channels for the non-invertible wavenet transformation
n_wn_layers (int): Number of layers for the non-invertible wavenet transformation
wn_kernel_size (int): Kernel size for the non-invertible wavenet transformation
"""
super().__init__()
self.upsample = torch.nn.ConvTranspose1d(n_mel_channels, n_mel_channels, 1024, stride=256)
self.n_mel_channels = n_mel_channels
assert n_group % 2 == 0
self.n_flows = n_flows
self.n_group = n_group
self.n_early_every = n_early_every
self.n_early_size = n_early_size
self.wavenet = torch.nn.ModuleList()
self.convinv = torch.nn.ModuleList()
self.mode = OperationMode.infer
n_half = n_group // 2
# Set up layers with the right sizes based on how many dimensions
# have been output already
n_remaining_channels = n_group
for k in range(n_flows):
if k % self.n_early_every == 0 and k > 0:
n_half = n_half - int(self.n_early_size / 2)
n_remaining_channels = n_remaining_channels - self.n_early_size
self.convinv.append(Invertible1x1Conv(n_remaining_channels))
self.wavenet.append(
WaveNet(
n_half,
n_mel_channels * n_group,
n_layers=n_wn_layers,
n_channels=n_wn_channels,
kernel_size=wn_kernel_size,
)
)
self.n_remaining_channels = n_remaining_channels
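        # time_cutoff is negative (stride - kernel_size, e.g. 256 - 1024 = -768); norm_dist_to_audio uses it
        # to crop transposed-convolution edge artifacts from the end of the upsampled spectrogram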
self.time_cutoff = self.upsample.stride[0] - self.upsample.kernel_size[0]
# Pre-calculating the sizes of noise to use so it's not dynamic
n_halves = []
n_half = self.n_remaining_channels // 2
for k in reversed(range(self.n_flows)):
n_halves.append(n_half)
if k % self.n_early_every == 0 and k > 0:
n_half = n_half + int(self.n_early_size / 2)
n_halves.reverse()
self.n_halves = n_halves
        self.removed_weightnorm = False

    def _prepare_for_export(self, **kwargs):
"""
Override this method to prepare module for export. This is in-place operation.
Base version does common necessary module replacements (Apex etc)
"""
self.remove_weightnorm()
        super()._prepare_for_export(**kwargs)

    @typecheck()
def forward(self, spec, z=None, audio=None, run_inverse=True, sigma=1.0):
""" TODO
"""
        if self.training and self.mode != OperationMode.training:
            raise ValueError(f"{self} has self.training set to True but self.mode was not set to OperationMode.training")
        if not self.training and self.mode == OperationMode.training:
            raise ValueError(f"{self} has self.training set to False but self.mode was set to OperationMode.training")
audio_pred = torch.zeros((1, 1))
if audio is not None and self.mode != OperationMode.infer:
            # audio_to_normal_dist is used to calculate loss, so only run this in train or val mode
z1, log_s_list, log_det_W_list = self.audio_to_normal_dist(spec=spec, audio=audio)
if run_inverse:
            # norm_dist_to_audio predicts audio from the spectrogram, so it is only used in val or infer mode
            # (train audio could also be logged, but currently is not)
audio_pred = self.norm_dist_to_audio(spec=spec, sigma=sigma, z=z)
# Return the necessary tensors
if self.mode == OperationMode.training or self.mode == OperationMode.validation:
return z1, log_s_list, log_det_W_list, audio_pred
        return audio_pred

    @property
def input_types(self):
if self.mode == OperationMode.infer:
return {
"spec": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
"z": NeuralType(('B', 'D', 'T'), MelSpectrogramType(), optional=True),
"sigma": NeuralType(optional=True),
}
else:
return {
"spec": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
"z": NeuralType(('B', 'D', 'T'), MelSpectrogramType(), optional=True),
"audio": NeuralType(('B', 'T'), AudioSignal(), optional=True),
"run_inverse": NeuralType(elements_type=IntType(), optional=True),
"sigma": NeuralType(optional=True),
            }

    @property
def output_types(self):
if self.mode == OperationMode.training or self.mode == OperationMode.validation:
return {
"pred_normal_dist": NeuralType(('B', 'flowgroup', 'T'), NormalDistributionSamplesType()),
"log_s_list": [NeuralType(('B', 'flowgroup', 'T'), VoidType())], # TODO: Figure out a good typing
"log_det_W_list": [NeuralType(elements_type=VoidType())], # TODO: Figure out a good typing
"audio_pred": NeuralType(('B', 'T'), AudioSignal()),
}
else:
return {
"audio": NeuralType(('B', 'T'), AudioSignal()),
            }

    def input_example(self, max_batch=1, max_dim=256):
"""
Generates input examples for tracing etc.
Returns:
A tuple of input examples.
"""
par = next(self.parameters())
mel = torch.randn((max_batch, self.n_mel_channels, max_dim), device=par.device, dtype=par.dtype)
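        # One latent time step per group of audio samples: max_dim frames upsampled by the transposed-conv
        # stride, then folded into groups of n_group samples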
z = torch.randn(
(max_batch, self.n_mel_channels, max_dim * self.upsample.stride[0] // self.n_group),
device=par.device,
dtype=par.dtype,
)
return {"spec": mel, "z": z}
def audio_to_normal_dist(self, *, spec: torch.Tensor, audio: torch.Tensor) -> Tuple[torch.Tensor, list, list]:
# Upsample spectrogram to size of audio
spec = self.upsample(spec)
assert spec.size(2) >= audio.size(1)
if spec.size(2) > audio.size(1):
spec = spec[:, :, : audio.size(1)]
# logging.debug(f"spec: {spec.shape}. n_group: {self.n_group}")
spec = split_view(spec, self.n_group, 2).permute(0, 2, 1, 3)
spec = spec.contiguous().view(spec.size(0), spec.size(1), -1)
spec = spec.permute(0, 2, 1)
audio = split_view(audio, self.n_group, 1).permute(0, 2, 1)
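        # After regrouping, audio has shape (B, n_group, T // n_group) and spec has shape
        # (B, n_mel_channels * n_group, T // n_group): samples and frames are folded into channels
        # so that each flow step operates on a multi-channel grouped signal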
output_audio = []
log_s_list = []
log_det_W_list = []
for k in range(self.n_flows):
if k % self.n_early_every == 0 and k > 0:
output_audio.append(audio[:, : self.n_early_size, :])
audio = audio[:, self.n_early_size :, :]
audio, log_det_W = self.convinv[k](audio)
log_det_W_list.append(log_det_W)
n_half = int(audio.size(1) / 2)
audio_0 = audio[:, :n_half, :]
audio_1 = audio[:, n_half:, :]
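            # Affine coupling: the WaveNet predicts a log-scale and bias from the first half of the
            # channels (conditioned on the spectrogram) and applies them to the second half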
output = self.wavenet[k]((audio_0, spec))
log_s = output[:, n_half:, :]
b = output[:, :n_half, :]
audio_1 = torch.exp(log_s) * audio_1 + b
log_s_list.append(log_s)
audio = torch.cat([audio_0, audio_1], 1)
output_audio.append(audio)
        return torch.cat(output_audio, 1), log_s_list, log_det_W_list

    def norm_dist_to_audio(self, *, spec, z=None, sigma: float = 1.0):
spec = self.upsample(spec)
spec = spec.contiguous().view(spec.size(0), spec.size(1), -1)
        # Trim transposed-conv edge artifacts (time_cutoff is negative); spec could instead be padded to a kernel multiple
if self.time_cutoff != 0:
spec = spec[:, :, : self.time_cutoff]
spec = split_view(spec, self.n_group, 2).permute(0, 2, 1, 3)
spec = spec.contiguous().view(spec.size(0), spec.size(1), -1)
spec = spec.permute(0, 2, 1)
z_size = torch.Size([spec.size(0), self.n_group, spec.size(2)])
if z is None:
z = sigma * torch.randn(z_size, device=spec.device).to(spec.dtype)
audio, z = torch.split(z, [self.n_remaining_channels, z.size(1) - self.n_remaining_channels], 1)
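        # Run the flows in reverse: undo each affine coupling and invertible 1x1 conv, re-attaching the
        # early-output noise chunks that were split off during the forward pass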
for k in reversed(range(self.n_flows)):
n_half = self.n_halves[k]
audio_0, audio_1 = torch.split(audio, [n_half, audio.size(1) - n_half], 1)
output = self.wavenet[k]((audio_0, spec))
b, s = torch.split(output, [n_half, output.size(1) - n_half], 1)
audio_1 = audio_1 - b
audio_1 = audio_1 / torch.exp(s)
audio = torch.cat((audio_0, audio_1), 1)
audio = self.convinv[k](audio, reverse=True)
if k % self.n_early_every == 0 and k > 0:
z1, z = torch.split(z, [self.n_early_size, z.size(1) - self.n_early_size], 1)
audio = torch.cat((z1, audio), 1)
        return audio.permute(0, 2, 1).contiguous().view(audio.size(0), -1)

    def remove_weightnorm(self):
if self.removed_weightnorm:
return
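        # Fold the weight-norm parametrization back into plain weights (needed before tracing/export)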
for wavenet in self.wavenet:
wavenet.start = torch.nn.utils.remove_weight_norm(wavenet.start)
wavenet.in_layers = remove(wavenet.in_layers)
wavenet.cond_layer = torch.nn.utils.remove_weight_norm(wavenet.cond_layer)
wavenet.res_skip_layers = remove(wavenet.res_skip_layers)
self.removed_weightnorm = True
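

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The hyperparameter
# values below are illustrative assumptions, not the defaults of any particular
# NeMo WaveGlow config.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    module = WaveGlowModule(
        n_mel_channels=80,
        n_flows=12,
        n_group=8,
        n_early_every=4,
        n_early_size=2,
        n_wn_channels=256,
        n_wn_layers=8,
        wn_kernel_size=3,
    )
    module.eval()  # default mode is OperationMode.infer, which matches eval() in the checks inside forward()
    with torch.no_grad():
        example = module.input_example(max_batch=1, max_dim=64)
        audio = module(spec=example["spec"], z=example["z"])
    # 64 mel frames * 256-sample upsampling stride -> 16384 predicted audio samples
    print(audio.shape)  # expected: torch.Size([1, 16384])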