| """ | |
| This is a base lightning module that can be used to train a model. | |
| The benefit of this abstraction is that all the logic outside of model definition can be reused for different models. | |
| """ | |
import inspect
from abc import ABC
from typing import Any, Dict

import torch
from lightning import LightningModule
from lightning.pytorch.utilities import grad_norm

from matcha import utils
from matcha.utils.utils import plot_tensor

log = utils.get_pylogger(__name__)


class BaseLightningClass(LightningModule, ABC):
    def update_data_statistics(self, data_statistics):
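        """Register the mel-spectrogram normalisation statistics as buffers.

        Falls back to an identity transform (mean 0.0, std 1.0) when no statistics are provided.
        """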
        if data_statistics is None:
            data_statistics = {
                "mel_mean": 0.0,
                "mel_std": 1.0,
            }

        self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"]))
        self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"]))

    def configure_optimizers(self) -> Any:
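        """Instantiate the optimizer and, if configured, the LR scheduler from the callables stored in ``self.hparams``.

        When resuming from a checkpoint, ``last_epoch`` is restored for schedulers that expose it
        (e.g. exponential schedulers).
        """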
        optimizer = self.hparams.optimizer(params=self.parameters())

        if self.hparams.scheduler not in (None, {}):
            scheduler_args = {}
            # Manage last epoch for exponential schedulers
            if "last_epoch" in inspect.signature(self.hparams.scheduler.scheduler).parameters:
                if hasattr(self, "ckpt_loaded_epoch"):
                    current_epoch = self.ckpt_loaded_epoch - 1
                else:
                    current_epoch = -1

            scheduler_args.update({"optimizer": optimizer})
            scheduler = self.hparams.scheduler.scheduler(**scheduler_args)
            scheduler.last_epoch = current_epoch
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "interval": self.hparams.scheduler.lightning_args.interval,
                    "frequency": self.hparams.scheduler.lightning_args.frequency,
                    "name": "learning_rate",
                },
            }

        return {"optimizer": optimizer}

    def get_losses(self, batch):
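        """Run a forward pass on a batch and return the duration, prior, and diffusion losses as a dict."""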
        x, x_lengths = batch["x"], batch["x_lengths"]
        y, y_lengths = batch["y"], batch["y_lengths"]
        spks = batch["spks"]

        dur_loss, prior_loss, diff_loss = self(
            x=x,
            x_lengths=x_lengths,
            y=y,
            y_lengths=y_lengths,
            spks=spks,
            out_size=self.out_size,
        )
        return {
            "dur_loss": dur_loss,
            "prior_loss": prior_loss,
            "diff_loss": diff_loss,
        }

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
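        """Remember the epoch the checkpoint was saved at, so the LR scheduler can be resumed from it."""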
        self.ckpt_loaded_epoch = checkpoint["epoch"]  # pylint: disable=attribute-defined-outside-init

    def training_step(self, batch: Any, batch_idx: int):
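        """Compute, log, and return the total training loss (sum of duration, prior, and diffusion losses)."""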
        loss_dict = self.get_losses(batch)
        self.log(
            "step",
            float(self.global_step),
            on_step=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )

        self.log(
            "sub_loss/train_dur_loss",
            loss_dict["dur_loss"],
            on_step=True,
            on_epoch=True,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "sub_loss/train_prior_loss",
            loss_dict["prior_loss"],
            on_step=True,
            on_epoch=True,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "sub_loss/train_diff_loss",
            loss_dict["diff_loss"],
            on_step=True,
            on_epoch=True,
            logger=True,
            sync_dist=True,
        )

        total_loss = sum(loss_dict.values())
        self.log(
            "loss/train",
            total_loss,
            on_step=True,
            on_epoch=True,
            logger=True,
            prog_bar=True,
            sync_dist=True,
        )

        return {"loss": total_loss, "log": loss_dict}

    def validation_step(self, batch: Any, batch_idx: int):
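        """Compute and log the per-term and total validation losses for a batch."""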
        loss_dict = self.get_losses(batch)
        self.log(
            "sub_loss/val_dur_loss",
            loss_dict["dur_loss"],
            on_step=True,
            on_epoch=True,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "sub_loss/val_prior_loss",
            loss_dict["prior_loss"],
            on_step=True,
            on_epoch=True,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "sub_loss/val_diff_loss",
            loss_dict["diff_loss"],
            on_step=True,
            on_epoch=True,
            logger=True,
            sync_dist=True,
        )

        total_loss = sum(loss_dict.values())
        self.log(
            "loss/val",
            total_loss,
            on_step=True,
            on_epoch=True,
            logger=True,
            prog_bar=True,
            sync_dist=True,
        )

        return total_loss

    def on_validation_end(self) -> None:
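        """On rank zero, synthesise two samples from the first validation batch and log the
        encoder/decoder spectrograms and alignments to the logger as images."""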
        if self.trainer.is_global_zero:
            one_batch = next(iter(self.trainer.val_dataloaders))

            if self.current_epoch == 0:
                log.debug("Plotting original samples")
                for i in range(2):
                    y = one_batch["y"][i].unsqueeze(0).to(self.device)
                    self.logger.experiment.add_image(
                        f"original/{i}",
                        plot_tensor(y.squeeze().cpu()),
                        self.current_epoch,
                        dataformats="HWC",
                    )

            log.debug("Synthesising...")
            for i in range(2):
                x = one_batch["x"][i].unsqueeze(0).to(self.device)
                x_lengths = one_batch["x_lengths"][i].unsqueeze(0).to(self.device)
                spks = one_batch["spks"][i].unsqueeze(0).to(self.device) if one_batch["spks"] is not None else None
                output = self.synthesise(x[:, :x_lengths], x_lengths, n_timesteps=10, spks=spks)
                y_enc, y_dec = output["encoder_outputs"], output["decoder_outputs"]
                attn = output["attn"]
                self.logger.experiment.add_image(
                    f"generated_enc/{i}",
                    plot_tensor(y_enc.squeeze().cpu()),
                    self.current_epoch,
                    dataformats="HWC",
                )
                self.logger.experiment.add_image(
                    f"generated_dec/{i}",
                    plot_tensor(y_dec.squeeze().cpu()),
                    self.current_epoch,
                    dataformats="HWC",
                )
                self.logger.experiment.add_image(
                    f"alignment/{i}",
                    plot_tensor(attn.squeeze().cpu()),
                    self.current_epoch,
                    dataformats="HWC",
                )

    def on_before_optimizer_step(self, optimizer):
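        """Log the 2-norm of each parameter's gradient just before the optimizer step."""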
        self.log_dict({f"grad_norm/{k}": v for k, v in grad_norm(self, norm_type=2).items()})