import os import pytorch_lightning as pl import hydra import torch import random import time from os.path import join, basename, exists from pytorch_lightning import seed_everything from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.strategies import DDPStrategy,FSDPStrategy from torch.utils.data import DataLoader from data_module import DataModule from lightning_module import CodecLightningModule from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger from omegaconf import OmegaConf seed = 1024 seed_everything(seed) @hydra.main(config_path='config', config_name='default', version_base=None) def train(cfg): checkpoint_callback = ModelCheckpoint(dirpath=cfg.log_dir, save_top_k=-1, save_last=True, every_n_train_steps=6000, monitor='mel_loss', mode='min') lr_monitor = LearningRateMonitor(logging_interval='step') callbacks = [checkpoint_callback, lr_monitor] datamodule = DataModule(cfg) lightning_module = CodecLightningModule(cfg) log_dir_name = os.path.basename(os.path.normpath(cfg.log_dir)) if cfg.get('use_wandb', True): logger = WandbLogger( project='wavvae_debug', name=log_dir_name, config=OmegaConf.to_container(cfg, resolve=True) ) else: # 用 TensorBoard,日志保存在 cfg.log_dir/tb_logs/ logger = TensorBoardLogger(save_dir=cfg.log_dir, name='tb_logs') ckpt_path = None last_ckpt = os.path.join(cfg.log_dir, 'last.ckpt') if os.path.exists(last_ckpt): ckpt_path = last_ckpt print(f"Resuming from checkpoint: {ckpt_path}") else: print("No checkpoint found, starting training from scratch.") trainer = pl.Trainer( **cfg.train.trainer, strategy=DDPStrategy(find_unused_parameters=True), callbacks=callbacks, logger=logger, profiler="simple", # 启用 Profiler limit_train_batches=1.0 if not cfg.debug else 0.001 ) torch.backends.cudnn.benchmark = True # lightning_module.strict_loading = False # LightningModule.strict_loading = True trainer.fit(lightning_module, datamodule=datamodule,ckpt_path=ckpt_path ) print(f'Training ends, best score: {checkpoint_callback.best_model_score}, ckpt path: {checkpoint_callback.best_model_path}') if __name__ == '__main__': train()