| | import os |
| | import warnings |
| | from typing import Optional |
| |
|
| | import torch |
| | import torch.distributed as dist |
| |
|
| | try: |
| | import horovod.torch as hvd |
| | except ImportError: |
| | hvd = None |
| |
|
| |
|
def is_global_master(args):
    """Return True when ``args.rank`` (the global rank) equals 0."""
    rank = args.rank
    return rank == 0
| |
|
| |
|
def is_local_master(args):
    """Return True when ``args.local_rank`` equals 0."""
    local_rank = args.local_rank
    return local_rank == 0
| |
|
| |
|
def is_master(args, local=False):
    """Return whether this process is the master.

    Args:
        args: Object exposing ``rank`` and ``local_rank`` attributes.
        local: If True, test the local rank; otherwise test the global rank.
    """
    if local:
        return is_local_master(args)
    return is_global_master(args)
| |
|
| |
|
def is_device_available(device):
    """Check whether a device type is recognised and currently usable.

    Args:
        device: Anything accepted by ``torch.device`` (e.g. ``'cuda'``, ``'cuda:1'``, ``'cpu'``).

    Returns:
        Tuple ``(is_avail, is_known)``. ``is_known`` is True only for the cuda, npu,
        mps and cpu device types; any other type yields ``(False, False)``.
    """
    kind = torch.device(device).type

    # Probes are lambdas so that only the backend actually requested is touched.
    probes = {
        'cuda': lambda: torch.cuda.is_available(),
        # NOTE(review): torch.npu appears to come from an out-of-tree extension;
        # this raises AttributeError if that extension is not loaded — confirm.
        'npu': lambda: torch.npu.is_available(),
        'mps': lambda: torch.backends.mps.is_available(),
        'cpu': lambda: True,
    }

    probe = probes.get(kind)
    if probe is None:
        return False, False
    return probe(), True
| |
|
| |
|
def set_device(device):
    """Make ``device`` the current device for its backend.

    Only indexed strings of the form ``'cuda:<idx>'`` or ``'npu:<idx>'`` have an
    effect; any other string (including a bare ``'cuda'``) is ignored.
    """
    backend, separator, _ = device.partition(':')
    if not separator:
        # No explicit index given: nothing to bind.
        return
    if backend == 'cuda':
        torch.cuda.set_device(device)
    elif backend == 'npu':
        torch.npu.set_device(device)
| |
|
| |
|
def is_using_horovod():
    """Guess from environment variables whether we were launched under Horovod/MPI.

    Returns True when both the rank and the size variable of either the
    ``OMPI_COMM_WORLD_*`` family or the ``PMI_*`` family are set.
    """
    env = os.environ
    variable_groups = (
        ("OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"),
        ("PMI_RANK", "PMI_SIZE"),
    )
    return any(all(name in env for name in group) for group in variable_groups)
| |
|
| |
|
def is_using_distributed():
    """Return True when the environment reports more than one process.

    ``WORLD_SIZE`` is consulted first, then ``SLURM_NTASKS``. The first one that
    is set decides the result, so ``WORLD_SIZE=1`` wins over a larger
    ``SLURM_NTASKS``. Returns False if neither is set.
    """
    for variable in ('WORLD_SIZE', 'SLURM_NTASKS'):
        raw = os.environ.get(variable)
        if raw is not None:
            return int(raw) > 1
    return False
| |
|
| |
|
def world_info_from_env():
    """Read ``(local_rank, global_rank, world_size)`` from environment variables.

    Each value comes from the first variable that is set in a fixed priority list
    (``LOCAL_RANK``/``RANK``/``WORLD_SIZE`` first, then the MPI/PMI/SLURM/OMPI
    alternatives). Missing values default to 0, 0 and 1 respectively. A set but
    non-integer value raises ``ValueError``.
    """
    env = os.environ

    def first_present(candidates, fallback):
        # Scan in priority order; the first variable present wins.
        return next((int(env[name]) for name in candidates if name in env), fallback)

    local_rank = first_present(
        ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'), 0)
    global_rank = first_present(
        ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'), 0)
    world_size = first_present(
        ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'), 1)
    return local_rank, global_rank, world_size
| |
|
| |
|
def init_distributed_device(args):
    """Initialize device/distributed state from an ``args`` namespace, updating it in place.

    Reads the optional attributes ``device`` (default ``'cuda'``), ``dist_backend``,
    ``dist_url``, ``horovod`` and ``no_set_device_rank`` from ``args``. Writes
    ``device``, ``world_size``, ``rank``, ``local_rank`` and ``distributed`` back onto
    it, and returns the resolved ``torch.device``.
    """
    # Reset to single-process defaults first, so args is in a consistent state
    # even if initialization below raises.
    args.distributed, args.world_size, args.rank, args.local_rank = False, 1, 0, 0

    options = {
        'device': getattr(args, 'device', 'cuda'),
        'dist_backend': getattr(args, 'dist_backend', None),
        'dist_url': getattr(args, 'dist_url', None),
        'horovod': getattr(args, 'horovod', False),
        'no_set_device_rank': getattr(args, 'no_set_device_rank', False),
    }
    info = init_distributed_device_so(**options)

    args.device = info['device']
    args.world_size = info['world_size']
    args.rank = info['global_rank']
    args.local_rank = info['local_rank']
    args.distributed = info['distributed']
    return torch.device(args.device)
| |
|
| |
|
def init_distributed_device_so(
    device: str = 'cuda',
    dist_backend: Optional[str] = None,
    dist_url: Optional[str] = None,
    horovod: bool = False,
    no_set_device_rank: bool = False,
):
    """Resolve the compute device and optionally initialize distributed training.

    Standalone variant that takes plain arguments and returns a result dict
    instead of mutating an ``args`` object.

    Args:
        device: Requested device, e.g. ``'cuda'``, ``'cuda:1'``, ``'npu'``, ``'mps'``
            or ``'cpu'``. A ``torch.device`` is also accepted. If the device type is
            known but unavailable, a warning is emitted and ``'cpu'`` is used.
        dist_backend: ``torch.distributed`` backend. If None, it is chosen from the
            device type: ``nccl`` for cuda, ``hccl`` for hpu/npu, ``ccl`` for xpu,
            and ``gloo`` otherwise.
        dist_url: ``init_method`` for ``init_process_group``; defaults to ``'env://'``.
        horovod: Initialize through Horovod instead of ``torch.distributed``.
        no_set_device_rank: If True, do not rebind the device to the local-rank index.

    Returns:
        dict with keys ``device`` (str), ``global_rank``, ``local_rank``,
        ``world_size`` and ``distributed``.

    Raises:
        ImportError: If ``horovod`` is True but ``horovod.torch`` cannot be imported.
    """
    distributed = False
    world_size = 1
    global_rank = 0
    local_rank = 0

    # Accept torch.device as well as str; str() round-trips e.g. 'cuda:1'.
    device = str(device)
    # 'cuda:1' -> device_type='cuda', device_idx=['1']; 'cuda' -> device_idx=[].
    device_type, *device_idx = device.split(':', maxsplit=1)
    is_avail, is_known = is_device_available(device_type)
    if not is_known:
        warnings.warn(f"Device {device} was not known and checked for availability, trying anyways.")
    elif not is_avail:
        warnings.warn(f"Device {device} was not available, falling back to CPU.")
        device_type = device = 'cpu'

    if horovod:
        # Imported lazily so Horovod is only required when explicitly requested.
        # A failed import must be handled here: an `assert` after the import could
        # never fire (and is stripped under -O).
        try:
            import horovod.torch as hvd
        except ImportError as err:
            raise ImportError("Horovod is not installed") from err
        hvd.init()
        local_rank = int(hvd.local_rank())
        global_rank = hvd.rank()
        world_size = hvd.size()
        distributed = True
    elif is_using_distributed():
        if dist_backend is None:
            # Default collective backend per device type; anything else uses gloo.
            dist_backends = {
                "cuda": "nccl",
                "hpu": "hccl",
                "npu": "hccl",
                "xpu": "ccl",
            }
            dist_backend = dist_backends.get(device_type, 'gloo')

        dist_url = dist_url or 'env://'

        if 'SLURM_PROCID' in os.environ:
            # Under SLURM: take ranks from the environment, export them under the
            # LOCAL_RANK/RANK/WORLD_SIZE names, and pass rank/size explicitly.
            local_rank, global_rank, world_size = world_info_from_env()
            os.environ['LOCAL_RANK'] = str(local_rank)
            os.environ['RANK'] = str(global_rank)
            os.environ['WORLD_SIZE'] = str(world_size)
            torch.distributed.init_process_group(
                backend=dist_backend,
                init_method=dist_url,
                world_size=world_size,
                rank=global_rank,
            )
        else:
            # Otherwise let init_process_group discover rank/size itself (e.g. via
            # env://), then query them back. Only local rank comes from the env.
            local_rank, _, _ = world_info_from_env()
            torch.distributed.init_process_group(
                backend=dist_backend,
                init_method=dist_url,
            )
            world_size = torch.distributed.get_world_size()
            global_rank = torch.distributed.get_rank()
        distributed = True

    if distributed and not no_set_device_rank and device_type not in ('cpu', 'mps'):
        # Bind each process to the device index matching its local rank; any
        # explicitly requested index is overridden (with a warning).
        if device_idx:
            warnings.warn(f'device index {device_idx[0]} removed from specified ({device}).')
        device = f'{device_type}:{local_rank}'
        set_device(device)

    return dict(
        device=device,
        global_rank=global_rank,
        local_rank=local_rank,
        world_size=world_size,
        distributed=distributed,
    )
| |
|
| |
|
def broadcast_object(args, obj, src=0):
    """Send ``obj`` from rank ``src`` to every rank and return the received value.

    Uses Horovod when ``args.horovod`` is set. Otherwise uses
    ``torch.distributed``, where only the source rank's ``obj`` is sent (other
    ranks pass a ``None`` placeholder).
    """
    if not args.horovod:
        payload = [obj if args.rank == src else None]
        dist.broadcast_object_list(payload, src=src)
        return payload[0]
    return hvd.broadcast_object(obj, root_rank=src)
| |
|
| |
|
def all_gather_object(args, obj, dst=0):
    """Gather ``obj`` from every rank and return the collected values.

    Uses Horovod when ``args.horovod`` is set. Otherwise uses ``torch.distributed``
    with an output list pre-sized to ``args.world_size``.

    Note: ``dst`` is accepted for interface compatibility but is not used.
    """
    if args.horovod:
        return hvd.allgather_object(obj)
    gathered = [None] * args.world_size
    dist.all_gather_object(gathered, obj)
    return gathered
| |
|