import argparse
import datetime
import os
import shutil
import sys
import time
import warnings
from functools import partial

import cv2
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.parallel
import torch.optim
import torch.utils.data as data
from loguru import logger

import utils.config as config
import wandb
from model import build_segmenter
from utils.misc import (init_random_seed, set_random_seed, setup_logger,
                        worker_init_fn)

warnings.filterwarnings("ignore")
cv2.setNumThreads(0)

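# Typical launch under torchrun, assuming a single 8-GPU node (the script
# name and config path here are illustrative, not fixed by this repo):
#   torchrun --nproc_per_node=8 train.py --config path/to/xxx.yaml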

def get_parser():
    parser = argparse.ArgumentParser(
        description='PyTorch Referring Expression Segmentation')
    parser.add_argument('--config',
                        default='path to xxx.yaml',
                        type=str,
                        help='config file')
    parser.add_argument('--opts',
                        default=None,
                        nargs=argparse.REMAINDER,
                        help='override some settings in the config.')
    parser.add_argument('--local-rank',
                        type=int,
                        default=0,
                        help='local rank for distributed training')

    args = parser.parse_args()
    assert args.config is not None
    cfg = config.load_cfg_from_cfg_file(args.config)
    if args.opts is not None:
        cfg = config.merge_cfg_from_list(cfg, args.opts)
    return cfg


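# Two launch modes are supported: under torchrun, LOCAL_RANK is present in
# the environment and each process calls main_worker_ddp() directly;
# otherwise main() spawns one main_worker_mp() process per visible GPU.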
@logger.catch
def main():
    args = get_parser()
    args.manual_seed = init_random_seed(args.manual_seed)
    set_random_seed(args.manual_seed, deterministic=True)

    if 'LOCAL_RANK' in os.environ:
        args.local_rank = int(os.environ['LOCAL_RANK'])
        logger.info(f"LOCAL_RANK from env: {args.local_rank}")
        main_worker_ddp(args)
    else:
        args.ngpus_per_node = torch.cuda.device_count()
        args.world_size = args.ngpus_per_node * getattr(args, 'world_size', 1)
        mp.spawn(main_worker_mp, nprocs=args.ngpus_per_node, args=(args,))


def main_worker_ddp(args):
    args.output_dir = os.path.join(args.output_folder, args.exp_name)

    args.gpu = args.local_rank
    # Under torchrun, RANK and WORLD_SIZE are set by the launcher; fall back
    # to the local rank for single-node runs.
    args.rank = int(os.environ.get('RANK', args.local_rank))
    args.world_size = int(os.environ.get('WORLD_SIZE', 1))

    torch.cuda.set_device(args.gpu)

    setup_logger(args.output_dir,
                 distributed_rank=args.gpu,
                 filename="train.log",
                 mode="a")

    logger.info(f"Starting with GPU: {args.gpu}, Rank: {args.rank}, "
                f"World Size: {args.world_size}")

    # The env:// rendezvous reads MASTER_ADDR/MASTER_PORT from the
    # environment, so no explicit URL needs to be constructed here.
    dist.init_process_group(backend=getattr(args, 'dist_backend', 'nccl'),
                            init_method='env://',
                            world_size=args.world_size,
                            rank=args.rank)

    run_training(args)


def main_worker_mp(gpu, args):
    args.output_dir = os.path.join(args.output_folder, args.exp_name)

    args.gpu = gpu

    # Global rank = node rank * GPUs per node + local GPU index.
    node_rank = getattr(args, 'rank', 0)
    args.rank = node_rank * args.ngpus_per_node + gpu
    torch.cuda.set_device(args.gpu)

    setup_logger(args.output_dir,
                 distributed_rank=args.gpu,
                 filename="train.log",
                 mode="a")

    dist_url = getattr(args, 'dist_url', 'tcp://localhost:12355')
    dist.init_process_group(backend=getattr(args, 'dist_backend', 'nccl'),
                            init_method=dist_url,
                            world_size=args.world_size,
                            rank=args.rank)

    run_training(args)


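# Shared training body for both entry points: initializes wandb on rank 0,
# builds the ref-zom dataset/model/optimizer, then runs the epoch loop.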
def run_training(args):
    # Only rank 0 logs to wandb; the other ranks wait at the barrier.
    if args.rank == 0:
        wandb.init(job_type="training",
                   mode="offline",
                   config=args,
                   project=args.exp_name,
                   name=args.exp_name,
                   tags=[args.dataset])
    dist.barrier()

    if args.dataset == 'ref-zom':
        from engine.engine_refzom_2 import train, validate
        from utils.dataset_zom_sbert import (RefZom_FilterDataset,
                                             Refzom_DistributedSampler)
    else:
        # Without this guard, train/validate and the dataset classes would
        # be undefined below.
        raise ValueError(f"Unsupported dataset: {args.dataset}")

    model, param_list = build_segmenter(args)
    model = model.cuda(args.gpu)

    if hasattr(model, 'text_encoder'):
        model.text_encoder = model.text_encoder.cuda(args.gpu)

    logger.info(f"Model moved to GPU: {args.gpu}")
    logger.info(args)

    optimizer = torch.optim.AdamW(param_list,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay,
                                  amsgrad=args.amsgrad)

    model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.gpu],
        find_unused_parameters=True)

    # Gradient scaler for mixed-precision (AMP) training.
    scaler = amp.GradScaler()

    # The configured batch sizes and worker count are global totals; split
    # them across the participating processes.
    args.batch_size = int(args.batch_size / dist.get_world_size())
    args.batch_size_val = int(args.batch_size_val / dist.get_world_size())
    args.workers = int(
        (args.workers + dist.get_world_size() - 1) / dist.get_world_size())

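    # Per-process data pipeline: distributed samplers shard the datasets,
    # and worker_init_fn seeds each dataloader worker from the global seed.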
    train_data = RefZom_FilterDataset(lmdb_dir=args.train_lmdb,
                                      mask_dir=args.mask_root,
                                      dataset=args.dataset,
                                      split=args.train_split,
                                      mode='train',
                                      input_size=args.input_size,
                                      word_length=args.word_len,
                                      args=args)
    val_data = RefZom_FilterDataset(lmdb_dir=args.val_lmdb,
                                    mask_dir=args.mask_root,
                                    dataset=args.dataset,
                                    split=args.val_split,
                                    mode='test',
                                    input_size=args.input_size,
                                    word_length=args.word_len,
                                    args=args)

    init_fn = partial(worker_init_fn,
                      num_workers=args.workers,
                      rank=args.rank,
                      seed=args.manual_seed)
    train_sampler = Refzom_DistributedSampler(
        train_data, num_replicas=args.world_size, rank=args.rank,
        shuffle=True)
    val_sampler = data.distributed.DistributedSampler(val_data, shuffle=False)

    train_loader = data.DataLoader(train_data,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   num_workers=args.workers,
                                   pin_memory=True,
                                   worker_init_fn=init_fn,
                                   sampler=train_sampler,
                                   drop_last=True)
    val_loader = data.DataLoader(val_data,
                                 batch_size=args.batch_size_val,
                                 shuffle=False,
                                 num_workers=args.workers_val,
                                 pin_memory=True,
                                 sampler=val_sampler,
                                 drop_last=False)

    # Polynomial decay (power 0.9) over the total number of iterations.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda x: (1 - x / (len(train_loader) * args.epochs)) ** 0.9)

    best_IoU = 0.0
    best_oIoU = 0.0

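    # Optionally restore training state from a checkpoint written by the
    # epoch loop below.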
    if args.resume:
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(
                args.resume, map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            best_IoU = checkpoint["best_iou"]
            # The decoder token embedding is dropped and the remaining
            # weights restored non-strictly; without this load, resuming
            # would restore the optimizer and scheduler but not the model.
            checkpoint['model_state_dict'].pop('decoder.tokens.weight', None)
            model.module.load_state_dict(checkpoint['model_state_dict'],
                                         strict=False)
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            raise ValueError(
                "=> resume failed! no checkpoint found at '{}'. "
                "Please check args.resume again!".format(args.resume))

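    # Epoch loop: train, synchronize, validate on every rank, and let rank 0
    # write checkpoints.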
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        epoch_log = epoch + 1

        # Re-seed the sampler so each epoch sees a different shuffle order.
        train_sampler.set_epoch(epoch_log)

        train(train_loader, model, optimizer, scheduler, scaler, epoch_log,
              args)

        torch.cuda.empty_cache()
        dist.barrier()

        logger.info(f"Start Evaluation : epoch {epoch_log}")
        iou, oiou, prec_dict, mean_acc = validate(val_loader, model,
                                                  epoch_log, args)

        if dist.get_rank() == 0:
            lastname = os.path.join(args.output_dir, "last_model.pth")
            torch.save(
                {
                    'epoch': epoch_log,
                    'cur_iou': iou,
                    'best_iou': best_IoU,
                    'best_oiou': best_oIoU,
                    'prec': prec_dict,
                    'mean_acc': mean_acc,
                    'model_state_dict': model.module.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict()
                }, lastname)
            # Track the best-mIoU and best-oIoU checkpoints separately.
            if iou >= best_IoU and epoch_log < 50:
                best_IoU = iou
                bestname = os.path.join(args.output_dir,
                                        "best_model_miou.pth")
                shutil.copyfile(lastname, bestname)
            if oiou >= best_oIoU and epoch_log < 50:
                best_oIoU = oiou
                bestname_oiou = os.path.join(args.output_dir,
                                             "best_model_oiou.pth")
                shutil.copyfile(lastname, bestname_oiou)

        torch.cuda.empty_cache()

    # Brief pause before wandb teardown.
    time.sleep(2)
    if dist.get_rank() == 0:
        try:
            wandb.finish()
        except AttributeError:
            logger.warning("Failed to properly finish wandb run due to "
                           "StreamToLoguru conflict")

    logger.info("* Best IoU={} *".format(best_IoU))
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('* Training time {} *'.format(total_time_str))


if __name__ == '__main__':
    main()
    sys.exit(0)