import argparse
import logging
import os

import laion_clap
import pytorch_lightning as pl
import torch
import torch.utils.data
import wandb
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger

from dataset import CLAPSepDataSet, CLAPSepDataEngineDataSet
from helpers import utils as local_utils
from model.CLAPSep import LightningModule
from model.CLAPSep_decoder import HTSAT_Decoder


def main(args):
    """Train CLAPSep.

    Builds the train/val datasets and loaders, assembles the CLAP encoder +
    HTSAT decoder into a LightningModule, and runs `trainer.fit`, resuming
    from / warm-starting with a checkpoint when one is configured.

    Args:
        args: argparse.Namespace holding both CLI options and the keys merged
            in from <exp_dir>/config.json (train_data, val_data, model,
            optim, batch_size, epochs, ...).
    """
    torch.set_float32_matmul_precision('medium')

    # Load datasets. Train uses the data-engine variant; val uses the plain one.
    data_train = CLAPSepDataEngineDataSet(**args.train_data)
    logging.info("Loaded train dataset at %s containing %d elements",
                 args.train_data['data_list'], len(data_train))
    data_val = CLAPSepDataSet(**args.val_data)
    logging.info("Loaded test dataset at %s containing %d elements",
                 args.val_data['data_list'], len(data_val))

    train_loader = torch.utils.data.DataLoader(data_train,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.n_workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(data_val,
                                             batch_size=args.eval_batch_size,
                                             shuffle=False,
                                             num_workers=args.n_workers,
                                             pin_memory=True)

    # CLAP encoder is built on CPU here; Lightning moves it to the device.
    clap_model = laion_clap.CLAP_Module(enable_fusion=False,
                                        amodel='HTSAT-base', device='cpu')
    clap_model.load_ckpt(args.clap_path)

    decoder = HTSAT_Decoder(**args.model)
    lightning_module = LightningModule(clap_model, decoder,
                                       lr=args.optim['lr'],
                                       use_lora=args.lora,
                                       rank=args.lora_rank,
                                       nfft=args.nfft)

    # BUGFIX: the monitored metric is a loss, so smaller is better. The
    # original used mode="max", which would retain the three *worst*
    # checkpoints. (If "val_loss" is actually logged as a score such as
    # SI-SDR in the LightningModule, revert this to "max".)
    checkpoint_callback = ModelCheckpoint(
        dirpath=os.path.join(args.exp_dir, 'checkpoints'),
        filename="{epoch:02d}-{step}-{val_loss:.2f}",
        monitor="val_loss",
        mode="min",
        save_top_k=3,
        every_n_train_steps=args.save_ckpt_every_steps,
        save_last=True)

    logger = TensorBoardLogger(args.exp_dir)

    # Switch to "ddp_find_unused_parameters_true" if some parameters (e.g.
    # frozen CLAP weights) receive no gradients and DDP raises an error.
    distributed_backend = "ddp"
    trainer = pl.Trainer(
        default_root_dir=args.exp_dir,
        devices=args.gpu_ids if args.use_cuda else "auto",
        accelerator="gpu" if args.use_cuda else "cpu",
        benchmark=True,
        gradient_clip_val=5.0,
        precision='bf16-mixed',
        limit_train_batches=1.0,
        max_epochs=args.epochs,
        strategy=distributed_backend,
        logger=logger,
        callbacks=[checkpoint_callback],
    )

    # Full resume (optimizer + scheduler state) takes priority over a plain
    # weight warm-start; `os.path.exists('')` is False, so the empty-string
    # defaults fall through to training from scratch.
    if os.path.exists(args.resume_ckpt):
        print('Load resume ckpt:', args.resume_ckpt)
        trainer.fit(model=lightning_module,
                    train_dataloaders=train_loader,
                    val_dataloaders=val_loader,
                    ckpt_path=args.resume_ckpt)
    elif os.path.exists(args.init_ckpt):
        print('Load init ckpt:', args.init_ckpt)
        # strict=False: the init checkpoint may miss or carry extra keys
        # (e.g. LoRA adapters added after pre-training) — TODO confirm.
        weights = torch.load(args.init_ckpt, map_location='cpu')['state_dict']
        lightning_module.load_state_dict(weights, strict=False)
        trainer.fit(model=lightning_module,
                    train_dataloaders=train_loader,
                    val_dataloaders=val_loader)
    else:
        print('Training from scratch')
        trainer.fit(model=lightning_module,
                    train_dataloaders=train_loader,
                    val_dataloaders=val_loader)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Data Params
    parser.add_argument('exp_dir', type=str,
                        default='./experiments/CLAPSep_base',
                        help="Path to save checkpoints and logs.")
    parser.add_argument('--init_ckpt', type=str, default='')
    parser.add_argument('--resume_ckpt', type=str, default='')
    parser.add_argument('--multi_label_training', dest='multi_label_training',
                        action='store_true',
                        help="Whether to multi label training")
    parser.add_argument('--use_cuda', dest='use_cuda', action='store_true',
                        help="Whether to use cuda")
    parser.add_argument('--gpu_ids', nargs='+', type=int, default=None,
                        help="List of GPU ids used for training. "
                             "Eg., --gpu_ids 2 4. All GPUs are used by default.")
    args = parser.parse_args()

    # Set the random seed for reproducible experiments
    pl.seed_everything(114514)

    # Set up the experiment directory (idempotent; safe under multi-process
    # launch where the exists-then-create check would race).
    os.makedirs(args.exp_dir, exist_ok=True)

    # Load model and training params from <exp_dir>/config.json and merge
    # them into the CLI namespace; config values overwrite CLI ones on clash.
    params = local_utils.Params(os.path.join(args.exp_dir, 'config.json'))
    for k, v in params.__dict__.items():
        vars(args)[k] = v

    main(args)