import argparse
import datetime
import os
import shutil
import sys
import time
import warnings
from functools import partial

import cv2
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data as data
from loguru import logger
from torch.optim.lr_scheduler import MultiStepLR

import utils.config as config
import wandb
from utils.dataset_sbert import RefDataset_gref
from engine.engine_gref import train, validate
from model import build_segmenter
from utils.misc import (init_random_seed, set_random_seed, setup_logger,
                        worker_init_fn, build_scheduler, collate_fn)

warnings.filterwarnings("ignore")
cv2.setNumThreads(0)


def get_parser():
    parser = argparse.ArgumentParser(
        description='Pytorch Referring Expression Segmentation')
    parser.add_argument('--config',
                        default='path to xxx.yaml',
                        type=str,
                        help='config file')
    parser.add_argument('--opts',
                        default=None,
                        nargs=argparse.REMAINDER,
                        help='override some settings in the config.')
    parser.add_argument('--local-rank',
                        type=int,
                        default=0,
                        help='local rank for distributed training')
    args = parser.parse_args()
    assert args.config is not None
    cfg = config.load_cfg_from_cfg_file(args.config)
    if args.opts is not None:
        cfg = config.merge_cfg_from_list(cfg, args.opts)
    return cfg


@logger.catch
def main():
    args = get_parser()
    args.manual_seed = init_random_seed(args.manual_seed)
    set_random_seed(args.manual_seed, deterministic=True)

    if 'LOCAL_RANK' in os.environ:
        # Launched with torchrun / torch.distributed.launch: one process per
        # GPU already exists, so join the env:// process group directly.
        args.local_rank = int(os.environ['LOCAL_RANK'])
        logger.info(f"LOCAL_RANK from env: {args.local_rank}")
        main_worker_ddp(args)
    else:
        # Fall back to spawning one worker process per visible GPU.
        args.ngpus_per_node = torch.cuda.device_count()
        args.world_size = args.ngpus_per_node * getattr(args, 'world_size', 1)
        mp.spawn(main_worker_mp, nprocs=args.ngpus_per_node, args=(args, ))


def main_worker_ddp(args):
    args.output_dir = os.path.join(args.output_folder, args.exp_name)
    # local rank == global rank under single-node torchrun
    args.gpu = args.local_rank
    args.rank = args.local_rank
    args.world_size = int(os.environ.get('WORLD_SIZE', 1))
    torch.cuda.set_device(args.gpu)
    setup_logger(args.output_dir,
                 distributed_rank=args.gpu,
                 filename="train.log",
                 mode="a")
    logger.info(f"Starting with GPU: {args.gpu}, Rank: {args.rank}, "
                f"World Size: {args.world_size}")
    # MASTER_ADDR / MASTER_PORT are read from the environment by init_method="env://".
    dist.init_process_group(backend=getattr(args, 'dist_backend', 'nccl'),
                            init_method="env://",
                            world_size=args.world_size,
                            rank=args.rank)
    run_training(args)


def main_worker_mp(gpu, args):
    args.output_dir = os.path.join(args.output_folder, args.exp_name)
    # local rank & global rank
    args.gpu = gpu
    rank = getattr(args, 'rank', 0)
    args.rank = rank * args.ngpus_per_node + gpu
    torch.cuda.set_device(args.gpu)
    setup_logger(args.output_dir,
                 distributed_rank=args.gpu,
                 filename="train.log",
                 mode="a")
    dist_url = getattr(args, 'dist_url', 'tcp://localhost:12355')
    dist.init_process_group(backend=getattr(args, 'dist_backend', 'nccl'),
                            init_method=dist_url,
                            world_size=args.world_size,
                            rank=args.rank)
    run_training(args)


def run_training(args):
    # wandb (rank 0 only)
    if args.rank == 0:
        wandb.init(job_type="training",
                   mode="offline",
                   config=args,
                   project=args.exp_name,
                   name=args.exp_name,
                   tags=[args.dataset])
    dist.barrier()

    # build model
    model, param_list = build_segmenter(args)
    model = model.cuda(args.gpu)
    if hasattr(model, 'text_encoder'):
        model.text_encoder = model.text_encoder.cuda(args.gpu)
    logger.info(f"Model moved to GPU: {args.gpu}")
    logger.info(args)

    # build optimizer & lr scheduler
    optimizer = torch.optim.AdamW(param_list,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay,
                                  amsgrad=args.amsgrad)
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[args.gpu], find_unused_parameters=True)
    scaler = amp.GradScaler()

    # split the global batch size and worker count across processes
    args.batch_size = int(args.batch_size / dist.get_world_size())
    args.batch_size_val = int(args.batch_size_val / dist.get_world_size())
    args.workers = int(
        (args.workers + dist.get_world_size() - 1) / dist.get_world_size())

    # build dataset
    train_data = RefDataset_gref(lmdb_dir=args.train_lmdb,
                                 mask_dir=args.mask_root,
                                 dataset=args.dataset,
                                 split=args.train_split,
                                 mode='train',
                                 input_size=args.input_size,
                                 word_length=args.word_len,
                                 args=args)
    val_data = RefDataset_gref(lmdb_dir=args.val_lmdb,
                               mask_dir=args.mask_root,
                               dataset=args.dataset,
                               split=args.val_split,
                               mode='val',
                               input_size=args.input_size,
                               word_length=args.word_len,
                               args=args)

    # build dataloader
    init_fn = partial(worker_init_fn,
                      num_workers=args.workers,
                      rank=args.rank,
                      seed=args.manual_seed)
    train_sampler = data.distributed.DistributedSampler(train_data,
                                                        shuffle=True)
    val_sampler = data.distributed.DistributedSampler(val_data, shuffle=False)
    train_loader = data.DataLoader(train_data,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   num_workers=args.workers,
                                   pin_memory=True,
                                   worker_init_fn=init_fn,
                                   sampler=train_sampler,
                                   collate_fn=collate_fn,
                                   drop_last=True)
    val_loader = data.DataLoader(val_data,
                                 batch_size=args.batch_size_val,
                                 shuffle=False,
                                 num_workers=args.workers_val,
                                 pin_memory=True,
                                 sampler=val_sampler,
                                 drop_last=False,
                                 collate_fn=collate_fn)

    # polynomial learning-rate decay over all training iterations
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda x: (1 - x / (len(train_loader) * args.epochs))**0.9)

    best_IoU = 0.0
    # resume
    if args.resume:
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(
                args.resume,
                map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            best_IoU = checkpoint["best_iou"]
            # drop the key that is re-initialized each run, then restore the
            # remaining model weights non-strictly
            checkpoint['model_state_dict'].pop('decoder.tokens.weight')
            model.module.load_state_dict(checkpoint['model_state_dict'],
                                         strict=False)
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            raise ValueError(
                "=> resume failed! no checkpoint found at '{}'. Please check args.resume again!"
                .format(args.resume))

    # start training
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        epoch_log = epoch + 1

        # shuffle loader
        train_sampler.set_epoch(epoch_log)

        # train
        train(train_loader, model, optimizer, scheduler, scaler, epoch_log,
              args)

        # evaluation
        iou, prec_dict = validate(val_loader, model, epoch_log, args)

        # save model
        if dist.get_rank() == 0:
            lastname = os.path.join(args.output_dir, "last_model.pth")
            torch.save(
                {
                    'epoch': epoch_log,
                    'cur_iou': iou,
                    'best_iou': best_IoU,
                    'prec': prec_dict,
                    'model_state_dict': model.module.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict()
                }, lastname)
            if iou >= best_IoU and epoch_log < 50:
                best_IoU = iou
                bestname = os.path.join(args.output_dir, "best_model.pth")
                shutil.copyfile(lastname, bestname)

        torch.cuda.empty_cache()

    time.sleep(2)
    if dist.get_rank() == 0:
        try:
            wandb.finish()
        except AttributeError:
            logger.warning("Failed to properly finish wandb run "
                           "due to StreamToLoguru conflict")

    logger.info("* Best IoU={} * ".format(best_IoU))
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('* Training time {} *'.format(total_time_str))


if __name__ == '__main__':
    main()
    sys.exit(0)
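
# Usage sketch (not part of the original script): the yaml path and GPU count
# below are placeholders. When launched with torchrun, LOCAL_RANK and
# WORLD_SIZE are set in the environment and the env:// DDP branch is taken;
# without torchrun, the script spawns one worker per visible GPU via mp.spawn
# and rendezvouses at tcp://localhost:12355.
#
#   torchrun --nproc_per_node=4 train.py --config path/to/experiment.yaml
#
#   python train.py --config path/to/experiment.yaml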