| from typing import List | |
| from lightning.pytorch.core import LightningModule | |
| import torch | |
| from torch.optim import AdamW, Optimizer, swa_utils | |
| from torch.optim.lr_scheduler import ExponentialLR | |
| from torch.utils.data import DataLoader | |
| from models.config import ( | |
| AcousticENModelConfig, | |
| AcousticFinetuningConfig, | |
| AcousticPretrainingConfig, | |
| AcousticTrainingConfig, | |
| VocoderFinetuningConfig, | |
| VocoderModelConfig, | |
| VocoderPretrainingConfig, | |
| VoicoderTrainingConfig, | |
| get_lang_map, | |
| lang2id, | |
| ) | |
| from models.config import ( | |
| PreprocessingConfigUnivNet as PreprocessingConfig, | |
| ) | |
| from models.helpers.dataloaders import train_dataloader | |
| from models.helpers.tools import get_mask_from_lengths | |
| # Models | |
| from models.tts.delightful_tts.acoustic_model import AcousticModel | |
| from models.vocoder.univnet.discriminator import Discriminator | |
| from models.vocoder.univnet.generator import Generator | |
| from training.loss import FastSpeech2LossGen, UnivnetLoss | |
| from training.preprocess.normalize_text import NormalizeText | |
| # Updated version of the tokenizer | |
| from training.preprocess.tokenizer_ipa_espeak import TokenizerIpaEspeak as TokenizerIPA | |
class DelightfulUnivnet(LightningModule):
    r"""DEPRECATED: This idea is basically wrong. The model should synthesis pretty well mel spectrograms and then use them to generate the waveform based on the good quality mel-spec.

    Joint trainer for the DelightfulTTS acoustic model and the UnivNet vocoder
    (generator + discriminator), optimized together with manual optimization,
    gradient accumulation, and SWA (Stochastic Weight Averaging).

    Args:
        fine_tuning (bool, optional): Whether to use fine-tuning mode or not. Defaults to True.
        lang (str): Language of the dataset.
        n_speakers (int): Number of speakers in the dataset.
        batch_size (int): The batch size.
        acc_grad_steps (int): The number of gradient accumulation steps.
        swa_steps (int): The number of global steps between SWA parameter updates.
    """

    def __init__(
        self,
        fine_tuning: bool = True,
        lang: str = "en",
        n_speakers: int = 5392,
        batch_size: int = 12,
        acc_grad_steps: int = 5,
        swa_steps: int = 1000,
    ):
        super().__init__()

        # Switch to manual optimization: training_step drives the three
        # optimizers (acoustic, univnet, discriminator) itself.
        self.automatic_optimization = False
        self.acc_grad_steps = acc_grad_steps
        self.swa_steps = swa_steps

        self.lang = lang
        self.fine_tuning = fine_tuning
        self.batch_size = batch_size

        lang_map = get_lang_map(lang)
        normilize_text_lang = lang_map.nemo

        self.tokenizer = TokenizerIPA(lang)
        # NOTE: attribute name kept as-is ("normilize") for checkpoint compatibility.
        self.normilize_text = NormalizeText(normilize_text_lang)

        # Acoustic model
        self.train_config_acoustic: AcousticTrainingConfig
        if self.fine_tuning:
            self.train_config_acoustic = AcousticFinetuningConfig()
        else:
            self.train_config_acoustic = AcousticPretrainingConfig()

        self.preprocess_config = PreprocessingConfig("english_only")
        self.model_config_acoustic = AcousticENModelConfig()

        # TODO: fix the arguments!
        self.acoustic_model = AcousticModel(
            preprocess_config=self.preprocess_config,
            model_config=self.model_config_acoustic,
            # NOTE: this parameter may be hyperparameter that you can define based on the demands
            n_speakers=n_speakers,
        )

        # Initialize SWA
        self.swa_averaged_acoustic = swa_utils.AveragedModel(self.acoustic_model)

        # NOTE: in case of training from 0 bin_warmup should be True!
        self.loss_acoustic = FastSpeech2LossGen(bin_warmup=False)

        # Vocoder models
        self.model_config_vocoder = VocoderModelConfig()
        self.train_config: VoicoderTrainingConfig = (
            VocoderFinetuningConfig() if fine_tuning else VocoderPretrainingConfig()
        )

        self.univnet = Generator(
            model_config=self.model_config_vocoder,
            preprocess_config=self.preprocess_config,
        )
        self.swa_averaged_univnet = swa_utils.AveragedModel(self.univnet)

        self.discriminator = Discriminator(model_config=self.model_config_vocoder)
        self.swa_averaged_discriminator = swa_utils.AveragedModel(self.discriminator)

        self.loss_univnet = UnivnetLoss()

    def forward(
        self, text: str, speaker_idx: torch.Tensor, lang: str = "en"
    ) -> torch.Tensor:
        r"""Performs a forward pass through the AcousticModel.

        This code must be run only with the loaded weights from the checkpoint!

        Args:
            text (str): The input text.
            speaker_idx (torch.Tensor): The index of the speaker.
            lang (str): The language.

        Returns:
            torch.Tensor: The generated waveform.
        """
        normalized_text = self.normilize_text(text)
        _, phones = self.tokenizer(normalized_text)

        # Convert to tensor with a leading batch dimension of 1.
        x = torch.tensor(
            phones,
            dtype=torch.int,
            device=speaker_idx.device,
        ).unsqueeze(0)

        # Broadcast speaker / language ids to one id per phoneme.
        speakers = speaker_idx.repeat(x.shape[1]).unsqueeze(0)

        langs = (
            torch.tensor(
                [lang2id[lang]],
                dtype=torch.int,
                device=speaker_idx.device,
            )
            .repeat(x.shape[1])
            .unsqueeze(0)
        )

        y_pred = self.acoustic_model.forward(
            x=x,
            speakers=speakers,
            langs=langs,
        )

        mel_lens = torch.tensor(
            [y_pred.shape[2]],
            dtype=torch.int32,
            device=y_pred.device,
        )

        wav = self.univnet.infer(y_pred, mel_lens)
        return wav

    # TODO: don't forget about torch.no_grad() !
    # default used by the Trainer
    # trainer = Trainer(inference_mode=True)
    # Use `torch.no_grad` instead
    # trainer = Trainer(inference_mode=False)
    def training_step(self, batch: List, batch_idx: int):
        r"""Performs a training step for the model.

        Runs the acoustic model and the vocoder (generator + discriminator) on one
        batch, logs all component losses, accumulates gradients over
        ``acc_grad_steps`` batches, and steps all three optimizers/schedulers.

        Args:
            batch (List): The batch of data for training. The batch should contain:
                - ids: List of indexes.
                - raw_texts: Raw text inputs.
                - speakers: Speaker identities.
                - texts: Text inputs.
                - src_lens: Lengths of the source sequences.
                - mels: Mel spectrogram targets.
                - pitches: Pitch targets.
                - pitches_stat: Statistics of the pitches.
                - mel_lens: Lengths of the mel spectrograms.
                - langs: Language identities.
                - attn_priors: Prior attention weights.
                - wavs: Waveform targets.
                - energies: Energy targets.
            batch_idx (int): Index of the batch.
        """
        (
            _,
            _,
            speakers,
            texts,
            src_lens,
            mels,
            pitches,
            _,
            mel_lens,
            langs,
            attn_priors,
            audio,
            energies,
        ) = batch

        #####################################
        ##  Acoustic model train step      ##
        #####################################
        outputs = self.acoustic_model.forward_train(
            x=texts,
            speakers=speakers,
            src_lens=src_lens,
            mels=mels,
            mel_lens=mel_lens,
            pitches=pitches,
            langs=langs,
            attn_priors=attn_priors,
            energies=energies,
        )

        y_pred = outputs["y_pred"]
        log_duration_prediction = outputs["log_duration_prediction"]
        p_prosody_ref = outputs["p_prosody_ref"]
        p_prosody_pred = outputs["p_prosody_pred"]
        pitch_prediction = outputs["pitch_prediction"]
        energy_pred = outputs["energy_pred"]
        energy_target = outputs["energy_target"]

        src_mask = get_mask_from_lengths(src_lens)
        mel_mask = get_mask_from_lengths(mel_lens)

        (
            acc_total_loss,
            acc_mel_loss,
            acc_ssim_loss,
            acc_duration_loss,
            acc_u_prosody_loss,
            acc_p_prosody_loss,
            acc_pitch_loss,
            acc_ctc_loss,
            acc_bin_loss,
            acc_energy_loss,
        ) = self.loss_acoustic.forward(
            src_masks=src_mask,
            mel_masks=mel_mask,
            mel_targets=mels,
            mel_predictions=y_pred,
            log_duration_predictions=log_duration_prediction,
            u_prosody_ref=outputs["u_prosody_ref"],
            u_prosody_pred=outputs["u_prosody_pred"],
            p_prosody_ref=p_prosody_ref,
            p_prosody_pred=p_prosody_pred,
            pitch_predictions=pitch_prediction,
            p_targets=outputs["pitch_target"],
            durations=outputs["attn_hard_dur"],
            attn_logprob=outputs["attn_logprob"],
            attn_soft=outputs["attn_soft"],
            attn_hard=outputs["attn_hard"],
            src_lens=src_lens,
            mel_lens=mel_lens,
            energy_pred=energy_pred,
            energy_target=energy_target,
            step=self.trainer.global_step,
        )

        self.log(
            "acc_total_loss", acc_total_loss, sync_dist=True, batch_size=self.batch_size
        )
        self.log(
            "acc_mel_loss", acc_mel_loss, sync_dist=True, batch_size=self.batch_size
        )
        self.log(
            "acc_ssim_loss", acc_ssim_loss, sync_dist=True, batch_size=self.batch_size
        )
        self.log(
            "acc_duration_loss",
            acc_duration_loss,
            sync_dist=True,
            batch_size=self.batch_size,
        )
        self.log(
            "acc_u_prosody_loss",
            acc_u_prosody_loss,
            sync_dist=True,
            batch_size=self.batch_size,
        )
        self.log(
            "acc_p_prosody_loss",
            acc_p_prosody_loss,
            sync_dist=True,
            batch_size=self.batch_size,
        )
        self.log(
            "acc_pitch_loss", acc_pitch_loss, sync_dist=True, batch_size=self.batch_size
        )
        self.log(
            "acc_ctc_loss", acc_ctc_loss, sync_dist=True, batch_size=self.batch_size
        )
        self.log(
            "acc_bin_loss", acc_bin_loss, sync_dist=True, batch_size=self.batch_size
        )
        self.log(
            "acc_energy_loss",
            acc_energy_loss,
            sync_dist=True,
            batch_size=self.batch_size,
        )

        #####################################
        ##  Univnet model train step       ##
        #####################################
        fake_audio = self.univnet.forward(y_pred)

        # Detach for the "fake" pass so the generator is not updated by the
        # discriminator's loss on fake samples.
        res_fake, period_fake = self.discriminator(fake_audio.detach())
        res_real, period_real = self.discriminator(audio)

        (
            voc_total_loss_gen,
            voc_total_loss_disc,
            voc_stft_loss,
            voc_score_loss,
            voc_esr_loss,
            voc_snr_loss,
        ) = self.loss_univnet.forward(
            audio,
            fake_audio,
            res_fake,
            period_fake,
            res_real,
            period_real,
        )

        self.log(
            "voc_total_loss_gen",
            voc_total_loss_gen,
            sync_dist=True,
            batch_size=self.batch_size,
        )
        self.log(
            "voc_total_loss_disc",
            voc_total_loss_disc,
            sync_dist=True,
            batch_size=self.batch_size,
        )
        self.log(
            "voc_stft_loss", voc_stft_loss, sync_dist=True, batch_size=self.batch_size
        )
        self.log(
            "voc_score_loss", voc_score_loss, sync_dist=True, batch_size=self.batch_size
        )
        self.log(
            "voc_esr_loss", voc_esr_loss, sync_dist=True, batch_size=self.batch_size
        )
        self.log(
            "voc_snr_loss", voc_snr_loss, sync_dist=True, batch_size=self.batch_size
        )

        # Manual optimization: fetch the optimizers/schedulers in the order
        # declared by configure_optimizers (acoustic, univnet, discriminator).
        optimizers = self.optimizers()
        schedulers = self.lr_schedulers()

        ####################################
        # Acoustic model manual optimizer ##
        ####################################
        opt_acoustic: Optimizer = optimizers[0]  # type: ignore
        sch_acoustic: ExponentialLR = schedulers[0]  # type: ignore

        # FIX: the univnet and discriminator previously reused indices 0 and 1,
        # so the generator stepped the acoustic optimizer and the discriminator
        # stepped the generator's; the discriminator optimizer (index 2) was
        # never stepped at all.
        opt_univnet: Optimizer = optimizers[1]  # type: ignore
        sch_univnet: ExponentialLR = schedulers[1]  # type: ignore

        opt_discriminator: Optimizer = optimizers[2]  # type: ignore
        sch_discriminator: ExponentialLR = schedulers[2]  # type: ignore

        # Backward pass for the acoustic model
        # NOTE: the loss is divided by the accumulated gradient steps
        self.manual_backward(acc_total_loss / self.acc_grad_steps, retain_graph=True)

        # Perform manual optimization univnet
        self.manual_backward(
            voc_total_loss_gen / self.acc_grad_steps, retain_graph=True
        )
        self.manual_backward(
            voc_total_loss_disc / self.acc_grad_steps, retain_graph=True
        )

        # accumulate gradients of N batches
        if (batch_idx + 1) % self.acc_grad_steps == 0:
            # Acoustic model optimizer step
            # clip gradients
            self.clip_gradients(
                opt_acoustic, gradient_clip_val=0.5, gradient_clip_algorithm="norm"
            )
            # optimizer step
            opt_acoustic.step()
            # Scheduler step
            sch_acoustic.step()
            # zero the gradients
            opt_acoustic.zero_grad()

            # Univnet model optimizer step
            # clip gradients
            self.clip_gradients(
                opt_univnet, gradient_clip_val=0.5, gradient_clip_algorithm="norm"
            )
            self.clip_gradients(
                opt_discriminator, gradient_clip_val=0.5, gradient_clip_algorithm="norm"
            )
            # optimizer step
            opt_univnet.step()
            opt_discriminator.step()
            # Scheduler step
            sch_univnet.step()
            sch_discriminator.step()
            # zero the gradients
            opt_univnet.zero_grad()
            opt_discriminator.zero_grad()

        # Update SWA model every swa_steps
        if self.trainer.global_step % self.swa_steps == 0:
            self.swa_averaged_acoustic.update_parameters(self.acoustic_model)
            self.swa_averaged_univnet.update_parameters(self.univnet)
            self.swa_averaged_discriminator.update_parameters(self.discriminator)

    def on_train_epoch_end(self):
        r"""Updates the averaged models at the end of every epoch with SWA."""
        self.swa_averaged_acoustic.update_parameters(self.acoustic_model)
        self.swa_averaged_univnet.update_parameters(self.univnet)
        self.swa_averaged_discriminator.update_parameters(self.discriminator)

    def configure_optimizers(self):
        r"""Configures the optimizers used for training.

        Returns:
            tuple: A tuple containing three dictionaries (acoustic, univnet,
                discriminator). Each dictionary contains the optimizer and
                learning rate scheduler for one of the models. The order here
                defines the indices used in ``training_step``.
        """
        ####################################
        # Acoustic model optimizer config ##
        ####################################
        # Compute the gamma and initial learning rate based on the current step,
        # so resumed runs continue on the same exponential decay curve.
        lr_decay = self.train_config_acoustic.optimizer_config.lr_decay
        default_lr = self.train_config_acoustic.optimizer_config.learning_rate

        init_lr = (
            default_lr
            if self.trainer.global_step == 0
            else default_lr * (lr_decay**self.trainer.global_step)
        )

        optimizer_acoustic = AdamW(
            self.acoustic_model.parameters(),
            lr=init_lr,
            betas=self.train_config_acoustic.optimizer_config.betas,
            eps=self.train_config_acoustic.optimizer_config.eps,
            weight_decay=self.train_config_acoustic.optimizer_config.weight_decay,
        )

        scheduler_acoustic = ExponentialLR(optimizer_acoustic, gamma=lr_decay)

        ####################################
        # Univnet model optimizer config  ##
        ####################################
        optim_univnet = AdamW(
            self.univnet.parameters(),
            self.train_config.learning_rate,
            betas=(self.train_config.adam_b1, self.train_config.adam_b2),
        )
        scheduler_univnet = ExponentialLR(
            optim_univnet,
            gamma=self.train_config.lr_decay,
            last_epoch=-1,
        )

        ####################################
        # Discriminator optimizer config  ##
        ####################################
        optim_discriminator = AdamW(
            self.discriminator.parameters(),
            self.train_config.learning_rate,
            betas=(self.train_config.adam_b1, self.train_config.adam_b2),
        )
        scheduler_discriminator = ExponentialLR(
            optim_discriminator,
            gamma=self.train_config.lr_decay,
            last_epoch=-1,
        )

        return (
            {"optimizer": optimizer_acoustic, "lr_scheduler": scheduler_acoustic},
            {"optimizer": optim_univnet, "lr_scheduler": scheduler_univnet},
            {"optimizer": optim_discriminator, "lr_scheduler": scheduler_discriminator},
        )

    def on_train_end(self):
        r"""Finalizes SWA: recomputes BatchNorm statistics for the averaged models.

        NOTE(review): ``swa_utils.update_bn`` forwards each raw dataloader batch
        directly into the model; these models take keyword tensors, not the raw
        batch list — confirm this works (or is reachable) before relying on it.
        """
        swa_utils.update_bn(self.train_dataloader(), self.swa_averaged_acoustic)
        swa_utils.update_bn(self.train_dataloader(), self.swa_averaged_univnet)
        swa_utils.update_bn(self.train_dataloader(), self.swa_averaged_discriminator)

    def train_dataloader(
        self,
        num_workers: int = 5,
        root: str = "datasets_cache/LIBRITTS",
        cache: bool = True,
        cache_dir: str = "datasets_cache",
        mem_cache: bool = False,
        url: str = "train-960",
    ) -> DataLoader:
        r"""Returns the training dataloader, that is using the LibriTTS dataset.

        Args:
            num_workers (int): The number of workers.
            root (str): The root directory of the dataset.
            cache (bool): Whether to cache the preprocessed data.
            cache_dir (str): The directory for the cache.
            mem_cache (bool): Whether to use memory cache.
            url (str): The URL of the dataset.

        Returns:
            DataLoader: The training dataloader.
        """
        return train_dataloader(
            batch_size=self.batch_size,
            num_workers=num_workers,
            root=root,
            cache=cache,
            cache_dir=cache_dir,
            mem_cache=mem_cache,
            url=url,
            lang=self.lang,
        )