Spaces:
No application file
No application file
| from argparse import ArgumentParser | |
| import matplotlib.pyplot as plt | |
| import pytorch_lightning as pl | |
| import torch | |
| import wandb | |
| from loguru import logger | |
| from mmengine import Config | |
| from mmengine.optim import OPTIMIZERS | |
| from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger | |
| from torch.utils.data import DataLoader | |
| from fish_diffusion.archs.diffsinger import DiffSinger | |
| from fish_diffusion.datasets import DATASETS | |
| from fish_diffusion.datasets.repeat import RepeatDataset | |
| from fish_diffusion.utils.scheduler import LR_SCHEUDLERS | |
| from fish_diffusion.utils.viz import viz_synth_sample | |
| from fish_diffusion.vocoders import VOCODERS | |
class FishDiffusion(pl.LightningModule):
    """PyTorch Lightning wrapper around the DiffSinger diffusion model.

    Owns the model and a frozen vocoder, builds optimizer/scheduler from the
    mmengine registries, and during validation synthesizes sample audio and
    logs mel figures + waveforms to WandB or TensorBoard.
    """

    def __init__(self, config):
        super().__init__()
        self.save_hyperparameters()

        self.model = DiffSinger(config.model)
        self.config = config

        # Vocoder: converts mel spectrograms back to waveforms.
        # Frozen — only used for validation visualization, never trained.
        self.vocoder = VOCODERS.build(config.model.vocoder)
        self.vocoder.freeze()

    def configure_optimizers(self):
        # The mmengine config objects are mutated in place to carry the
        # runtime objects (parameters / optimizer) the registries need.
        self.config.optimizer.params = self.parameters()
        optimizer = OPTIMIZERS.build(self.config.optimizer)

        self.config.scheduler.optimizer = optimizer
        scheduler = LR_SCHEUDLERS.build(self.config.scheduler)

        # interval="step": advance the LR scheduler every optimizer step,
        # not once per epoch.
        return [optimizer], dict(scheduler=scheduler, interval="step")

    def _step(self, batch, batch_idx, mode):
        """Shared train/valid step.

        Args:
            batch: collated batch dict (speakers, contents, mels, pitches, lengths).
            batch_idx: index of the batch (unused, required by Lightning).
            mode: "train" or "valid"; validation additionally synthesizes and
                logs sample audio.

        Returns:
            The diffusion loss tensor.
        """
        # Pitch and mel sequences must be frame-aligned.
        assert batch["pitches"].shape[1] == batch["mels"].shape[1]

        # Clone pitches before the model call in case the model mutates them;
        # the clone is what gets visualized below.
        pitches = batch["pitches"].clone()
        batch_size = batch["speakers"].shape[0]

        output = self.model(
            speakers=batch["speakers"],
            contents=batch["contents"],
            src_lens=batch["content_lens"],
            max_src_len=batch["max_content_len"],
            mels=batch["mels"],
            mel_lens=batch["mel_lens"],
            max_mel_len=batch["max_mel_len"],
            pitches=batch["pitches"],
        )

        self.log(f"{mode}_loss", output["loss"], batch_size=batch_size, sync_dist=True)

        if mode != "valid":
            return output["loss"]

        # Run the reverse diffusion to obtain predicted mels for visualization.
        x = self.model.diffusion(output["features"])

        for idx, (gt_mel, gt_pitch, predict_mel, predict_mel_len) in enumerate(
            zip(batch["mels"], pitches, x, batch["mel_lens"])
        ):
            image_mels, wav_reconstruction, wav_prediction = viz_synth_sample(
                gt_mel=gt_mel,
                gt_pitch=gt_pitch,
                predict_mel=predict_mel,
                predict_mel_len=predict_mel_len,
                vocoder=self.vocoder,
                return_image=False,
            )

            wav_reconstruction = wav_reconstruction.to(torch.float32).cpu().numpy()
            wav_prediction = wav_prediction.to(torch.float32).cpu().numpy()

            # NOTE(review): sample rate is hard-coded to 44100 — presumably
            # matches the vocoder config; confirm before changing datasets.

            # WandB logger
            if isinstance(self.logger, WandbLogger):
                self.logger.experiment.log(
                    {
                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
                        "wavs": [
                            wandb.Audio(
                                wav_reconstruction,
                                sample_rate=44100,
                                caption="reconstruction (gt)",
                            ),
                            wandb.Audio(
                                wav_prediction,
                                sample_rate=44100,
                                caption="prediction",
                            ),
                        ],
                    },
                )

            # TensorBoard logger
            if isinstance(self.logger, TensorBoardLogger):
                self.logger.experiment.add_figure(
                    f"sample-{idx}/mels",
                    image_mels,
                    global_step=self.global_step,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/gt",
                    wav_reconstruction,
                    self.global_step,
                    sample_rate=44100,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/prediction",
                    wav_prediction,
                    self.global_step,
                    sample_rate=44100,
                )

            # Free the figure to avoid matplotlib accumulating open figures
            # across validation epochs.
            if isinstance(image_mels, plt.Figure):
                plt.close(image_mels)

        return output["loss"]

    def training_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, mode="train")

    def validation_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, mode="valid")
if __name__ == "__main__":
    pl.seed_everything(42, workers=True)

    parser = ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--resume", type=str, default=None)
    parser.add_argument(
        "--tensorboard",
        action="store_true",
        default=False,
        help="Use tensorboard logger, default is wandb.",
    )
    parser.add_argument("--resume-id", type=str, default=None, help="Wandb run id.")
    parser.add_argument("--entity", type=str, default=None, help="Wandb entity.")
    parser.add_argument("--name", type=str, default=None, help="Wandb run name.")
    parser.add_argument(
        "--pretrained", type=str, default=None, help="Pretrained model."
    )
    parser.add_argument(
        "--only-train-speaker-embeddings",
        action="store_true",
        default=False,
        help="Only train speaker embeddings.",
    )

    args = parser.parse_args()

    cfg = Config.fromfile(args.config)
    model = FishDiffusion(cfg)

    # We only load the state_dict of the model, not the optimizer.
    if args.pretrained:
        state_dict = torch.load(args.pretrained, map_location="cpu")

        # Lightning checkpoints nest the weights under "state_dict".
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]

        result = model.load_state_dict(state_dict, strict=False)

        missing_keys = set(result.missing_keys)
        unexpected_keys = set(result.unexpected_keys)

        # Make sure incorrect keys are just noise predictor keys.
        unexpected_keys = unexpected_keys - set(
            i.replace(".naive_noise_predictor.", ".") for i in missing_keys
        )

        assert len(unexpected_keys) == 0

    # Optionally freeze everything except the speaker embedding table
    # (useful for adapting a pretrained model to new speakers).
    if args.only_train_speaker_embeddings:
        for name, param in model.named_parameters():
            if "speaker_encoder" not in name:
                param.requires_grad = False

        logger.info(
            "Only train speaker embeddings, all other parameters are frozen."
        )

    # FIX: named `train_logger` (not `logger`) so we don't shadow loguru's
    # `logger`, which is used for console logging above.
    train_logger = (
        TensorBoardLogger("logs", name=cfg.model.type)
        if args.tensorboard
        else WandbLogger(
            project=cfg.model.type,
            save_dir="logs",
            log_model=True,
            name=args.name,
            entity=args.entity,
            resume="must" if args.resume_id else False,
            id=args.resume_id,
        )
    )

    trainer = pl.Trainer(
        logger=train_logger,
        **cfg.trainer,
    )

    train_dataset = DATASETS.build(cfg.dataset.train)
    train_loader = DataLoader(
        train_dataset,
        collate_fn=train_dataset.collate_fn,
        **cfg.dataloader.train,
    )

    valid_dataset = DATASETS.build(cfg.dataset.valid)
    # Repeat the validation set once per device so every rank sees data
    # under DDP validation.
    valid_dataset = RepeatDataset(
        valid_dataset, repeat=trainer.num_devices, collate_fn=valid_dataset.collate_fn
    )
    valid_loader = DataLoader(
        valid_dataset,
        collate_fn=valid_dataset.collate_fn,
        **cfg.dataloader.valid,
    )

    trainer.fit(model, train_loader, valid_loader, ckpt_path=args.resume)