import logging
import os
import random
import socket
import subprocess
from datetime import timedelta
from enum import Enum
from typing import List, Sequence

import torch
import torch.distributed as dist

logger = logging.getLogger("dinov3")

_DEFAULT_PROCESS_GROUP = None
_PROCESS_SUBGROUP = None
_BUILTIN_PRINT = None


def is_distributed_enabled() -> bool:
    """
    Returns:
        True if distributed training is enabled.
    """
    return dist.is_available() and dist.is_initialized()


def get_rank(group=None) -> int:
    """
    Returns:
        The rank of the current process within the specified process group.
    """
    if not is_distributed_enabled():
        return 0
    return dist.get_rank(group=group)


def get_world_size(group=None) -> int:
    """
    Returns:
        The number of processes in the specified process group.
    """
    if not is_distributed_enabled():
        return 1
    return dist.get_world_size(group=group)


def is_main_process(group=None) -> bool:
    """
    Returns:
        True if the current process is the main one in the specified process group.
    """
    return get_rank(group) == 0


def save_in_main_process(*args, **kwargs) -> None:
    """Utility function to save only from the main process."""
    group = kwargs.pop("group", None)
    if not is_main_process(group):
        return
    torch.save(*args, **kwargs)
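

# Example (a sketch; the object and path are placeholders): every rank can
# call this unconditionally, and only rank 0 actually writes the file.
#
#   save_in_main_process(model.state_dict(), "checkpoint.pth")
#
# Note that no barrier is implied: non-main ranks return immediately.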


def _restrict_print_to_main_process() -> None:
    """This function disables printing when not in the main process."""
    import builtins as __builtin__

    global _BUILTIN_PRINT
    _BUILTIN_PRINT = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if is_main_process() or force:
            _BUILTIN_PRINT(*args, **kwargs)

    __builtin__.print = print
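

# Once _restrict_print_to_main_process() has run, non-main ranks stay silent
# unless explicitly forced:
#
#   print("visible on rank 0 only")
#   print("visible on every rank", force=True)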


def _get_master_port(seed: int = 0) -> int:
    MIN_MASTER_PORT, MAX_MASTER_PORT = (20_000, 60_000)

    master_port_str = os.environ.get("MASTER_PORT")
    if master_port_str is None:
        rng = random.Random(seed)
        return rng.randint(MIN_MASTER_PORT, MAX_MASTER_PORT)

    return int(master_port_str)


def _get_available_port() -> int:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        # Binding to port 0 makes the OS pick a free ephemeral port, which is
        # then read back from the socket address.
        s.bind(("", 0))
        port = s.getsockname()[1]
        return port


def _parse_slurm_node_list(s: str) -> List[str]:
    # Expand a (possibly compressed) Slurm node list into individual hostnames.
    return subprocess.check_output(["scontrol", "show", "hostnames", s], text=True).splitlines()
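

# For example, "scontrol show hostnames" expands a compressed node list such
# as "node[1-3]" into ["node1", "node2", "node3"].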


class JobType(Enum):
    TORCHELASTIC = "TorchElastic"
    SLURM = "Slurm"
    MANUAL = "manual"


class TorchDistributedEnvironment:
    """
    Helper class to get (and set) distributed job information from the
    environment. Identifies and supports (in this order):
    - TorchElastic,
    - Slurm,
    - Manual launch (single-node).
    """

    def __init__(self):
        if "TORCHELASTIC_RUN_ID" in os.environ:
            # TorchElastic / torchrun launch: all variables are provided.
            self.job_id = os.environ["TORCHELASTIC_RUN_ID"]
            self.job_type = JobType.TORCHELASTIC

            self.master_addr = os.environ["MASTER_ADDR"]
            self.master_port = int(os.environ["MASTER_PORT"])
            self.rank = int(os.environ["RANK"])
            self.world_size = int(os.environ["WORLD_SIZE"])
            self.local_rank = int(os.environ["LOCAL_RANK"])
            self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
        elif "SLURM_JOB_ID" in os.environ:
            # Slurm job: derive the rendezvous point from the job description.
            self.job_id = int(os.environ["SLURM_JOB_ID"])
            self.job_type = JobType.SLURM

            node_count = int(os.environ["SLURM_JOB_NUM_NODES"])
            nodes = _parse_slurm_node_list(os.environ["SLURM_JOB_NODELIST"])
            assert len(nodes) == node_count

            self.master_addr = nodes[0]
            # Seeding with the job id makes the port choice deterministic
            # across all tasks of the same job.
            self.master_port = _get_master_port(seed=self.job_id)
            self.rank = int(os.environ["SLURM_PROCID"])
            self.world_size = int(os.environ["SLURM_NTASKS"])
            self.local_rank = int(os.environ["SLURM_LOCALID"])
            self.local_world_size = self.world_size // node_count
        else:
            # Manual launch: assume a single process on the local machine.
            self.job_id = None
            self.job_type = JobType.MANUAL

            self.master_addr = "127.0.0.1"
            self.master_port = _get_available_port()
            self.rank = 0
            self.world_size = 1
            self.local_rank = 0
            self.local_world_size = 1

        assert self.rank < self.world_size
        assert self.local_rank < self.local_world_size

    def export(
        self,
        *,
        overwrite: bool,
        nccl_async_error_handling: bool = False,
    ) -> "TorchDistributedEnvironment":
        """Export the job information as the environment variables expected by
        torch.distributed (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE,
        LOCAL_RANK and LOCAL_WORLD_SIZE)."""
        env_vars = {
            "MASTER_ADDR": self.master_addr,
            "MASTER_PORT": str(self.master_port),
            "RANK": str(self.rank),
            "WORLD_SIZE": str(self.world_size),
            "LOCAL_RANK": str(self.local_rank),
            "LOCAL_WORLD_SIZE": str(self.local_world_size),
        }
        if nccl_async_error_handling:
            env_vars.update(
                {
                    "TORCH_NCCL_ASYNC_ERROR_HANDLING": "1",
                }
            )

        if not overwrite:
            for k, v in env_vars.items():
                # Only conflicting values are an error; identical ones are fine.
                if k not in os.environ:
                    continue
                if os.environ[k] == v:
                    continue
                raise RuntimeError(f"Cannot export environment variables as {k} is already set to a different value")

        os.environ.update(env_vars)
        return self

    @property
    def is_main_process(self) -> bool:
        return self.rank == 0

    def __str__(self):
        return (
            f"{self.job_type.value} job "
            + (f"({self.job_id}) " if self.job_id else "")
            + f"using {self.master_addr}:{self.master_port} "
            f"(rank={self.rank}, world size={self.world_size})"
        )

    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f"master_addr={self.master_addr},"
            f"master_port={self.master_port},"
            f"rank={self.rank},"
            f"world_size={self.world_size},"
            f"local_rank={self.local_rank},"
            f"local_world_size={self.local_world_size}"
            ")"
        )
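

# Typical use (a sketch): detect the launch mode, then export the standard
# variables before initializing the process group.
#
#   torch_env = TorchDistributedEnvironment()
#   torch_env.export(overwrite=False)
#   # MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE, LOCAL_RANK and
#   # LOCAL_WORLD_SIZE are now set for torch.distributed to consume.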


def enable_distributed(
    *,
    set_cuda_current_device: bool = True,
    overwrite: bool = False,
    nccl_async_error_handling: bool = False,
    restrict_print_to_main_process: bool = True,
    timeout: timedelta | None = None,
):
    """Enable distributed mode.

    Args:
        set_cuda_current_device: If True, call torch.cuda.set_device() to set the
            current PyTorch CUDA device to the one matching the local rank.
        overwrite: If True, overwrite environment variables that are already set.
            If False, raise if a variable is already set to a different value.
        nccl_async_error_handling: Enable NCCL asynchronous error handling. As a
            side effect, this enables timing out PyTorch distributed operations
            (after a default delay of 30 minutes).
        restrict_print_to_main_process: If True, the print function of non-main
            processes (i.e. rank > 0) is disabled. Use print(..., force=True) to
            print anyway. If False, nothing is changed and all processes can
            print as usual.
        timeout: Timeout for operations executed against the process group.
            The default value is 10 minutes for NCCL and 30 minutes for other
            backends.
    """
    global _DEFAULT_PROCESS_GROUP

    if _DEFAULT_PROCESS_GROUP is not None:
        raise RuntimeError("Distributed mode has already been enabled")

    torch_env = TorchDistributedEnvironment()
    logger.info(f"PyTorch distributed environment: {torch_env}")
    torch_env.export(
        overwrite=overwrite,
        nccl_async_error_handling=nccl_async_error_handling,
    )

    if set_cuda_current_device:
        torch.cuda.set_device(torch_env.local_rank)

    dist.init_process_group(backend="nccl", timeout=timeout)
    dist.barrier()

    if restrict_print_to_main_process:
        _restrict_print_to_main_process()

    _DEFAULT_PROCESS_GROUP = torch.distributed.group.WORLD
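

# Example (a sketch, assuming a torchrun launch such as
# `torchrun --nproc-per-node=8 train.py`, where train.py is hypothetical):
#
#   enable_distributed(overwrite=True)
#   device = torch.device("cuda", torch.cuda.current_device())
#   t = torch.ones(1, device=device)
#   dist.all_reduce(t)  # t now holds the world size on every rank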


def get_default_process_group():
    return _DEFAULT_PROCESS_GROUP


def disable_distributed() -> None:
    global _BUILTIN_PRINT
    if _BUILTIN_PRINT is not None:
        # Restore the built-in print function.
        import builtins as __builtin__

        __builtin__.print = _BUILTIN_PRINT
        _BUILTIN_PRINT = None

    global _PROCESS_SUBGROUP
    if _PROCESS_SUBGROUP is not None:
        torch.distributed.destroy_process_group(_PROCESS_SUBGROUP)
        _PROCESS_SUBGROUP = None

    global _DEFAULT_PROCESS_GROUP
    if _DEFAULT_PROCESS_GROUP is not None:
        torch.distributed.destroy_process_group(_DEFAULT_PROCESS_GROUP)
        _DEFAULT_PROCESS_GROUP = None


def new_subgroups(all_subgroup_ranks: Sequence[Sequence[int]]):
    """Create new process subgroups according to the provided specification.

    Args:
        all_subgroup_ranks: a sequence of rank sequences (first rank, ..., last rank),
            one for each process subgroup. Example: ((0, 1), (2, 3), (4, 5, 6, 7)).

    Note:
        This is similar to the (undocumented) new_subgroups_by_enumeration().
        It should be called once with the full specification (not once per
        subgroup), as every rank must take part in creating every subgroup.
    """
    all_ranks = tuple(rank for subgroup_ranks in all_subgroup_ranks for rank in subgroup_ranks)
    rank = get_rank()
    assert len(all_ranks) == len(set(all_ranks))  # Subgroups must not overlap
    assert rank in all_ranks  # Every rank must belong to a subgroup

    global _PROCESS_SUBGROUP
    assert _PROCESS_SUBGROUP is None

    for subgroup_ranks in all_subgroup_ranks:
        # new_group() is a collective: all ranks must call it for each subgroup,
        # but only the ranks listed in subgroup_ranks become members.
        subgroup = torch.distributed.new_group(subgroup_ranks)
        if rank in subgroup_ranks:
            _PROCESS_SUBGROUP = subgroup
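

# Example (a sketch for an 8-rank job): split the ranks into two 4-rank
# subgroups, then run collectives within the local subgroup only.
#
#   new_subgroups(((0, 1, 2, 3), (4, 5, 6, 7)))
#   t = torch.ones(1, device="cuda")
#   dist.all_reduce(t, group=get_process_subgroup())  # t == 4 on every rank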


def get_process_subgroup():
    """
    Returns:
        The process subgroup of this rank, falling back to the default
        process group if no subgroup was created.
    """
    return _PROCESS_SUBGROUP or _DEFAULT_PROCESS_GROUP


def get_subgroup_rank() -> int:
    """
    Returns:
        The rank of the current process within its process subgroup.
    """
    return get_rank(group=get_process_subgroup())


def get_subgroup_size() -> int:
    """
    Returns:
        The number of processes in the process subgroup.
    """
    return get_world_size(group=get_process_subgroup())


def is_subgroup_main_process() -> bool:
    """
    Returns:
        True if the current process is the main one within its process subgroup.
    """
    return get_rank(group=get_process_subgroup()) == 0
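

# The subgroup helpers mirror the global ones; for instance, to log once per
# subgroup (a sketch):
#
#   if is_subgroup_main_process():
#       logger.info(f"subgroup size: {get_subgroup_size()}")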