| import torch
|
| import numpy as np
|
| import torch.backends.cudnn as cudnn
|
| import torch.distributed as dist
|
| import torch.backends.cudnn as cudnn
|
| import os
|
| import logging
|
|
|
def setup_for_distributed(is_master):
    """Silence ``print`` on every process that is not the master.

    The global builtin ``print`` is replaced by a wrapper. The wrapper forwards
    to the ``print`` that was installed when this function ran, but only if
    ``is_master`` is true or the call passes ``force=True``.

    Args:
        is_master: Whether this process should keep printing normally.
    """
    import builtins as __builtin__

    previous_print = __builtin__.print

    def print(*args, **kwargs):
        # ``force`` is always consumed here (pop runs before the ``or``), so
        # it never reaches the real print as an unknown keyword.
        if kwargs.pop('force', False) or is_master:
            previous_print(*args, **kwargs)

    __builtin__.print = print
|
|
|
def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is available AND initialized.

    ``dist.is_initialized`` is consulted only when ``dist.is_available()`` is
    true, because the ``and`` short-circuits.
    """
    return bool(dist.is_available() and dist.is_initialized())
|
|
|
def get_world_size():
    """Return the number of processes in the default group, or 1 if not distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
|
|
|
def get_rank():
    """Return this process's rank in the default group, or 0 if not distributed."""
    if is_dist_avail_and_initialized():
        return dist.get_rank()
    return 0
|
|
|
def is_main_process():
    """Return True when this process has rank 0 (always True when not distributed)."""
    current_rank = get_rank()
    return current_rank == 0
|
|
|
def save_on_master(*args, **kwargs):
    """Forward all arguments to ``torch.save``, but only on the main process.

    On every other rank this is a no-op. In all cases it returns None.
    """
    if not is_main_process():
        return
    torch.save(*args, **kwargs)
|
|
|
def init_distributed_mode(args):
    """Configure torch.distributed from launcher environment variables.

    The launcher is detected in this order:

    1. ``args.dist_on_itp`` is truthy: rank, world size and local rank come
       from the OpenMPI ``OMPI_COMM_WORLD_*`` variables. A ``tcp://`` init URL
       is built from ``MASTER_ADDR``/``MASTER_PORT``, and ``RANK``,
       ``LOCAL_RANK`` and ``WORLD_SIZE`` are exported.
    2. ``RANK`` and ``WORLD_SIZE`` are both in the environment (e.g. torchrun):
       they are used directly, and ``LOCAL_RANK`` selects the GPU.
    3. ``SLURM_PROCID`` is in the environment: rank comes from SLURM, and world
       size from ``SLURM_NTASKS`` when set, otherwise from the existing
       ``args.world_size``. The GPU is ``rank % torch.cuda.device_count()``.
       ``RANK``, ``LOCAL_RANK`` and ``WORLD_SIZE`` are exported.

    If no launcher is detected, ``args.distributed`` is set to False and the
    function returns without touching torch.distributed.

    Otherwise it sets ``args.distributed = True``, binds the process to
    ``args.gpu``, initializes an NCCL process group, waits on a barrier, and
    silences ``print`` on non-zero ranks via ``setup_for_distributed``.

    Args:
        args: Mutable namespace (e.g. ``argparse.Namespace``).

            Reads:
            - ``dist_on_itp``: treated as False if missing.
            - ``dist_url``: defaults to ``'env://'`` if missing or empty.
            - ``world_size``: only in the SLURM branch, when ``SLURM_NTASKS``
              is absent.

            Writes ``rank``, ``world_size``, ``gpu``, ``distributed``,
            ``dist_backend`` and possibly ``dist_url``.

    Raises:
        KeyError: A variable required by the selected branch is missing
            (e.g. ``LOCAL_RANK`` under case 2, ``MASTER_ADDR`` under case 1).
    """
    if getattr(args, 'dist_on_itp', False):
        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
        args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
        os.environ['LOCAL_RANK'] = str(args.gpu)
        os.environ['RANK'] = str(args.rank)
        os.environ['WORLD_SIZE'] = str(args.world_size)
    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        # Previously world_size was never taken from SLURM, so a stale
        # args.world_size (often a default of 1) was exported as WORLD_SIZE.
        # SLURM_NTASKS is the total task count for the job step.
        if 'SLURM_NTASKS' in os.environ:
            args.world_size = int(os.environ['SLURM_NTASKS'])
        args.gpu = args.rank % torch.cuda.device_count()

        os.environ['RANK'] = str(args.rank)
        os.environ['LOCAL_RANK'] = str(args.gpu)
        os.environ['WORLD_SIZE'] = str(args.world_size)
    else:
        logging.info('Not using distributed mode')
        args.distributed = False
        return

    # Outside the ITP branch dist_url is expected to be preset. Fall back to
    # torch's own default init method rather than crashing on a missing attr.
    if not getattr(args, 'dist_url', None):
        args.dist_url = 'env://'

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    logging.info('| distributed init (rank %s): %s, gpu %s',
                 args.rank, args.dist_url, args.gpu)
    dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                            world_size=args.world_size, rank=args.rank)
    dist.barrier()
    setup_for_distributed(args.rank == 0)
|
|
|
def fix_seed(seed, deterministic=False):
    """Seed the global random number generators for reproducible runs.

    Seeds Python's ``random`` module, NumPy's global RNG and torch
    (``torch.manual_seed`` seeds the generators on all devices).

    Args:
        seed: Integer seed applied to every generator.
        deterministic: If False (the default, matching previous behavior),
            ``cudnn.benchmark`` is enabled for speed. Benchmarking may pick
            nondeterministic cuDNN kernels, so GPU results can still vary
            between runs. If True, benchmarking is disabled and
            ``cudnn.deterministic`` is enabled instead.
    """
    # Local import, matching the function-scope import style used elsewhere
    # in this file; `random` was previously left unseeded.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if deterministic:
        cudnn.benchmark = False
        cudnn.deterministic = True
    else:
        cudnn.benchmark = True
|
|
|