from contextlib import contextmanager
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from torch.distributed._tensor import (
    distribute_tensor,
    DTensor,
    Partial,
    Replicate,
    Shard,
)
from torch.utils.checkpoint import (
    checkpoint,
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)


# Module-level switch: when False, the ReplicateComputation parametrization
# acts as a no-op and parameters are returned as-is (see disable_data_parallel).
_active_parametrization = True


@contextmanager
def disable_data_parallel():
    """Temporarily disable the data-parallel parametrization.

    Inside this context, parameter access returns the underlying DTensor
    directly, with no all-gather; useful for model inspection, debugging,
    or initialization.
    """
    global _active_parametrization
    try:
        _active_parametrization = False
        yield
    finally:
        _active_parametrization = True
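

# Example (sketch): `linear` is a placeholder for any module already wrapped
# by data_parallel(); inside the context, attribute access bypasses the
# all-gather and yields the raw (sharded) DTensor parameter:
#
#     with disable_data_parallel():
#         weight = linear.weight  # no communication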


@dataclass(frozen=True)
class MixedPrecisionPolicy:
    """Dtypes for the parameter all-gather (param_dtype) and the gradient
    reduction (reduce_dtype); None keeps the parameter's original dtype."""

    param_dtype: Optional[torch.dtype] = None
    reduce_dtype: Optional[torch.dtype] = None
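

# Example (sketch): all-gather parameters in bf16 for compute while reducing
# gradients in fp32:
#
#     mp = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)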


def fsdp_policy():
    """Build a selective-checkpoint context that recomputes the FSDP
    all-gather (and its mixed-precision cast) in backward while saving
    every other intermediate; this mimics FSDP's reshard_after_forward."""

    def _fsdp_recomp_policy():
        def _custom_policy(ctx, func, *args, **kwargs):
            to_recompute = func in {
                torch.ops._c10d_functional.all_gather_into_tensor.default,
                torch.ops._c10d_functional.wait_tensor.default,
                torch.ops.aten._to_copy.default,  # dtype cast for mixed precision
            }
            return (
                CheckpointPolicy.MUST_RECOMPUTE
                if to_recompute
                else CheckpointPolicy.MUST_SAVE
            )

        return _custom_policy

    return create_selective_checkpoint_contexts(_fsdp_recomp_policy())
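

# Note: all_gather_into_tensor and wait_tensor are the functional collectives
# that DTensor's redistribute() emits for the all-gather, and aten._to_copy is
# the cast inserted by mixed precision; recomputing exactly these ops frees
# the unsharded parameters after forward.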


class ReplicateComputation(torch.nn.Module):
    def __init__(self, device_mesh, param_sharding, mode, regional_ac, mp_policy):
        super().__init__()
        self.device_mesh = device_mesh
        self.param_sharding = param_sharding
        self.mode = mode
        # compute happens on replicated parameters; gradients come back as
        # partial tensors that are averaged across the mesh
        self.compute_placements = [Replicate()] * self.device_mesh.ndim
        self.grad_placements = [Partial(reduce_op="avg")] * self.device_mesh.ndim
        self.regional_ac = regional_ac
        mp_policy = mp_policy or MixedPrecisionPolicy()
        self.param_dtype = mp_policy.param_dtype
        self.reduce_dtype = mp_policy.reduce_dtype

    def replicate_compute(self, x):
        # At runtime, replicate (all-gather) the sharded parameters and let
        # the module compute locally; gradients come back as partial tensors
        # that still need reduction (DDP: all-reduce, FSDP: reduce-scatter,
        # HSDP: a mix of both).

        # FSDP + TP composition: x is a 2D DTensor (TP assumed to shard the
        # inner-most dim), so all-gather only along the dp mesh dimension.
        if self.mode == "fully_shard" and x._spec.mesh.ndim == 2:
            dp_placement, tp_placement = x._spec.placements
            dp_mesh, tp_mesh = self.device_mesh, x._spec.mesh["tp"]

            # re-wrap the 2D DTensor as a 1D DTensor on dp_mesh so the
            # all-gather below spans only the data-parallel ranks
            sharded_local_tensor = x.to_local()
            sharded_dtensor = DTensor.from_local(
                sharded_local_tensor, dp_mesh, self.param_sharding
            )

            # the actual FSDP all-gather on dp_mesh, casting parameters to
            # param_dtype in forward and gradients to reduce_dtype in backward
            replicated_dtensor = sharded_dtensor.redistribute(
                placements=self.compute_placements,
                forward_dtype=self.param_dtype,
                backward_dtype=self.reduce_dtype,
            )

            # re-wrap the all-gathered tensor as a 1D DTensor on tp_mesh so
            # tensor-parallel compute sees its expected placement
            replicated_local_tensor = replicated_dtensor.to_local(
                grad_placements=self.grad_placements
            )
            output = DTensor.from_local(
                replicated_local_tensor, tp_mesh, (tp_placement,)
            )
        else:
            output = x.redistribute(
                placements=self.compute_placements,
                forward_dtype=self.param_dtype,
                backward_dtype=self.reduce_dtype,
            ).to_local(grad_placements=self.grad_placements)

        return output

    def forward(self, x):
        global _active_parametrization
        # No-op while the parametrization is disabled (see
        # disable_data_parallel()): return the sharded parameter as-is for
        # model inspection, debugging, or initialization.
        if not _active_parametrization:
            return x

        if self.regional_ac and self.mode in ("fully_shard", "hybrid_shard"):
            # run the all-gather under selective checkpointing so the
            # unsharded parameter is recomputed in backward rather than saved
            # (the equivalent of FSDP's reshard_after_forward)
            output = checkpoint(
                self.replicate_compute, x, use_reentrant=False, context_fn=fsdp_policy
            )
        else:
            output = self.replicate_compute(x)

        return output
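

# Note: once registered as a parametrization (see data_parallel below), this
# module's forward() runs on every access to the parametrized attribute, so
# `mod.weight` transparently returns the all-gathered, compute-ready tensor.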


def data_parallel(
    model,
    device_mesh,
    mode="replicate",
    ac_mode: str = "none",
    mp_policy: Optional[MixedPrecisionPolicy] = None,
):
    if mode == "replicate":
        param_sharding = (Replicate(),)
    elif mode == "fully_shard":
        param_sharding = (Shard(0),)
    elif mode == "hybrid_shard":
        # replicate across the outer mesh dim, fully shard across the inner
        param_sharding = (Replicate(), Shard(0))
        assert (
            device_mesh.ndim == 2
        ), "hybrid sharded data parallel requires 2D DeviceMesh"
    else:
        raise ValueError(f"Unsupported mode {mode}")

    modules = list(model.modules())

    # use regional AC (with fsdp_policy) only when no global AC is configured
    regional_ac = ac_mode == "none"

    for mod in modules:
        params_dict = dict(mod.named_parameters(recurse=False))
        for p_name, p in params_dict.items():
            if p is not None and p.numel() > 0:
                # replace the parameter with its sharded DTensor, then install
                # ReplicateComputation as a parametrization that all-gathers
                # it on every access
                mod.register_parameter(
                    p_name,
                    nn.Parameter(distribute_tensor(p, device_mesh, param_sharding)),
                )
                nn.utils.parametrize.register_parametrization(
                    mod,
                    p_name,
                    ReplicateComputation(
                        device_mesh,
                        param_sharding,
                        mode,
                        regional_ac,
                        mp_policy=mp_policy,
                    ),
                    unsafe=True,
                )

    return model
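
# Minimal end-to-end usage sketch (illustrative only: `MyModel`, `world_size`,
# and `batch` are placeholders, not part of this module):
#
#     from torch.distributed.device_mesh import init_device_mesh
#
#     mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("dp",))
#     model = data_parallel(
#         MyModel().cuda(),
#         mesh,
#         mode="fully_shard",
#         mp_policy=MixedPrecisionPolicy(torch.bfloat16, torch.float32),
#     )
#     model(batch).sum().backward()  # grads reduce via the Partial(avg) placements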