| import argparse |
| import datetime |
| import os |
| import shutil |
| import sys |
| import time |
| import warnings |
| from functools import partial |
|
|
| import cv2 |
| import torch |
| import torch.cuda.amp as amp |
| import torch.distributed as dist |
| import torch.multiprocessing as mp |
| import torch.nn as nn |
| import torch.nn.parallel |
| import torch.optim |
| import torch.utils.data as data |
| from loguru import logger |
| from torch.optim.lr_scheduler import MultiStepLR |
|
|
| import utils.config as config |
| import wandb |
| from utils.dataset_mosaic import RefDataset |
| from engine.engine import train, validate |
| from model import build_segmenter |
| from utils.misc import (init_random_seed, set_random_seed, setup_logger, |
| worker_init_fn, build_scheduler) |
|
|
# Silence all library warnings (the blanket filter already covers UserWarning,
# so the second call is redundant but harmless).
warnings.filterwarnings("ignore")
warnings.filterwarnings("ignore", category=UserWarning)
# Keep OpenCV single-threaded so it does not spawn threads inside
# DataLoader worker processes.
cv2.setNumThreads(0)


# Release any cached CUDA memory held by this interpreter before training.
torch.cuda.empty_cache()


# NOTE(review): deepspeed is deliberately imported after torch/cv2 setup —
# confirm whether the ordering matters for the CUDA extension build.
# WarmupLR is imported but not used anywhere in this file.
import deepspeed
from deepspeed.runtime.lr_schedules import WarmupLR
|
|
|
|
def get_parser():
    """Parse command-line arguments and return the merged training config.

    Reads ``--config`` (path to a YAML file), registers the DeepSpeed
    launcher arguments, then applies any ``--opts`` key/value overrides on
    top of the loaded config.

    Returns:
        The configuration object from ``utils.config`` with CLI overrides
        applied (note: the config, not the argparse parser, is returned).

    Raises:
        ValueError: if no config file path is supplied.
    """
    parser = argparse.ArgumentParser(
        description='Pytorch Referring Expression Segmentation')
    parser.add_argument('--config',
                        default='path to xxx.yaml',
                        type=str,
                        help='config file')
    parser.add_argument('--opts',
                        default=None,
                        nargs=argparse.REMAINDER,
                        help='override some settings in the config.')
    # Kept for launchers that still pass --local_rank explicitly.
    parser.add_argument("--local_rank", type=int, default=0)
    parser = deepspeed.add_config_arguments(parser)
    args = parser.parse_args()
    # Explicit check instead of `assert`: asserts are stripped under `python -O`.
    if args.config is None:
        raise ValueError("a --config file path must be provided")
    cfg = config.load_cfg_from_cfg_file(args.config)
    if args.opts is not None:
        cfg = config.merge_cfg_from_list(cfg, args.opts)
    return cfg
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
def main(args):
    """Distributed training entry point.

    Sets up the process group and logging, builds the model, datasets and
    DeepSpeed engine, then runs the train/validate loop, checkpointing the
    last and best models each epoch.

    Args:
        args: Config namespace from ``get_parser()``. Must be launched under
            torchrun/deepspeed so RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR
            and MASTER_PORT are present in the environment.
    """
    # Per-process distributed identity comes from the launcher environment.
    args.rank = int(os.environ["RANK"])
    args.world_size = int(os.environ['WORLD_SIZE'])
    args.gpu = int(os.environ['LOCAL_RANK'])
    args.output_dir = os.path.join(args.output_folder, args.exp_name)
    args.dist_url = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
    torch.cuda.set_device(args.gpu)

    # Append-mode log file shared across restarts of the same experiment.
    setup_logger(args.output_dir,
                 distributed_rank=args.gpu,
                 filename="train.log",
                 mode="a")

    deepspeed.init_distributed(init_method=args.dist_url,
                               world_size=args.world_size,
                               rank=args.rank)
    print("dist init done")

    # Experiment tracking on the main process only; offline mode avoids any
    # network dependency during training.
    if args.rank == 0:
        wandb.init(job_type="training",
                   mode="offline",
                   config=args,
                   project=args.exp_name,
                   name=args.exp_name,
                   tags=[args.dataset])
    dist.barrier()

    # DeepSpeed engine config: fp16 with dynamic loss scaling, ZeRO stage 3
    # with optimizer/parameter offload to CPU, and AdamW.
    deepspeed_config = {
        "train_batch_size": args.batch_size,
        "gradient_accumulation_steps": 2,
        "fp16": {
            "enabled": True,
            "auto_cast": False,
            "loss_scale": 0,  # 0 = dynamic loss scaling
            "initial_scale_power": 16,
            "loss_scale_window": 1000,
            "hysteresis": 2,
            "consecutive_hysteresis": False,
            "min_loss_scale": 1
        },
        "zero_optimization": {
            "stage": 3,
            "contiguous_gradients": True,
            "overlap_comm": True,
            "reduce_scatter": True,
            "reduce_bucket_size": 5e8,
            "allgather_bucket_size": 5e8,
            "offload_optimizer": {
                "device": "cpu"
            },
            "offload_param": {
                "device": "cpu"
            }
        },
        "optimizer": {
            "type": "AdamW",
            "params": {
                "lr": args.lr,
                "weight_decay": args.weight_decay,
                "betas": [
                    0.9,
                    0.999
                ],
                "eps": 1e-8,
                "amsgrad": False
            }
        },
    }

    model, param_list = build_segmenter(args)

    logger.info(args)

    # Convert the global batch size into a per-GPU micro-batch size, and
    # split the dataloader workers across GPUs (rounded up).
    grad_acc_steps = deepspeed_config['gradient_accumulation_steps']
    args.batch_size = int(args.batch_size / args.ngpus_per_node / grad_acc_steps)
    args.batch_size_val = int(args.batch_size_val / args.ngpus_per_node)
    args.workers = int(
        (args.workers + args.ngpus_per_node - 1) / args.ngpus_per_node)
    train_data = RefDataset(lmdb_dir=args.train_lmdb,
                            mask_dir=args.mask_root,
                            dataset=args.dataset,
                            split=args.train_split,
                            mode='train',
                            input_size=args.input_size,
                            word_length=args.word_len,
                            args=args)
    val_data = RefDataset(lmdb_dir=args.val_lmdb,
                          mask_dir=args.mask_root,
                          dataset=args.dataset,
                          split=args.val_split,
                          mode='val',
                          input_size=args.input_size,
                          word_length=args.word_len,
                          args=args)

    # Deterministic per-worker seeding for DataLoader worker processes.
    init_fn = partial(worker_init_fn,
                      num_workers=args.workers,
                      rank=args.rank,
                      seed=args.manual_seed)
    # Shuffling is delegated to the sampler, so the loaders use shuffle=False.
    train_sampler = data.distributed.DistributedSampler(train_data,
                                                        shuffle=True)
    val_sampler = data.distributed.DistributedSampler(val_data, shuffle=False)

    train_loader = data.DataLoader(train_data,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   num_workers=args.workers,
                                   pin_memory=True,
                                   worker_init_fn=init_fn,
                                   sampler=train_sampler,
                                   drop_last=True)
    val_loader = data.DataLoader(val_data,
                                 batch_size=args.batch_size_val,
                                 shuffle=False,
                                 num_workers=args.workers_val,
                                 pin_memory=True,
                                 sampler=val_sampler,
                                 drop_last=False,
                                 )

    # No torch GradScaler: fp16 loss scaling is handled by DeepSpeed.
    scaler = None

    # Polynomial (power 0.9) per-step decay over the expected step count;
    # DeepSpeed instantiates the scheduler on the wrapped optimizer.
    scheduler = partial(torch.optim.lr_scheduler.LambdaLR,
                        lr_lambda=lambda x: (1 - x / (len(train_data)/args.batch_size * args.epochs)) ** 0.9)
    model, optimizer, _, scheduler = deepspeed.initialize(model=model,
                                                          config_params=deepspeed_config,
                                                          model_parameters=param_list,
                                                          lr_scheduler=scheduler,
                                                          dist_init_required=True)

    best_IoU = 0.0

    if args.resume:
        if os.path.isfile(args.resume):
            # Bug fix: the format string has one placeholder but was given two
            # args (args.resume, args.ckpt_id) — the second was silently
            # dropped and would raise if ckpt_id were missing from the config.
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            # NOTE(review): DeepSpeed's load_checkpoint expects a checkpoint
            # *directory* tag layout — confirm the isfile() guard matches how
            # save_checkpoint() below lays out files.
            _, client_sd = model.load_checkpoint(args.resume)
        else:
            raise ValueError(
                "=> resume failed! no checkpoint found at '{}'. Please check args.resume again!"
                .format(args.resume))

    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        epoch_log = epoch + 1

        # Reshuffle the distributed sampler with a per-epoch seed.
        train_sampler.set_epoch(epoch_log)

        train(train_loader, model, optimizer, scheduler, scaler, epoch_log, args)

        iou, prec_dict = validate(val_loader, model, epoch_log, args)

        # Rank 0 decides whether this epoch is a new best and broadcasts the
        # decision so every rank enters save_checkpoint together (DeepSpeed
        # checkpointing under ZeRO-3 is a collective call).
        if dist.get_rank() == 0:
            # Bug fix: best_IoU was never updated, so every epoch < 50
            # overwrote "best_model" and the final log always reported 0.0.
            found_best = iou >= best_IoU and epoch_log < 50
            best_IoU = max(best_IoU, iou)
            found_best_epoch = [found_best]
        else:
            found_best_epoch = [False]
        dist.broadcast_object_list(found_best_epoch, src=0, device=model.device)

        model.save_checkpoint(args.output_dir, "last_model")
        if found_best_epoch[0]:
            model.save_checkpoint(args.output_dir, "best_model")

        torch.cuda.empty_cache()

    time.sleep(2)
    if dist.get_rank() == 0:
        wandb.finish()

    # best_IoU is only tracked on rank 0; other ranks will log the initial 0.0.
    logger.info("* Best IoU={} * ".format(best_IoU))
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('* Training time {} *'.format(total_time_str))
|
|
|
|
if __name__ == '__main__':
    # Load the merged YAML config (with CLI overrides) as the args namespace.
    args = get_parser()
    # Derive (or validate) the random seed, then seed all RNGs deterministically.
    args.manual_seed = init_random_seed(args.manual_seed)
    set_random_seed(args.manual_seed, deterministic=True)
    args.ngpus_per_node = torch.cuda.device_count()
    # NOTE(review): main() overwrites world_size from the WORLD_SIZE env var,
    # so this product is only a pre-launch estimate.
    args.world_size = args.ngpus_per_node * args.world_size
    main(args)
    sys.exit(0)
|
|