# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch
from nemo.collections.tts.modules.submodules import Invertible1x1Conv, WaveNet
from nemo.collections.tts.parts.utils.helpers import OperationMode, remove, split_view
from nemo.core.classes import Exportable, NeuralModule, typecheck
from nemo.core.neural_types.elements import (
AudioSignal,
IntType,
MelSpectrogramType,
NormalDistributionSamplesType,
VoidType,
)
from nemo.core.neural_types.neural_type import NeuralType
class WaveGlowModule(NeuralModule, Exportable):
    """
    WaveGlow normalizing-flow vocoder module.

    Stacks ``n_flows`` flow steps — each an invertible 1x1 convolution
    (``Invertible1x1Conv``) followed by a WaveNet-based affine coupling
    transform — to map between grouped audio samples and a latent noise
    tensor, conditioned on a mel spectrogram that is upsampled to audio
    rate with a transposed convolution. The forward direction
    (``audio_to_normal_dist``) is used for training; the inverse
    (``norm_dist_to_audio``) synthesizes audio from noise at inference.
    """
    def __init__(
        self,
        n_mel_channels: int,
        n_flows: int,
        n_group: int,
        n_early_every: int,
        n_early_size: int,
        n_wn_channels: int,
        n_wn_layers: int,
        wn_kernel_size: int,
    ):
        """
        WaveGlow module
        Args:
            n_mel_channels (int): Number of mel channels to output.
            n_flows (int): Number of flow layers
            n_group (int): Number of groups to respace the inputs
            n_early_every (int): Every n_early_every layers, n_early_size gets skip connected to the output
            n_early_size (int): The size of the chunk to be skip connected
            n_wn_channels (int): Number of channels for the non-invertible wavenet transformation
            n_wn_layers (int): Number of layers for the non-invertible wavenet transformation
            wn_kernel_size (int): Kernel size for the non-invertible wavenet transformation
        """
        super().__init__()
        # Transposed conv (kernel 1024, stride 256) upsamples the mel
        # spectrogram in time so it can condition per-sample audio groups.
        self.upsample = torch.nn.ConvTranspose1d(n_mel_channels, n_mel_channels, 1024, stride=256)
        self.n_mel_channels = n_mel_channels
        # Coupling layers split the channel dim in half, so n_group must be even.
        assert n_group % 2 == 0
        self.n_flows = n_flows
        self.n_group = n_group
        self.n_early_every = n_early_every
        self.n_early_size = n_early_size
        self.wavenet = torch.nn.ModuleList()
        self.convinv = torch.nn.ModuleList()
        # Operating mode (training / validation / infer) gates which tensors
        # forward() computes and returns; see input_types/output_types.
        self.mode = OperationMode.infer
        n_half = n_group // 2
        # Set up layers with the right sizes based on how many dimensions
        # have been output already
        n_remaining_channels = n_group
        for k in range(n_flows):
            # Every n_early_every flows (except the first), n_early_size
            # channels are emitted early, shrinking subsequent layers.
            if k % self.n_early_every == 0 and k > 0:
                n_half = n_half - int(self.n_early_size / 2)
                n_remaining_channels = n_remaining_channels - self.n_early_size
            self.convinv.append(Invertible1x1Conv(n_remaining_channels))
            self.wavenet.append(
                WaveNet(
                    n_half,
                    n_mel_channels * n_group,
                    n_layers=n_wn_layers,
                    n_channels=n_wn_channels,
                    kernel_size=wn_kernel_size,
                )
            )
        # Channels still flowing after all early outputs have been taken.
        self.n_remaining_channels = n_remaining_channels
        # stride (256) - kernel_size (1024) = -768: a NEGATIVE slice bound
        # used in norm_dist_to_audio to trim transposed-conv edge artifacts.
        self.time_cutoff = self.upsample.stride[0] - self.upsample.kernel_size[0]
        # Pre-calculating the sizes of noise to use so it's not dynamic
        # (keeps shapes static for export/tracing). n_halves[k] is the
        # coupling-split size of flow k when walked in reverse.
        n_halves = []
        n_half = self.n_remaining_channels // 2
        for k in reversed(range(self.n_flows)):
            n_halves.append(n_half)
            if k % self.n_early_every == 0 and k > 0:
                n_half = n_half + int(self.n_early_size / 2)
        n_halves.reverse()
        self.n_halves = n_halves
        self.removed_weightnorm = False
    def _prepare_for_export(self, **kwargs):
        """
        Override this method to prepare module for export. This is in-place operation.
        Base version does common necessary module replacements (Apex etc)
        """
        # Weight norm must be folded into plain weights before export.
        self.remove_weightnorm()
        super()._prepare_for_export(**kwargs)
    @typecheck()
    def forward(self, spec, z=None, audio=None, run_inverse=True, sigma=1.0):
        """
        Run the flow in the direction(s) implied by ``self.mode``.

        Args:
            spec: Mel spectrogram, shape (B, n_mel_channels, T_mel).
            z: Optional latent noise for the inverse pass; sampled from a
                Gaussian scaled by ``sigma`` when None.
            audio: Ground-truth audio (B, T); required for the forward
                (training/validation) pass.
            run_inverse: When truthy, also synthesize audio from noise.
            sigma: Std-dev scale of the sampled latent.

        Returns:
            In training/validation mode: (z1, log_s_list, log_det_W_list,
            audio_pred) for loss computation. In infer mode: audio_pred only.
        """
        # self.mode must be kept in sync with module.training by the caller.
        if self.training and self.mode != OperationMode.training:
            raise ValueError(f"{self} has self.training set to True but self.OperationMode was not set to training")
        if not self.training and self.mode == OperationMode.training:
            raise ValueError(f"{self} has self.training set to False but self.OperationMode was set to training")
        # Placeholder so a tensor is always returned even when the inverse
        # pass is skipped.
        audio_pred = torch.zeros((1, 1))
        if audio is not None and self.mode != OperationMode.infer:
            # audio_to_normal_dist is used to calculate loss so only run this in train or val model
            z1, log_s_list, log_det_W_list = self.audio_to_normal_dist(spec=spec, audio=audio)
        if run_inverse:
            # norm_dist_to_audio is used to predict audio from spectrogram so only used in val or infer mode
            # Could also log train audio but currently not done
            audio_pred = self.norm_dist_to_audio(spec=spec, sigma=sigma, z=z)
        # Return the necessary tensors
        # NOTE(review): if mode is training/validation but audio is None,
        # z1 is unbound here and this raises NameError — callers appear
        # expected to always pass audio in those modes; confirm upstream.
        if self.mode == OperationMode.training or self.mode == OperationMode.validation:
            return z1, log_s_list, log_det_W_list, audio_pred
        return audio_pred
    @property
    def input_types(self):
        # The accepted-input schema narrows in infer mode: no ground-truth
        # audio and no run_inverse flag are expected.
        if self.mode == OperationMode.infer:
            return {
                "spec": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
                "z": NeuralType(('B', 'D', 'T'), MelSpectrogramType(), optional=True),
                "sigma": NeuralType(optional=True),
            }
        else:
            return {
                "spec": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
                "z": NeuralType(('B', 'D', 'T'), MelSpectrogramType(), optional=True),
                "audio": NeuralType(('B', 'T'), AudioSignal(), optional=True),
                "run_inverse": NeuralType(elements_type=IntType(), optional=True),
                "sigma": NeuralType(optional=True),
            }
    @property
    def output_types(self):
        # Mirrors forward(): four outputs for loss computation in
        # train/val, a single predicted waveform otherwise.
        if self.mode == OperationMode.training or self.mode == OperationMode.validation:
            return {
                "pred_normal_dist": NeuralType(('B', 'flowgroup', 'T'), NormalDistributionSamplesType()),
                "log_s_list": [NeuralType(('B', 'flowgroup', 'T'), VoidType())],  # TODO: Figure out a good typing
                "log_det_W_list": [NeuralType(elements_type=VoidType())],  # TODO: Figure out a good typing
                "audio_pred": NeuralType(('B', 'T'), AudioSignal()),
            }
        else:
            return {
                "audio": NeuralType(('B', 'T'), AudioSignal()),
            }
    def input_example(self, max_batch=1, max_dim=256):
        """
        Generates input examples for tracing etc.
        Returns:
            A tuple of input examples.
        """
        par = next(self.parameters())
        mel = torch.randn((max_batch, self.n_mel_channels, max_dim), device=par.device, dtype=par.dtype)
        # z time-length matches the grouped, upsampled spectrogram:
        # max_dim * stride samples regrouped by n_group.
        z = torch.randn(
            (max_batch, self.n_mel_channels, max_dim * self.upsample.stride[0] // self.n_group),
            device=par.device,
            dtype=par.dtype,
        )
        return {"spec": mel, "z": z}
    def audio_to_normal_dist(self, *, spec: torch.Tensor, audio: torch.Tensor) -> Tuple[torch.Tensor, list, list]:
        """
        Forward flow: transform audio into latent samples (training pass).

        Args:
            spec: Mel spectrogram (B, n_mel_channels, T_mel).
            audio: Waveform (B, T).

        Returns:
            Tuple of (latent samples concatenated over channels,
            list of log_s tensors, list of log_det_W tensors) — the
            quantities needed for the WaveGlow likelihood loss.
        """
        # Upsample spectrogram to size of audio
        spec = self.upsample(spec)
        assert spec.size(2) >= audio.size(1)
        if spec.size(2) > audio.size(1):
            spec = spec[:, :, : audio.size(1)]
        # logging.debug(f"spec: {spec.shape}. n_group: {self.n_group}")
        # Regroup time into chunks of n_group and fold them into channels
        # so every flow step conditions on n_mel_channels * n_group features.
        spec = split_view(spec, self.n_group, 2).permute(0, 2, 1, 3)
        spec = spec.contiguous().view(spec.size(0), spec.size(1), -1)
        spec = spec.permute(0, 2, 1)
        # Same regrouping for audio: (B, T) -> (B, n_group, T // n_group).
        audio = split_view(audio, self.n_group, 1).permute(0, 2, 1)
        output_audio = []
        log_s_list = []
        log_det_W_list = []
        for k in range(self.n_flows):
            if k % self.n_early_every == 0 and k > 0:
                # Emit early-output channels and keep flowing the rest.
                output_audio.append(audio[:, : self.n_early_size, :])
                audio = audio[:, self.n_early_size :, :]
            audio, log_det_W = self.convinv[k](audio)
            log_det_W_list.append(log_det_W)
            # Affine coupling: first half passes through unchanged and
            # parameterizes the transform of the second half.
            n_half = int(audio.size(1) / 2)
            audio_0 = audio[:, :n_half, :]
            audio_1 = audio[:, n_half:, :]
            output = self.wavenet[k]((audio_0, spec))
            # WaveNet output layout: first n_half channels are bias b,
            # remaining channels are log-scale log_s.
            log_s = output[:, n_half:, :]
            b = output[:, :n_half, :]
            audio_1 = torch.exp(log_s) * audio_1 + b
            log_s_list.append(log_s)
            audio = torch.cat([audio_0, audio_1], 1)
        output_audio.append(audio)
        return torch.cat(output_audio, 1), log_s_list, log_det_W_list
    def norm_dist_to_audio(self, *, spec, z=None, sigma: float = 1.0):
        """
        Inverse flow: synthesize audio from latent noise (inference pass).

        Args:
            spec: Mel spectrogram (B, n_mel_channels, T_mel).
            z: Optional latent tensor with n_group channels; sampled as
                sigma * N(0, I) when None.
            sigma: Std-dev scale applied to the sampled noise.

        Returns:
            Predicted waveform of shape (B, T).
        """
        spec = self.upsample(spec)
        spec = spec.contiguous().view(spec.size(0), spec.size(1), -1)
        # trim conv artifacts. maybe pad spec to kernel multiple
        # time_cutoff is negative (stride - kernel), so this drops the
        # trailing transposed-conv edge region.
        if self.time_cutoff != 0:
            spec = spec[:, :, : self.time_cutoff]
        # Regroup conditioning exactly as in audio_to_normal_dist.
        spec = split_view(spec, self.n_group, 2).permute(0, 2, 1, 3)
        spec = spec.contiguous().view(spec.size(0), spec.size(1), -1)
        spec = spec.permute(0, 2, 1)
        z_size = torch.Size([spec.size(0), self.n_group, spec.size(2)])
        if z is None:
            z = sigma * torch.randn(z_size, device=spec.device).to(spec.dtype)
        # First n_remaining_channels of z seed the deepest flow; the rest
        # is consumed chunk-by-chunk at the early-output points below.
        audio, z = torch.split(z, [self.n_remaining_channels, z.size(1) - self.n_remaining_channels], 1)
        for k in reversed(range(self.n_flows)):
            # Use the precomputed static split size for this flow step.
            n_half = self.n_halves[k]
            audio_0, audio_1 = torch.split(audio, [n_half, audio.size(1) - n_half], 1)
            output = self.wavenet[k]((audio_0, spec))
            b, s = torch.split(output, [n_half, output.size(1) - n_half], 1)
            # Invert the affine coupling: (x - b) / exp(log_s).
            audio_1 = audio_1 - b
            audio_1 = audio_1 / torch.exp(s)
            audio = torch.cat((audio_0, audio_1), 1)
            audio = self.convinv[k](audio, reverse=True)
            if k % self.n_early_every == 0 and k > 0:
                # Reattach the noise channels that were emitted early in
                # the forward direction.
                z1, z = torch.split(z, [self.n_early_size, z.size(1) - self.n_early_size], 1)
                audio = torch.cat((z1, audio), 1)
        # Unfold grouped channels back into a flat waveform (B, T).
        return audio.permute(0, 2, 1).contiguous().view(audio.size(0), -1)
    def remove_weightnorm(self):
        """
        Fold weight normalization into plain weights on every WaveNet
        sub-layer. Idempotent: subsequent calls are no-ops. Required
        before export/tracing (see _prepare_for_export).
        """
        if self.removed_weightnorm:
            return
        for wavenet in self.wavenet:
            wavenet.start = torch.nn.utils.remove_weight_norm(wavenet.start)
            wavenet.in_layers = remove(wavenet.in_layers)
            wavenet.cond_layer = torch.nn.utils.remove_weight_norm(wavenet.cond_layer)
            wavenet.res_skip_layers = remove(wavenet.res_skip_layers)
        self.removed_weightnorm = True
|