from typing import Optional, Tuple

import torch
| import torch.distributed._functional_collectives as funcol |
| import torch.distributed.tensor |
| from diffusers.utils import is_accelerate_available |
| from torch.distributed._composable.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard |
from torch.distributed._composable.replicate import replicate

from ..utils._common import DIFFUSERS_TRANSFORMER_BLOCK_NAMES


if is_accelerate_available():
| from accelerate import Accelerator |
| from accelerate.utils import ( |
| DataLoaderConfiguration, |
| DistributedDataParallelKwargs, |
| InitProcessGroupKwargs, |
| ProjectConfiguration, |
| ) |
|
|
|
|
| def apply_fsdp2_ptd( |
| model: torch.nn.Module, |
| dp_mesh: torch.distributed.device_mesh.DeviceMesh, |
| param_dtype: torch.dtype, |
| reduce_dtype: torch.dtype, |
| output_dtype: torch.dtype, |
| pp_enabled: bool = False, |
| cpu_offload: bool = False, |
| ) -> None: |
| r"""Apply FSDP2 on a model.""" |
| mp_policy = MixedPrecisionPolicy(param_dtype, reduce_dtype, output_dtype, cast_forward_inputs=True) |
| fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} |
|
|
| if cpu_offload: |
| fsdp_config["offload_policy"] = CPUOffloadPolicy(pin_memory=True) |
|
|
    def apply_fully_shard(blocks):
        for layer_index, block in enumerate(blocks):
            if pp_enabled:
                # For PP, do not reshard after forward to avoid per-microbatch
                # all-gathers, which can be expensive and non-overlapped
                reshard_after_forward = False
            else:
                # As an optimization, do not reshard after forward for the last
                # transformer block since FSDP would prefetch it immediately
                reshard_after_forward = layer_index < len(blocks) - 1
            fully_shard(block, **fsdp_config, reshard_after_forward=reshard_after_forward)
|
|
    for transformer_block_name in DIFFUSERS_TRANSFORMER_BLOCK_NAMES:
        blocks = getattr(model, transformer_block_name, None)
        if blocks is not None:
            apply_fully_shard(blocks)

    # Shard the remaining parameters (embeddings, norms, projections) at the root module
    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
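
# A minimal usage sketch (hypothetical names: `transformer` stands for any diffusers
# transformer model; assumes torch.distributed is already initialized):
#
#   from torch.distributed.device_mesh import init_device_mesh
#
#   dp_mesh = init_device_mesh("cuda", (torch.distributed.get_world_size(),), mesh_dim_names=("dp",))
#   apply_fsdp2_ptd(
#       transformer,
#       dp_mesh,
#       param_dtype=torch.bfloat16,
#       reduce_dtype=torch.float32,
#       output_dtype=torch.bfloat16,
#   )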
|
|
|
|
def apply_ddp_accelerate(
    model: torch.nn.Module,
    project_config: Optional["ProjectConfiguration"] = None,
    ddp_kwargs: Optional["DistributedDataParallelKwargs"] = None,
    init_process_group_kwargs: Optional["InitProcessGroupKwargs"] = None,
    dataloader_config: Optional["DataLoaderConfiguration"] = None,
    gradient_accumulation_steps: Optional[int] = None,
    accelerator: Optional["Accelerator"] = None,
) -> Tuple["Accelerator", torch.nn.Module]:
    r"""Prepare a model for DDP training with Accelerate, creating an `Accelerator` if one is not provided."""
    if accelerator is None:
        accelerator = Accelerator(
            project_config=project_config,
            dataloader_config=dataloader_config,
            gradient_accumulation_steps=gradient_accumulation_steps,
            log_with=None,
            kwargs_handlers=[kwargs for kwargs in (ddp_kwargs, init_process_group_kwargs) if kwargs is not None],
        )
    if torch.backends.mps.is_available():
        # Disable native automatic mixed precision, which is not supported on MPS backends
        accelerator.native_amp = False
    model = accelerator.prepare_model(model)
    return accelerator, model
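
# A minimal usage sketch (hypothetical; `model` is any torch.nn.Module, launched
# under `accelerate launch` or torchrun):
#
#   from accelerate.utils import DistributedDataParallelKwargs, InitProcessGroupKwargs
#
#   accelerator, model = apply_ddp_accelerate(
#       model,
#       ddp_kwargs=DistributedDataParallelKwargs(find_unused_parameters=False),
#       init_process_group_kwargs=InitProcessGroupKwargs(backend="nccl"),
#       gradient_accumulation_steps=1,
#   )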
|
|
|
|
def apply_ddp_ptd(model: torch.nn.Module, dp_mesh: torch.distributed.device_mesh.DeviceMesh) -> None:
    r"""Apply DDP on a model via the composable ``replicate`` API."""
    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
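
# A minimal usage sketch (hypothetical; mirrors the FSDP2 path above but replicates
# parameters instead of sharding them):
#
#   dp_mesh = torch.distributed.device_mesh.init_device_mesh("cuda", (world_size,))
#   apply_ddp_ptd(model, dp_mesh)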
|
|
|
|
def dist_reduce(x: torch.Tensor, reduceOp: str, mesh: torch.distributed.device_mesh.DeviceMesh) -> float:
    r"""Reduce a scalar tensor across all ranks of ``mesh`` and return the result as a Python float."""
    if isinstance(x, torch.distributed.tensor.DTensor):
        # Functional collectives do not support DTensor inputs, so materialize the full tensor first
        x = x.full_tensor()
    assert x.numel() == 1
    return funcol.all_reduce(x, reduceOp=reduceOp, group=mesh).item()
|
|
|
|
| def dist_max(x: torch.Tensor, mesh: torch.distributed.device_mesh.DeviceMesh) -> float: |
| return dist_reduce(x, reduceOp=torch.distributed.distributed_c10d.ReduceOp.MAX.name, mesh=mesh) |
|
|
|
|
| def dist_mean(x: torch.Tensor, mesh: torch.distributed.device_mesh.DeviceMesh) -> float: |
| return dist_reduce(x, reduceOp=torch.distributed.distributed_c10d.ReduceOp.AVG.name, mesh=mesh) |
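
# A minimal usage sketch (hypothetical; `loss` and `grad_norm` are scalar tensors,
# possibly DTensors when the model is sharded with FSDP2):
#
#   avg_loss = dist_mean(loss.detach(), dp_mesh)     # average across data-parallel ranks
#   max_grad_norm = dist_max(grad_norm, dp_mesh)     # max across data-parallel ranks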
|
|