import datetime
import functools
import io
import logging
import os
import random
import tempfile
import time
from typing import Any, Callable, List, Optional, Tuple

import torch
import torch.autograd as autograd
import torch.distributed as dist

# Index of the CUDA device used by the current process. It is updated by
# `set_cuda_device_index` / `set_cpu_device` and read by
# `init_distributed_data_parallel_model`.
_cuda_device_index: int = 0

_CPU_DEVICE_INDEX = -1
_PRIMARY_RANK = 0


@functools.lru_cache()
def _get_global_gloo_group():
    """
    Return a process group based on the gloo backend, containing all the ranks.
    The result is cached.
    """
    if dist.get_backend() == "nccl":
        # NCCL does not support CPU tensors, so build a dedicated gloo group for
        # CPU-side collectives. Use a generous timeout (43200 s = 12 hours).
        timeout = 43200
        return dist.new_group(
            backend="gloo",
            timeout=datetime.timedelta(seconds=timeout),
        )

    return dist.group.WORLD


def is_main_process():
    """Return True if the current process is the main one."""
    return get_rank() == 0


def all_gather_via_filesys(data, filesys_save_dir=None, gather_to_rank_0_only=False):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors), similar to
    `all_gather` below, but using the filesystem instead of collective ops.

    If gather_to_rank_0_only is True, only rank 0 will load the gathered object list
    (and other ranks will have an empty list).
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    print("gathering via files")
    cpu_group = _get_global_gloo_group()

    # Resolve the directory used to exchange the pickled objects.
    if filesys_save_dir is not None:
        save_dir = filesys_save_dir
    elif "EXP_DIR" in os.environ:
        save_dir = os.environ["EXP_DIR"]
    else:
        # Fall back to the directory containing this file.
        save_dir = os.path.dirname(__file__)
    save_dir = os.path.join(save_dir, "all_gather_via_filesys")
    if is_main_process():
        os.makedirs(save_dir, exist_ok=True)

    # Use a timestamp and a random salt to distinguish this call from other
    # all_gather calls. Only the main process draws the values; the all_reduce
    # below (sum with zeros elsewhere) shares them with every rank.
    timestamp = int(time.time()) if is_main_process() else 0
    salt = random.randint(0, 2**31 - 1) if is_main_process() else 0
    timestamp_and_salt = torch.tensor([timestamp, salt], dtype=torch.long)
    dist.all_reduce(timestamp_and_salt, group=cpu_group)
    timestamp, salt = timestamp_and_salt.tolist()

    # Each rank saves its object to a unique file, then waits for the others.
    rank_save = get_rank()
    save_data_filename = f"data_to_gather_{timestamp}_{salt}_{rank_save}.pkl"
    save_data_path = os.path.join(save_dir, save_data_filename)
    assert not os.path.exists(save_data_path), f"{save_data_path} already exists"
    torch.save(data, save_data_path)
    dist.barrier(group=cpu_group)

    # Load the objects saved by all ranks (on every rank, or only on rank 0 if
    # gather_to_rank_0_only is set), then wait again so that no file is removed
    # before everyone has read it.
    data_list = []
    if rank_save == 0 or not gather_to_rank_0_only:
        for rank_load in range(world_size):
            load_data_filename = f"data_to_gather_{timestamp}_{salt}_{rank_load}.pkl"
            load_data_path = os.path.join(save_dir, load_data_filename)
            assert os.path.exists(load_data_path), f"cannot read {load_data_path}"
            data_list.append(torch.load(load_data_path))
    dist.barrier(group=cpu_group)

    # Each rank removes its own file.
    os.remove(save_data_path)
    return data_list


def all_gather(data, force_cpu=False, force_filesys=False, filesys_save_dir=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        force_cpu: if True, run the gather on CPU tensors via a gloo group
        force_filesys: if True, run the gather via the filesystem
            (see `all_gather_via_filesys`)
        filesys_save_dir: directory to use for the filesystem-based gather
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # The MDETR_* environment variables can also select the filesystem or CPU paths.
    if os.getenv("MDETR_FILESYS_REDUCE_RANK_0_ONLY") == "1":
        return all_gather_via_filesys(
            data, filesys_save_dir, gather_to_rank_0_only=True
        )

    if os.getenv("MDETR_FILESYS_REDUCE") == "1" or force_filesys:
        return all_gather_via_filesys(data, filesys_save_dir)

    cpu_group = None
    if os.getenv("MDETR_CPU_REDUCE") == "1" or force_cpu:
        cpu_group = _get_global_gloo_group()

    # Serialize the object into a byte tensor on the device used for the gather.
    buffer = io.BytesIO()
    torch.save(data, buffer)
    data_view = buffer.getbuffer()
    device = "cuda" if cpu_group is None else "cpu"
    tensor = torch.ByteTensor(data_view).to(device)

    # Obtain the serialized size of the object on each rank.
    local_size = torch.tensor([tensor.numel()], device=device, dtype=torch.long)
    size_list = [
        torch.tensor([0], device=device, dtype=torch.long) for _ in range(world_size)
    ]
    if cpu_group is None:
        dist.all_gather(size_list, local_size)
    else:
        print("gathering on cpu")
        dist.all_gather(size_list, local_size, group=cpu_group)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)
    assert isinstance(local_size.item(), int)
    local_size = int(local_size.item())

    # All tensors passed to all_gather must have the same shape, so pad the
    # local byte tensor up to the maximum serialized size before gathering.
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device=device))
    if local_size != max_size:
        padding = torch.empty(
            size=(max_size - local_size,), dtype=torch.uint8, device=device
        )
        tensor = torch.cat((tensor, padding), dim=0)
    if cpu_group is None:
        dist.all_gather(tensor_list, tensor)
    else:
        dist.all_gather(tensor_list, tensor, group=cpu_group)

    # Strip the padding and deserialize each rank's object.
    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        tensor = torch.split(tensor, [size, max_size - size], dim=0)[0]
        buffer = io.BytesIO(tensor.cpu().numpy())
        obj = torch.load(buffer)
        data_list.append(obj)

    return data_list


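# Usage sketch (hypothetical, not part of the original module): gather small
# per-rank Python objects, e.g. evaluation stats, onto every rank. Assumes the
# default process group has already been initialized elsewhere, e.g. via
# torch.distributed.init_process_group.
#
#     stats = {"rank": get_rank(), "num_samples": 128}
#     all_stats = all_gather(stats)   # list with one dict per rank
#     if is_main_process():
#         total = sum(s["num_samples"] for s in all_stats)

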
def convert_to_distributed_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, str]:
    """
    For some backends, such as NCCL, communication only works if the
    tensor is on the GPU. This helper function converts to the correct
    device and returns the tensor + original device.
    """
    orig_device = "cpu" if not tensor.is_cuda else "gpu"
    if (
        torch.distributed.is_available()
        and torch.distributed.get_backend() == torch.distributed.Backend.NCCL
        and not tensor.is_cuda
    ):
        tensor = tensor.cuda()
    return (tensor, orig_device)


def convert_to_normal_tensor(tensor: torch.Tensor, orig_device: str) -> torch.Tensor:
    """
    For some backends, such as NCCL, communication only works if the
    tensor is on the GPU. This converts the tensor back to the original device.
    """
    if tensor.is_cuda and orig_device == "cpu":
        tensor = tensor.cpu()
    return tensor


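# Usage sketch (hypothetical, not part of the original module): the pair of
# helpers above lets callers pass CPU tensors even when the NCCL backend is
# active — move to the backend's device, run the collective, then move back.
#
#     t, orig_device = convert_to_distributed_tensor(torch.zeros(3))
#     torch.distributed.all_reduce(t)
#     t = convert_to_normal_tensor(t, orig_device)

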
def is_distributed_training_run() -> bool:
    return (
        torch.distributed.is_available()
        and torch.distributed.is_initialized()
        and (torch.distributed.get_world_size() > 1)
    )


def is_primary() -> bool:
    """
    Returns True if this is rank 0 of a distributed training job OR if it is
    a single trainer job. Otherwise False.
    """
    return get_rank() == _PRIMARY_RANK


def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    """
    Wrapper over torch.distributed.all_reduce for performing mean reduction
    of tensor over all processes.
    """
    return all_reduce_op(
        tensor,
        torch.distributed.ReduceOp.SUM,
        lambda t: t / torch.distributed.get_world_size(),
    )


def all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor:
    """
    Wrapper over torch.distributed.all_reduce for performing sum
    reduction of tensor over all processes in both distributed /
    non-distributed scenarios.
    """
    return all_reduce_op(tensor, torch.distributed.ReduceOp.SUM)


def all_reduce_min(tensor: torch.Tensor) -> torch.Tensor:
    """
    Wrapper over torch.distributed.all_reduce for performing min
    reduction of tensor over all processes in both distributed /
    non-distributed scenarios.
    """
    return all_reduce_op(tensor, torch.distributed.ReduceOp.MIN)


def all_reduce_max(tensor: torch.Tensor) -> torch.Tensor:
    """
    Wrapper over torch.distributed.all_reduce for performing max
    reduction of tensor over all processes in both distributed /
    non-distributed scenarios.
    """
    return all_reduce_op(tensor, torch.distributed.ReduceOp.MAX)


def all_reduce_op(
    tensor: torch.Tensor,
    op: torch.distributed.ReduceOp,
    after_op_func: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
) -> torch.Tensor:
    """
    Wrapper over torch.distributed.all_reduce for performing
    reduction of tensor over all processes in both distributed /
    non-distributed scenarios.
    """
    if is_distributed_training_run():
        tensor, orig_device = convert_to_distributed_tensor(tensor)
        torch.distributed.all_reduce(tensor, op)
        if after_op_func is not None:
            tensor = after_op_func(tensor)
        tensor = convert_to_normal_tensor(tensor, orig_device)
    return tensor


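# Usage sketch (hypothetical, not part of the original module): average a scalar
# loss across workers for logging. In a non-distributed run these helpers are
# no-ops, so the same code works in both settings.
#
#     loss_for_logging = all_reduce_mean(loss.detach())
#     max_batch_time = all_reduce_max(torch.tensor([batch_time]))

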
def gather_tensors_from_all(tensor: torch.Tensor) -> List[torch.Tensor]:
    """
    Wrapper over torch.distributed.all_gather for performing
    'gather' of 'tensor' over all processes in both distributed /
    non-distributed scenarios.
    """
    if tensor.ndim == 0:
        # 0-dim tensors cannot be gathered, so add a leading dimension.
        tensor = tensor.unsqueeze(0)

    if is_distributed_training_run():
        tensor, orig_device = convert_to_distributed_tensor(tensor)
        gathered_tensors = [
            torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())
        ]
        torch.distributed.all_gather(gathered_tensors, tensor)
        gathered_tensors = [
            convert_to_normal_tensor(_tensor, orig_device)
            for _tensor in gathered_tensors
        ]
    else:
        gathered_tensors = [tensor]

    return gathered_tensors


def gather_from_all(tensor: torch.Tensor) -> torch.Tensor:
    gathered_tensors = gather_tensors_from_all(tensor)
    gathered_tensor = torch.cat(gathered_tensors, 0)
    return gathered_tensor


def broadcast(tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
    """
    Wrapper over torch.distributed.broadcast for broadcasting a tensor from the source
    to all processes in both distributed / non-distributed scenarios.
    """
    if is_distributed_training_run():
        tensor, orig_device = convert_to_distributed_tensor(tensor)
        torch.distributed.broadcast(tensor, src)
        tensor = convert_to_normal_tensor(tensor, orig_device)
    return tensor


def barrier() -> None:
    """
    Wrapper over torch.distributed.barrier. If the distributed process group is
    not initialized, it returns immediately instead of raising an error.
    """
    if not torch.distributed.is_available() or not torch.distributed.is_initialized():
        return
    torch.distributed.barrier()


def get_world_size() -> int:
    """
    Simple wrapper for correctly getting the world size in both distributed
    / non-distributed settings
    """
    return (
        torch.distributed.get_world_size()
        if torch.distributed.is_available() and torch.distributed.is_initialized()
        else 1
    )


def get_rank() -> int:
    """
    Simple wrapper for correctly getting rank in both distributed
    / non-distributed settings
    """
    return (
        torch.distributed.get_rank()
        if torch.distributed.is_available() and torch.distributed.is_initialized()
        else 0
    )


def get_primary_rank() -> int:
    return _PRIMARY_RANK


def set_cuda_device_index(idx: int) -> None:
    global _cuda_device_index
    _cuda_device_index = idx
    torch.cuda.set_device(_cuda_device_index)


def set_cpu_device() -> None:
    global _cuda_device_index
    _cuda_device_index = _CPU_DEVICE_INDEX


def get_cuda_device_index() -> int:
    return _cuda_device_index


def init_distributed_data_parallel_model(
    model: torch.nn.Module,
    broadcast_buffers: bool = False,
    find_unused_parameters: bool = True,
    bucket_cap_mb: int = 25,
) -> torch.nn.parallel.DistributedDataParallel:
    global _cuda_device_index

    if _cuda_device_index == _CPU_DEVICE_INDEX:
        # CPU-only model: do not specify a device for DDP.
        return torch.nn.parallel.DistributedDataParallel(
            model,
            broadcast_buffers=broadcast_buffers,
            find_unused_parameters=find_unused_parameters,
            bucket_cap_mb=bucket_cap_mb,
        )
    else:
        # GPU model: pin DDP to the CUDA device selected for this process.
        return torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[_cuda_device_index],
            output_device=_cuda_device_index,
            broadcast_buffers=broadcast_buffers,
            find_unused_parameters=find_unused_parameters,
            bucket_cap_mb=bucket_cap_mb,
        )


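# Usage sketch (hypothetical, not part of the original module): typical per-process
# setup after torch.distributed.init_process_group has been called, where
# `local_rank` comes from the launcher (e.g. torchrun).
#
#     set_cuda_device_index(local_rank)
#     ddp_model = init_distributed_data_parallel_model(model)

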
def broadcast_object(obj: Any, src: int = _PRIMARY_RANK, use_disk: bool = True) -> Any:
    """Broadcast an object from a source to all workers.

    Args:
        obj: Object to broadcast, must be serializable
        src: Source rank for broadcast (default is primary)
        use_disk: If enabled, removes redundant CPU memory copies by writing to
            disk
    """
    if get_rank() == src:
        # Serialize the object and send its length first, then its bytes.
        buffer = io.BytesIO()
        torch.save(obj, buffer)
        data_view = buffer.getbuffer()
        length_tensor = torch.LongTensor([len(data_view)])
        length_tensor = broadcast(length_tensor, src=src)
        data_tensor = torch.ByteTensor(data_view)
        data_tensor = broadcast(data_tensor, src=src)
    else:
        # Receive the length, allocate a buffer of that size, then receive the bytes.
        length_tensor = torch.LongTensor([0])
        length_tensor = broadcast(length_tensor, src=src)
        data_tensor = torch.empty([length_tensor.item()], dtype=torch.uint8)
        data_tensor = broadcast(data_tensor, src=src)
        if use_disk:
            with tempfile.TemporaryFile("r+b") as f:
                f.write(data_tensor.numpy())
                # Free the byte tensor before deserializing to avoid holding
                # two copies of the data in memory at the same time.
                del data_tensor
                f.seek(0)
                obj = torch.load(f)
        else:
            buffer = io.BytesIO(data_tensor.numpy())
            obj = torch.load(buffer)
    return obj


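# Usage sketch (hypothetical, not part of the original module): share a config
# dictionary built on the primary rank with all other workers.
#
#     config = build_config() if is_primary() else None
#     config = broadcast_object(config)

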
def all_gather_tensor(tensor: torch.Tensor, world_size=None):
    if world_size is None:
        world_size = get_world_size()

    assert tensor.is_contiguous(), f"{tensor.shape} is not contiguous!"
    tensor, orig_device = convert_to_distributed_tensor(tensor)
    tensor_all = [torch.ones_like(tensor) for _ in range(world_size)]
    dist.all_gather(tensor_all, tensor, async_op=False)
    tensor_all = [
        convert_to_normal_tensor(tensor, orig_device) for tensor in tensor_all
    ]
    return tensor_all


def all_gather_batch(tensors: List[torch.Tensor]):
    """
    Performs all_gather operation on the provided tensors.
    Note: gradients do not flow back through this gather; use
    `all_gather_batch_with_grad` if gradients are needed.
    """
    world_size = get_world_size()
    # There is no need for reduction in the single-process case.
    if world_size == 1:
        return tensors
    tensor_list = []
    output_tensor = []
    for tensor in tensors:
        tensor_all = all_gather_tensor(tensor, world_size)
        tensor_list.append(tensor_all)

    for tensor_all in tensor_list:
        output_tensor.append(torch.cat(tensor_all, dim=0))
    return output_tensor


class GatherLayer(autograd.Function):
    """
    Gather tensors from all workers with support for backward propagation:
    unlike torch.distributed.all_gather, this implementation does not cut the
    gradients.
    """

    @staticmethod
    def forward(ctx, x):
        output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(output, x)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        all_gradients = torch.stack(grads)
        dist.all_reduce(all_gradients)
        return all_gradients[dist.get_rank()]


def all_gather_batch_with_grad(tensors):
    """
    Performs all_gather operation on the provided tensors.
    Graph remains connected for backward grad computation.
    """
    world_size = get_world_size()
    if world_size == 1:
        return tensors
    tensor_list = []
    output_tensor = []

    for tensor in tensors:
        tensor_all = GatherLayer.apply(tensor)
        tensor_list.append(tensor_all)

    for tensor_all in tensor_list:
        output_tensor.append(torch.cat(tensor_all, dim=0))
    return output_tensor


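# Usage sketch (hypothetical, not part of the original module): gather image and
# text embeddings from all workers for a contrastive loss, keeping the autograd
# graph connected so gradients flow back to the local embeddings.
#
#     image_emb_all, text_emb_all = all_gather_batch_with_grad([image_emb, text_emb])
#     logits = image_emb_all @ text_emb_all.t()

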
def unwrap_ddp_if_wrapped(model):
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        return model.module
    return model


def create_new_process_group(group_size):
    """
    Creates process groups of a given `group_size` and returns the
    process group that the current GPU participates in.

    `group_size` must divide the total number of GPUs (world_size).

    Modified from
    https://github.com/NVIDIA/apex/blob/4e1ae43f7f7ac69113ef426dd15f37123f0a2ed3/apex/parallel/__init__.py#L60

    Args:
        group_size (int): number of GPUs to collaborate for sync bn
    """

    assert group_size > 0

    world_size = torch.distributed.get_world_size()
    if world_size <= 8:
        if group_size > world_size:
            logging.warning(
                f"Requested group size [{group_size}] > world size [{world_size}]. "
                "Assuming local debug run and capping it to world size."
            )
            group_size = world_size
    assert world_size >= group_size
    assert world_size % group_size == 0

    # torch.distributed.new_group is a collective: every rank must create every
    # group, but each rank only keeps the group it belongs to.
    group = None
    for group_num in range(world_size // group_size):
        group_ids = range(group_num * group_size, (group_num + 1) * group_size)
        cur_group = torch.distributed.new_group(ranks=group_ids)
        if torch.distributed.get_rank() // group_size == group_num:
            group = cur_group

    assert group is not None
    return group


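# Usage sketch (hypothetical, not part of the original module): build groups of 8
# ranks, e.g. one group per node, for synchronized batch norm.
#
#     sync_bn_group = create_new_process_group(group_size=8)
#     model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
#         model, process_group=sync_bn_group
#     )

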
def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True