| | from inspect import signature |
| | from typing import Dict, List, Tuple |
| |
|
| | import numpy as np |
| | import torch |
| | from coqpit import Coqpit |
| | from torch import nn |
| | from torch.utils.data import DataLoader |
| | from torch.utils.data.distributed import DistributedSampler |
| | from trainer.trainer_utils import get_optimizer, get_scheduler |
| |
|
| | from TTS.utils.audio import AudioProcessor |
| | from TTS.utils.io import load_fsspec |
| | from TTS.vocoder.datasets.gan_dataset import GANDataset |
| | from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss |
| | from TTS.vocoder.models import setup_discriminator, setup_generator |
| | from TTS.vocoder.models.base_vocoder import BaseVocoder |
| | from TTS.vocoder.utils.generic_utils import plot_results |
| |
|
| |
|
class GAN(BaseVocoder):
    """Trainer-compatible wrapper around a generator / discriminator pair.

    Index convention used by every list this class hands to the trainer
    (optimizers, learning rates, schedulers, criteria) and by `optimizer_idx`
    in `train_step()`: index 0 is the discriminator, index 1 is the generator.
    """

    def __init__(self, config: Coqpit, ap: AudioProcessor = None):
        """Wrap a generator and a discriminator network. It provides a compatible interface for the trainer.
        It also helps mixing and matching different generator and discriminator networks easily.

        To implement new GAN models, you just need to define the generator and the discriminator networks, the rest
        is handled by the `GAN` class.

        Args:
            config (Coqpit): Model configuration.
            ap (AudioProcessor): 🐸TTS AudioProcessor instance. Defaults to None.

        Examples:
            Initializing the GAN model with HifiGAN generator and discriminator.
            >>> from TTS.vocoder.configs import HifiganConfig
            >>> config = HifiganConfig()
            >>> model = GAN(config)
        """
        super().__init__(config)
        self.config = config
        self.model_g = setup_generator(config)
        self.model_d = setup_discriminator(config)
        # When False, `train_step()` skips every discriminator forward pass.
        self.train_disc = False
        # Generator outputs cached by the discriminator pass (optimizer_idx == 0)
        # and consumed by the generator pass (optimizer_idx == 1). Initialised
        # here so the generator pass never hits a missing attribute.
        self.y_hat_g = None
        self.y_hat_sub = None
        self.y_sub_g = None
        self.ap = ap

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the generator's forward pass.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: output of the GAN generator network.
        """
        return self.model_g.forward(x)

    def inference(self, x: torch.Tensor) -> torch.Tensor:
        """Run the generator's inference pass.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: output of the GAN generator network.
        """
        return self.model_g.inference(x)

    def train_step(self, batch: Dict, criterion: Dict, optimizer_idx: int) -> Tuple[Dict, Dict]:
        """Compute model outputs and the loss values. `optimizer_idx` selects the discriminator (0) or the
        generator (1) pass.

        The discriminator pass runs the generator, caches its output on `self`, and (only when
        `self.train_disc` is set) computes the discriminator loss. The generator pass reuses the cached
        generator output, so it expects the discriminator pass to have run on the same batch first.

        Args:
            batch (Dict): Batch of samples returned by the dataloader. Reads `input` and `waveform`, plus
                `input_disc` and `waveform_disc` when `config.diff_samples_for_G_and_D` is set.
            criterion (Dict): Criterion used to compute the losses, indexed by `optimizer_idx`.
            optimizer_idx (int): ID of the optimizer in use on the current pass.

        Raises:
            ValueError: `optimizer_idx` is an unexpected value.

        Returns:
            Tuple[Dict, Dict]: model outputs and the computed loss values. Both are empty for the
                discriminator pass while `self.train_disc` is False.
        """
        outputs = {}
        loss_dict = {}

        x = batch["input"]
        y = batch["waveform"]

        if optimizer_idx not in [0, 1]:
            raise ValueError(" [!] Unexpected `optimizer_idx`.")

        if optimizer_idx == 0:
            # ---- DISCRIMINATOR pass ----

            # Generator forward, cropped along the last axis to the target length.
            y_hat = self.model_g(x)[:, :, : y.size(2)]

            # Cache for the generator pass.
            self.y_hat_g = y_hat
            self.y_hat_sub = None
            self.y_sub_g = None

            # More than one output channel: keep the multi-channel output as the
            # sub-band prediction and merge it with the generator's PQMF synthesis.
            if y_hat.shape[1] > 1:
                self.y_hat_sub = y_hat
                y_hat = self.model_g.pqmf_synthesis(y_hat)
                self.y_hat_g = y_hat
                self.y_sub_g = self.model_g.pqmf_analysis(y)

            if self.train_disc:
                if self.config.diff_samples_for_G_and_D:
                    # Feed the discriminator a different sample than the generator.
                    x_d = batch["input_disc"]
                    y_d = batch["waveform_disc"]
                    with torch.no_grad():
                        y_hat = self.model_g(x_d)

                    if y_hat.shape[1] > 1:
                        y_hat = self.model_g.pqmf_synthesis(y_hat)
                else:
                    # Reuse the generator's samples.
                    x_d = x.clone()
                    y_d = y.clone()
                    y_hat = self.y_hat_g

                # A two-parameter `forward` means the discriminator also takes the
                # conditioning input. The prediction is detached so no gradient
                # reaches the generator on this pass.
                if len(signature(self.model_d.forward).parameters) == 2:
                    d_out_fake = self.model_d(y_hat.detach().clone(), x_d)
                    d_out_real = self.model_d(y_d, x_d)
                else:
                    d_out_fake = self.model_d(y_hat.detach())
                    d_out_real = self.model_d(y_d)

                # A tuple output is (scores, features); only scores are used here.
                if isinstance(d_out_fake, tuple):
                    scores_fake, _ = d_out_fake
                    scores_real, _ = d_out_real
                else:
                    scores_fake = d_out_fake
                    scores_real = d_out_real

                loss_dict = criterion[optimizer_idx](scores_fake, scores_real)
                outputs = {"model_outputs": y_hat}

        if optimizer_idx == 1:
            # ---- GENERATOR pass ----
            scores_fake, feats_fake, feats_real = None, None, None
            if self.train_disc:
                conditional = len(signature(self.model_d.forward).parameters) == 2
                if conditional:
                    d_out_fake = self.model_d(self.y_hat_g, x)
                else:
                    d_out_fake = self.model_d(self.y_hat_g)
                d_out_real = None

                if self.config.use_feat_match_loss:
                    # Real features are targets only, so no gradient is needed.
                    # A conditional discriminator receives the same conditioning
                    # input as the fake pass above.
                    with torch.no_grad():
                        d_out_real = self.model_d(y, x) if conditional else self.model_d(y)

                if isinstance(d_out_fake, tuple):
                    scores_fake, feats_fake = d_out_fake
                    if d_out_real is not None:
                        _, feats_real = d_out_real
                else:
                    scores_fake = d_out_fake

            # Computed even when the discriminator is disabled; the adversarial and
            # feature arguments are then None.
            loss_dict = criterion[optimizer_idx](
                self.y_hat_g, y, scores_fake, feats_fake, feats_real, self.y_hat_sub, self.y_sub_g
            )
            outputs = {"model_outputs": self.y_hat_g}
        return outputs, loss_dict

    def _log(self, name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, Dict]:
        """Logging shared by the training and evaluation.

        Args:
            name (str): Name of the run. `train` or `eval`,
            ap (AudioProcessor): Audio processor used in training.
            batch (Dict): Batch used in the last train/eval step.
            outputs (Dict): Model outputs from the last train/eval step, indexed by optimizer
                (0: discriminator pass, 1: generator pass).

        Returns:
            Tuple[Dict, Dict]: log figures and audio samples.
        """
        # The discriminator pass only fills its outputs while the discriminator is
        # enabled; otherwise fall back to the generator pass outputs.
        y_hat = outputs[0]["model_outputs"] if self.train_disc else outputs[1]["model_outputs"]
        y = batch["waveform"]
        figures = plot_results(y_hat, y, ap, name)
        sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
        audios = {f"{name}/audio": sample_voice}
        return figures, audios

    def train_log(self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int) -> None:
        """Call `_log()` for training and send the results to the logger's training section.

        Args:
            batch (Dict): Batch used in the last train step.
            outputs (Dict): Model outputs from the last train step.
            logger (Logger): Dashboard logger receiving the figures and audio.
            assets (Dict): Unused.
            steps (int): Step the logged items are recorded at.
        """
        figures, audios = self._log("train", self.ap, batch, outputs)
        logger.train_figures(steps, figures)
        logger.train_audios(steps, audios, self.ap.sample_rate)

    @torch.no_grad()
    def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
        """Call `train_step()` with `no_grad()`"""
        # Force the discriminator branch so the discriminator pass also yields a
        # loss and outputs during evaluation. The flag stays set until
        # `on_train_step_start()` recomputes it.
        self.train_disc = True
        return self.train_step(batch, criterion, optimizer_idx)

    def eval_log(self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int) -> None:
        """Call `_log()` for evaluation and send the results to the logger's evaluation section.

        Args:
            batch (Dict): Batch used in the last eval step.
            outputs (Dict): Model outputs from the last eval step.
            logger (Logger): Dashboard logger receiving the figures and audio.
            assets (Dict): Unused.
            steps (int): Step the logged items are recorded at.
        """
        figures, audios = self._log("eval", self.ap, batch, outputs)
        logger.eval_figures(steps, figures)
        logger.eval_audios(steps, audios, self.ap.sample_rate)

    def load_checkpoint(
        self,
        config: Coqpit,
        checkpoint_path: str,
        eval: bool = False,
        cache: bool = False,
    ) -> None:
        """Load a GAN checkpoint and initialize model parameters.

        Args:
            config (Coqpit): Model config.
            checkpoint_path (str): Checkpoint file path.
            eval (bool, optional): If true, load the model for inference: the discriminator is dropped and
                weight norm is removed from the generator when it supports that. Defaults to False.
            cache (bool, optional): Forwarded to `load_fsspec`. Defaults to False.
        """
        # NOTE(review): `load_fsspec` presumably unpickles the file (torch.load) -
        # only load checkpoints from trusted sources.
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
        if "model_disc" in state:
            # Checkpoint layout with a separate `model_disc` entry: let the
            # generator load itself from the file.
            self.model_g.load_checkpoint(config, checkpoint_path, eval)
        else:
            self.load_state_dict(state["model"])
            if eval:
                self.model_d = None
                if hasattr(self.model_g, "remove_weight_norm"):
                    self.model_g.remove_weight_norm()

    def on_train_step_start(self, trainer) -> None:
        """Enable the discriminator training based on `steps_to_start_discriminator`

        Args:
            trainer (Trainer): Trainer object.
        """
        self.train_disc = trainer.total_steps_done >= self.config.steps_to_start_discriminator

    def get_optimizer(self) -> List:
        """Initiate and return the GAN optimizers based on the config parameters.

        It returns 2 optimizers in a list. First one is for the discriminator and the second one is for the
        generator, matching the `optimizer_idx` convention of `train_step()`.

        Returns:
            List: optimizers.
        """
        optimizer_gen = get_optimizer(
            self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, self.model_g
        )
        optimizer_disc = get_optimizer(
            self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.model_d
        )
        return [optimizer_disc, optimizer_gen]

    def get_lr(self) -> List:
        """Set the initial learning rates for each optimizer.

        Returns:
            List: learning rates for each optimizer, discriminator first.
        """
        return [self.config.lr_disc, self.config.lr_gen]

    def get_scheduler(self, optimizer) -> List:
        """Set the schedulers for each optimizer.

        Args:
            optimizer (List[`torch.optim.Optimizer`]): List of optimizers in the order returned by
                `get_optimizer()`: discriminator first, generator second.

        Returns:
            List: Schedulers, one for each optimizer, in the same order.
        """
        # Each scheduler must wrap the optimizer of its own network.
        scheduler_disc = get_scheduler(
            self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[0]
        )
        scheduler_gen = get_scheduler(
            self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[1]
        )
        return [scheduler_disc, scheduler_gen]

    @staticmethod
    def format_batch(batch: List) -> Dict:
        """Format the batch for training.

        Args:
            batch (List): Batch out of the dataloader. Either an `(input, waveform)` pair, or two such
                pairs (generator first, discriminator second) when the first element is a list.

        Returns:
            Dict: formatted model inputs.
        """
        if isinstance(batch[0], list):
            x_G, y_G = batch[0]
            x_D, y_D = batch[1]
            return {"input": x_G, "waveform": y_G, "input_disc": x_D, "waveform_disc": y_D}
        x, y = batch
        return {"input": x, "waveform": y}

    def get_data_loader(
        self,
        config: Coqpit,
        assets: Dict,
        is_eval: bool,
        samples: List,
        verbose: bool,
        num_gpus: int,
        rank: int = None,
    ) -> DataLoader:
        """Initiate and return the GAN dataloader.

        Args:
            config (Coqpit): Model config.
            assets (Dict): Unused.
            is_eval (bool): Set the dataloader for evaluation if true.
            samples (List): Data samples.
            verbose (bool): Log information if true.
            num_gpus (int): Number of GPUs in use.
            rank (int): Rank of the current GPU. Unused. Defaults to None.

        Returns:
            DataLoader: Torch dataloader.
        """
        dataset = GANDataset(
            ap=self.ap,
            items=samples,
            seq_len=config.seq_len,
            hop_len=self.ap.hop_length,
            pad_short=config.pad_short,
            conv_pad=config.conv_pad,
            return_pairs=config.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in config else False,
            is_training=not is_eval,
            return_segments=not is_eval,
            use_noise_augment=config.use_noise_augment,
            use_cache=config.use_cache,
            verbose=verbose,
        )
        dataset.shuffle_mapping()
        sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None
        # NOTE(review): loader-level shuffling is only on when `num_gpus == 0`; with
        # exactly one GPU there is neither a sampler nor shuffling. Confirm this is
        # intended before changing it.
        loader = DataLoader(
            dataset,
            batch_size=1 if is_eval else config.batch_size,
            shuffle=num_gpus == 0,
            drop_last=False,
            sampler=sampler,
            num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
            pin_memory=False,
        )
        return loader

    def get_criterion(self):
        """Return criterions for the optimizers, discriminator loss first and generator loss second."""
        return [DiscriminatorLoss(self.config), GeneratorLoss(self.config)]

    @staticmethod
    def init_from_config(config: Coqpit, verbose=True) -> "GAN":
        """Build a `GAN` together with an `AudioProcessor` created from the same config.

        Args:
            config (Coqpit): Model configuration.
            verbose (bool): Forwarded to `AudioProcessor.init_from_config`. Defaults to True.

        Returns:
            GAN: the initialized model.
        """
        ap = AudioProcessor.init_from_config(config, verbose=verbose)
        return GAN(config, ap=ap)
| |
|