import os
import pdb

import torch
import torch.distributed as dist


def dist_pdb(rank, in_rank=0):
    """Drop rank ``in_rank`` into pdb while the other ranks wait at a barrier."""
    if rank != in_rank:
        dist.barrier()
    else:
        pdb.set_trace()
        dist.barrier()


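# Usage note (illustrative): inside rank-aware code, dist_pdb(get_rank()) opens a pdb
# prompt on rank 0 while every other rank blocks on the barrier until the session continues.

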
def init_distributed_mode(args):
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        # Environment variables set by torchrun / torch.distributed.launch.
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        # Launched via SLURM; SLURM_NTASKS gives the total number of processes.
        args.rank = int(os.environ['SLURM_PROCID'])
        args.world_size = int(os.environ['SLURM_NTASKS'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                            world_size=args.world_size, rank=args.rank)
    dist.barrier()


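# Launch note (illustrative command, not part of the original script): the first branch
# of init_distributed_mode reads the RANK / WORLD_SIZE / LOCAL_RANK variables that
# torchrun exports, e.g.
#   torchrun --nproc_per_node=4 train.py
# with args.dist_url set to 'env://' so init_process_group picks up the same variables.

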
def cleanup():
    """Destroy the default process group."""
    dist.destroy_process_group()


def is_dist_avail_and_initialized():
    """Check whether the distributed environment is available and initialized."""
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def reduce_value(value, average=True):
    """All-reduce a tensor across processes, returning the sum or the mean."""
    world_size = get_world_size()
    if world_size < 2:
        # Single process: nothing to reduce.
        return value

    with torch.no_grad():
        dist.all_reduce(value)
        if average:
            value /= world_size

    return value


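# Minimal single-process smoke test (a sketch; the argparse.Namespace fields and the
# dist_url value below are illustrative assumptions, not part of the module).
if __name__ == "__main__":
    import argparse

    args = argparse.Namespace(dist_url='env://')
    init_distributed_mode(args)  # prints 'Not using distributed mode' when run without torchrun
    print(get_rank(), get_world_size(), is_main_process())  # -> 0 1 True
    t = torch.tensor([1.0])
    print(reduce_value(t))  # world_size == 1, so the tensor is returned unchanged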