import os
import logging
import argparse

import torch
import torch.utils.data
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint

import laion_clap

from model.CLAPSep_decoder import HTSAT_Decoder
from model.CLAPSep import LightningModule
from helpers import utils as local_utils
from dataset import CLAPSepDataSet, CLAPSepDataEngineDataSet


def main(args):
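    # 'medium' lets float32 matmuls use reduced internal precision
    # (e.g. bfloat16) for a sizeable speedup on supported GPUs.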
    torch.set_float32_matmul_precision('medium')

    data_train = CLAPSepDataEngineDataSet(**args.train_data)
    logging.info("Loaded train dataset at %s containing %d elements" %
                 (args.train_data['data_list'], len(data_train)))
    data_val = CLAPSepDataSet(**args.val_data)
    logging.info("Loaded validation dataset at %s containing %d elements" %
                 (args.val_data['data_list'], len(data_val)))
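
    # Shuffle only the training set; pin_memory speeds up host-to-GPU copies.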
    train_loader = torch.utils.data.DataLoader(data_train,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.n_workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(data_val,
                                             batch_size=args.eval_batch_size,
                                             shuffle=False,
                                             num_workers=args.n_workers,
                                             pin_memory=True)
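
    # Instantiate the pretrained CLAP encoder on CPU; Lightning moves the
    # parameters to the training device when fit() starts.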
    clap_model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cpu')
    clap_model.load_ckpt(args.clap_path)
    decoder = HTSAT_Decoder(**args.model)
    lightning_module = LightningModule(clap_model, decoder, lr=args.optim['lr'],
                                       use_lora=args.lora,
                                       rank=args.lora_rank,
                                       nfft=args.nfft)
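
    # Checkpoint every `save_ckpt_every_steps` steps, keeping the three best by
    # validation loss; save_last=True also writes last.ckpt for easy resuming.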
    checkpoint_callback = ModelCheckpoint(dirpath=os.path.join(args.exp_dir, 'checkpoints'),
                                          filename="{epoch:02d}-{step}-{val_loss:.2f}",
                                          monitor="val_loss",
                                          mode="min",
                                          save_top_k=3,
                                          every_n_train_steps=args.save_ckpt_every_steps,
                                          save_last=True)
    logger = TensorBoardLogger(args.exp_dir)
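
    # Multi-GPU training via DDP with bf16 mixed precision and gradient clipping.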
    trainer = pl.Trainer(
        default_root_dir=args.exp_dir,
        devices=args.gpu_ids if args.use_cuda else "auto",
        accelerator="gpu" if args.use_cuda else "cpu",
        benchmark=True,
        gradient_clip_val=5.0,
        precision='bf16-mixed',
        limit_train_batches=1.0,
        max_epochs=args.epochs,
        strategy="ddp",
        logger=logger,
        callbacks=[checkpoint_callback],
    )
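
    # Resuming a run (with optimizer/scheduler state) takes precedence over
    # warm-starting from weights only.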
    if os.path.exists(args.resume_ckpt):
        print('Loading resume checkpoint:', args.resume_ckpt)
        trainer.fit(model=lightning_module, train_dataloaders=train_loader, val_dataloaders=val_loader,
                    ckpt_path=args.resume_ckpt)
    elif os.path.exists(args.init_ckpt):
        print('Loading init checkpoint:', args.init_ckpt)
        # Warm-start: load weights only; strict=False tolerates missing/extra keys.
        weights = torch.load(args.init_ckpt, map_location='cpu')['state_dict']
        lightning_module.load_state_dict(weights, strict=False)
        trainer.fit(model=lightning_module, train_dataloaders=train_loader, val_dataloaders=val_loader)
    else:
        print('Training from scratch')
        trainer.fit(model=lightning_module, train_dataloaders=train_loader, val_dataloaders=val_loader)
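

# Example invocation (script name and paths illustrative). exp_dir must contain
# a config.json supplying the remaining hyperparameters read below: train_data,
# val_data, batch_size, eval_batch_size, n_workers, clap_path, model, optim,
# lora, lora_rank, nfft, epochs, save_ckpt_every_steps.
#
#   python train.py ./experiments/CLAPSep_base --use_cuda --gpu_ids 0 1
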
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # nargs='?' makes the positional optional so the default below takes effect.
    parser.add_argument('exp_dir', type=str, nargs='?',
                        default='./experiments/CLAPSep_base',
                        help="Path to save checkpoints and logs.")
    parser.add_argument('--init_ckpt', type=str, default='',
                        help="Checkpoint to warm-start model weights from.")
    parser.add_argument('--resume_ckpt', type=str, default='',
                        help="Checkpoint to resume training from.")

    parser.add_argument('--multi_label_training', dest='multi_label_training', action='store_true',
                        help="Whether to use multi-label training.")

    parser.add_argument('--use_cuda', dest='use_cuda', action='store_true',
                        help="Whether to use CUDA.")
    parser.add_argument('--gpu_ids', nargs='+', type=int, default=None,
                        help="List of GPU ids used for training, "
                             "e.g. --gpu_ids 2 4. All GPUs are used by default.")

    args = parser.parse_args()
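
    # Fix all RNG seeds (python, numpy, torch) for reproducibility.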
    pl.seed_everything(114514)

    if not os.path.exists(args.exp_dir):
        os.makedirs(args.exp_dir)
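
    # Merge config.json from exp_dir into the parsed args; config values
    # overwrite any CLI flag of the same name.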
    params = local_utils.Params(os.path.join(args.exp_dir, 'config.json'))
    for k, v in params.__dict__.items():
        vars(args)[k] = v

    main(args)