|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from random import random, randrange |
|
|
from typing import List, Optional |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import torchvision |
|
|
from einops import rearrange |
|
|
from hydra.utils import instantiate |
|
|
from omegaconf import DictConfig |
|
|
from pytorch_lightning import Trainer |
|
|
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger |
|
|
from torch.utils.tensorboard.writer import SummaryWriter |
|
|
|
|
|
from nemo.collections.tts.losses.spectrogram_enhancer_losses import ( |
|
|
ConsistencyLoss, |
|
|
GeneratorLoss, |
|
|
GradientPenaltyLoss, |
|
|
HingeLoss, |
|
|
) |
|
|
from nemo.collections.tts.parts.utils.helpers import mask_sequence_tensor, to_device_recursive |
|
|
from nemo.core import Exportable, ModelPT, typecheck |
|
|
from nemo.core.neural_types import LengthsType, MelSpectrogramType, NeuralType |
|
|
from nemo.core.neural_types.elements import BoolType |
|
|
from nemo.utils import logging |
|
|
|
|
|
|
|
|
class SpectrogramEnhancerModel(ModelPT, Exportable):
    """
    GAN-based model to add details to blurry spectrograms from TTS models like Tacotron or FastPitch. Based on StyleGAN 2 [1]

    [1] Karras et. al. - Analyzing and Improving the Image Quality of StyleGAN (https://arxiv.org/abs/1912.04958)
    """

    def __init__(self, cfg: DictConfig, trainer: Trainer = None) -> None:
        # Assigned before ModelPT.__init__ runs.
        # NOTE(review): presumably so the attribute exists during base-class setup — confirm.
        self.spectrogram_model = None
        super().__init__(cfg=cfg, trainer=trainer)

        self.generator = instantiate(cfg.generator)
        self.discriminator = instantiate(cfg.discriminator)

        self.generator_loss = GeneratorLoss()
        self.discriminator_loss = HingeLoss()
        self.consistency_loss = ConsistencyLoss(cfg.consistency_loss_weight)
        self.gradient_penalty_loss = GradientPenaltyLoss(cfg.gradient_penalty_loss_weight)

    def move_to_correct_device(self, e):
        """
        Move ``e`` to the device holding the generator's parameters.

        ``e`` may be a tensor or, as in ``generate_zs``, a list of tensors; the traversal is done by
        ``to_device_recursive``.
        """
        return to_device_recursive(e, next(iter(self.generator.parameters())).device)

    def normalize_spectrograms(self, spectrogram: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        """
        Linearly rescale values so that ``cfg.spectrogram_min_value`` maps to 0 and
        ``cfg.spectrogram_max_value`` maps to 1, then mask positions beyond ``lengths``
        with ``mask_sequence_tensor``.
        """
        spectrogram = spectrogram - self._cfg.spectrogram_min_value
        spectrogram = spectrogram / (self._cfg.spectrogram_max_value - self._cfg.spectrogram_min_value)
        return mask_sequence_tensor(spectrogram, lengths)

    def unnormalize_spectrograms(self, spectrogram: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        """
        Inverse of ``normalize_spectrograms``: map 0 back to ``cfg.spectrogram_min_value`` and 1 back to
        ``cfg.spectrogram_max_value``, then mask positions beyond ``lengths``.
        """
        spectrogram = spectrogram * (self._cfg.spectrogram_max_value - self._cfg.spectrogram_min_value)
        spectrogram = spectrogram + self._cfg.spectrogram_min_value
        return mask_sequence_tensor(spectrogram, lengths)

    def generate_zs(self, batch_size: int = 1, mixing: bool = False):
        """
        Sample one latent ``z`` per generator layer.

        Without style mixing, every layer receives the same latent values. With ``mixing=True``, style
        mixing happens with probability ``cfg.mixed_prob``. In that case two independent latents are
        drawn and the layers are split at a random point: layers before it get the first latent and the
        remaining layers get the second.

        Returns:
            list of ``generator.num_layers`` tensors of shape (batch_size, cfg.latent_dim), moved to the
            generator's device.
        """
        num_layers = self.generator.num_layers
        latent_dim = self._cfg.latent_dim

        # `mixed_prob` is the probability of style mixing, so mix when the uniform draw falls below it.
        # A split point needs at least two layers (randrange(1, 1) would raise ValueError).
        if mixing and num_layers > 1 and random() < self._cfg.mixed_prob:
            mixing_point = randrange(1, num_layers)
            first_z = torch.randn(batch_size, latent_dim)
            second_z = torch.randn(batch_size, latent_dim)
            zs = [first_z] * mixing_point + [second_z] * (num_layers - mixing_point)
        else:
            zs = [torch.randn(batch_size, latent_dim)] * num_layers

        return self.move_to_correct_device(zs)

    def generate_noise(self, batch_size: int = 1) -> torch.Tensor:
        """
        Sample per-pixel noise of shape (batch_size, cfg.n_bands, 4096, 1).

        The noise is drawn uniformly from [0, 1) (``torch.rand``) on the generator's device.
        NOTE(review): the fixed size 4096 presumably bounds the padded spectrogram length the generator
        can consume — confirm.
        """
        noise = torch.rand(batch_size, self._cfg.n_bands, 4096, 1)
        return self.move_to_correct_device(noise)

    def pad_spectrograms(self, spectrograms):
        """
        Right-pad the last (time) dimension with zeros so its length is a multiple of
        ``generator.upsample_factor``.

        Note: when the length is already a multiple, a full extra ``upsample_factor`` frames are still
        appended. ``forward_with_custom_noise`` crops the generator output back to the original length.
        """
        multiplier = self.generator.upsample_factor
        *_, max_length = spectrograms.shape
        return F.pad(spectrograms, (0, multiplier - max_length % multiplier))

    @typecheck(
        input_types={
            "input_spectrograms": NeuralType(("B", "D", "T_spec"), MelSpectrogramType()),
            "lengths": NeuralType(("B",), LengthsType()),
            "mixing": NeuralType(None, BoolType(), optional=True),
            "normalize": NeuralType(None, BoolType(), optional=True),
        }
    )
    def forward(
        self, *, input_spectrograms: torch.Tensor, lengths: torch.Tensor, mixing: bool = False, normalize: bool = True,
    ):
        """
        Generator forward pass. Noise inputs will be generated.

        input_spectrograms: batch of spectrograms, typically synthetic
        lengths: length for every spectrogram in the batch
        mixing: style mixing, usually True during training
        normalize: normalize spectrogram range to ~[0, 1], True for normal use

        returns: batch of enhanced spectrograms

        For explanation of style mixing refer to [1]
        [1] Karras et. al. - A Style-Based Generator Architecture for Generative Adversarial Networks, 2018 (https://arxiv.org/abs/1812.04948)
        """
        return self.forward_with_custom_noise(
            input_spectrograms=input_spectrograms,
            lengths=lengths,
            mixing=mixing,
            normalize=normalize,
            zs=None,
            ws=None,
            noise=None,
        )

    def forward_with_custom_noise(
        self,
        input_spectrograms: torch.Tensor,
        lengths: torch.Tensor,
        zs: Optional[List[torch.Tensor]] = None,
        ws: Optional[List[torch.Tensor]] = None,
        noise: Optional[torch.Tensor] = None,
        mixing: bool = False,
        normalize: bool = True,
    ):
        """
        Generator forward pass. Noise inputs will be generated if None.

        input_spectrograms: batch of spectrograms of shape (B, D, T), typically synthetic
        lengths: length for every spectrogram in the batch
        zs: latent inputs z, one per generator layer, mapped to ws by ``generator.style_mapping``
            (either this or ws or neither); sampled with ``generate_zs`` when neither is given
        ws: latent inputs in the style space, one per generator layer (either this or zs or neither)
        noise: per-pixel noise input; if None it is sampled with ``generate_noise`` (uniform in [0, 1))
        mixing: style mixing, usually True during training; only used when neither zs nor ws is given
        normalize: normalize spectrogram range to ~[0, 1], True for normal use

        returns: batch of enhanced spectrograms, cropped back to the input's time length

        raises: ValueError if both zs and ws are given

        For explanation of style mixing refer to [1]
        For definitions of z, w [2]
        [1] Karras et. al. - A Style-Based Generator Architecture for Generative Adversarial Networks, 2018 (https://arxiv.org/abs/1812.04948)
        [2] Karras et. al. - Analyzing and Improving the Image Quality of StyleGAN, 2019 (https://arxiv.org/abs/1912.04958)
        """
        batch_size, *_, max_length = input_spectrograms.shape

        if zs is not None and ws is not None:
            raise ValueError(
                "Please specify either zs or ws or neither, but not both. It is not clear which one to use."
            )

        # zs are only needed to derive ws; do not sample (unused) latents when ws is supplied directly.
        if ws is None:
            if zs is None:
                zs = self.generate_zs(batch_size, mixing)
            ws = [self.generator.style_mapping(z) for z in zs]
        if noise is None:
            noise = self.generate_noise(batch_size)

        # Add a singleton channel axis: (B, D, T) -> (B, 1, D, T).
        input_spectrograms = rearrange(input_spectrograms, "b c l -> b 1 c l")
        if normalize:
            input_spectrograms = self.normalize_spectrograms(input_spectrograms, lengths)
        input_spectrograms = self.pad_spectrograms(input_spectrograms)

        enhanced_spectrograms = self.generator(input_spectrograms, lengths, ws, noise)

        if normalize:
            enhanced_spectrograms = self.unnormalize_spectrograms(enhanced_spectrograms, lengths)
        # Drop the padding added by pad_spectrograms and remove the channel axis again.
        enhanced_spectrograms = enhanced_spectrograms[:, :, :, :max_length]
        enhanced_spectrograms = rearrange(enhanced_spectrograms, "b 1 c l -> b c l")

        return enhanced_spectrograms

    def training_step(self, batch, batch_idx, optimizer_idx):
        """
        One adversarial training step.

        ``batch`` is (input_spectrograms, target_spectrograms, lengths). ``optimizer_idx`` selects the update,
        matching the order returned by ``configure_optimizers``:
            0 -> discriminator (hinge loss, plus gradient penalty every
                 ``cfg.gradient_penalty_loss_every_n_steps`` batches)
            1 -> generator (generator loss plus consistency loss w.r.t. the input)
        """
        input_spectrograms, target_spectrograms, lengths = batch

        with torch.no_grad():
            input_spectrograms = self.normalize_spectrograms(input_spectrograms, lengths)
            target_spectrograms = self.normalize_spectrograms(target_spectrograms, lengths)

        # Discriminator update.
        if optimizer_idx == 0:
            # Only the discriminator is optimized here, so no autograd graph is needed for the generator.
            with torch.no_grad():
                enhanced_spectrograms = self.forward(
                    input_spectrograms=input_spectrograms, lengths=lengths, mixing=True, normalize=False
                )
            enhanced_spectrograms = rearrange(enhanced_spectrograms, "b c l -> b 1 c l")
            fake_logits = self.discriminator(enhanced_spectrograms, input_spectrograms, lengths)

            # requires_grad_ so the gradient penalty can presumably differentiate real_logits w.r.t. the
            # real samples.
            target_spectrograms_ = rearrange(target_spectrograms, "b c l -> b 1 c l").requires_grad_()
            real_logits = self.discriminator(target_spectrograms_, input_spectrograms, lengths)
            d_loss = self.discriminator_loss(real_logits, fake_logits)
            self.log("d_loss", d_loss, prog_bar=True)

            if batch_idx % self._cfg.gradient_penalty_loss_every_n_steps == 0:
                gp_loss = self.gradient_penalty_loss(target_spectrograms_, real_logits)
                self.log("d_loss_gp", gp_loss, prog_bar=True)
                return d_loss + gp_loss

            return d_loss

        # Generator update.
        if optimizer_idx == 1:
            enhanced_spectrograms = self.forward(
                input_spectrograms=input_spectrograms, lengths=lengths, mixing=True, normalize=False
            )

            input_spectrograms = rearrange(input_spectrograms, "b c l -> b 1 c l")
            enhanced_spectrograms = rearrange(enhanced_spectrograms, "b c l -> b 1 c l")

            fake_logits = self.discriminator(enhanced_spectrograms, input_spectrograms, lengths)
            g_loss = self.generator_loss(fake_logits)
            c_loss = self.consistency_loss(input_spectrograms, enhanced_spectrograms, lengths)

            self.log("g_loss", g_loss, prog_bar=True)
            self.log("c_loss", c_loss, prog_bar=True)

            with torch.no_grad():
                target_spectrograms = rearrange(target_spectrograms, "b c l -> b 1 c l")
                self.log_illustration(target_spectrograms, input_spectrograms, enhanced_spectrograms, lengths)
            return g_loss + c_loss

    def configure_optimizers(self):
        """
        Build the optimizers from config.

        The order is part of the contract with ``training_step``: index 0 is the discriminator optimizer
        and index 1 is the generator optimizer. No LR schedulers are configured.
        """
        generator_opt = instantiate(self._cfg.generator_opt, params=self.generator.parameters())
        discriminator_opt = instantiate(self._cfg.discriminator_opt, params=self.discriminator.parameters())
        return [discriminator_opt, generator_opt], []

    def setup_training_data(self, train_data_config):
        """
        Instantiate ``train_data_config.dataset`` and wrap it in a DataLoader that uses the dataset's own
        ``collate_fn`` and ``train_data_config.dataloader_params``.
        """
        dataset = instantiate(train_data_config.dataset)
        self._train_dl = torch.utils.data.DataLoader(
            dataset, collate_fn=dataset.collate_fn, **train_data_config.dataloader_params
        )

    def setup_validation_data(self, val_data_config):
        """
        There is no validation step for this model.
        It is not clear whether any of the used losses is a sensible metric for choosing between two models.
        This might change in the future.
        """
        pass

    @classmethod
    def list_available_models(cls):
        """No pretrained checkpoints are listed for this model."""
        return []

    def log_illustration(self, target_spectrograms, input_spectrograms, enhanced_spectrograms, lengths):
        """
        Log an image of the first batch element to the attached TensorBoard / W&B loggers.

        The grid has one row each for the residual (output - input), input, output and ground truth. The
        rows are cropped to the sample's length and clamped to [0, 1]. Logging only happens on global
        rank 0 and only every ``trainer.log_every_n_steps`` steps (counted as ``global_step // 2``).
        """
        if self.global_rank != 0 or not self.loggers:
            return

        # NOTE(review): halving presumably compensates for global_step counting both optimizer steps
        # (discriminator + generator) of each batch — confirm against the Lightning version in use.
        step = self.trainer.global_step // 2
        if step % self.trainer.log_every_n_steps != 0:
            return

        idx = 0
        length = int(lengths.flatten()[idx].item())

        # Slice the single sample first, instead of stacking and copying the whole batch to CPU.
        enhanced = enhanced_spectrograms[idx, :, :, :length]
        source = input_spectrograms[idx, :, :, :length]
        target = target_spectrograms[idx, :, :, :length]
        tensor = torch.stack([enhanced - source, source, enhanced, target], dim=0).cpu()

        grid = torchvision.utils.make_grid(tensor, nrow=1).clamp(0.0, 1.0)

        for logger in self.loggers:
            if isinstance(logger, TensorBoardLogger):
                writer: SummaryWriter = logger.experiment
                writer.add_image("spectrograms", grid, global_step=step)
                writer.flush()
            elif isinstance(logger, WandbLogger):
                logger.log_image("spectrograms", [grid], caption=["residual, input, output, ground truth"], step=step)
            else:
                logging.warning("Unsupported logger type: %s", str(type(logger)))
|
|
|