import argparse
import datetime
import os
import shutil
import sys
import time
import warnings
from functools import partial

import cv2
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
import torch.multiprocessing as mp  # https://blog.csdn.net/hxxjxw/article/details/119839548
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data as data
from loguru import logger
# https://hanjunqiang.blog.csdn.net/article/details/124779625
from torch.optim.lr_scheduler import MultiStepLR

import utils.config as config
import wandb
from utils.dataset import RefDataset
from engine.engine import train, validate
from model import build_segmenter
from utils.misc import (init_random_seed, set_random_seed, setup_logger,
                        worker_init_fn, build_scheduler)  # , collate_fn

warnings.filterwarnings("ignore")
warnings.filterwarnings("ignore", category=UserWarning)
cv2.setNumThreads(0)
torch.cuda.empty_cache()

import deepspeed
from deepspeed.runtime.lr_schedules import WarmupLR


def get_parser():
    parser = argparse.ArgumentParser(
        description='Pytorch Referring Expression Segmentation')
    parser.add_argument('--config',
                        default='path to xxx.yaml',
                        type=str,
                        help='config file')
    parser.add_argument('--opts',
                        default=None,
                        nargs=argparse.REMAINDER,
                        help='override some settings in the config.')
    parser.add_argument("--local_rank", type=int, default=0)
    parser = deepspeed.add_config_arguments(parser)
    args = parser.parse_args()
    assert args.config is not None
    cfg = config.load_cfg_from_cfg_file(args.config)
    if args.opts is not None:
        cfg = config.merge_cfg_from_list(cfg, args.opts)
    return cfg


# @logger.catch
# def main():
#     args = get_parser()
#     args.manual_seed = init_random_seed(args.manual_seed)
#     set_random_seed(args.manual_seed, deterministic=True)
#     args.ngpus_per_node = torch.cuda.device_count()
#     args.world_size = args.ngpus_per_node * args.world_size
#     # mp.spawn(main_worker, nprocs=args.ngpus_per_node, args=(args, ))


def main(args):
    # local rank & global rank
    args.rank = int(os.environ["RANK"])
    args.world_size = int(os.environ['WORLD_SIZE'])
    args.gpu = int(os.environ['LOCAL_RANK'])
    args.output_dir = os.path.join(args.output_folder, args.exp_name)
    args.dist_url = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
    torch.cuda.set_device(args.gpu)

    # logger
    setup_logger(args.output_dir,
                 distributed_rank=args.gpu,
                 filename="train.log",
                 mode="a")

    # dist init
    # dist.init_process_group(backend=args.dist_backend,
    #                         init_method=args.dist_url,
    #                         world_size=args.world_size,
    #                         rank=args.rank)
    deepspeed.init_distributed(init_method=args.dist_url,  # args.dist_backend,
                               world_size=args.world_size,
                               rank=args.rank)
    print("dist init done")

    # wandb
    if args.rank == 0:
        wandb.init(job_type="training",
                   mode="offline",
                   config=args,
                   project=args.exp_name,
                   name=args.exp_name,
                   tags=[args.dataset])
    dist.barrier()

    deepspeed_config = {
        "train_batch_size": args.batch_size,
        "gradient_accumulation_steps": 2,
        "fp16": {
            "enabled": True,
            "auto_cast": False,
            "loss_scale": 0,
            "initial_scale_power": 16,
            "loss_scale_window": 1000,
            "hysteresis": 2,
            "consecutive_hysteresis": False,
            "min_loss_scale": 1
        },
        "zero_optimization": {
            "stage": 3,
            "contiguous_gradients": True,
            "overlap_comm": True,
            "reduce_scatter": True,
            "reduce_bucket_size": 5e8,
            "allgather_bucket_size": 5e8,
            "offload_optimizer": {
                "device": "cpu"
            },
            "offload_param": {
                "device": "cpu"
            }
        },
        "optimizer": {
            "type": "AdamW",
            "params": {
                "lr": args.lr,
                "weight_decay": args.weight_decay,
                "betas": [0.9, 0.999],
                "eps": 1e-8,
                "amsgrad": False
            }
        },
    }
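    # Note on the DeepSpeed config above (descriptive comment added here, not
    # in the original source): "train_batch_size" is the *global* batch size,
    # from which DeepSpeed derives the per-GPU micro-batch together with
    # "gradient_accumulation_steps" and the world size; "loss_scale": 0
    # selects dynamic fp16 loss scaling; and ZeRO stage 3 partitions
    # parameters, gradients and optimizer states across ranks, with both
    # parameters and optimizer states offloaded to CPU memory.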
    # build model
    model, param_list = build_segmenter(args)
    # logger.info(model)
    logger.info(args)

    # build optimizer & lr scheduler
    # optimizer = torch.optim.AdamW(param_list,
    #                               lr=args.lr,
    #                               weight_decay=args.weight_decay,
    #                               amsgrad=args.amsgrad)

    # build dataset
    grad_acc_steps = deepspeed_config['gradient_accumulation_steps']
    args.batch_size = int(args.batch_size / args.ngpus_per_node / grad_acc_steps)
    args.batch_size_val = int(args.batch_size_val / args.ngpus_per_node)
    args.workers = int(
        (args.workers + args.ngpus_per_node - 1) / args.ngpus_per_node)
    train_data = RefDataset(lmdb_dir=args.train_lmdb,
                            mask_dir=args.mask_root,
                            dataset=args.dataset,
                            split=args.train_split,
                            mode='train',
                            input_size=args.input_size,
                            word_length=args.word_len,
                            args=args)
    val_data = RefDataset(lmdb_dir=args.val_lmdb,
                          mask_dir=args.mask_root,
                          dataset=args.dataset,
                          split=args.val_split,
                          mode='val',
                          input_size=args.input_size,
                          word_length=args.word_len,
                          args=args)

    # build dataloader
    init_fn = partial(worker_init_fn,
                      num_workers=args.workers,
                      rank=args.rank,
                      seed=args.manual_seed)
    train_sampler = data.distributed.DistributedSampler(train_data, shuffle=True)
    val_sampler = data.distributed.DistributedSampler(val_data, shuffle=False)
    train_loader = data.DataLoader(train_data,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   num_workers=args.workers,
                                   pin_memory=True,
                                   worker_init_fn=init_fn,
                                   sampler=train_sampler,
                                   # collate_fn=collate_fn,
                                   drop_last=True)
    val_loader = data.DataLoader(val_data,
                                 batch_size=args.batch_size_val,
                                 shuffle=False,
                                 num_workers=args.workers_val,
                                 pin_memory=True,
                                 sampler=val_sampler,
                                 drop_last=False,
                                 # collate_fn=collate_fn,
                                 )

    # scheduler = WarmupLR(optimizer)
    scaler = None  # amp.GradScaler()
    # torch.optim.lr_scheduler.LambdaLR(optimizer, lambda x: (1 - x / (len(train_loader) * args.epochs)) ** 0.9)
    scheduler = partial(
        torch.optim.lr_scheduler.LambdaLR,
        lr_lambda=lambda x: (1 - x / (len(train_data) / args.batch_size * args.epochs)) ** 0.9)  # len(train_loader)
    model, optimizer, _, scheduler = deepspeed.initialize(model=model,
                                                          config_params=deepspeed_config,
                                                          model_parameters=param_list,
                                                          lr_scheduler=scheduler,
                                                          dist_init_required=True)
    best_IoU = 0.0

    # resume
    if args.resume:
        # if os.path.isfile(args.resume):
        #     logger.info("=> loading checkpoint '{}'".format(args.resume))
        #     checkpoint = torch.load(
        #         args.resume, map_location=lambda storage, loc: storage.cuda())
        #     args.start_epoch = checkpoint['epoch']
        #     best_IoU = checkpoint["best_iou"]
        #     checkpoint['model_state_dict'].pop('decoder.tokens.weight')
        #     optimizer.load_state_dict(checkpoint['optimizer'])
        #     scheduler.load_state_dict(checkpoint['scheduler'])
        #     logger.info("=> loaded checkpoint '{}' (epoch {})".format(
        #         args.resume, checkpoint['epoch']))
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}' (ckpt_id: {})".format(
                args.resume, args.ckpt_id))
            _, client_sd = model.load_checkpoint(args.resume)
        else:
            raise ValueError(
                "=> resume failed! no checkpoint found at '{}'. Please check args.resume again!"
                .format(args.resume))
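    # Note (added comment, not in the original): DeepSpeed's
    # engine.load_checkpoint(load_dir) returns a (load_path, client_state)
    # tuple; load_dir is the directory that model.save_checkpoint() wrote,
    # and an optional tag (e.g. "best_model") can be passed to pick a
    # specific checkpoint inside it.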
    # start training
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        epoch_log = epoch + 1

        # shuffle loader
        train_sampler.set_epoch(epoch_log)

        # train
        train(train_loader, model, optimizer, scheduler, scaler, epoch_log, args)

        # evaluation
        iou, prec_dict = validate(val_loader, model, epoch_log, args)

        # save model
        # if dist.get_rank() == 0:
        #     lastname = os.path.join(args.output_dir, "last_model.pth")
        #     torch.save(
        #         {
        #             'epoch': epoch_log,
        #             'cur_iou': iou,
        #             'best_iou': best_IoU,
        #             'prec': prec_dict,
        #             'model_state_dict': model.module.state_dict(),
        #             'optimizer': optimizer.state_dict(),
        #             'scheduler': scheduler.state_dict()
        #         }, lastname)
        #     if iou >= best_IoU and epoch_log < 50:
        #         best_IoU = iou
        #         bestname = os.path.join(args.output_dir, "best_model.pth")
        #         shutil.copyfile(lastname, bestname)
        if dist.get_rank() == 0:
            found_best_epoch = [(iou >= best_IoU and epoch_log < 50)]
            # track the best IoU so far (mirrors the commented-out logic above)
            best_IoU = max(best_IoU, iou)
        else:
            found_best_epoch = [False]
        dist.broadcast_object_list(found_best_epoch, src=0, device=model.device)
        model.save_checkpoint(args.output_dir, "last_model")
        if found_best_epoch[0]:
            model.save_checkpoint(args.output_dir, "best_model")

        torch.cuda.empty_cache()
        time.sleep(2)

    if dist.get_rank() == 0:
        wandb.finish()
        logger.info("* Best IoU={} * ".format(best_IoU))
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('* Training time {} *'.format(total_time_str))


if __name__ == '__main__':
    args = get_parser()
    args.manual_seed = init_random_seed(args.manual_seed)
    set_random_seed(args.manual_seed, deterministic=True)
    args.ngpus_per_node = torch.cuda.device_count()
    args.world_size = args.ngpus_per_node * args.world_size
    # mp.spawn(main_worker, nprocs=args.ngpus_per_node, args=(args, ))
    main(args)
    sys.exit(0)
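# Launch note (illustrative addition, not from the original source): main()
# reads RANK / WORLD_SIZE / LOCAL_RANK / MASTER_ADDR / MASTER_PORT from the
# environment, so the script is meant to be started by a distributed
# launcher. A minimal single-node sketch, assuming 4 GPUs, this file saved as
# train.py, and a hypothetical config at config/refcoco/refcoco.yaml:
#
#   torchrun --nproc_per_node=4 --master_addr=127.0.0.1 --master_port=29500 \
#       train.py --config config/refcoco/refcoco.yaml
#
# The `deepspeed` launcher can be used instead; it sets the same environment
# variables and additionally passes --local_rank, which the parser above
# accepts.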