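"""Training entry point for CLAPSep.

Assembles a pretrained CLAP audio encoder (optionally adapted with LoRA) and an
HTSAT-based decoder into a PyTorch Lightning module, then trains it with bf16
mixed precision under DDP.

Usage (illustrative; script name and paths are examples):
    python train.py ./experiments/CLAPSep_base --use_cuda --gpu_ids 0 1
"""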
import os
import logging
import torch
import torch.utils.data
import pytorch_lightning as pl
import laion_clap
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from model.CLAPSep_decoder import HTSAT_Decoder
from model.CLAPSep import LightningModule
import argparse
from helpers import utils as local_utils
from dataset import CLAPSepDataSet, CLAPSepDataEngineDataSet

# Optional Weights & Biases logging; only used by the commented-out
# wandb_logger lines below.
import wandb
from pytorch_lightning.loggers import WandbLogger


def main(args):
    torch.set_float32_matmul_precision('medium')
    # Load dataset
    data_train = CLAPSepDataEngineDataSet(**args.train_data)
    # data_train = CLAPSepDataSet(**args.train_data)
    logging.info("Loaded train dataset at %s containing %d elements" %
                 (args.train_data['data_list'], len(data_train)))
    data_val = CLAPSepDataSet(**args.val_data)
    logging.info("Loaded test dataset at %s containing %d elements" %
                 (args.val_data['data_list'], len(data_val)))
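    # Plain PyTorch dataloaders; only the training set is shuffled, and
    # pin_memory=True speeds up host-to-GPU transfers.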
    train_loader = torch.utils.data.DataLoader(data_train,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.n_workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(data_val,
                                             batch_size=args.eval_batch_size,
                                             shuffle=False,
                                             num_workers=args.n_workers,
                                             pin_memory=True)

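    # Build the model: a pretrained CLAP encoder (HTSAT-base, feature fusion
    # disabled) loaded on CPU, plus an HTSAT decoder configured from the
    # experiment config; the LightningModule ties them together, optionally
    # with a LoRA adaptation (use_lora / rank).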
    clap_model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cpu')
    clap_model.load_ckpt(args.clap_path)
    decoder = HTSAT_Decoder(**args.model)
    lightning_module = LightningModule(clap_model, decoder, lr=args.optim['lr'],
                                       use_lora=args.lora,
                                       rank=args.lora_rank,
                                       nfft=args.nfft,)
    
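    # Checkpointing: keep the 3 best checkpoints by validation loss plus the
    # most recent one, checked every `save_ckpt_every_steps` training steps.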
    checkpoint_callback = ModelCheckpoint(dirpath=os.path.join(args.exp_dir, 'checkpoints'),
                                          filename="{epoch:02d}-{step}-{val_loss:.2f}",
                                          monitor="val_loss",
                                          mode="min",  # lower val_loss is better; "max" would keep the worst checkpoints
                                          save_top_k=3,
                                          every_n_train_steps=args.save_ckpt_every_steps,
                                          save_last=True)
    logger = TensorBoardLogger(args.exp_dir)
    # wandb_logger = WandbLogger(project='clapsep')
    # wandb_logger = WandbLogger(project='clapsep', id='', resume='must')
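    # Plain DDP by default; the commented-out variant below can help if frozen
    # or LoRA-bypassed parameters trigger DDP unused-parameter errors.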
    # distributed_backend = "ddp_find_unused_parameters_true"
    distributed_backend = "ddp"
    trainer = pl.Trainer(
        default_root_dir=args.exp_dir,
        devices=args.gpu_ids if args.use_cuda else "auto",
        accelerator="gpu" if args.use_cuda else "cpu",
        benchmark=True,
        gradient_clip_val=5.0,
        precision='bf16-mixed',
        limit_train_batches=1.0,
        max_epochs=args.epochs,
        strategy=distributed_backend,
        logger=logger,
        callbacks=[checkpoint_callback],
    )

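    # Three start modes: resume the full trainer state (optimizer, schedulers,
    # global step) from resume_ckpt; warm-start weights only from init_ckpt
    # (strict=False tolerates missing or unexpected keys); or train from
    # scratch. os.path.exists('') is False, so the empty-string defaults fall
    # through to scratch training.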
    if os.path.exists(args.resume_ckpt):
        print('Load resume ckpt:', args.resume_ckpt)
        trainer.fit(model=lightning_module, train_dataloaders=train_loader, val_dataloaders=val_loader,
                    ckpt_path=args.resume_ckpt)
    elif os.path.exists(args.init_ckpt):
        print('Load init ckpt:', args.init_ckpt)
        weights = torch.load(args.init_ckpt, map_location='cpu')['state_dict']
        lightning_module.load_state_dict(weights, strict=False)
        trainer.fit(model=lightning_module, train_dataloaders=train_loader, val_dataloaders=val_loader)
    else:
        print('Training from scratch')
        trainer.fit(model=lightning_module, train_dataloaders=train_loader, val_dataloaders=val_loader)


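# The remaining hyperparameters are read from <exp_dir>/config.json and merged
# into args below. An illustrative sketch of the expected shape: the key names
# are taken from their uses in main(), but every value here is a placeholder,
# not a recommended or original setting:
#
#   {
#     "clap_path": "path/to/clap_checkpoint.pt",
#     "train_data": {"data_list": "path/to/train_list", ...},
#     "val_data": {"data_list": "path/to/val_list", ...},
#     "model": {...},
#     "optim": {"lr": 1e-4},
#     "batch_size": 32,
#     "eval_batch_size": 32,
#     "n_workers": 8,
#     "epochs": 150,
#     "save_ckpt_every_steps": 2000,
#     "lora": true,
#     "lora_rank": 16,
#     "nfft": 1024
#   }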
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Data Params
    # argparse ignores `default` on required positionals, so nargs='?' is
    # needed for the default experiment directory to take effect.
    parser.add_argument('exp_dir', type=str, nargs='?',
                        default='./experiments/CLAPSep_base',
                        help="Path to save checkpoints and logs.")
    parser.add_argument('--init_ckpt', type=str, default='')
    parser.add_argument('--resume_ckpt', type=str, default='')

    parser.add_argument('--multi_label_training', dest='multi_label_training', action='store_true',
                        help="Whether to use multi-label training.")

    parser.add_argument('--use_cuda', dest='use_cuda', action='store_true',
                        help="Whether to use CUDA.")
    parser.add_argument('--gpu_ids', nargs='+', type=int, default=None,
                        help="List of GPU ids used for training. "
                             "E.g., --gpu_ids 2 4. All GPUs are used by default.")

    args = parser.parse_args()

    # Set the random seed for reproducible experiments
    pl.seed_everything(114514)
    # Set up checkpoints
    os.makedirs(args.exp_dir, exist_ok=True)

    # Load model and training hyperparameters from <exp_dir>/config.json
    # and merge them into the parsed CLI arguments.
    params = local_utils.Params(os.path.join(args.exp_dir, 'config.json'))
    for k, v in params.__dict__.items():
        vars(args)[k] = v
    main(args)