import os
import warnings
from typing import Optional

import torch
import torch.distributed as dist

try:
    import horovod.torch as hvd
except ImportError:
    hvd = None


def is_global_master(args):
    return args.rank == 0


def is_local_master(args):
    return args.local_rank == 0


def is_master(args, local=False):
    return is_local_master(args) if local else is_global_master(args)


def is_device_available(device):
    device_type = torch.device(device).type
    is_avail = False
    is_known = False
    if device_type == 'cuda':
        is_avail = torch.cuda.is_available()
        is_known = True
    elif device_type == 'npu':
        # NOTE: the npu device extension must be autoloaded for this check
        # not to raise.
        is_avail = torch.npu.is_available()
        is_known = True
    elif device_type == 'mps':
        is_avail = torch.backends.mps.is_available()
        is_known = True
    elif device_type == 'cpu':
        is_avail = True
        is_known = True
    return is_avail, is_known
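

# Example (illustrative; results depend on the host):
#   is_device_available('cuda') -> (True, True) on a machine with working CUDA
#   is_device_available('cpu')  -> (True, True) everywhere
#   is_device_available('xpu')  -> (False, False): a valid torch device type,
#                                  but not one this helper knows how to probe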


def set_device(device):
    if device.startswith('cuda:'):
        torch.cuda.set_device(device)
    elif device.startswith('npu:'):
        torch.npu.set_device(device)


def is_using_horovod():
    # NOTE: with horovodrun the OMPI_* vars should be set, while under SLURM
    # the PMI_* vars are set instead. Differentiating between Horovod and DDP
    # use via SLURM may not be possible, so an explicit horovod arg is still required...
    ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"]
    pmi_vars = ["PMI_RANK", "PMI_SIZE"]
    return all(var in os.environ for var in ompi_vars) or all(var in os.environ for var in pmi_vars)


def is_using_distributed():
    if 'WORLD_SIZE' in os.environ:
        return int(os.environ['WORLD_SIZE']) > 1
    if 'SLURM_NTASKS' in os.environ:
        return int(os.environ['SLURM_NTASKS']) > 1
    return False
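

# Illustrative: WORLD_SIZE=8 (e.g. set by torchrun) or SLURM_NTASKS=8 both
# make is_using_distributed() return True; unset, or set to 1, yields False.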


def world_info_from_env():
    # Resolve (local_rank, global_rank, world_size) from the first matching
    # launcher env var, checked in order of precedence.
    local_rank = 0
    for v in ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'):
        if v in os.environ:
            local_rank = int(os.environ[v])
            break
    global_rank = 0
    for v in ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'):
        if v in os.environ:
            global_rank = int(os.environ[v])
            break
    world_size = 1
    for v in ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'):
        if v in os.environ:
            world_size = int(os.environ[v])
            break
    return local_rank, global_rank, world_size
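

# Illustrative values (not from a real job), assuming a SLURM launch where
# only SLURM_* vars are set:
#   SLURM_LOCALID=1 SLURM_PROCID=5 SLURM_NTASKS=8
#   world_info_from_env() -> (1, 5, 8)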


def init_distributed_device(args):
    # Distributed training = training on more than one GPU.
    # Works in both single and multi-node scenarios.
    args.distributed = False
    args.world_size = 1
    args.rank = 0  # global rank
    args.local_rank = 0
    result = init_distributed_device_so(
        device=getattr(args, 'device', 'cuda'),
        dist_backend=getattr(args, 'dist_backend', None),
        dist_url=getattr(args, 'dist_url', None),
        horovod=getattr(args, 'horovod', False),
        no_set_device_rank=getattr(args, 'no_set_device_rank', False),
    )
    args.device = result['device']
    args.world_size = result['world_size']
    args.rank = result['global_rank']
    args.local_rank = result['local_rank']
    args.distributed = result['distributed']
    device = torch.device(args.device)
    return device
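

# Minimal usage sketch with an argparse-style namespace; the attribute names
# are assumed to match what this function reads and writes:
#
#   args = argparse.Namespace(device='cuda', dist_backend=None, dist_url=None,
#                             horovod=False, no_set_device_rank=False)
#   device = init_distributed_device(args)
#   print(args.rank, args.local_rank, args.world_size, args.distributed)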


def init_distributed_device_so(
        device: str = 'cuda',
        dist_backend: Optional[str] = None,
        dist_url: Optional[str] = None,
        horovod: bool = False,
        no_set_device_rank: bool = False,
):
    # Distributed training = training on more than one GPU.
    # Works in both single and multi-node scenarios.
    distributed = False
    world_size = 1
    global_rank = 0
    local_rank = 0
    device_type, *device_idx = device.split(':', maxsplit=1)
    is_avail, is_known = is_device_available(device_type)
    if not is_known:
        warnings.warn(f"Device {device} is of an unknown type and was not checked for availability; trying anyway.")
    elif not is_avail:
        warnings.warn(f"Device {device} is not available, falling back to CPU.")
        device_type = device = 'cpu'

    if horovod:
        # The module-level import already set hvd to None if Horovod is
        # missing; re-importing here would raise before the assert could fire.
        assert hvd is not None, "Horovod is not installed"
        hvd.init()
        local_rank = int(hvd.local_rank())
        global_rank = hvd.rank()
        world_size = hvd.size()
        distributed = True
    elif is_using_distributed():
        if dist_backend is None:
            dist_backends = {
                "cuda": "nccl",
                "hpu": "hccl",
                "npu": "hccl",
                "xpu": "ccl",
            }
            dist_backend = dist_backends.get(device_type, 'gloo')
        dist_url = dist_url or 'env://'
        if 'SLURM_PROCID' in os.environ:
            # DDP via SLURM
            local_rank, global_rank, world_size = world_info_from_env()
            # SLURM vars -> torch.distributed vars, in case they are needed
            os.environ['LOCAL_RANK'] = str(local_rank)
            os.environ['RANK'] = str(global_rank)
            os.environ['WORLD_SIZE'] = str(world_size)
            torch.distributed.init_process_group(
                backend=dist_backend,
                init_method=dist_url,
                world_size=world_size,
                rank=global_rank,
            )
        else:
            # DDP via torchrun, torch.distributed.launch
            local_rank, _, _ = world_info_from_env()
            torch.distributed.init_process_group(
                backend=dist_backend,
                init_method=dist_url,
            )
            world_size = torch.distributed.get_world_size()
            global_rank = torch.distributed.get_rank()
        distributed = True

    if distributed and not no_set_device_rank and device_type not in ('cpu', 'mps'):
        # In distributed mode, ignore any manually specified device index and
        # override it with the resolved local rank; this avoids headaches in
        # most setups.
        if device_idx:
            warnings.warn(f'Device index {device_idx[0]} removed from specified device ({device}); using local rank instead.')
        device = f'{device_type}:{local_rank}'
        set_device(device)
    return dict(
        device=device,
        global_rank=global_rank,
        local_rank=local_rank,
        world_size=world_size,
        distributed=distributed,
    )
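

# Minimal usage sketch, assuming a `torchrun --nproc_per_node=2 train.py`
# launch (torchrun sets RANK, LOCAL_RANK, and WORLD_SIZE for each process):
#
#   info = init_distributed_device_so(device='cuda')
#   model = model.to(info['device'])
#   if info['distributed']:
#       model = torch.nn.parallel.DistributedDataParallel(
#           model, device_ids=[info['local_rank']])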


def broadcast_object(args, obj, src=0):
    # Broadcast a pickle-able python object from the src rank to all ranks.
    if args.horovod:
        return hvd.broadcast_object(obj, root_rank=src)
    else:
        if args.rank == src:
            objects = [obj]
        else:
            objects = [None]
        dist.broadcast_object_list(objects, src=src)
        return objects[0]
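

# Hedged example: rank 0 picks a seed and every rank receives it. `args` is
# assumed to carry `horovod` and `rank` as set by init_distributed_device():
#
#   seed = broadcast_object(args, random.randint(0, 2**31) if args.rank == 0 else None)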


def all_gather_object(args, obj, dst=0):
    # Gather a pickle-able python object across all ranks.
    # NOTE: dst is unused; all-gather delivers the full list to every rank.
    if args.horovod:
        return hvd.allgather_object(obj)
    else:
        objects = [None for _ in range(args.world_size)]
        dist.all_gather_object(objects, obj)
        return objects
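

# Hedged example: collect per-rank metric dicts on every rank for logging
# (local_loss is a stand-in for a value computed on each rank):
#
#   all_metrics = all_gather_object(args, {'rank': args.rank, 'loss': local_loss})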