| | |
| | |
| | |
| | |
| | |
| |
|
import contextlib
import math
import os
from collections.abc import Callable, Generator, Iterable
from datetime import timedelta

import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.distributed_c10d as c10d
from torch import distributed as dist
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor

from torchtitan.components.ft import ft_clip_grad_norm_util, ft_dist_reduce
from torchtitan.tools.logging import logger
from torchtitan.tools.utils import device_module, device_type
| |
|
| |
|
def _dist_reduce(x: torch.Tensor, reduceOp: str, mesh: DeviceMesh) -> float:
    """All-reduce a single-element tensor over ``mesh`` and return it as a Python scalar.

    Args:
        x: Tensor or DTensor expected to hold exactly one element.
        reduceOp: Name of a c10d reduce op (e.g. ``c10d.ReduceOp.MAX.name``).
        mesh: Device mesh whose process group performs the reduction.

    Returns:
        The reduced value, via ``.item()``.
    """
    # The FT helper may hand back a different tensor, op and mesh; presumably it
    # strips a fault-tolerance replicate dimension (TODO confirm in torchtitan.components.ft).
    local_value, op_name, reduce_mesh = ft_dist_reduce(x, reduceOp, mesh)

    # Functional collectives operate on plain tensors, so materialize a DTensor first.
    if isinstance(local_value, DTensor):
        local_value = local_value.full_tensor()

    # `.item()` below only makes sense for a single element.
    assert local_value.numel() == 1
    reduced = funcol.all_reduce(local_value, reduceOp=op_name, group=reduce_mesh)
    return reduced.item()
| |
|
| |
|
def dist_max(x: torch.Tensor, mesh: DeviceMesh) -> float:
    """Return the maximum of the single-element tensor ``x`` across all ranks of ``mesh``.

    Delegates to ``_dist_reduce`` with the c10d ``MAX`` reduce op.
    """
    max_op_name = c10d.ReduceOp.MAX.name
    return _dist_reduce(x, mesh=mesh, reduceOp=max_op_name)
| |
|
| |
|
def dist_mean(x: torch.Tensor, mesh: DeviceMesh) -> float:
    """Return the average of the single-element tensor ``x`` across all ranks of ``mesh``.

    Delegates to ``_dist_reduce`` with the c10d ``AVG`` reduce op.
    """
    avg_op_name = c10d.ReduceOp.AVG.name
    return _dist_reduce(x, mesh=mesh, reduceOp=avg_op_name)
| |
|
| |
|
def set_determinism(
    world_mesh: DeviceMesh | None,
    device: torch.device,
    seed: int | None = None,
    deterministic: bool = False,
) -> None:
    """
    Seed RNGs for reproducibility and optionally enable deterministic kernels.

    All ranks first agree on a base seed: the caller's ``seed`` if given,
    otherwise 8 bytes of rank 0's torch CPU RNG state broadcast to every rank.
    If there is more than one rank and the mesh has a ``"pp"`` dimension, the
    local PP rank is added to the seed. Ranks in the same SPMD group therefore
    share a seed, while different PP stages get different seeds.

    ``torch.manual_seed`` (which seeds the default generators on all devices) and
    ``PYTHONHASHSEED`` are then set. The DTensor RNG tracker is seeded on the
    SPMD sub-mesh when one exists and this rank belongs to it.

    Args:
        world_mesh: Full device mesh, or ``None`` for a single-process job. For a
            single-process job, only ``torch.manual_seed``/``PYTHONHASHSEED`` are
            set, and only when ``seed`` is given. A mesh without dim names is
            treated as having no ``"pp"`` dimension.
        device: Device the seed tensor is moved to for the broadcast from rank 0
            (presumably must match the default process group's backend).
        seed: Base seed; ``None`` lets rank 0 choose one.
        deterministic: If True, enable deterministic algorithms (with loss of
            performance) and set ``CUBLAS_WORKSPACE_CONFIG``.
    """
    if deterministic:
        logger.info("Deterministic algorithm enabled (expect perf degradation).")
        torch.use_deterministic_algorithms(True)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        # cuBLAS requires a fixed workspace config for deterministic results, see
        # https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

    if world_mesh is None:
        if seed is not None:
            torch.manual_seed(seed)
            os.environ["PYTHONHASHSEED"] = str(seed % 2**32)
            logger.debug(f"Single-process job using seed: {seed}")
        return

    # All ranks must start from the same base seed. Without a user-provided one,
    # rank 0 supplies it and everyone else adopts it.
    if seed is None:
        seed_tensor = torch.get_rng_state()[:8].to(device)
        torch.distributed.broadcast(seed_tensor, src=0)
        # Reinterpret the 8 broadcast bytes as a single unsigned 64-bit integer.
        seed = seed_tensor.to("cpu").view(torch.uint64).item()

    # A mesh created without dim names has mesh_dim_names == None; treat that as
    # "no pp dim" instead of raising TypeError on `"pp" in None`.
    mesh_dim_names = world_mesh.mesh_dim_names or ()

    if c10d.get_world_size() > 1 and "pp" in mesh_dim_names:
        # Give each pipeline stage its own seed, kept in the uint64 range.
        pp_mesh = world_mesh["pp"]
        seed += pp_mesh.get_local_rank()
        seed %= 2**64

        logger.debug(
            f"PP rank {pp_mesh.get_local_rank()}, Global rank {c10d.get_rank()} using seed: {seed}"
        )
        spmd_mesh_dims = [name for name in mesh_dim_names if name != "pp"]
        # A PP-only (1-D) mesh leaves no SPMD sub-mesh to seed.
        spmd_mesh = world_mesh[spmd_mesh_dims] if spmd_mesh_dims else None
    else:
        spmd_mesh = world_mesh
        logger.debug(f"Global Rank {c10d.get_rank()} using seed: {seed}")

    torch.manual_seed(seed)
    # PYTHONHASHSEED must lie in [0, 2**32 - 1]. Note that it only affects
    # Python processes started after this point, not the current interpreter.
    os.environ["PYTHONHASHSEED"] = str(seed % 2**32)

    # Seed DTensor's RNG tracker only if this rank is part of the SPMD mesh
    # (get_coordinate() returns None otherwise).
    if spmd_mesh is not None and spmd_mesh.get_coordinate() is not None:
        torch.distributed.tensor._random.manual_seed(seed, spmd_mesh)
| |
|
| |
|
def create_context_parallel_ctx(
    cp_mesh: DeviceMesh,
    cp_buffers: list[torch.Tensor],
    cp_seq_dims: list[int],
    cp_no_restore_buffers: set[torch.Tensor],
    cp_rotate_method: str,
):
    """
    Build the context-parallel context for a training step.

    Sets the global CP rotate method, then returns the result of PyTorch's
    experimental ``context_parallel`` API for ``cp_mesh``.

    Args:
        cp_mesh: Device mesh of the context-parallel dimension.
        cp_buffers: Tensors forwarded to ``context_parallel`` as ``buffers``.
        cp_seq_dims: Sequence dimension of each buffer, forwarded as ``buffer_seq_dims``.
        cp_no_restore_buffers: Buffers forwarded as ``no_restore_buffers``.
        cp_rotate_method: Rotate method name passed to ``set_rotate_method``.

    Returns:
        The context manager returned by ``context_parallel``.

    Raises:
        ImportError: If the installed PyTorch lacks the experimental Context
            Parallel API.
    """
    try:
        from torch.distributed.tensor.experimental import context_parallel
        from torch.distributed.tensor.experimental._attention import set_rotate_method
    except ImportError as err:
        # Fail loudly here: continuing would crash below with an unhelpful
        # UnboundLocalError on `set_rotate_method`.
        raise ImportError(
            f"PyTorch version {torch.__version__} does not include the experimental "
            "Context Parallel API. Please update to a newer version."
        ) from err

    set_rotate_method(cp_rotate_method)
    return context_parallel(
        cp_mesh,
        buffers=cp_buffers,
        buffer_seq_dims=cp_seq_dims,
        no_restore_buffers=cp_no_restore_buffers,
    )
| |
|
| |
|
| | def get_train_context( |
| | enable_loss_parallel: bool, enable_compiled_autograd: bool |
| | ) -> Generator[None, None, None]: |
| | @contextlib.contextmanager |
| | def context(cp_context: Generator[None, None, None] | None = None): |
| | with contextlib.ExitStack() as stack: |
| | if enable_loss_parallel: |
| | stack.enter_context(torch.distributed.tensor.parallel.loss_parallel()) |
| |
|
| | if enable_compiled_autograd: |
| | stack.enter_context( |
| | torch._dynamo.utils.maybe_enable_compiled_autograd(True) |
| | ) |
| |
|
| | if cp_context is not None: |
| | from torch.nn.attention import sdpa_kernel, SDPBackend |
| |
|
| | stack.enter_context( |
| | sdpa_kernel( |
| | [ |
| | SDPBackend.FLASH_ATTENTION, |
| | SDPBackend.EFFICIENT_ATTENTION, |
| | SDPBackend.CUDNN_ATTENTION, |
| | ] |
| | ) |
| | ) |
| | stack.enter_context(cp_context) |
| |
|
| | yield |
| |
|
| | return context |
| |
|
| |
|
def init_distributed(job_config):
    """
    Configure NCCL environment variables, then initialize the default process group.

    Environment variables written, in order:
      * ``TORCH_NCCL_ASYNC_ERROR_HANDLING`` = ``"3"`` (the value this code names
        SKIP_CLEANUP);
      * ``TORCH_NCCL_TRACE_BUFFER_SIZE`` = ``job_config.comm.trace_buf_size``;
      * if that size is positive: ``TORCH_NCCL_DUMP_ON_TIMEOUT`` = ``"1"`` and
        ``TORCH_NCCL_DEBUG_INFO_TEMP_FILE`` = ``<job.dump_folder>/comm_trace/rank_``
        (the ``comm_trace`` directory is created if missing);
      * ``TORCH_NCCL_AVOID_RECORD_STREAMS`` = ``"1"``.
    All but the last log a warning when they replace an existing value.

    The backend is the default one registered for ``device_type`` (``"nccl"``
    if none is registered). With ``training.enable_cpu_offload`` it becomes the
    composite ``"<device_type>:<backend>,cpu:gloo"``. The init timeout is
    ``job_config.comm.init_timeout_seconds``.
    """

    def _override_env(name, value):
        # Warn before clobbering a value the user may have set deliberately.
        if name in os.environ:
            logger.warning(
                f"ENV[{name}] = {os.environ[name]} will be overridden to {value} based on job config"
            )
        os.environ[name] = value

    _override_env("TORCH_NCCL_ASYNC_ERROR_HANDLING", "3")

    # NCCL trace buffer; a positive size also enables dumping trace files
    # (prefixed `rank_` under comm_trace/) when a timeout is detected.
    trace_buf_size = job_config.comm.trace_buf_size
    _override_env("TORCH_NCCL_TRACE_BUFFER_SIZE", str(trace_buf_size))
    if trace_buf_size > 0:
        _override_env("TORCH_NCCL_DUMP_ON_TIMEOUT", "1")
        trace_dir = f"{job_config.job.dump_folder}/comm_trace"
        os.makedirs(trace_dir, exist_ok=True)
        _override_env("TORCH_NCCL_DEBUG_INFO_TEMP_FILE", f"{trace_dir}/rank_")

    # Set unconditionally, without the override warning used above.
    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

    backend_map = torch.distributed.Backend.default_device_backend_map
    backend = backend_map.get(device_type) if device_type in backend_map else "nccl"
    if job_config.training.enable_cpu_offload:
        # Route CPU tensors (offloaded state) through gloo alongside the device backend.
        backend = f"{device_type}:{backend},cpu:gloo"

    torch.distributed.init_process_group(
        backend=backend,
        timeout=timedelta(seconds=job_config.comm.init_timeout_seconds),
    )
| |
|
| |
|
def set_pg_timeouts(timeout, world_mesh):
    """
    Apply ``timeout`` to every process group in ``world_mesh`` and to the default (world) group.

    Before changing anything, all ranks meet at a barrier and the current device
    is synchronized. Without that, a fast rank could start issuing collectives
    under the new (shorter) timeout while a slow rank is still finishing work
    that the old timeout permitted, and the fast rank would time out spuriously.

    Args:
        timeout: New timeout handed to ``_set_pg_timeout``; presumably a
            ``datetime.timedelta`` (confirm against callers).
        world_mesh: Device mesh whose per-dimension groups are updated.
    """
    logger.info(
        f"Synchronizing and adjusting timeout for all ProcessGroups to {timeout}"
    )

    # Rendezvous first so no rank runs ahead under the new timeout (see docstring).
    torch.distributed.barrier(device_ids=[device_module.current_device()])
    device_module.synchronize()

    # One group per mesh dimension, then None to address the default PG,
    # which is not part of the mesh.
    target_groups = [*map(world_mesh.get_group, range(world_mesh.ndim)), None]
    for pg in target_groups:
        torch.distributed.distributed_c10d._set_pg_timeout(timeout, pg)
| |
|
| |
|
| | @torch.no_grad() |
| | def clip_grad_norm_( |
| | parameters: torch.Tensor | Iterable[torch.Tensor], |
| | max_norm: float, |
| | norm_type: float = 2.0, |
| | error_if_nonfinite: bool = False, |
| | foreach: bool | None = None, |
| | pp_mesh: DeviceMesh | None = None, |
| | ) -> torch.Tensor: |
| | """ |
| | Clip the gradient norm of an iterable of parameters. |
| | |
| | Gradient norm clipping requires computing the gradient norm over the entire model. |
| | `torch.nn.utils.clip_grad_norm_` only computes gradient norm along DP/FSDP/TP dimensions. |
| | We need to manually reduce the gradient norm across PP stages. |
| | See https://github.com/pytorch/torchtitan/issues/596 for details. |
| | |
| | Args: |
| | parameters: an iterable of Tensors or a single Tensor that will have gradients normalized |
| | max_norm (float): max norm of the gradients |
| | norm_type (float): type of the used p-norm. Can be ``'inf'`` for |
| | infinity norm. |
| | error_if_nonfinite (bool): if True, an error is thrown if the total |
| | norm of the gradients from :attr:`parameters` is ``nan``, |
| | ``inf``, or ``-inf``. Default: False (will switch to True in the future) |
| | foreach (bool): use the faster foreach-based implementation. |
| | If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently |
| | fall back to the slow implementation for other device types. |
| | Default: ``None`` |
| | pp_mesh: pipeline parallel device mesh. If not None, will reduce gradient norm across PP stages. |
| | |
| | Returns: |
| | Total norm of the parameter gradients (viewed as a single vector). |
| | |
| | """ |
| | grads = [p.grad for p in parameters if p.grad is not None] |
| | total_norm = torch.nn.utils.get_total_norm( |
| | grads, norm_type, error_if_nonfinite, foreach |
| | ) |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | if isinstance(total_norm, DTensor): |
| | |
| | |
| |
|
| | |
| | total_norm = ft_clip_grad_norm_util(total_norm) |
| | total_norm = total_norm.full_tensor() |
| |
|
| | if pp_mesh is not None: |
| | if math.isinf(norm_type): |
| | dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_mesh.get_group()) |
| | else: |
| | total_norm **= norm_type |
| | dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group()) |
| | total_norm **= 1.0 / norm_type |
| |
|
| | torch.nn.utils.clip_grads_with_norm_(parameters, max_norm, total_norm, foreach) |
| | return total_norm |
| |
|