| """ |
| This is a base lightning module that can be used to train a model. |
| The benefit of this abstraction is that all the logic outside of model definition can be reused for different models. |
| """ |
| import inspect |
| from abc import ABC |
| from typing import Any, Dict |
|
|
| import torch |
| from lightning import LightningModule |
| from lightning.pytorch.utilities import grad_norm |
|
|
| from matcha import utils |
| from matcha.utils.utils import plot_tensor |
|
|
| log = utils.get_pylogger(__name__) |
|
|
|
|
| class BaseLightningClass(LightningModule, ABC): |
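    """Shared training, validation, and logging scaffolding for Matcha-TTS models.

    Concrete subclasses are expected to provide the model itself: a ``forward``
    returning the duration, prior, and diffusion losses, a ``synthesise`` method
    used for the validation-time plots below, and an ``out_size`` attribute.
    """
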
    def update_data_statistics(self, data_statistics):
        if data_statistics is None:
            data_statistics = {
                "mel_mean": 0.0,
                "mel_std": 1.0,
            }

        # Register the normalisation statistics as buffers so they are saved in
        # checkpoints and moved across devices together with the module.
        self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"]))
        self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"]))
    def configure_optimizers(self) -> Any:
        optimizer = self.hparams.optimizer(params=self.parameters())
        if self.hparams.scheduler not in (None, {}):
            scheduler = self.hparams.scheduler.scheduler(optimizer=optimizer)

            # For schedulers that track `last_epoch`, rewind it when resuming
            # from a checkpoint so the schedule continues where it left off.
            if "last_epoch" in inspect.signature(self.hparams.scheduler.scheduler).parameters:
                scheduler.last_epoch = self.ckpt_loaded_epoch - 1 if hasattr(self, "ckpt_loaded_epoch") else -1

            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "interval": self.hparams.scheduler.lightning_args.interval,
                    "frequency": self.hparams.scheduler.lightning_args.frequency,
                    "name": "learning_rate",
                },
            }

        return {"optimizer": optimizer}
    def get_losses(self, batch):
        x, x_lengths = batch["x"], batch["x_lengths"]
        y, y_lengths = batch["y"], batch["y_lengths"]
        spks = batch["spks"]

        dur_loss, prior_loss, diff_loss = self(
            x=x,
            x_lengths=x_lengths,
            y=y,
            y_lengths=y_lengths,
            spks=spks,
            out_size=self.out_size,
        )
        return {
            "dur_loss": dur_loss,
            "prior_loss": prior_loss,
            "diff_loss": diff_loss,
        }
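
    # The batch is expected to be a dict of padded tensors; the shapes below are
    # illustrative for a text-to-mel model, not prescribed by this class:
    #
    #     batch = {
    #         "x": ...,          # (B, T_text) token IDs
    #         "x_lengths": ...,  # (B,) valid lengths of x
    #         "y": ...,          # (B, n_mels, T_mel) target mel-spectrograms
    #         "y_lengths": ...,  # (B,) valid lengths of y
    #         "spks": ...,       # (B,) speaker IDs, or None for a single speaker
    #     }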
    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # Remember the checkpoint epoch so that configure_optimizers can rewind
        # the LR scheduler when training resumes.
        self.ckpt_loaded_epoch = checkpoint["epoch"]
    def training_step(self, batch: Any, batch_idx: int):
        loss_dict = self.get_losses(batch)

        # Expose the global step to the progress bar and logger.
        self.log(
            "step",
            float(self.global_step),
            on_step=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )

        for name, value in loss_dict.items():
            self.log(
                f"sub_loss/train_{name}",
                value,
                on_step=True,
                on_epoch=True,
                logger=True,
                sync_dist=True,
            )

        total_loss = sum(loss_dict.values())
        self.log(
            "loss/train",
            total_loss,
            on_step=True,
            on_epoch=True,
            logger=True,
            prog_bar=True,
            sync_dist=True,
        )

        return {"loss": total_loss, "log": loss_dict}
    def validation_step(self, batch: Any, batch_idx: int):
        loss_dict = self.get_losses(batch)
        for name, value in loss_dict.items():
            self.log(
                f"sub_loss/val_{name}",
                value,
                on_step=True,
                on_epoch=True,
                logger=True,
                sync_dist=True,
            )

        total_loss = sum(loss_dict.values())
        self.log(
            "loss/val",
            total_loss,
            on_step=True,
            on_epoch=True,
            logger=True,
            prog_bar=True,
            sync_dist=True,
        )

        return total_loss
    def on_validation_end(self) -> None:
        if self.trainer.is_global_zero:
            one_batch = next(iter(self.trainer.val_dataloaders))

            # Plot the ground-truth mels once, at the first epoch only.
            if self.current_epoch == 0:
                log.debug("Plotting original samples")
                for i in range(2):
                    y = one_batch["y"][i].unsqueeze(0).to(self.device)
                    # NOTE: add_image assumes a TensorBoard-style logger.
                    self.logger.experiment.add_image(
                        f"original/{i}",
                        plot_tensor(y.squeeze().cpu()),
                        self.current_epoch,
                        dataformats="HWC",
                    )

            log.debug("Synthesising...")
            for i in range(2):
                x = one_batch["x"][i].unsqueeze(0).to(self.device)
                x_lengths = one_batch["x_lengths"][i].unsqueeze(0).to(self.device)
                spks = one_batch["spks"][i].unsqueeze(0).to(self.device) if one_batch["spks"] is not None else None
                output = self.synthesise(x[:, :x_lengths], x_lengths, n_timesteps=10, spks=spks)
                y_enc, y_dec = output["encoder_outputs"], output["decoder_outputs"]
                attn = output["attn"]
                self.logger.experiment.add_image(
                    f"generated_enc/{i}",
                    plot_tensor(y_enc.squeeze().cpu()),
                    self.current_epoch,
                    dataformats="HWC",
                )
                self.logger.experiment.add_image(
                    f"generated_dec/{i}",
                    plot_tensor(y_dec.squeeze().cpu()),
                    self.current_epoch,
                    dataformats="HWC",
                )
                self.logger.experiment.add_image(
                    f"alignment/{i}",
                    plot_tensor(attn.squeeze().cpu()),
                    self.current_epoch,
                    dataformats="HWC",
                )
    def on_before_optimizer_step(self, optimizer):
        # Log the 2-norm of every parameter's gradient (plus the total norm)
        # just before each optimizer step.
        self.log_dict({f"grad_norm/{k}": v for k, v in grad_norm(self, norm_type=2).items()})
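

# A minimal wiring sketch; `MyModel`, its arguments, and the datamodule are
# hypothetical:
#
#     class MyModel(BaseLightningClass):
#         def forward(self, x, x_lengths, y, y_lengths, spks, out_size):
#             ...  # return dur_loss, prior_loss, diff_loss
#
#         def synthesise(self, x, x_lengths, n_timesteps, spks=None):
#             ...  # return {"encoder_outputs": ..., "decoder_outputs": ..., "attn": ...}
#
#     model = MyModel(...)
#     model.update_data_statistics({"mel_mean": ..., "mel_std": ...})
#     trainer = lightning.Trainer(...)
#     trainer.fit(model, datamodule=...)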