import os
import cv2
import sys
import tqdm
import torch
import datetime

import torch.nn as nn
import torch.distributed as dist
import torch.cuda as cuda

from torch.utils.data.dataloader import DataLoader
from torch.optim import Adam, SGD
from torch.utils.data.distributed import DistributedSampler
from torch.cuda.amp import GradScaler, autocast
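# Make the repository root importable so the lib/, data/, and utils/ packages
# resolve no matter which directory this script is launched from.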
filepath = os.path.split(os.path.abspath(__file__))[0]
repopath = os.path.split(filepath)[0]
sys.path.append(repopath)
from lib import *
from lib.optim import *
from data.dataloader import *
from utils.misc import *
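# Disable TF32 kernels so matmul and cuDNN convolutions compute in full FP32,
# keeping results consistent across GPU generations.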
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
def train(opt, args):
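    # Resolve the dataset class named in the config and instantiate it.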
    train_dataset = eval(opt.Train.Dataset.type)(
        root=opt.Train.Dataset.root,
        sets=opt.Train.Dataset.sets,
        tfs=opt.Train.Dataset.transforms)
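    # Multi-GPU: bind this process to its GPU, join the NCCL process group,
    # and shard the dataset across ranks with a DistributedSampler.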
    if args.device_num > 1:
        cuda.set_device(args.local_rank)
        dist.init_process_group(backend='nccl', rank=args.local_rank,
                                world_size=args.device_num,
                                timeout=datetime.timedelta(seconds=3600))
        train_sampler = DistributedSampler(train_dataset, shuffle=True)
    else:
        train_sampler = None
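    # Shuffle in the loader only when no sampler is set; DataLoader forbids
    # passing both shuffle=True and a sampler.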
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=opt.Train.Dataloader.batch_size,
                              shuffle=train_sampler is None,
                              sampler=train_sampler,
                              num_workers=opt.Train.Dataloader.num_workers,
                              pin_memory=opt.Train.Dataloader.pin_memory,
                              drop_last=True)
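    # Optionally resume: latest.pth holds the model weights, state.pth holds
    # the optimizer, scheduler, and epoch counter.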
    model_ckpt = None
    state_ckpt = None
    if args.resume is True:
        if os.path.isfile(os.path.join(opt.Train.Checkpoint.checkpoint_dir, 'latest.pth')):
            model_ckpt = torch.load(os.path.join(opt.Train.Checkpoint.checkpoint_dir, 'latest.pth'),
                                    map_location='cpu')
            if args.local_rank <= 0:
                print('Resume from checkpoint')
        if os.path.isfile(os.path.join(opt.Train.Checkpoint.checkpoint_dir, 'state.pth')):
            state_ckpt = torch.load(os.path.join(opt.Train.Checkpoint.checkpoint_dir, 'state.pth'),
                                    map_location='cpu')
            if args.local_rank <= 0:
                print('Resume from state')
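    # Instantiate the model by its configured name; under multi-GPU training,
    # convert BatchNorm layers to SyncBatchNorm before wrapping with DDP.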
    model = eval(opt.Model.name)(**opt.Model)
    if model_ckpt is not None:
        model.load_state_dict(model_ckpt)

    if args.device_num > 1:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = model.cuda()
        model = nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
    else:
        model = model.cuda()
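    # Split parameters into two groups: backbone weights train at the base
    # learning rate, while decoder weights get a 10x higher rate.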
    backbone_params = nn.ParameterList()
    decoder_params = nn.ParameterList()

    for name, param in model.named_parameters():
        if 'backbone' in name:
            backbone_params.append(param)
        else:
            decoder_params.append(param)

    params_list = [{'params': backbone_params},
                   {'params': decoder_params, 'lr': opt.Train.Optimizer.lr * 10}]

    optimizer = eval(opt.Train.Optimizer.type)(
        params_list, opt.Train.Optimizer.lr, weight_decay=opt.Train.Optimizer.weight_decay)
    if state_ckpt is not None:
        optimizer.load_state_dict(state_ckpt['optimizer'])
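    # GradScaler applies dynamic loss scaling so FP16 gradients do not
    # underflow during mixed-precision training.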
    if opt.Train.Optimizer.mixed_precision is True:
        scaler = GradScaler()
    else:
        scaler = None
    scheduler = eval(opt.Train.Scheduler.type)(optimizer,
                                               gamma=opt.Train.Scheduler.gamma,
                                               minimum_lr=opt.Train.Scheduler.minimum_lr,
                                               max_iteration=len(train_loader) * opt.Train.Scheduler.epoch,
                                               warmup_iteration=opt.Train.Scheduler.warmup_iteration)
    if state_ckpt is not None:
        scheduler.load_state_dict(state_ckpt['scheduler'])
    model.train()

    start = 1
    if state_ckpt is not None:
        start = state_ckpt['epoch']

    epoch_iter = range(start, opt.Train.Scheduler.epoch + 1)
    if args.local_rank <= 0 and args.verbose is True:
        epoch_iter = tqdm.tqdm(epoch_iter, desc='Epoch', total=opt.Train.Scheduler.epoch,
                               initial=start - 1, position=0,
                               bar_format='{desc:<5.5}{percentage:3.0f}%|{bar:40}{r_bar}')
    for epoch in epoch_iter:
        if args.device_num > 1 and train_sampler is not None:
            # set_epoch must run on every rank each epoch, not only when the
            # progress bar is active, or every epoch replays the same shuffle.
            train_sampler.set_epoch(epoch)
        if args.local_rank <= 0 and args.verbose is True:
            step_iter = tqdm.tqdm(enumerate(train_loader, start=1), desc='Iter',
                                  total=len(train_loader), position=1, leave=False,
                                  bar_format='{desc:<5.5}{percentage:3.0f}%|{bar:40}{r_bar}')
        else:
            step_iter = enumerate(train_loader, start=1)
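        # One optimization step per batch: forward under autocast when mixed
        # precision is on, scale the loss before backward, then step the
        # optimizer and the per-iteration LR scheduler.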
        for i, sample in step_iter:
            optimizer.zero_grad()
            if opt.Train.Optimizer.mixed_precision is True and scaler is not None:
                with autocast():
                    sample = to_cuda(sample)
                    out = model(sample)
                scaler.scale(out['loss']).backward()
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
            else:
                sample = to_cuda(sample)
                out = model(sample)
                out['loss'].backward()
                optimizer.step()
                scheduler.step()

            if args.local_rank <= 0 and args.verbose is True:
                step_iter.set_postfix({'loss': out['loss'].item()})
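        # Only rank 0 writes checkpoints and debug visualizations.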
        if args.local_rank <= 0:
            os.makedirs(opt.Train.Checkpoint.checkpoint_dir, exist_ok=True)
            os.makedirs(os.path.join(opt.Train.Checkpoint.checkpoint_dir, 'debug'), exist_ok=True)
            if epoch % opt.Train.Checkpoint.checkpoint_epoch == 0:
                if args.device_num > 1:
                    model_ckpt = model.module.state_dict()
                else:
                    model_ckpt = model.state_dict()
                state_ckpt = {'epoch': epoch + 1,
                              'optimizer': optimizer.state_dict(),
                              'scheduler': scheduler.state_dict()}
                torch.save(model_ckpt, os.path.join(opt.Train.Checkpoint.checkpoint_dir, 'latest.pth'))
                torch.save(state_ckpt, os.path.join(opt.Train.Checkpoint.checkpoint_dir, 'state.pth'))
            if args.debug is True:
                debout = debug_tile(sum([out[k] for k in opt.Train.Debug.keys], []),
                                    activation=torch.sigmoid)
                cv2.imwrite(os.path.join(opt.Train.Checkpoint.checkpoint_dir, 'debug',
                                         str(epoch) + '.png'), debout)
    # Write the final weights so latest.pth reflects the last epoch even when
    # training does not end on a checkpoint interval.
    if args.local_rank <= 0:
        torch.save(model.module.state_dict() if args.device_num > 1 else model.state_dict(),
                   os.path.join(opt.Train.Checkpoint.checkpoint_dir, 'latest.pth'))
if __name__ == '__main__':
    args = parse_args()
    opt = load_config(args.config)
    train(opt, args)