| import os |
| |
| import pytorch_lightning as pl |
| import hydra |
| import torch |
| import random |
| import time |
| from os.path import join, basename, exists |
| from pytorch_lightning import seed_everything |
| from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor |
| from pytorch_lightning.callbacks.early_stopping import EarlyStopping |
| from pytorch_lightning.strategies import DDPStrategy,FSDPStrategy |
| from torch.utils.data import DataLoader |
| from data_module import DataModule |
| from lightning_module import CodecLightningModule |
| from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger |
| from omegaconf import OmegaConf |
| |
# Global RNG seed, applied once at import time via Lightning's
# ``seed_everything`` before any model or data object is created.
seed = 1024
seed_everything(seed)
| |
def _build_logger(cfg):
    """Return the experiment logger selected by ``cfg.use_wandb``.

    A ``WandbLogger`` is used when ``use_wandb`` is truthy or absent (the
    lookup defaults to ``True``); otherwise a ``TensorBoardLogger`` writing
    under ``cfg.log_dir`` is returned.
    """
    if cfg.get('use_wandb', True):
        # The run is named after the last path component of the log directory.
        run_name = os.path.basename(os.path.normpath(cfg.log_dir))
        return WandbLogger(
            project='wavvae_debug',
            name=run_name,
            config=OmegaConf.to_container(cfg, resolve=True),
        )
    return TensorBoardLogger(save_dir=cfg.log_dir, name='tb_logs')


def _find_resume_checkpoint(log_dir):
    """Return the path to ``last.ckpt`` inside *log_dir*, or ``None`` if absent.

    Prints a one-line message saying whether training resumes or starts
    from scratch.
    """
    last_ckpt = os.path.join(log_dir, 'last.ckpt')
    if os.path.exists(last_ckpt):
        print(f"Resuming from checkpoint: {last_ckpt}")
        return last_ckpt
    print("No checkpoint found, starting training from scratch.")
    return None


@hydra.main(config_path='config', config_name='default', version_base=None)
def train(cfg):
    """Fit a ``CodecLightningModule`` on a ``DataModule`` built from *cfg*.

    Args:
        cfg: Hydra/OmegaConf config. Keys read here are ``log_dir``,
            ``train.trainer`` (forwarded as keyword arguments to
            ``pl.Trainer``), and the optional ``use_wandb`` (default
            ``True``) and ``debug`` (default ``False``) flags.

    Training resumes automatically from ``<log_dir>/last.ckpt`` when that
    file exists.
    """
    # Checkpoint every 6000 training steps into the log directory and also
    # maintain ``last.ckpt``, which is what the resume logic looks for.
    checkpoint_callback = ModelCheckpoint(
        dirpath=cfg.log_dir,
        save_top_k=-1,
        save_last=True,
        every_n_train_steps=6000,
        monitor='mel_loss',
        mode='min',
    )
    lr_monitor = LearningRateMonitor(logging_interval='step')
    callbacks = [checkpoint_callback, lr_monitor]

    datamodule = DataModule(cfg)
    lightning_module = CodecLightningModule(cfg)
    logger = _build_logger(cfg)
    ckpt_path = _find_resume_checkpoint(cfg.log_dir)

    # Enable cuDNN autotuning *before* the Trainer is constructed so it acts
    # as the session default; setting it afterwards would overwrite whatever
    # the Trainer derived from its own ``benchmark``/``deterministic`` flags.
    torch.backends.cudnn.benchmark = True

    # ``cfg.get`` (rather than attribute access) keeps this working when the
    # config has no ``debug`` key, matching how ``use_wandb`` is read.
    debug = bool(cfg.get('debug', False))

    trainer = pl.Trainer(
        **cfg.train.trainer,
        strategy=DDPStrategy(find_unused_parameters=True),
        callbacks=callbacks,
        logger=logger,
        profiler="simple",
        # In debug mode only a tiny fraction of the training batches is used.
        limit_train_batches=0.001 if debug else 1.0,
    )

    trainer.fit(lightning_module, datamodule=datamodule, ckpt_path=ckpt_path)
    print(f'Training ends, best score: {checkpoint_callback.best_model_score}, ckpt path: {checkpoint_callback.best_model_path}')
|
|
# Script entry point; the ``@hydra.main`` decorator on ``train`` supplies ``cfg``.
if __name__ == "__main__":
    train()
|
|