| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | """ |
| | General namespace and dataclass related classes |
| | """ |
| |
|
| | import argparse |
| | import copy |
| | import enum |
| | import functools |
| | import logging |
| | import os |
| | import warnings |
| | from collections.abc import Iterable |
| | from contextlib import contextmanager |
| | from dataclasses import dataclass, field |
| | from datetime import timedelta |
| | from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union, get_args |
| |
|
| | import torch |
| |
|
| | from .constants import ( |
| | BETA_CP_AVAILABLE_PYTORCH_VERSION, |
| | BETA_TP_AVAILABLE_PYTORCH_VERSION, |
| | BETA_TP_AVAILABLE_TRANSFORMERS_VERSION, |
| | FSDP2_PYTORCH_VERSION, |
| | FSDP_AUTO_WRAP_POLICY, |
| | FSDP_BACKWARD_PREFETCH, |
| | FSDP_SHARDING_STRATEGY, |
| | MITA_PROFILING_AVAILABLE_PYTORCH_VERSION, |
| | XPU_PROFILING_AVAILABLE_PYTORCH_VERSION, |
| | ) |
| | from .environment import parse_flag_from_env, str_to_bool |
| | from .imports import ( |
| | is_cuda_available, |
| | is_hpu_available, |
| | is_mlu_available, |
| | is_msamp_available, |
| | is_musa_available, |
| | is_npu_available, |
| | is_transformer_engine_available, |
| | is_xpu_available, |
| | ) |
| | from .versions import compare_versions, is_torch_version |
| |
|
| |
|
| | if TYPE_CHECKING: |
| | |
| | from torchao.float8 import Float8LinearConfig |
| |
|
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
| | class KwargsHandler: |
| | """ |
| | Internal mixin that implements a `to_kwargs()` method for a dataclass. |
| | """ |
| |
|
| | def to_dict(self): |
| | return copy.deepcopy(self.__dict__) |
| |
|
| | def to_kwargs(self): |
| | """ |
| | Returns a dictionary containing the attributes with values different from the default of this class. |
| | """ |
| | |
| | from .environment import clear_environment |
| |
|
| | with clear_environment(): |
| | default_dict = self.__class__().to_dict() |
| | this_dict = self.to_dict() |
| | return {k: v for k, v in this_dict.items() if default_dict[k] != v} |
| |
|
| |
|
| | class EnumWithContains(enum.EnumMeta): |
| | "A metaclass that adds the ability to check if `self` contains an item with the `in` operator" |
| |
|
| | def __contains__(cls, item): |
| | try: |
| | cls(item) |
| | except ValueError: |
| | return False |
| | return True |
| |
|
| |
|
| | class BaseEnum(enum.Enum, metaclass=EnumWithContains): |
| | "An enum class that can get the value of an item with `str(Enum.key)`" |
| |
|
| | def __str__(self): |
| | return self.value |
| |
|
| | @classmethod |
| | def list(cls): |
| | "Method to list all the possible items in `cls`" |
| | return list(map(str, cls)) |
| |
|
| |
|
| | @dataclass |
| | class AutocastKwargs(KwargsHandler): |
| | """ |
| | Use this object in your [`Accelerator`] to customize how `torch.autocast` behaves. Please refer to the |
| | documentation of this [context manager](https://pytorch.org/docs/stable/amp.html#torch.autocast) for more |
| | information on each argument. |
| | |
| | Example: |
| | |
| | ```python |
| | from accelerate import Accelerator |
| | from accelerate.utils import AutocastKwargs |
| | |
| | kwargs = AutocastKwargs(cache_enabled=True) |
| | accelerator = Accelerator(kwargs_handlers=[kwargs]) |
| | ``` |
| | """ |
| |
|
| | enabled: bool = True |
| | cache_enabled: Optional[bool] = None |
| |
|
| |
|
| | class DDPCommunicationHookType(BaseEnum): |
| | """ |
| | Represents a type of communication hook used in DDP. |
| | |
| | Values: |
| | |
| | - **NO** -- no communication hook |
| | - **FP16** -- DDP communication hook to compress the gradients in FP16 |
| | - **BF16** -- DDP communication hook to compress the gradients in BF16 |
| | - **POWER_SGD** -- DDP communication hook to use PowerSGD |
| | - **BATCHED_POWER_SGD** -- DDP communication hook to use batched PowerSGD |
| | """ |
| |
|
| | NO = "no" |
| | FP16 = "fp16" |
| | BF16 = "bf16" |
| | POWER_SGD = "power_sgd" |
| | BATCHED_POWER_SGD = "batched_power_sgd" |
| |
|
| |
|
| | @dataclass |
| | class DistributedDataParallelKwargs(KwargsHandler): |
| | """ |
| | Use this object in your [`Accelerator`] to customize how your model is wrapped in a |
| | `torch.nn.parallel.DistributedDataParallel`. Please refer to the documentation of this |
| | [wrapper](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html) for more |
| | information on each argument. |
| | |
| | <Tip warning={true}> |
| | |
| | `gradient_as_bucket_view` is only available in PyTorch 1.7.0 and later versions. |
| | |
| | `static_graph` is only available in PyTorch 1.11.0 and later versions. |
| | |
| | </Tip> |
| | |
| | Example: |
| | |
| | ```python |
| | from accelerate import Accelerator |
| | from accelerate.utils import DistributedDataParallelKwargs |
| | |
| | kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) |
| | accelerator = Accelerator(kwargs_handlers=[kwargs]) |
| | ``` |
| | """ |
| |
|
| | dim: int = 0 |
| | broadcast_buffers: bool = True |
| | bucket_cap_mb: int = 25 |
| | find_unused_parameters: bool = False |
| | check_reduction: bool = False |
| | gradient_as_bucket_view: bool = False |
| | static_graph: bool = False |
| |
|
| | comm_hook: DDPCommunicationHookType = DDPCommunicationHookType.NO |
| | comm_wrapper: Literal[ |
| | DDPCommunicationHookType.NO, |
| | DDPCommunicationHookType.FP16, |
| | DDPCommunicationHookType.BF16, |
| | ] = DDPCommunicationHookType.NO |
| | comm_state_option: dict = field(default_factory=dict) |
| |
|
| | def to_dict(self, ignore_keys=("comm_hook", "comm_wrapper", "comm_state_option")): |
| | return {k: v for k, v in super().to_dict().items() if k not in ignore_keys} |
| |
|
| | def register_comm_hook(self, model): |
| | from torch.distributed.algorithms.ddp_comm_hooks import ( |
| | default_hooks, |
| | powerSGD_hook, |
| | ) |
| |
|
| | hook_map: dict[DDPCommunicationHookType, Callable] = { |
| | DDPCommunicationHookType.FP16: default_hooks.fp16_compress_hook, |
| | DDPCommunicationHookType.BF16: default_hooks.bf16_compress_hook, |
| | DDPCommunicationHookType.POWER_SGD: powerSGD_hook.powerSGD_hook, |
| | DDPCommunicationHookType.BATCHED_POWER_SGD: powerSGD_hook.batched_powerSGD_hook, |
| | } |
| |
|
| | wrapper_map: dict[DDPCommunicationHookType, Callable] = { |
| | DDPCommunicationHookType.FP16: default_hooks.fp16_compress_wrapper, |
| | DDPCommunicationHookType.BF16: default_hooks.bf16_compress_wrapper, |
| | } |
| |
|
| | hook: Optional[Callable] = hook_map.get(self.comm_hook) |
| | wrapper: Optional[Callable] = wrapper_map.get(self.comm_wrapper) |
| |
|
| | if hook and wrapper: |
| | hook = wrapper(hook) |
| |
|
| | if hook: |
| | state = ( |
| | powerSGD_hook.PowerSGDState(None, **self.comm_state_option) |
| | if self.comm_hook |
| | in ( |
| | DDPCommunicationHookType.POWER_SGD, |
| | DDPCommunicationHookType.BATCHED_POWER_SGD, |
| | ) |
| | else None |
| | ) |
| | model.register_comm_hook( |
| | state=state, |
| | hook=hook, |
| | ) |
| |
|
| |
|
| | @dataclass |
| | class GradScalerKwargs(KwargsHandler): |
| | """ |
| | Use this object in your [`Accelerator`] to customize the behavior of mixed precision, specifically how the |
| | `torch.amp.GradScaler` or `torch.cuda.amp.GradScaler` used is created. Please refer to the documentation of this |
| | [scaler](https://pytorch.org/docs/stable/amp.html?highlight=gradscaler) for more information on each argument. |
| | |
| | <Tip warning={true}> |
| | |
| | `torch.cuda.amp.GradScaler` is only available in PyTorch 1.5.0 and later versions, and `torch.amp.GradScaler` is |
| | only available in PyTorch 2.4.0 and later versions. |
| | |
| | </Tip> |
| | |
| | Example: |
| | |
| | ```python |
| | from accelerate import Accelerator |
| | from accelerate.utils import GradScalerKwargs |
| | |
| | kwargs = GradScalerKwargs(backoff_factor=0.25) |
| | accelerator = Accelerator(kwargs_handlers=[kwargs]) |
| | ``` |
| | """ |
| |
|
| | init_scale: float = 65536.0 |
| | growth_factor: float = 2.0 |
| | backoff_factor: float = 0.5 |
| | growth_interval: int = 2000 |
| | enabled: bool = True |
| |
|
| |
|
| | @dataclass |
| | class InitProcessGroupKwargs(KwargsHandler): |
| | """ |
| | Use this object in your [`Accelerator`] to customize the initialization of the distributed processes. Please refer |
| | to the documentation of this |
| | [method](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) for more |
| | information on each argument. |
| | |
| | Note: If `timeout` is set to `None`, the default will be based upon how `backend` is set. |
| | |
| | ```python |
| | from datetime import timedelta |
| | from accelerate import Accelerator |
| | from accelerate.utils import InitProcessGroupKwargs |
| | |
| | kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=800)) |
| | accelerator = Accelerator(kwargs_handlers=[kwargs]) |
| | ``` |
| | """ |
| |
|
| | backend: Optional[str] = "nccl" |
| | init_method: Optional[str] = None |
| | timeout: Optional[timedelta] = None |
| |
|
| | def __post_init__(self): |
| | if self.timeout is None: |
| | seconds = 1800 if self.backend != "nccl" else 600 |
| | self.timeout = timedelta(seconds=seconds) |
| |
|
| |
|
| | |
| | Backend = Literal["MSAMP", "TE"] |
| | OptLevel = Literal["O1", "O2"] |
| | FP8Format = Literal["HYBRID", "E4M3", "E5M2"] |
| | AmaxComputeAlgorithm = Literal["max", "most_recent"] |
| |
|
| |
|
| | |
| | @dataclass |
| | class AORecipeKwargs(KwargsHandler): |
| | """ |
| | Use this object in your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision |
| | training with `torchao` FP8. |
| | |
| | Args: |
| | config (`torchao.float8.Float8LinearConfig`, *optional*, default to `None`): |
| | The configuration for the FP8 training. In general, the default config should be sufficient. |
| | module_filter_func (`Callable`, *optional*, default to `None`): |
| | Optional function that must take in a module and layer name, and returns a boolean indicating whether the |
| | module should be converted to FP8. Defaults to `accelerate.utils.ao.filter_linear_layers`. See it for an |
| | example. |
| | """ |
| |
|
| | config: Optional["Float8LinearConfig"] = None |
| | module_filter_func: Optional[Callable] = None |
| |
|
| |
|
| | @dataclass |
| | class TERecipeKwargs(KwargsHandler): |
| | """ |
| | Use this object in your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision |
| | training with `transformer-engine`. |
| | |
| | <Tip> |
| | |
| | For more information on the args, please refer to the API |
| | [documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html). |
| | |
| | </Tip> |
| | |
| | ```python |
| | from accelerate import Accelerator |
| | from accelerate.utils import TERecipeKwargs |
| | |
| | kwargs = TERecipeKwargs(fp8_format="HYBRID") |
| | accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[kwargs]) |
| | ``` |
| | |
| | Args: |
| | use_autocast_during_eval (`bool`, *optional*, default to `False`): |
| | Whether to use FP8 autocast during eval mode. Generally better metrics are found when this is `False`. |
| | margin (`int`, *optional*, default to 0): |
| | The margin to use for the gradient scaling. |
| | interval (`int`, *optional*, default to 1): |
| | The interval to use for how often the scaling factor is recomputed. |
| | fp8_format (`str`, *optional*, default to "HYBRID"): |
| | The format to use for the FP8 recipe. Must be one of `HYBRID`, `E4M3` or `E5M2`. (Generally `HYBRID` for |
| | training, `E4M3` or `E5M2` for evaluation) |
| | amax_history_len (`int`, *optional*, default to 1024): |
| | The length of the history to use for the scaling factor computation |
| | amax_compute_algo (`str`, *optional*, default to "most_recent"): |
| | The algorithm to use for the scaling factor computation. Must be one of `max` or `most_recent`. |
| | override_linear_precision (`tuple` of three `bool`, *optional*, default to `(False, False, False)`): |
| | Whether or not to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision. |
| | """ |
| |
|
| | use_autocast_during_eval: Optional[bool] = None |
| | margin: Optional[int] = None |
| | interval: Optional[int] = None |
| | fp8_format: FP8Format = None |
| | amax_history_len: Optional[int] = None |
| | amax_compute_algo: AmaxComputeAlgorithm = None |
| | override_linear_precision: tuple[bool, bool, bool] = None |
| | use_mxfp8_block_scaling: Optional[bool] = None |
| |
|
| | def __post_init__(self): |
| | env_prefix = "ACCELERATE_FP8_" |
| | if not is_transformer_engine_available(): |
| | raise ImportError("TransformerEngine is not available. Please install it or use a different backend.") |
| | if self.use_autocast_during_eval is None: |
| | self.use_autocast_during_eval = parse_flag_from_env(env_prefix + "USE_AUTOCAST_DURING_EVAL") |
| | if self.margin is None: |
| | self.margin = int(os.environ.get(env_prefix + "MARGIN", 0)) |
| | if self.interval is None: |
| | self.interval = int(os.environ.get(env_prefix + "INTERVAL", 1)) |
| | if self.fp8_format is None: |
| | self.fp8_format = os.environ.get(env_prefix + "FORMAT", "HYBRID") |
| | self.fp8_format = self.fp8_format.upper() |
| | if self.fp8_format not in get_args(FP8Format): |
| | raise ValueError(f"`fp8_format` must be one of {' or '.join(get_args(FP8Format))}.") |
| | if self.amax_compute_algo is None: |
| | self.amax_compute_algo = os.environ.get(env_prefix + "AMAX_COMPUTE_ALGO", "most_recent") |
| | self.amax_compute_algo = self.amax_compute_algo.lower() |
| | if self.amax_compute_algo not in get_args(AmaxComputeAlgorithm): |
| | raise ValueError(f"`amax_compute_algo` must be one of {' or '.join(get_args(AmaxComputeAlgorithm))}") |
| | if self.amax_history_len is None: |
| | self.amax_history_len = int(os.environ.get(env_prefix + "AMAX_HISTORY_LEN", 1024)) |
| | if self.override_linear_precision is None: |
| | fprop = parse_flag_from_env(env_prefix + "OVERRIDE_FPROP") |
| | dgrad = parse_flag_from_env(env_prefix + "OVERRIDE_DGRAD") |
| | wgrad = parse_flag_from_env(env_prefix + "OVERRIDE_WGRAD") |
| | self.override_linear_precision = (fprop, dgrad, wgrad) |
| | if self.use_mxfp8_block_scaling is None: |
| | self.use_mxfp8_block_scaling = parse_flag_from_env(env_prefix + "USE_MXFP8_BLOCK_SCALING") |
| |
|
| |
|
| | @dataclass |
| | class MSAMPRecipeKwargs(KwargsHandler): |
| | """ |
| | Use this object in your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision |
| | training with `ms-amp`. |
| | """ |
| |
|
| | opt_level: OptLevel = None |
| |
|
| | def __post_init__(self): |
| | env_prefix = "ACCELERATE_FP8_" |
| | if self.opt_level is None: |
| | self.opt_level = os.environ.get(env_prefix + "OPT_LEVEL", "O2") |
| | if self.opt_level not in get_args(OptLevel): |
| | raise ValueError(f"`opt_level` must be one of {' or '.join(get_args(OptLevel))}") |
| |
|
| |
|
| | @dataclass |
| | class FP8RecipeKwargs(TERecipeKwargs, MSAMPRecipeKwargs): |
| | """ |
| | Deprecated. Please use one of the proper FP8 recipe kwargs classes such as `TERecipeKwargs` or `MSAMPRecipeKwargs` |
| | instead. |
| | """ |
| |
|
| | backend: Backend = None |
| |
|
| | def __post_init__(self): |
| | env_prefix = "ACCELERATE_FP8_" |
| | warnings.warn( |
| | "FP8RecipeKwargs is deprecated and will be removed in Accelerate v2.0.0. " |
| | "Please use one of the proper FP8 recipe kwargs classes such as TERecipeKwargs or MSAMPRecipeKwargs instead.", |
| | FutureWarning, |
| | ) |
| | default_backend = "msamp" if is_msamp_available() else "te" |
| | if self.backend is None: |
| | self.backend = os.environ.get(env_prefix + "BACKEND", default_backend) |
| | self.backend = self.backend.upper() |
| | if self.backend not in get_args(Backend): |
| | raise ValueError("`backend` must be 'MSAMP' or 'TE' (TransformerEngine) to use `FP8RecipeKwargs`.") |
| | super().__post_init__() |
| |
|
| |
|
| | |
| | ProfilerActivity = Literal["cpu", "xpu", "mtia", "cuda", "hpu"] |
| |
|
| |
|
| | @dataclass |
| | class ProfileKwargs(KwargsHandler): |
| | """ |
| | Use this object in your [`Accelerator`] to customize the initialization of the profiler. Please refer to the |
| | documentation of this [context manager](https://pytorch.org/docs/stable/profiler.html#torch.profiler.profile) for |
| | more information on each argument. |
| | |
| | <Tip warning={true}> |
| | |
| | `torch.profiler` is only available in PyTorch 1.8.1 and later versions. |
| | |
| | </Tip> |
| | |
| | Example: |
| | |
| | ```python |
| | from accelerate import Accelerator |
| | from accelerate.utils import ProfileKwargs |
| | |
| | kwargs = ProfileKwargs(activities=["cpu", "cuda"]) |
| | accelerator = Accelerator(kwargs_handlers=[kwargs]) |
| | ``` |
| | |
| | Args: |
| | activities (`List[str]`, *optional*, default to `None`): |
| | The list of activity groups to use in profiling. Must be one of `"cpu"`, `"xpu"`, `"mtia"`, "hpu" or |
| | `"cuda"`. |
| | schedule_option (`Dict[str, int]`, *optional*, default to `None`): |
| | The schedule option to use for the profiler. Available keys are `wait`, `warmup`, `active`, `repeat` and |
| | `skip_first`. The profiler will skip the first `skip_first` steps, then wait for `wait` steps, then do the |
| | warmup for the next `warmup` steps, then do the active recording for the next `active` steps and then |
| | repeat the cycle starting with `wait` steps. The optional number of cycles is specified with the `repeat` |
| | parameter, the zero value means that the cycles will continue until the profiling is finished. |
| | on_trace_ready (`Callable`, *optional*, default to `None`): |
| | Callable that is called at each step when schedule returns `ProfilerAction.RECORD_AND_SAVE` during the |
| | profiling. |
| | record_shapes (`bool`, *optional*, default to `False`): |
| | Save information about operator’s input shapes. |
| | profile_memory (`bool`, *optional*, default to `False`): |
| | Track tensor memory allocation/deallocation |
| | with_stack (`bool`, *optional*, default to `False`): |
| | Record source information (file and line number) for the ops. |
| | with_flops (`bool`, *optional*, default to `False`): |
| | Use formula to estimate the FLOPS of specific operators |
| | with_modules (`bool`, *optional*, default to `False`): |
| | Record module hierarchy (including function names) corresponding to the callstack of the op. |
| | output_trace_dir (`str`, *optional*, default to `None`): |
| | Exports the collected trace in Chrome JSON format. Chrome use 'chrome://tracing' view json file. Defaults |
| | to None, which means profiling does not store json files. |
| | """ |
| |
|
| | activities: Optional[list[ProfilerActivity]] = None |
| | schedule_option: Optional[dict[str, int]] = None |
| | on_trace_ready: Optional[Callable] = None |
| | record_shapes: bool = False |
| | profile_memory: bool = False |
| | with_stack: bool = False |
| | with_flops: bool = False |
| | with_modules: bool = False |
| | output_trace_dir: Optional[str] = None |
| |
|
| | def _get_profiler_activity(self, activity: ProfilerActivity) -> torch.profiler.ProfilerActivity: |
| | """Get the profiler activity from the string. |
| | |
| | Args: |
| | activity (str): The profiler activity name. |
| | |
| | Returns: |
| | torch.profiler.ProfilerActivity: The profiler activity. |
| | """ |
| |
|
| | profiler_activity_map: dict[str, torch.profiler.ProfilerActivity] = { |
| | "cpu": torch.profiler.ProfilerActivity.CPU, |
| | "cuda": torch.profiler.ProfilerActivity.CUDA, |
| | } |
| |
|
| | if is_hpu_available(): |
| | profiler_activity_map["hpu"] = torch.profiler.ProfilerActivity.HPU |
| |
|
| | if is_torch_version(">=", XPU_PROFILING_AVAILABLE_PYTORCH_VERSION): |
| | if torch.xpu.is_available(): |
| | profiler_activity_map["xpu"] = torch.profiler.ProfilerActivity.XPU |
| |
|
| | if is_torch_version(">=", MITA_PROFILING_AVAILABLE_PYTORCH_VERSION): |
| | if torch.mtia.is_available(): |
| | profiler_activity_map["mtia"] = torch.profiler.ProfilerActivity.MTIA |
| |
|
| | if activity not in profiler_activity_map: |
| | raise ValueError(f"Invalid profiler activity: {activity}. Must be one of {list(profiler_activity_map)}.") |
| | return profiler_activity_map[activity] |
| |
|
| | def build(self) -> torch.profiler.profile: |
| | """ |
| | Build a profiler object with the current configuration. |
| | |
| | Returns: |
| | torch.profiler.profile: The profiler object. |
| | """ |
| | activities: Optional[list[ProfilerActivity]] = None |
| | if self.activities is not None: |
| | activities = [self._get_profiler_activity(activity) for activity in self.activities] |
| | schedule: Optional[torch.profiler.schedule] = None |
| | if self.schedule_option is not None: |
| | schedule = torch.profiler.schedule(**self.schedule_option) |
| |
|
| | return torch.profiler.profile( |
| | activities=activities, |
| | schedule=schedule, |
| | on_trace_ready=self.on_trace_ready, |
| | record_shapes=self.record_shapes, |
| | profile_memory=self.profile_memory, |
| | with_stack=self.with_stack, |
| | with_flops=self.with_flops, |
| | with_modules=self.with_modules, |
| | ) |
| |
|
| |
|
| | class DistributedType(str, enum.Enum): |
| | """ |
| | Represents a type of distributed environment. |
| | |
| | Values: |
| | |
| | - **NO** -- Not a distributed environment, just a single process. |
| | - **MULTI_CPU** -- Distributed on multiple CPU nodes. |
| | - **MULTI_GPU** -- Distributed on multiple GPUs. |
| | - **MULTI_MLU** -- Distributed on multiple MLUs. |
| | - **MULTI_SDAA** -- Distributed on multiple SDAAs. |
| | - **MULTI_MUSA** -- Distributed on multiple MUSAs. |
| | - **MULTI_NPU** -- Distributed on multiple NPUs. |
| | - **MULTI_XPU** -- Distributed on multiple XPUs. |
| | - **MULTI_HPU** -- Distributed on multiple HPUs. |
| | - **DEEPSPEED** -- Using DeepSpeed. |
| | - **XLA** -- Using TorchXLA. |
| | """ |
| |
|
| | |
| | NO = "NO" |
| | MULTI_CPU = "MULTI_CPU" |
| | MULTI_GPU = "MULTI_GPU" |
| | MULTI_NPU = "MULTI_NPU" |
| | MULTI_MLU = "MULTI_MLU" |
| | MULTI_SDAA = "MULTI_SDAA" |
| | MULTI_MUSA = "MULTI_MUSA" |
| | MULTI_XPU = "MULTI_XPU" |
| | DEEPSPEED = "DEEPSPEED" |
| | FSDP = "FSDP" |
| | XLA = "XLA" |
| | MEGATRON_LM = "MEGATRON_LM" |
| | MULTI_HPU = "MULTI_HPU" |
| |
|
| |
|
| | class SageMakerDistributedType(str, enum.Enum): |
| | """ |
| | Represents a type of distributed environment. |
| | |
| | Values: |
| | |
| | - **NO** -- Not a distributed environment, just a single process. |
| | - **DATA_PARALLEL** -- using sagemaker distributed data parallelism. |
| | - **MODEL_PARALLEL** -- using sagemaker distributed model parallelism. |
| | """ |
| |
|
| | |
| | NO = "NO" |
| | DATA_PARALLEL = "DATA_PARALLEL" |
| | MODEL_PARALLEL = "MODEL_PARALLEL" |
| |
|
| |
|
| | class FP8BackendType(str, enum.Enum): |
| | """ |
| | Represents the backend used for FP8. |
| | |
| | Values: |
| | |
| | - **TE** -- using TransformerEngine. |
| | - **MSAMP** -- using msamp. |
| | """ |
| |
|
| | |
| | NO = "NO" |
| | TE = "TE" |
| | MSAMP = "MSAMP" |
| | AO = "AO" |
| |
|
| |
|
| | class ComputeEnvironment(str, enum.Enum): |
| | """ |
| | Represents a type of the compute environment. |
| | |
| | Values: |
| | |
| | - **LOCAL_MACHINE** -- private/custom cluster hardware. |
| | - **AMAZON_SAGEMAKER** -- Amazon SageMaker as compute environment. |
| | """ |
| |
|
| | |
| | LOCAL_MACHINE = "LOCAL_MACHINE" |
| | AMAZON_SAGEMAKER = "AMAZON_SAGEMAKER" |
| |
|
| |
|
| | class DynamoBackend(str, BaseEnum): |
| | """ |
| | Represents a dynamo backend (see https://pytorch.org/docs/stable/torch.compiler.html). |
| | |
| | Values: |
| | |
| | - **NO** -- Do not use torch dynamo. |
| | - **EAGER** -- Uses PyTorch to run the extracted GraphModule. This is quite useful in debugging TorchDynamo |
| | issues. |
| | - **AOT_EAGER** -- Uses AotAutograd with no compiler, i.e, just using PyTorch eager for the AotAutograd's |
| | extracted forward and backward graphs. This is useful for debugging, and unlikely to give speedups. |
| | - **INDUCTOR** -- Uses TorchInductor backend with AotAutograd and cudagraphs by leveraging codegened Triton |
| | kernels. [Read |
| | more](https://dev-discuss.pytorch.org/t/torchinductor-a-pytorch-native-compiler-with-define-by-run-ir-and-symbolic-shapes/747) |
| | - **AOT_TS_NVFUSER** -- nvFuser with AotAutograd/TorchScript. [Read |
| | more](https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593) |
| | - **NVPRIMS_NVFUSER** -- nvFuser with PrimTorch. [Read |
| | more](https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593) |
| | - **CUDAGRAPHS** -- cudagraphs with AotAutograd. [Read more](https://github.com/pytorch/torchdynamo/pull/757) |
| | - **OFI** -- Uses Torchscript optimize_for_inference. Inference only. [Read |
| | more](https://pytorch.org/docs/stable/generated/torch.jit.optimize_for_inference.html) |
| | - **FX2TRT** -- Uses Nvidia TensorRT for inference optimizations. Inference only. [Read |
| | more](https://github.com/pytorch/TensorRT/blob/master/docsrc/tutorials/getting_started_with_fx_path.rst) |
| | - **ONNXRT** -- Uses ONNXRT for inference on CPU/GPU. Inference only. [Read more](https://onnxruntime.ai/) |
| | - **TENSORRT** -- Uses ONNXRT to run TensorRT for inference optimizations. [Read |
| | more](https://github.com/onnx/onnx-tensorrt) |
| | - **AOT_TORCHXLA_TRACE_ONCE** -- Uses Pytorch/XLA with TorchDynamo optimization, for training. [Read |
| | more](https://github.com/pytorch/xla/blob/r2.0/docs/dynamo.md) |
| | - **TORCHXLA_TRACE_ONCE** -- Uses Pytorch/XLA with TorchDynamo optimization, for inference. [Read |
| | more](https://github.com/pytorch/xla/blob/r2.0/docs/dynamo.md) |
| | - **IPEX** -- Uses IPEX for inference on CPU. Inference only. [Read |
| | more](https://github.com/intel/intel-extension-for-pytorch). |
| | - **TVM** -- Uses Apache TVM for inference optimizations. [Read more](https://tvm.apache.org/) |
| | - **HPU_BACKEND** -- Uses HPU backend for inference optimizations. |
| | |
| | """ |
| |
|
| | |
| | NO = "NO" |
| | EAGER = "EAGER" |
| | AOT_EAGER = "AOT_EAGER" |
| | INDUCTOR = "INDUCTOR" |
| | AOT_TS_NVFUSER = "AOT_TS_NVFUSER" |
| | NVPRIMS_NVFUSER = "NVPRIMS_NVFUSER" |
| | CUDAGRAPHS = "CUDAGRAPHS" |
| | OFI = "OFI" |
| | FX2TRT = "FX2TRT" |
| | ONNXRT = "ONNXRT" |
| | TENSORRT = "TENSORRT" |
| | AOT_TORCHXLA_TRACE_ONCE = "AOT_TORCHXLA_TRACE_ONCE" |
| | TORCHXLA_TRACE_ONCE = "TORCHXLA_TRACE_ONCE" |
| | IPEX = "IPEX" |
| | TVM = "TVM" |
| | HPU_BACKEND = "HPU_BACKEND" |
| |
|
| |
|
| | class LoggerType(BaseEnum): |
| | """Represents a type of supported experiment tracker |
| | |
| | Values: |
| | |
| | - **ALL** -- all available trackers in the environment that are supported |
| | - **TENSORBOARD** -- TensorBoard as an experiment tracker |
| | - **WANDB** -- wandb as an experiment tracker |
| | - **TRACKIO** -- trackio as an experiment tracker |
| | - **COMETML** -- comet_ml as an experiment tracker |
| | - **MLFLOW** -- mlflow as an experiment tracker |
| | - **CLEARML** -- clearml as an experiment tracker |
| | - **DVCLIVE** -- dvclive as an experiment tracker |
| | - **SWANLAB** -- swanlab as an experiment tracker |
| | """ |
| |
|
| | ALL = "all" |
| | AIM = "aim" |
| | TENSORBOARD = "tensorboard" |
| | WANDB = "wandb" |
| | TRACKIO = "trackio" |
| | COMETML = "comet_ml" |
| | MLFLOW = "mlflow" |
| | CLEARML = "clearml" |
| | DVCLIVE = "dvclive" |
| | SWANLAB = "swanlab" |
| |
|
| |
|
| | class PrecisionType(str, BaseEnum): |
| | """Represents a type of precision used on floating point values |
| | |
| | Values: |
| | |
| | - **NO** -- using full precision (FP32) |
| | - **FP16** -- using half precision |
| | - **BF16** -- using brain floating point precision |
| | """ |
| |
|
| | NO = "no" |
| | FP8 = "fp8" |
| | FP16 = "fp16" |
| | BF16 = "bf16" |
| |
|
| |
|
| | class RNGType(BaseEnum): |
| | TORCH = "torch" |
| | CUDA = "cuda" |
| | MLU = "mlu" |
| | SDAA = "sdaa" |
| | MUSA = "musa" |
| | NPU = "npu" |
| | XLA = "xla" |
| | XPU = "xpu" |
| | HPU = "hpu" |
| | GENERATOR = "generator" |
| |
|
| |
|
| | class CustomDtype(enum.Enum): |
| | r""" |
| | An enum that contains multiple custom dtypes that can be used for `infer_auto_device_map`. |
| | """ |
| |
|
| | FP8 = "fp8" |
| | INT4 = "int4" |
| | INT2 = "int2" |
| |
|
| |
|
| | |
| |
|
| |
|
| | @dataclass |
| | class TensorInformation: |
| | shape: torch.Size |
| | dtype: torch.dtype |
| |
|
| |
|
| | @dataclass |
| | class DataLoaderConfiguration: |
| | """ |
| | Configuration for dataloader-related items when calling `accelerator.prepare`. |
| | |
| | Args: |
| | split_batches (`bool`, defaults to `False`): |
| | Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If |
| | `True`, the actual batch size used will be the same on any kind of distributed processes, but it must be a |
| | round multiple of `num_processes` you are using. If `False`, actual batch size used will be the one set in |
| | your script multiplied by the number of processes. |
| | dispatch_batches (`bool`, defaults to `None`): |
| | If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process |
| | and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose |
| | underlying dataset is an `IterableDataset`, `False` otherwise. |
| | even_batches (`bool`, defaults to `True`): |
| | If set to `True`, in cases where the total batch size across all processes does not exactly divide the |
| | dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among |
| | all workers. |
| | use_seedable_sampler (`bool`, defaults to `False`): |
| | Whether or not use a fully seedable random sampler ([`data_loader.SeedableRandomSampler`]). Ensures |
| | training results are fully reproducible using a different sampling technique. While seed-to-seed results |
| | may differ, on average the differences are negligible when using multiple different seeds to compare. |
| | Should also be ran with [`~utils.set_seed`] for the best results. |
| | data_seed (`int`, defaults to `None`): |
| | The seed to use for the underlying generator when using `use_seedable_sampler`. If `None`, the generator |
| | will use the current default seed from torch. |
| | non_blocking (`bool`, defaults to `False`): |
| | If set to `True`, the dataloader prepared by the Accelerator will utilize non-blocking host-to-device |
| | transfers, allowing for better overlap between dataloader communication and computation. Recommended that |
| | the prepared dataloader has `pin_memory` set to `True` to work properly. |
| | use_stateful_dataloader (`bool`, defaults to `False`): |
| | If set to `True`, the dataloader prepared by the Accelerator will be backed by |
| | [torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). |
| | This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed. |
| | """ |
| |
|
| | split_batches: bool = field( |
| | default=False, |
| | metadata={ |
| | "help": "Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If" |
| | " `True` the actual batch size used will be the same on any kind of distributed processes, but it must be a" |
| | " round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set" |
| | " in your script multiplied by the number of processes." |
| | }, |
| | ) |
| | dispatch_batches: bool = field( |
| | default=None, |
| | metadata={ |
| | "help": "If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process" |
| | " and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose" |
| | " underlying dataset is an `IterableDataset`, `False` otherwise." |
| | }, |
| | ) |
| | even_batches: bool = field( |
| | default=True, |
| | metadata={ |
| | "help": "If set to `True`, in cases where the total batch size across all processes does not exactly divide the" |
| | " dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among" |
| | " all workers." |
| | }, |
| | ) |
| | use_seedable_sampler: bool = field( |
| | default=False, |
| | metadata={ |
| | "help": "Whether or not use a fully seedable random sampler ([`data_loader.SeedableRandomSampler`])." |
| | "Ensures training results are fully reproducible using a different sampling technique. " |
| | "While seed-to-seed results may differ, on average the differences are negligible when using" |
| | "multiple different seeds to compare. Should also be ran with [`~utils.set_seed`] for the best results." |
| | }, |
| | ) |
| | data_seed: int = field( |
| | default=None, |
| | metadata={ |
| | "help": "The seed to use for the underlying generator when using `use_seedable_sampler`. If `None`, the generator" |
| | " will use the current default seed from torch." |
| | }, |
| | ) |
| | non_blocking: bool = field( |
| | default=False, |
| | metadata={ |
| | "help": "If set to `True`, the dataloader prepared by the Accelerator will utilize non-blocking host-to-device" |
| | " transfers, allowing for better overlap between dataloader communication and computation. Recommended that the" |
| | " prepared dataloader has `pin_memory` set to `True` to work properly." |
| | }, |
| | ) |
| | use_stateful_dataloader: bool = field( |
| | default=False, |
| | metadata={ |
| | "help": "If set to `True`, the dataloader prepared by the Accelerator will be backed by " |
| | "[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed." |
| | }, |
| | ) |
| |
|
| |
|
| | @dataclass |
| | class ProjectConfiguration: |
| | """ |
| | Configuration for the Accelerator object based on inner-project needs. |
| | |
| | Args: |
| | project_dir (`str`, defaults to `None`): |
| | A path to a directory for storing data. |
| | logging_dir (`str`, defaults to `None`): |
| | A path to a directory for storing logs of locally-compatible loggers. If None, defaults to `project_dir`. |
| | automatic_checkpoint_naming (`bool`, defaults to `False`): |
| | Whether saved states should be automatically iteratively named. |
| | total_limit (`int`, defaults to `None`): |
| | The maximum number of total saved states to keep. |
| | iteration (`int`, defaults to `0`): |
| | The current save iteration. |
| | save_on_each_node (`bool`, defaults to `False`): |
| | When doing multi-node distributed training, whether to save models and checkpoints on each node, or only on |
| | the main one. |
| | """ |
| |
|
| | project_dir: str = field(default=None, metadata={"help": "A path to a directory for storing data."}) |
| | logging_dir: str = field( |
| | default=None, |
| | metadata={ |
| | "help": "A path to a directory for storing logs of locally-compatible loggers. If None, defaults to `project_dir`." |
| | }, |
| | ) |
| | automatic_checkpoint_naming: bool = field( |
| | default=False, |
| | metadata={"help": "Whether saved states should be automatically iteratively named."}, |
| | ) |
| |
|
| | total_limit: int = field( |
| | default=None, |
| | metadata={"help": "The maximum number of total saved states to keep."}, |
| | ) |
| |
|
| | iteration: int = field( |
| | default=0, |
| | metadata={"help": "The current save iteration."}, |
| | ) |
| |
|
| | save_on_each_node: bool = field( |
| | default=False, |
| | metadata={ |
| | "help": ( |
| | "When doing multi-node distributed training, whether to save models and checkpoints on each node, or" |
| | " only on the main one" |
| | ) |
| | }, |
| | ) |
| |
|
| | def set_directories(self, project_dir: Optional[str] = None): |
| | "Sets `self.project_dir` and `self.logging_dir` to the appropriate values." |
| | self.project_dir = project_dir |
| | if self.logging_dir is None: |
| | self.logging_dir = project_dir |
| |
|
| | def __post_init__(self): |
| | self.set_directories(self.project_dir) |
| |
|
| |
|
| | @dataclass |
| | class GradientAccumulationPlugin(KwargsHandler): |
| | """ |
| | A plugin to configure gradient accumulation behavior. You can only pass one of `gradient_accumulation_plugin` or |
| | `gradient_accumulation_steps` to [`Accelerator`]. Passing both raises an error. |
| | |
| | Parameters: |
| | num_steps (`int`): |
| | The number of steps to accumulate gradients for. |
| | adjust_scheduler (`bool`, *optional*, defaults to `True`): |
| | Whether to adjust the scheduler steps to account for the number of steps being accumulated. Should be |
| | `True` if the used scheduler was not adjusted for gradient accumulation. |
| | sync_with_dataloader (`bool`, *optional*, defaults to `True`): |
| | Whether to synchronize setting the gradients when at the end of the dataloader. |
| | sync_each_batch (`bool`, *optional*): |
| | Whether to synchronize setting the gradients at each data batch. Setting to `True` may reduce memory |
| | requirements when using gradient accumulation with distributed training, at expense of speed. |
| | |
| | Example: |
| | |
| | ```python |
| | from accelerate.utils import GradientAccumulationPlugin |
| | |
| | gradient_accumulation_plugin = GradientAccumulationPlugin(num_steps=2) |
| | accelerator = Accelerator(gradient_accumulation_plugin=gradient_accumulation_plugin) |
| | ``` |
| | """ |
| |
|
| | num_steps: int = field( |
| | default=None, |
| | metadata={"help": "The number of steps to accumulate gradients for."}, |
| | ) |
| | adjust_scheduler: bool = field( |
| | default=True, |
| | metadata={ |
| | "help": "Whether to adjust the scheduler steps to account for the number of steps being accumulated. Should be `True` if the used scheduler was not adjusted for gradient accumulation." |
| | }, |
| | ) |
| | sync_with_dataloader: bool = field( |
| | default=True, |
| | metadata={ |
| | "help": "Whether to synchronize setting the gradients when at the end of the dataloader. Should only be set to `False` if you know what you're doing." |
| | }, |
| | ) |
| | sync_each_batch: bool = field( |
| | default=False, |
| | metadata={ |
| | "help": "Whether to synchronize setting the gradients at each data batch. Setting to `True` may reduce memory requirements when using gradient accumulation with distributed training, at expense of speed." |
| | }, |
| | ) |
| |
|
| |
|
| | @dataclass |
| | class TorchDynamoPlugin(KwargsHandler): |
| | """ |
| | This plugin is used to compile a model with PyTorch 2.0 |
| | |
| | Args: |
| | backend (`DynamoBackend`, defaults to `None`): |
| | A valid Dynamo backend. See https://pytorch.org/docs/stable/torch.compiler.html for more details. |
| | mode (`str`, defaults to `None`): |
| | Possible options are 'default', 'reduce-overhead' or 'max-autotune'. |
| | fullgraph (`bool`, defaults to `None`): |
| | Whether it is ok to break model into several subgraphs. |
| | dynamic (`bool`, defaults to `None`): |
| | Whether to use dynamic shape for tracing. |
| | options (`Any`, defaults to `None`): |
| | A dictionary of options to pass to the backend. |
| | disable (`bool`, defaults to `False`): |
| | Turn torch.compile() into a no-op for testing |
| | use_regional_compilation (`bool`, defaults to `None`): |
| | Use it to reduce the cold start compilation time of torch.compile() by targeting repeated blocks of the |
| | same class and compiling them sequentially to hit the compiler's cache. For example, in `GPT2LMHeadModel`, |
| | the repeated block/class is `GPT2Block`, and can be accessed as `model.transformer.h[0]`. The rest of the |
| | model (e.g model.lm_head) is compiled separately. |
| | """ |
| |
|
| | backend: DynamoBackend = field( |
| | default=None, |
| | metadata={"help": f"Possible options are {[b.value.lower() for b in DynamoBackend]}"}, |
| | ) |
| | mode: str = field( |
| | default=None, |
| | metadata={"help": "Possible options are 'default', 'reduce-overhead' or 'max-autotune'"}, |
| | ) |
| | fullgraph: bool = field( |
| | default=None, |
| | metadata={"help": "Whether it is ok to break model into several subgraphs"}, |
| | ) |
| | dynamic: bool = field(default=None, metadata={"help": "Whether to use dynamic shape for tracing"}) |
| | options: Any = field( |
| | default=None, |
| | metadata={"help": "A dictionary of options to pass to the backend."}, |
| | ) |
| | disable: bool = field( |
| | default=False, |
| | metadata={"help": "Turn torch.compile() into a no-op for testing"}, |
| | ) |
| |
|
| | use_regional_compilation: bool = field( |
| | default=None, |
| | metadata={ |
| | "help": ( |
| | |
| | "Use it to reduce the cold start compilation time of torch.compile() by targeting repeated " |
| | "blocks of the same class and compiling them sequentially to hit the compiler's cache. For " |
| | "example, in `GPT2LMHeadModel`, the repeated block/class is `GPT2Block`, and can be accessed " |
| | "as `model.transformer.h[0]`. The rest of the model (e.g model.lm_head) is compiled separately." |
| | ) |
| | }, |
| | ) |
| |
|
| | def __post_init__(self): |
| | prefix = "ACCELERATE_DYNAMO_" |
| | if self.backend is None: |
| | self.backend = os.environ.get(prefix + "BACKEND", "no") |
| | self.backend = DynamoBackend(self.backend.upper()) |
| |
|
| | if self.mode is None: |
| | self.mode = os.environ.get(prefix + "MODE", "default") |
| | if self.fullgraph is None: |
| | self.fullgraph = str_to_bool(os.environ.get(prefix + "USE_FULLGRAPH", "False")) == 1 |
| | if self.use_regional_compilation is None: |
| | self.use_regional_compilation = ( |
| | str_to_bool(os.environ.get(prefix + "USE_REGIONAL_COMPILATION", "False")) == 1 |
| | ) |
| |
|
| | if self.dynamic is None and os.environ.get(prefix + "USE_DYNAMIC", None) is not None: |
| | self.dynamic = str_to_bool(os.environ.get(prefix + "USE_DYNAMIC", "False")) == 1 |
| |
|
| | def to_dict(self): |
| | dynamo_config = copy.deepcopy(self.__dict__) |
| | dynamo_config["backend"] = dynamo_config["backend"].value.lower() |
| | return dynamo_config |
| |
|
| | def to_kwargs(self): |
| | kwargs = super().to_kwargs() |
| | kwargs.pop("use_regional_compilation", None) |
| | return kwargs |
| |
|
| |
|
| | @dataclass |
| | class DeepSpeedPlugin: |
| | """ |
| | This plugin is used to integrate DeepSpeed. |
| | |
| | Args: |
| | hf_ds_config (`Any`, defaults to `None`): |
| | Path to DeepSpeed config file or dict or an object of class `accelerate.utils.deepspeed.HfDeepSpeedConfig`. |
| | gradient_accumulation_steps (`int`, defaults to `None`): |
| | Number of steps to accumulate gradients before updating optimizer states. If not set, will use the value |
| | from the `Accelerator` directly. |
| | gradient_clipping (`float`, defaults to `None`): |
| | Enable gradient clipping with value. |
| | zero_stage (`int`, defaults to `None`): |
| | Possible options are 0, 1, 2, 3. Default will be taken from environment variable. |
| | is_train_batch_min (`bool`, defaults to `True`): |
| | If both train & eval dataloaders are specified, this will decide the `train_batch_size`. |
| | offload_optimizer_device (`str`, defaults to `None`): |
| | Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3. |
| | offload_param_device (`str`, defaults to `None`): |
| | Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3. |
| | offload_optimizer_nvme_path (`str`, defaults to `None`): |
| | Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3. |
| | offload_param_nvme_path (`str`, defaults to `None`): |
| | Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3. |
| | zero3_init_flag (`bool`, defaults to `None`): |
| | Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3. |
| | zero3_save_16bit_model (`bool`, defaults to `None`): |
| | Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3. |
| | transformer_moe_cls_names (`str`, defaults to `None`): |
| | Comma-separated list of Transformers MoE layer class names (case-sensitive). For example, |
| | `MixtralSparseMoeBlock`, `Qwen2MoeSparseMoeBlock`, `JetMoEAttention`, `JetMoEBlock`, etc. |
| | enable_msamp (`bool`, defaults to `None`): |
| | Flag to indicate whether to enable MS-AMP backend for FP8 training. |
| | msasmp_opt_level (`Optional[Literal["O1", "O2"]]`, defaults to `None`): |
| | Optimization level for MS-AMP (defaults to 'O1'). Only applicable if `enable_msamp` is True. Should be one |
| | of ['O1' or 'O2']. |
| | """ |
| |
|
| | hf_ds_config: Any = field( |
| | default=None, |
| | metadata={ |
| | "help": "path to DeepSpeed config file or dict or an object of class `accelerate.utils.deepspeed.HfDeepSpeedConfig`." |
| | }, |
| | ) |
| | gradient_accumulation_steps: int = field( |
| | default=None, |
| | metadata={ |
| | "help": "Number of steps to accumulate gradients before updating optimizer states. If not set, will use the value from the `Accelerator` directly." |
| | }, |
| | ) |
| | gradient_clipping: float = field(default=None, metadata={"help": "Enable gradient clipping with value"}) |
| | zero_stage: int = field( |
| | default=None, |
| | metadata={"help": "Possible options are 0,1,2,3; Default will be taken from environment variable"}, |
| | ) |
| | is_train_batch_min: bool = field( |
| | default=True, |
| | metadata={"help": "If both train & eval dataloaders are specified, this will decide the train_batch_size"}, |
| | ) |
| | offload_optimizer_device: str = field( |
| | default=None, |
| | metadata={"help": "Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3."}, |
| | ) |
| | offload_param_device: str = field( |
| | default=None, |
| | metadata={"help": "Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3."}, |
| | ) |
| | offload_optimizer_nvme_path: str = field( |
| | default=None, |
| | metadata={"help": "Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3."}, |
| | ) |
| | offload_param_nvme_path: str = field( |
| | default=None, |
| | metadata={"help": "Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3."}, |
| | ) |
| | zero3_init_flag: bool = field( |
| | default=None, |
| | metadata={ |
| | "help": "Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models." |
| | "Only applicable with ZeRO Stage-3." |
| | }, |
| | ) |
| | zero3_save_16bit_model: bool = field( |
| | default=None, |
| | metadata={"help": "Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3."}, |
| | ) |
| | transformer_moe_cls_names: str = field( |
| | default=None, |
| | metadata={ |
| | "help": "comma-separated list of transformers MoE layer class names (case-sensitive), e.g : " |
| | " `MixtralSparseMoeBlock`, `Qwen2MoeSparseMoeBlock`, `JetMoEAttention,JetMoEBlock` ..." |
| | }, |
| | ) |
| | enable_msamp: bool = field( |
| | default=None, |
| | metadata={"help": "Flag to indicate whether to enable MS-AMP backend for FP8 training."}, |
| | ) |
| | msamp_opt_level: Optional[Literal["O1", "O2"]] = field( |
| | default=None, |
| | metadata={ |
| | "help": "Optimization level for MS-AMP (defaults to 'O1'). Only applicable if `enable_msamp` is True. Should be one of ['O1' or 'O2']." |
| | }, |
| | ) |
| |
|
| | def __post_init__(self): |
| | from .deepspeed import HfDeepSpeedConfig |
| |
|
| | if self.gradient_accumulation_steps is None: |
| | gas = os.environ.get("ACCELERATE_GRADIENT_ACCUMULATION_STEPS", "auto") |
| | self.gradient_accumulation_steps = int(gas) if gas.isdigit() else gas |
| |
|
| | if self.gradient_clipping is None: |
| | gradient_clipping = os.environ.get("ACCELERATE_GRADIENT_CLIPPING", "auto") |
| | self.gradient_clipping = gradient_clipping if gradient_clipping == "auto" else float(gradient_clipping) |
| |
|
| | if self.zero_stage is None: |
| | self.zero_stage = int(os.environ.get("ACCELERATE_DEEPSPEED_ZERO_STAGE", 2)) |
| |
|
| | if self.offload_optimizer_device is None: |
| | self.offload_optimizer_device = os.environ.get("ACCELERATE_DEEPSPEED_OFFLOAD_OPTIMIZER_DEVICE", "none") |
| |
|
| | if self.offload_param_device is None: |
| | self.offload_param_device = os.environ.get("ACCELERATE_DEEPSPEED_OFFLOAD_PARAM_DEVICE", "none") |
| |
|
| | if self.offload_optimizer_nvme_path is None: |
| | self.offload_optimizer_nvme_path = os.environ.get( |
| | "ACCELERATE_DEEPSPEED_OFFLOAD_OPTIMIZER_NVME_PATH", "none" |
| | ) |
| |
|
| | if self.offload_param_nvme_path is None: |
| | self.offload_param_nvme_path = os.environ.get("ACCELERATE_DEEPSPEED_OFFLOAD_PARAM_NVME_PATH", "none") |
| |
|
| | if self.zero3_save_16bit_model is None: |
| | self.zero3_save_16bit_model = ( |
| | os.environ.get("ACCELERATE_DEEPSPEED_ZERO3_SAVE_16BIT_MODEL", "false").lower() == "true" |
| | ) |
| | if self.enable_msamp is None: |
| | self.enable_msamp = os.environ.get("ACCELERATE_FP8_BACKEND", None) == "MSAMP" |
| |
|
| | if self.msamp_opt_level is None: |
| | self.msamp_opt_level = os.environ.get("ACCELERATE_FP8_OPT_LEVEL", "O1") |
| |
|
| | if self.hf_ds_config is None: |
| | self.hf_ds_config = os.environ.get("ACCELERATE_DEEPSPEED_CONFIG_FILE", "none") |
| |
|
| | if ( |
| | isinstance(self.hf_ds_config, dict) |
| | or (isinstance(self.hf_ds_config, str) and self.hf_ds_config != "none") |
| | or isinstance(self.hf_ds_config, HfDeepSpeedConfig) |
| | ): |
| | if not isinstance(self.hf_ds_config, HfDeepSpeedConfig): |
| | self.hf_ds_config = HfDeepSpeedConfig(self.hf_ds_config) |
| | if "gradient_accumulation_steps" not in self.hf_ds_config.config: |
| | self.hf_ds_config.config["gradient_accumulation_steps"] = 1 |
| | if "zero_optimization" not in self.hf_ds_config.config: |
| | raise ValueError("Please specify the ZeRO optimization config in the DeepSpeed config.") |
| |
|
| | self._deepspeed_config_checks() |
| | plugin_to_config_mapping = { |
| | "gradient_accumulation_steps": "gradient_accumulation_steps", |
| | "gradient_clipping": "gradient_clipping", |
| | "zero_stage": "zero_optimization.stage", |
| | "offload_optimizer_device": "zero_optimization.offload_optimizer.device", |
| | "offload_param_device": "zero_optimization.offload_param.device", |
| | "offload_param_nvme_path": "zero_optimization.offload_param.nvme_path", |
| | "offload_optimizer_nvme_path": "zero_optimization.offload_optimizer.nvme_path", |
| | "zero3_save_16bit_model": "zero_optimization.stage3_gather_16bit_weights_on_model_save", |
| | } |
| | kwargs = {v: getattr(self, k) for k, v in plugin_to_config_mapping.items() if getattr(self, k) is not None} |
| | for key in kwargs.keys(): |
| | self.fill_match(key, **kwargs, must_match=False) |
| | self.hf_ds_config.set_stage_and_offload() |
| |
|
| | |
| | |
| | for key, value in plugin_to_config_mapping.items(): |
| | config_value = self.hf_ds_config.get_value(value) |
| | if config_value is not None and config_value != "auto": |
| | setattr(self, key, config_value) |
| | else: |
| | config = { |
| | "train_batch_size": "auto", |
| | "train_micro_batch_size_per_gpu": "auto", |
| | "gradient_accumulation_steps": self.gradient_accumulation_steps, |
| | "zero_optimization": { |
| | "stage": self.zero_stage, |
| | "offload_optimizer": { |
| | "device": self.offload_optimizer_device, |
| | "nvme_path": ( |
| | self.offload_optimizer_nvme_path if self.offload_optimizer_device == "nvme" else None |
| | ), |
| | }, |
| | "offload_param": { |
| | "device": self.offload_param_device, |
| | "nvme_path": (self.offload_param_nvme_path if self.offload_param_device == "nvme" else None), |
| | }, |
| | "stage3_gather_16bit_weights_on_model_save": self.zero3_save_16bit_model, |
| | }, |
| | } |
| | if self.gradient_clipping: |
| | config["gradient_clipping"] = self.gradient_clipping |
| | self.hf_ds_config = HfDeepSpeedConfig(config) |
| |
|
| | self.deepspeed_config = self.hf_ds_config.config |
| | self.deepspeed_config["steps_per_print"] = float("inf") |
| | if self.zero3_init_flag is None: |
| | self.zero3_init_flag = ( |
| | str_to_bool( |
| | os.environ.get( |
| | "ACCELERATE_DEEPSPEED_ZERO3_INIT", |
| | str(self.hf_ds_config.is_zero3()), |
| | ) |
| | ) |
| | == 1 |
| | ) |
| | if self.zero3_init_flag and not self.hf_ds_config.is_zero3(): |
| | warnings.warn("DeepSpeed Zero3 Init flag is only applicable for ZeRO Stage 3. Setting it to False.") |
| | self.zero3_init_flag = False |
| | |
| | |
| | self._set_selected(False) |
| |
|
| | |
| | if self.enable_msamp and "msamp" not in self.deepspeed_config: |
| | if self.zero_stage == 3: |
| | raise NotImplementedError( |
| | "MS-AMP is not supported for ZeRO Stage 3. Please use ZeRO Stage 0, 1, or 2 instead." |
| | ) |
| | if self.msamp_opt_level not in ["O1", "O2"]: |
| | raise ValueError("Invalid optimization level for MS-AMP. Please use one of ['O1' or'O2'].") |
| | self.deepspeed_config["msamp"] = { |
| | "enabled": True, |
| | "opt_level": self.msamp_opt_level, |
| | } |
| |
|
| | def fill_match(self, ds_key_long, mismatches=None, must_match=True, **kwargs): |
| | mismatches = [] if mismatches is None else mismatches |
| | config, ds_key = self.hf_ds_config.find_config_node(ds_key_long) |
| | if config is None: |
| | return |
| |
|
| | if config.get(ds_key) == "auto": |
| | if ds_key_long in kwargs: |
| | config[ds_key] = kwargs[ds_key_long] |
| | return |
| | else: |
| | raise ValueError( |
| | f"`{ds_key_long}` not found in kwargs. " |
| | f"Please specify `{ds_key_long}` without `auto` (set to correct value) in the DeepSpeed config file or " |
| | "pass it in kwargs." |
| | ) |
| |
|
| | if not must_match: |
| | return |
| |
|
| | ds_val = config.get(ds_key) |
| | if ds_val is not None and ds_key_long in kwargs: |
| | if ds_val != kwargs[ds_key_long]: |
| | mismatches.append(f"- ds {ds_key_long}={ds_val} vs arg {ds_key_long}={kwargs[ds_key_long]}") |
| |
|
| | def is_auto(self, ds_key_long): |
| | val = self.hf_ds_config.get_value(ds_key_long) |
| | if val is None: |
| | return False |
| | else: |
| | return val == "auto" |
| |
|
| | def get_value(self, ds_key_long, default=None): |
| | return self.hf_ds_config.get_value(ds_key_long, default) |
| |
|
| | def deepspeed_config_process(self, prefix="", mismatches=None, config=None, must_match=True, **kwargs): |
| | """Process the DeepSpeed config with the values from the kwargs.""" |
| | mismatches = [] if mismatches is None else mismatches |
| | if config is None: |
| | config = self.deepspeed_config |
| | for key, value in config.items(): |
| | if isinstance(value, dict): |
| | self.deepspeed_config_process( |
| | prefix=prefix + key + ".", |
| | mismatches=mismatches, |
| | config=value, |
| | must_match=must_match, |
| | **kwargs, |
| | ) |
| | else: |
| | self.fill_match(prefix + key, mismatches, must_match=must_match, **kwargs) |
| | if len(mismatches) > 0 and prefix == "": |
| | mismatches_msg = "\n".join(mismatches) |
| | raise ValueError( |
| | "Please correct the following DeepSpeed config values that mismatch kwargs " |
| | f" values:\n{mismatches_msg}\nThe easiest method is to set these DeepSpeed config values to 'auto'." |
| | ) |
| |
|
| | def set_mixed_precision(self, mixed_precision): |
| | ds_config = self.deepspeed_config |
| | kwargs = { |
| | "fp16.enabled": mixed_precision == "fp16", |
| | |
| | "bf16.enabled": mixed_precision in ("bf16", "fp8"), |
| | } |
| | if mixed_precision == "fp16": |
| | if "fp16" not in ds_config: |
| | ds_config["fp16"] = {"enabled": True, "auto_cast": True} |
| | elif mixed_precision in ("bf16", "fp8"): |
| | if "bf16" not in ds_config: |
| | ds_config["bf16"] = {"enabled": True} |
| |
|
| | if mixed_precision == "fp8" and self.enable_msamp: |
| | if "msamp" not in ds_config: |
| | ds_config["msamp"] = { |
| | "enabled": True, |
| | "opt_level": self.msamp_opt_level, |
| | } |
| |
|
| | if mixed_precision != "no": |
| | diff_dtype = "bf16" if mixed_precision == "fp16" else "fp16" |
| | if str(ds_config.get(diff_dtype, {}).get("enabled", "False")).lower() == "true": |
| | raise ValueError( |
| | f"`--mixed_precision` arg cannot be set to `{mixed_precision}` when `{diff_dtype}` is set in the DeepSpeed config file." |
| | ) |
| | for dtype in ["fp16", "bf16"]: |
| | if dtype not in ds_config: |
| | ds_config[dtype] = {"enabled": False} |
| | self.fill_match("fp16.enabled", must_match=False, **kwargs) |
| | self.fill_match("bf16.enabled", must_match=False, **kwargs) |
| |
|
| | def set_deepspeed_weakref(self): |
| | from .imports import is_transformers_available |
| |
|
| | ds_config = copy.deepcopy(self.deepspeed_config) |
| | if self.zero3_init_flag: |
| | if not is_transformers_available(): |
| | raise Exception( |
| | "When `zero3_init_flag` is set, it requires Transformers to be installed. " |
| | "Please run `pip install transformers`." |
| | ) |
| | if "gradient_accumulation_steps" not in ds_config or ds_config["gradient_accumulation_steps"] == "auto": |
| | ds_config["gradient_accumulation_steps"] = 1 |
| | if "train_micro_batch_size_per_gpu" not in ds_config or ds_config["train_micro_batch_size_per_gpu"] == "auto": |
| | ds_config["train_micro_batch_size_per_gpu"] = 1 |
| | if ds_config.get("train_batch_size", None) == "auto": |
| | del ds_config["train_batch_size"] |
| |
|
| | if compare_versions("transformers", "<", "4.46"): |
| | from transformers.deepspeed import ( |
| | HfDeepSpeedConfig, |
| | unset_hf_deepspeed_config, |
| | ) |
| | else: |
| | from transformers.integrations import ( |
| | HfDeepSpeedConfig, |
| | unset_hf_deepspeed_config, |
| | ) |
| |
|
| | unset_hf_deepspeed_config() |
| | self.dschf = HfDeepSpeedConfig(ds_config) |
| |
|
| | def is_zero3_init_enabled(self): |
| | return self.zero3_init_flag |
| |
|
| | @contextmanager |
| | def zero3_init_context_manager(self, enable=False): |
| | old = self.zero3_init_flag |
| | if old == enable: |
| | yield |
| | else: |
| | self.zero3_init_flag = enable |
| | self.dschf = None |
| | self.set_deepspeed_weakref() |
| | yield |
| | self.zero3_init_flag = old |
| | self.dschf = None |
| | self.set_deepspeed_weakref() |
| |
|
| | def _deepspeed_config_checks(self): |
| | env_variable_names_to_ignore = [ |
| | "ACCELERATE_GRADIENT_ACCUMULATION_STEPS", |
| | "ACCELERATE_GRADIENT_CLIPPING", |
| | "ACCELERATE_DEEPSPEED_ZERO_STAGE", |
| | "ACCELERATE_DEEPSPEED_OFFLOAD_OPTIMIZER_DEVICE", |
| | "ACCELERATE_DEEPSPEED_OFFLOAD_PARAM_DEVICE", |
| | "ACCELERATE_DEEPSPEED_OFFLOAD_PARAM_NVME_PATH", |
| | "ACCELERATE_DEEPSPEED_OFFLOAD_OPTIMIZER_NVME_PATH", |
| | "ACCELERATE_DEEPSPEED_ZERO3_SAVE_16BIT_MODEL", |
| | "ACCELERATE_MIXED_PRECISION", |
| | ] |
| | env_variable_names_to_ignore = [ |
| | name.replace("ACCELERATE_", "").replace("DEEPSPEED_", "").lower() for name in env_variable_names_to_ignore |
| | ] |
| |
|
| | deepspeed_fields_from_accelerate_config = os.environ.get("ACCELERATE_CONFIG_DS_FIELDS", "").split(",") |
| |
|
| | if any(name in env_variable_names_to_ignore for name in deepspeed_fields_from_accelerate_config): |
| | raise ValueError( |
| | f"When using `deepspeed_config_file`, the following accelerate config variables will be ignored: {env_variable_names_to_ignore}.\n" |
| | "Please specify them appropriately in the DeepSpeed config file.\n" |
| | "If you are using an accelerate config file, remove others config variables mentioned in the above specified list.\n" |
| | "The easiest method is to create a new config following the questionnaire via `accelerate config`.\n" |
| | "It will only ask for the necessary config variables when using `deepspeed_config_file`." |
| | ) |
| |
|
| | def set_moe_leaf_modules(self, model): |
| | if self.transformer_moe_cls_names is None: |
| | self.transformer_moe_cls_names = os.environ.get("ACCELERATE_DEEPSPEED_MOE_LAYER_CLS_NAMES", None) |
| | if self.transformer_moe_cls_names is not None: |
| | if compare_versions("deepspeed", "<", "0.14.0"): |
| | raise ImportError("DeepSpeed version must be >= 0.14.0 to use MOE support. Please update DeepSpeed.") |
| | from deepspeed.utils import set_z3_leaf_modules |
| |
|
| | class_names = self.transformer_moe_cls_names.split(",") |
| | transformer_moe_cls = [] |
| | for layer_class in class_names: |
| | transformer_cls = get_module_class_from_name(model, layer_class) |
| | if transformer_cls is None: |
| | raise Exception( |
| | f"Could not find a transformer layer class called '{layer_class}' to wrap in the model." |
| | ) |
| | else: |
| | transformer_moe_cls.append(transformer_cls) |
| | set_z3_leaf_modules(model, transformer_moe_cls) |
| |
|
| | def select(self, _from_accelerator_state: bool = False): |
| | """ |
| | Sets the HfDeepSpeedWeakref to use the current deepspeed plugin configuration |
| | """ |
| | if not _from_accelerator_state: |
| | raise ValueError( |
| | "A `DeepSpeedPlugin` object must be enabled manually by calling `AcceleratorState().enable_deepspeed_plugin(plugin_key)`." |
| | ) |
| | self.set_deepspeed_weakref() |
| | self._set_selected(True) |
| |
|
| | def _unselect(self): |
| | self._set_selected(False) |
| |
|
| | def _set_selected(self, value: bool): |
| | """ |
| | Private setter for the 'enabled' attribute. |
| | """ |
| | self._selected = value |
| |
|
| | @property |
| | def selected(self): |
| | return self._selected |
| |
|
| | @selected.setter |
| | def selected(self, value): |
| | raise NotImplementedError( |
| | "'enabled' can only be set through calling 'AcceleratorState().enable_deepspeed_plugin(key)'." |
| | ) |
| |
|
| |
|
| | @dataclass |
| | class FullyShardedDataParallelPlugin: |
| | """ |
| | This plugin is used to enable fully sharded data parallelism. |
| | |
| | Args: |
| | fsdp_version (`int`, defaults to `1`): |
| | The version of FSDP to use. Defaults to 1. If set to 2, launcher expects the config to be converted to |
| | FSDP2 format. |
| | sharding_strategy (`Union[str, torch.distributed.fsdp.ShardingStrategy]`, defaults to `'FULL_SHARD'`): |
| | Sharding strategy to use. Should be either a `str` or an instance of |
| | `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`. Is deprecated in favor of |
| | `reshard_after_forward`. |
| | reshard_after_forward (`Union[str, torch.distributed.fsdp.ShardingStrategy, bool]`, defaults to `'FULL_SHARD'` for `fsdp_version=1` and `True` for `fsdp_version=2`): |
| | Sharding strategy to use. Should be a bool if `fsdp_version` is set to 2 else a `str` or an instance of |
| | `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`. |
| | backward_prefetch (`Union[str, torch.distributed.fsdp.BackwardPrefetch]`, defaults to `'NO_PREFETCH'`): |
| | Backward prefetch strategy to use. Should be either a `str` or an instance of |
| | `torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch`. |
| | mixed_precision_policy (`Optional[Union[dict, str, torch.distributed.fsdp.MixedPrecision, torch.distributed.fsdp.MixedPrecisionPolicy]]`, defaults to `None`): |
| | A config to enable mixed precision training with FullyShardedDataParallel. If passing in a `dict`, it |
| | should have the following keys: `param_dtype`, `reduce_dtype`, and `buffer_dtype`, can be an instance of |
| | `torch.distributed.fsdp.MixedPrecisionPolicy` if `fsdp_version` is set to 2. If passing in a `str`, it |
| | should be one of the following values: fp8, fp16, bf16, fp32, and used to set `param_dtype`, |
| | `reduce_dtype`, and `buffer_dtype`. |
| | auto_wrap_policy (`Optional(Union[Callable, Literal["transformer_based_wrap", "size_based_wrap", "no_wrap"]]), defaults to `NO_WRAP`): |
| | A callable or string specifying a policy to recursively wrap layers with FSDP. If a string, it must be one |
| | of `transformer_based_wrap`, `size_based_wrap`, or `no_wrap`. See |
| | `torch.distributed.fsdp.wrap.size_based_wrap_policy` for a direction on what it should look like. |
| | cpu_offload (`Union[bool, torch.distributed.fsdp.CPUOffload, torch.distributed.fsdp.CPUOffloadPolicy]`, defaults to `False`): |
| | Whether to offload parameters to CPU. Should be either a `bool` or an instance of |
| | `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload` or |
| | `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffloadPolicy` if `fsdp_version` is set to 2. |
| | ignored_modules (`Optional[Union[Iterable[torch.nn.Module], str]]`, defaults to `None`): |
| | A list of modules to ignore when wrapping with FSDP. When passing a string, will match the modules by name |
| | using regex fullmatch. If `fsdp_version` is set to 2, the modules are converted to parameters and used. |
| | state_dict_type (`Union[str, torch.distributed.fsdp.StateDictType]`, defaults to `'FULL_STATE_DICT'`): |
| | State dict type to use. If a string, it must be one of `full_state_dict`, `local_state_dict`, or |
| | `sharded_state_dict`. |
| | state_dict_config (`Optional[Union[torch.distributed.fsdp.FullStateDictConfig, torch.distributed.fsdp.ShardedStateDictConfig]`, defaults to `None`): |
| | State dict config to use. Is determined based on the `state_dict_type` if not passed in. |
| | optim_state_dict_config (`Optional[Union[torch.distributed.fsdp.FullOptimStateDictConfig, torch.distributed.fsdp.ShardedOptimStateDictConfig]`, defaults to `None`): |
| | Optim state dict config to use. Is determined based on the `state_dict_type` if not passed in. |
| | limit_all_gathers (`bool`, defaults to `True`): |
| | Whether to have FSDP explicitly synchronizes the CPU thread to prevent too many in-flight all-gathers. This |
| | bool only affects the sharded strategies that schedule all-gathers. Enabling this can help lower the number |
| | of CUDA malloc retries. |
| | use_orig_params (`bool`, defaults to `False`): |
| | Whether to use the original parameters for the optimizer. |
| | param_init_fn (`Optional[Callable[[torch.nn.Module], None]`, defaults to `None`): |
| | A `Callable[torch.nn.Module] -> None` that specifies how modules that are currently on the meta device |
| | should be initialized onto an actual device. Only applicable when `sync_module_states` is `True`. By |
| | default is a `lambda` which calls `to_empty` on the module. |
| | sync_module_states (`bool`, defaults to `False`): |
| | Whether each individually wrapped FSDP unit should broadcast module parameters from rank 0 to ensure they |
| | are the same across all ranks after initialization. Defaults to `False` unless `cpu_ram_efficient_loading` |
| | is `True`, then will be forcibly enabled. |
| | forward_prefetch (`bool`, defaults to `False`): |
| | Whether to have FSDP explicitly prefetches the next upcoming all-gather while executing in the forward |
| | pass. only use with Static graphs. |
| | activation_checkpointing (`bool`, defaults to `False`): |
| | A technique to reduce memory usage by clearing activations of certain layers and recomputing them during a |
| | backward pass. Effectively, this trades extra computation time for reduced memory usage. |
| | cpu_ram_efficient_loading (`bool`, defaults to `None`): |
| | If True, only the first process loads the pretrained model checkoint while all other processes have empty |
| | weights. Only applicable for Transformers. When using this, `sync_module_states` needs to be `True`. |
| | transformer_cls_names_to_wrap (`Optional[List[str]]`, defaults to `None`): |
| | A list of transformer layer class names to wrap. Only applicable when `auto_wrap_policy` is |
| | `transformer_based_wrap`. |
| | min_num_params (`Optional[int]`, defaults to `None`): |
| | The minimum number of parameters a module must have to be wrapped. Only applicable when `auto_wrap_policy` |
| | is `size_based_wrap`. |
| | """ |
| |
|
| | fsdp_version: int = field( |
| | default=None, |
| | metadata={ |
| | "help": "The version of FSDP to use. Defaults to 1. If set to 2, launcher expects the config to be converted to FSDP2 format." |
| | }, |
| | ) |
| |
|
| | sharding_strategy: Union[str, "torch.distributed.fsdp.ShardingStrategy"] = field( |
| | default=None, |
| | metadata={ |
| | "help": "Sharding strategy to use. Should be either a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`. Defaults to 'FULL_SHARD'. Is deprecated in favor of `reshard_after_forward` " |
| | }, |
| | ) |
| |
|
| | reshard_after_forward: Union[str, "torch.distributed.fsdp.ShardingStrategy", bool] = field( |
| | default=None, |
| | metadata={ |
| | "help": "Sharding strategy to use. Should be a bool if `fsdp_version` is set to 2 else a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`. Defaults to 'FULL_SHARD'" |
| | }, |
| | ) |
| | backward_prefetch: Optional[Union[str, "torch.distributed.fsdp.BackwardPrefetch"]] = field( |
| | default=None, |
| | metadata={ |
| | "help": "Backward prefetch strategy to use. Should be either a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch`. Defaults to 'NO_PREFETCH'. This becomes obsolete in FSDP2." |
| | }, |
| | ) |
| | mixed_precision_policy: Optional[ |
| | Union[ |
| | dict, |
| | str, |
| | "torch.distributed.fsdp.MixedPrecision", |
| | "torch.distributed.fsdp.MixedPrecisionPolicy", |
| | ] |
| | ] = field( |
| | default=None, |
| | metadata={ |
| | "help": "A config to enable mixed precision training with FullyShardedDataParallel. " |
| | "If passing in a `dict`, it should have the following keys: `param_dtype`, `reduce_dtype`, and `buffer_dtype`." |
| | "Can also be an instance of `torch.distributed.fsdp.MixedPrecisionPolicy` if `fsdp_version` is set to 2." |
| | }, |
| | ) |
| | auto_wrap_policy: Optional[Union[Callable, Literal["transformer_based_wrap", "size_based_wrap", "no_wrap"]]] = ( |
| | field( |
| | default=None, |
| | metadata={ |
| | "help": "A callable or string specifying a policy to recursively wrap layers with FSDP. If a string, it must be one of `transformer_based_wrap`, `size_based_wrap`, or `no_wrap`. " |
| | "Defaults to `NO_WRAP`. See `torch.distributed.fsdp.wrap.size_based_wrap_policy` for a direction on what it should look like" |
| | }, |
| | ) |
| | ) |
| | cpu_offload: Union[ |
| | bool, |
| | "torch.distributed.fsdp.CPUOffload", |
| | "torch.distributed.fsdp.CPUOffloadPolicy", |
| | ] = field( |
| | default=None, |
| | metadata={ |
| | "help": "Whether to offload parameters to CPU. Should be either a `bool` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload` or `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffloadPolicy` if `fsdp_version` is set to 2. Defaults to `False`" |
| | }, |
| | ) |
| | ignored_modules: Optional[Union[Iterable[torch.nn.Module], str]] = field( |
| | default=None, |
| | metadata={"help": "A list of modules to ignore when wrapping with FSDP."}, |
| | ) |
| |
|
| | state_dict_type: Union[str, "torch.distributed.fsdp.StateDictType"] = field( |
| | default=None, |
| | metadata={ |
| | "help": "State dict type to use. If a string, it must be one of `full_state_dict`, `local_state_dict`, or `sharded_state_dict`. Defaults to `FULL_STATE_DICT`" |
| | }, |
| | ) |
| | state_dict_config: Optional[ |
| | Union[ |
| | "torch.distributed.fsdp.FullStateDictConfig", |
| | "torch.distributed.fsdp.ShardedStateDictConfig", |
| | ] |
| | ] = field( |
| | default=None, |
| | metadata={"help": "State dict config to use. Is determined based on the `state_dict_type` if not passed in."}, |
| | ) |
| | optim_state_dict_config: Optional[ |
| | Union[ |
| | "torch.distributed.fsdp.FullOptimStateDictConfig", |
| | "torch.distributed.fsdp.ShardedOptimStateDictConfig", |
| | ] |
| | ] = field( |
| | default=None, |
| | metadata={ |
| | "help": "Optim state dict config to use. Is determined based on the `state_dict_type` if not passed in." |
| | }, |
| | ) |
| | limit_all_gathers: bool = field( |
| | default=True, |
| | metadata={ |
| | "help": "Whether to have FSDP explicitly synchronizes the CPU thread to prevent " |
| | "too many in-flight all-gathers. This bool only affects the sharded strategies that schedule all-gathers. " |
| | "Enabling this can help lower the number of CUDA malloc retries." |
| | }, |
| | ) |
| | use_orig_params: Optional[bool] = field( |
| | default=None, |
| | metadata={ |
| | "help": "Whether to use the original parameters for the optimizer. Defaults to `False`. This becomes obsolete in FSDP2." |
| | }, |
| | ) |
| | param_init_fn: Optional[Callable[[torch.nn.Module], None]] = field( |
| | default=None, |
| | metadata={ |
| | "help": "A Callable[torch.nn.Module] -> None that specifies how modules " |
| | "that are currently on the meta device should be initialized onto an actual device. " |
| | "Only applicable when `sync_module_states` is `True`. By default is a `lambda` which calls `to_empty` on the module." |
| | }, |
| | ) |
| | sync_module_states: Optional[bool] = field( |
| | default=None, |
| | metadata={ |
| | "help": "Whether each individually wrapped FSDP unit should broadcast module parameters from rank 0 " |
| | "to ensure they are the same across all ranks after initialization. Defaults to `False` unless " |
| | "`cpu_ram_efficient_loading` is `True`, then will be forcibly enabled. This becomes obsolete in FSDP2." |
| | }, |
| | ) |
| | forward_prefetch: bool = field( |
| | default=None, |
| | metadata={ |
| | "help": "Whether to have FSDP explicitly prefetches the next upcoming " |
| | "all-gather while executing in the forward pass. only use with Static graphs. Defaults to `False`" |
| | }, |
| | ) |
| | activation_checkpointing: bool = field( |
| | default=None, |
| | metadata={ |
| | "help": "A technique to reduce memory usage by clearing activations of " |
| | "certain layers and recomputing them during a backward pass. Effectively, this trades extra computation time " |
| | "for reduced memory usage. Defaults to `False`" |
| | }, |
| | ) |
| | cpu_ram_efficient_loading: bool = field( |
| | default=None, |
| | metadata={ |
| | "help": "If True, only the first process loads the pretrained model checkoint while all other processes have empty weights. " |
| | "Only applicable for 🤗 Transformers. When using this, `sync_module_states` needs to be `True`. Defaults to `False`." |
| | }, |
| | ) |
| | transformer_cls_names_to_wrap: Optional[list[str]] = field( |
| | default=None, |
| | metadata={ |
| | "help": "A list of transformer layer class names to wrap. Only applicable when `auto_wrap_policy` is `transformer_based_wrap`." |
| | }, |
| | ) |
| | min_num_params: Optional[int] = field( |
| | default=None, |
| | metadata={ |
| | "help": "The minimum number of parameters a module must have to be wrapped. Only applicable when `auto_wrap_policy` is `size_based_wrap`." |
| | }, |
| | ) |
| |
|
| | def __post_init__(self): |
| | from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy |
| |
|
| | _fsdp2_warnings = set() |
| |
|
| | env_prefix = "FSDP_" |
| | |
| | if self.fsdp_version is None: |
| | self.fsdp_version = int(os.environ.get(env_prefix + "VERSION", "1")) |
| |
|
| | if self.fsdp_version == 2: |
| | if not is_torch_version(">=", FSDP2_PYTORCH_VERSION): |
| | raise ImportError(f"FSDP2 requires PyTorch >= {FSDP2_PYTORCH_VERSION}") |
| |
|
| | if self.sharding_strategy is not None: |
| | |
| | |
| | _fsdp2_warnings.add( |
| | "sharding_strategy is deprecated in favor of reshard_after_forward. " |
| | "This will be removed in a future version of Accelerate." |
| | ) |
| | if self.fsdp_version == 1: |
| | if self.sharding_strategy is None: |
| | self.sharding_strategy = os.environ.get(env_prefix + "SHARDING_STRATEGY", "FULL_SHARD") |
| | if isinstance(self.sharding_strategy, str): |
| | if self.sharding_strategy.upper() in FSDP_SHARDING_STRATEGY: |
| | self.sharding_strategy = FSDP_SHARDING_STRATEGY.index(self.sharding_strategy.upper()) + 1 |
| | if isinstance(self.sharding_strategy, int) or self.sharding_strategy.isdigit(): |
| | self.sharding_strategy = ShardingStrategy(int(self.sharding_strategy)) |
| | else: |
| | self.sharding_strategy = ShardingStrategy[self.sharding_strategy.upper()] |
| |
|
| | |
| | if self.reshard_after_forward is None and self.sharding_strategy is None: |
| | reshard_after_forward = os.environ.get( |
| | env_prefix + "RESHARD_AFTER_FORWARD", |
| | "true" if self.fsdp_version == 2 else "FULL_SHARD", |
| | ) |
| | if self.fsdp_version == 2: |
| | self.reshard_after_forward = str_to_bool(reshard_after_forward.lower(), to_bool=True) |
| | else: |
| | self.reshard_after_forward = reshard_after_forward |
| | if isinstance(self.reshard_after_forward, str): |
| | if self.fsdp_version == 2: |
| | self.reshard_after_forward = str_to_bool(self.reshard_after_forward.lower(), to_bool=True) |
| | else: |
| | |
| | if self.reshard_after_forward.upper() in FSDP_SHARDING_STRATEGY: |
| | self.reshard_after_forward = FSDP_SHARDING_STRATEGY.index(self.reshard_after_forward.upper()) + 1 |
| | if isinstance(self.reshard_after_forward, int) or self.reshard_after_forward.isdigit(): |
| | self.reshard_after_forward = ShardingStrategy(int(self.reshard_after_forward)) |
| | else: |
| | self.reshard_after_forward = ShardingStrategy[self.reshard_after_forward.upper()] |
| |
|
| | if self.fsdp_version == 2 and not isinstance(self.reshard_after_forward, bool): |
| | raise ValueError( |
| | f"reshard_after_forward set to {self.reshard_after_forward}. This is not supported with FSDP2, please set to a `bool`" |
| | ) |
| | if self.fsdp_version == 1 and isinstance(self.reshard_after_forward, bool): |
| | raise ValueError( |
| | f"reshard_after_forward set to {self.reshard_after_forward}. This is not supported with FSDP1, please set to a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`" |
| | ) |
| |
|
| | if self.cpu_offload is None: |
| | self.cpu_offload = str_to_bool(os.environ.get(env_prefix + "OFFLOAD_PARAMS", "False")) == 1 |
| |
|
| | self.set_cpu_offload() |
| | self.validate_cpu_offload() |
| |
|
| | if self.backward_prefetch is None: |
| | self.backward_prefetch = os.environ.get(env_prefix + "BACKWARD_PREFETCH", None) |
| | if isinstance(self.backward_prefetch, str) and self.backward_prefetch.upper() == "NO_PREFETCH": |
| | self.backward_prefetch = None |
| | if self.backward_prefetch is not None and not isinstance(self.backward_prefetch, BackwardPrefetch): |
| | if isinstance(self.backward_prefetch, str) and self.backward_prefetch.upper() in FSDP_BACKWARD_PREFETCH: |
| | self.backward_prefetch = FSDP_BACKWARD_PREFETCH.index(self.backward_prefetch.upper()) + 1 |
| | if isinstance(self.backward_prefetch, int) or self.backward_prefetch.isdigit(): |
| | self.backward_prefetch = BackwardPrefetch(int(self.backward_prefetch)) |
| | else: |
| | self.backward_prefetch = BackwardPrefetch[self.backward_prefetch.upper()] |
| | if self.fsdp_version == 2 and self.backward_prefetch is not None: |
| | _fsdp2_warnings.add("backward_prefetch is not supported in FSDP2. Setting backward prefetch to None.") |
| | self.backward_prefetch = None |
| |
|
| | self.set_state_dict_type() |
| |
|
| | if self.auto_wrap_policy is None: |
| | self.auto_wrap_policy = os.environ.get(env_prefix + "AUTO_WRAP_POLICY", "NO_WRAP") |
| | if isinstance(self.auto_wrap_policy, str): |
| | if self.auto_wrap_policy.upper() not in FSDP_AUTO_WRAP_POLICY: |
| | raise ValueError( |
| | f"Invalid auto wrap policy: {self.auto_wrap_policy}. Must be one of {FSDP_AUTO_WRAP_POLICY}" |
| | ) |
| | from torch.distributed.fsdp.wrap import ( |
| | size_based_auto_wrap_policy, |
| | transformer_auto_wrap_policy, |
| | ) |
| |
|
| | if self.auto_wrap_policy.upper() == "TRANSFORMER_BASED_WRAP": |
| | self.auto_wrap_policy = transformer_auto_wrap_policy |
| | if self.transformer_cls_names_to_wrap is None: |
| | self.transformer_cls_names_to_wrap = os.environ.get(env_prefix + "TRANSFORMER_CLS_TO_WRAP", None) |
| | if isinstance(self.transformer_cls_names_to_wrap, str): |
| | self.transformer_cls_names_to_wrap = self.transformer_cls_names_to_wrap.split(",") |
| | elif self.auto_wrap_policy.upper() == "SIZE_BASED_WRAP": |
| | self.auto_wrap_policy = size_based_auto_wrap_policy |
| | if self.min_num_params is None: |
| | self.min_num_params = int(os.environ.get(env_prefix + "MIN_NUM_PARAMS", 0)) |
| | elif not isinstance(self.min_num_params, int): |
| | raise ValueError( |
| | f"`min_num_params` must be an integer. Got {self.min_num_params} of type {type(self.min_num_params)}" |
| | ) |
| | elif self.auto_wrap_policy.upper() == "NO_WRAP": |
| | self.auto_wrap_policy = None |
| |
|
| | if self.use_orig_params is None and self.fsdp_version == 1: |
| | self.use_orig_params = str_to_bool(os.environ.get(env_prefix + "USE_ORIG_PARAMS", "False")) == 1 |
| | if self.fsdp_version == 2 and self.use_orig_params is not None: |
| | _fsdp2_warnings.add("use_orig_params is obsolete in FSDP2, as FSDP2 always uses the original parameters.") |
| | self.use_orig_params = None |
| |
|
| | if self.sync_module_states is None and self.fsdp_version == 1: |
| | self.sync_module_states = str_to_bool(os.environ.get(env_prefix + "SYNC_MODULE_STATES", "False")) == 1 |
| | if self.fsdp_version == 2 and self.sync_module_states is not None: |
| | _fsdp2_warnings.add( |
| | "sync_module_states is obsolete in FSDP2, as it is not needed anymore." |
| | "Setting sync_module_states to None." |
| | ) |
| | self.sync_module_states = None |
| |
|
| | if self.forward_prefetch is None and self.fsdp_version == 1: |
| | self.forward_prefetch = str_to_bool(os.environ.get(env_prefix + "FORWARD_PREFETCH", "False")) == 1 |
| | if self.fsdp_version == 2 and self.forward_prefetch is not None: |
| | raise ValueError("forward_prefetch is not yet implemented in FSDP2, set to None or use `fsdp_version=1`") |
| |
|
| | if self.activation_checkpointing is None: |
| | self.activation_checkpointing = ( |
| | str_to_bool(os.environ.get(env_prefix + "ACTIVATION_CHECKPOINTING", "False")) == 1 |
| | ) |
| |
|
| | if self.ignored_modules is None: |
| | self.ignored_modules = os.environ.get(env_prefix + "IGNORED_MODULES", None) |
| |
|
| | if self.cpu_ram_efficient_loading is None: |
| | self.cpu_ram_efficient_loading = ( |
| | str_to_bool(os.environ.get(env_prefix + "CPU_RAM_EFFICIENT_LOADING", "False")) == 1 |
| | ) |
| | else: |
| | |
| | os.environ[env_prefix + "CPU_RAM_EFFICIENT_LOADING"] = str(self.cpu_ram_efficient_loading) |
| | |
| | if self.fsdp_version == 1 and self.cpu_ram_efficient_loading and not self.sync_module_states: |
| | warnings.warn( |
| | "sync_module_states cannot be False since efficient cpu ram loading enabled. " |
| | "Setting sync_module_states to True." |
| | ) |
| | self.sync_module_states = True |
| | if isinstance(self.mixed_precision_policy, str): |
| | |
| | |
| | self.set_mixed_precision(self.mixed_precision_policy, override=True) |
| | elif isinstance(self.mixed_precision_policy, dict): |
| | self.set_mixed_precision(self.mixed_precision_policy) |
| | if self.mixed_precision_policy is not None: |
| | self.validate_mixed_precision_policy() |
| |
|
| | if self.sync_module_states: |
| | if is_npu_available(): |
| | device = torch.npu.current_device() |
| | elif is_mlu_available(): |
| | device = torch.mlu.current_device() |
| | elif is_musa_available(): |
| | device = torch.musa.current_device() |
| | elif is_cuda_available(): |
| | device = torch.cuda.current_device() |
| | elif is_xpu_available(): |
| | device = torch.xpu.current_device() |
| | elif is_hpu_available(): |
| | device = torch.hpu.current_device() |
| | else: |
| | raise RuntimeError( |
| | "There are currently no available devices found, must be one of 'XPU', 'CUDA', 'MLU', 'NPU', 'MUSA', or 'HPU'." |
| | ) |
| | |
| | |
| | self.param_init_fn = lambda x: x.to_empty(device=device, recurse=False) |
| | if is_torch_version("<", "2.7.0") and self.fsdp_version == 2 and self.ignored_modules is not None: |
| | _fsdp2_warnings.add( |
| | "FSDP2 ignored_params/ignored_modules is not available for torch version < 2.7.0" |
| | "Setting ignored_modules to None." |
| | ) |
| | self.ignored_modules = None |
| | |
| | if _fsdp2_warnings: |
| | logger.warning("Multiple deprecation warnings due to FSDP2 conversion:\n".join(_fsdp2_warnings)) |
| |
|
| | def set_state_dict_type(self, state_dict_type=None): |
| | """ |
| | Set the state dict config based on the `StateDictType`. |
| | """ |
| | from torch.distributed.fsdp.fully_sharded_data_parallel import ( |
| | FullOptimStateDictConfig, |
| | FullStateDictConfig, |
| | ShardedOptimStateDictConfig, |
| | ShardedStateDictConfig, |
| | StateDictType, |
| | ) |
| |
|
| | |
| | |
| | if state_dict_type is not None: |
| | self.state_dict_type = state_dict_type |
| |
|
| | if self.state_dict_type is None: |
| | self.state_dict_type = os.environ.get( |
| | "FSDP_STATE_DICT_TYPE", |
| | "FULL_STATE_DICT" if self.fsdp_version == 1 else "SHARDED_STATE_DICT", |
| | ) |
| | if isinstance(self.state_dict_type, str): |
| | if self.state_dict_type.isdigit(): |
| | self.state_dict_type = StateDictType(int(self.state_dict_type)) |
| | else: |
| | self.state_dict_type = StateDictType[self.state_dict_type.upper()] |
| |
|
| | if self.state_dict_type == StateDictType.FULL_STATE_DICT: |
| | if self.state_dict_config is None: |
| | self.state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) |
| | if self.optim_state_dict_config is None: |
| | self.optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True) |
| | elif self.state_dict_type == StateDictType.SHARDED_STATE_DICT: |
| | if self.state_dict_config is None: |
| | self.state_dict_config = ShardedStateDictConfig(offload_to_cpu=True) |
| | if self.optim_state_dict_config is None: |
| | self.optim_state_dict_config = ShardedOptimStateDictConfig(offload_to_cpu=True) |
| |
|
| | if self.fsdp_version == 2 and self.state_dict_type == StateDictType.LOCAL_STATE_DICT: |
| | raise ValueError( |
| | "FSDP2 does not support LOCAL_STATE_DICT. " |
| | "Please set `fsdp_state_dict_type` to `SHARDED_STATE_DICT` or `FULL_STATE_DICT`." |
| | ) |
| |
|
| | def set_auto_wrap_policy(self, model): |
| | """ |
| | Given `model`, creates an `auto_wrap_policy` based on the passed in policy and if we can use the |
| | `transformer_cls_to_wrap` |
| | """ |
| | from torch.distributed.fsdp.wrap import ( |
| | size_based_auto_wrap_policy, |
| | transformer_auto_wrap_policy, |
| | ) |
| |
|
| | |
| | no_split_modules = getattr(model, "_no_split_modules", None) |
| | default_transformer_cls_names_to_wrap = list(no_split_modules) if no_split_modules is not None else [] |
| | if self.auto_wrap_policy == transformer_auto_wrap_policy: |
| | if self.transformer_cls_names_to_wrap is None: |
| | self.transformer_cls_names_to_wrap = default_transformer_cls_names_to_wrap |
| | transformer_cls_to_wrap = set() |
| | for layer_class in self.transformer_cls_names_to_wrap: |
| | transformer_cls = get_module_class_from_name(model, layer_class) |
| | if transformer_cls is None: |
| | raise ValueError(f"Could not find the transformer layer class {layer_class} in the model.") |
| | transformer_cls_to_wrap.add(transformer_cls) |
| | |
| | self.auto_wrap_policy = functools.partial( |
| | self.auto_wrap_policy, transformer_layer_cls=transformer_cls_to_wrap |
| | ) |
| |
|
| | elif self.auto_wrap_policy == size_based_auto_wrap_policy: |
| | |
| | if self.min_num_params > 0: |
| | self.auto_wrap_policy = functools.partial(self.auto_wrap_policy, min_num_params=self.min_num_params) |
| | else: |
| | self.auto_wrap_policy = None |
| |
|
| | def set_mixed_precision(self, mixed_precision, buffer_autocast=False, override=False): |
| | "Sets the mixed precision policy for FSDP" |
| | mixed_precision_mapping = { |
| | "fp8": torch.bfloat16, |
| | "fp16": torch.float16, |
| | "bf16": torch.bfloat16, |
| | "fp32": torch.float32, |
| | } |
| | dtype = mixed_precision |
| | if isinstance(mixed_precision, str): |
| | dtype = mixed_precision_mapping.get(mixed_precision, None) |
| | if dtype is None: |
| | raise ValueError( |
| | f"Invalid mixed precision: {mixed_precision}. Must be one of {list(mixed_precision_mapping.keys())}" |
| | ) |
| | elif isinstance(mixed_precision, torch.dtype) and mixed_precision not in mixed_precision_mapping.values(): |
| | raise ValueError( |
| | f"Invalid mixed precision: {mixed_precision}. Must be one of {list(mixed_precision_mapping.values())}" |
| | ) |
| |
|
| | buffer_type = torch.float32 if buffer_autocast else dtype |
| |
|
| | if self.fsdp_version == 1: |
| | from torch.distributed.fsdp import MixedPrecision |
| | elif self.fsdp_version == 2: |
| | from torch.distributed.fsdp import MixedPrecisionPolicy as MixedPrecision |
| |
|
| | if override or self.mixed_precision_policy is None: |
| | dtype_args = {"param_dtype": dtype, "reduce_dtype": dtype} |
| | if self.fsdp_version == 1: |
| | dtype_args["buffer_dtype"] = buffer_type |
| | else: |
| | dtype_args["output_dtype"] = dtype |
| | |
| | self.mixed_precision_policy = MixedPrecision(**dtype_args) |
| | elif isinstance(self.mixed_precision_policy, dict): |
| | |
| | valid_keys = ["param_dtype", "reduce_dtype"] + ( |
| | ["buffer_dtype"] if self.fsdp_version == 1 else ["output_dtype"] |
| | ) |
| | missing_keys = [k for k in valid_keys if k not in self.mixed_precision_policy] |
| | invalid_values = [ |
| | k for k, v in self.mixed_precision_policy.items() if v not in mixed_precision_mapping.values() |
| | ] |
| | if missing_keys or invalid_values: |
| | raise ValueError( |
| | f"Invalid mixed precision policy: {self.mixed_precision_policy}. " |
| | f"Must be a `dict` with keys {valid_keys}." |
| | f"Values must be one of {list(mixed_precision_mapping.values())}" |
| | ) |
| | self.mixed_precision_policy = MixedPrecision(**self.mixed_precision_policy) |
| |
|
| | def validate_mixed_precision_policy(self): |
| | """ |
| | Validates the mixed precision policy, abstracted away to not bring in the imports if not needed. |
| | """ |
| | if self.fsdp_version == 2: |
| | from torch.distributed.fsdp import MixedPrecisionPolicy as MixedPrecision |
| | else: |
| | from torch.distributed.fsdp import MixedPrecision |
| |
|
| | if not isinstance(self.mixed_precision_policy, MixedPrecision): |
| | required_type = ( |
| | "`torch.distributed.fsdp.MixedPrecisionPolicy`" |
| | if self.fsdp_version == 2 |
| | else "`torch.distributed.fsdp.MixedPrecision`" |
| | ) |
| | raise ValueError(f"mixed_precision_policy must be an instance of {required_type}.") |
| |
|
| | def set_cpu_offload(self): |
| | if self.fsdp_version == 2: |
| | from torch.distributed.fsdp import CPUOffloadPolicy, OffloadPolicy |
| | else: |
| | from torch.distributed.fsdp import CPUOffload |
| |
|
| | if isinstance(self.cpu_offload, bool): |
| | if self.fsdp_version == 2: |
| | if not self.cpu_offload: |
| | self.cpu_offload = OffloadPolicy() |
| | else: |
| | self.cpu_offload = CPUOffloadPolicy() |
| | else: |
| | self.cpu_offload = CPUOffload(offload_params=self.cpu_offload) |
| |
|
| | def validate_cpu_offload(self): |
| | if self.fsdp_version == 2: |
| | from torch.distributed.fsdp import OffloadPolicy |
| | else: |
| | from torch.distributed.fsdp import CPUOffload |
| |
|
| | if self.fsdp_version == 2 and not isinstance(self.cpu_offload, OffloadPolicy): |
| | raise ValueError( |
| | f"`cpu_offload` must be an instance of `torch.distributed.fsdp.OffloadPolicy` in FSDP2, got {self.cpu_offload}" |
| | ) |
| | if self.fsdp_version == 1 and not isinstance(self.cpu_offload, CPUOffload): |
| | raise ValueError( |
| | f"`cpu_offload` must be an instance of `torch.distributed.fsdp.CPUOffload` in FSDP1, got {self.cpu_offload}" |
| | ) |
| |
|
| |
|
| | @dataclass |
| | class TorchTensorParallelPlugin: |
| | """ |
| | This plugin is used to enable tensor parallelism using PyTorch >= 2.0. |
| | """ |
| |
|
| | tp_size: int = field( |
| | default=1, |
| | metadata={"help": "tensor parallel size will be used in the device mesh preparation"}, |
| | ) |
| |
|
| | |
| | torch_device_mesh: Optional["torch.distributed.DeviceMesh"] = field(default=None) |
| |
|
| |
|
| | @dataclass |
| | class TorchContextParallelConfig: |
| | """ |
| | This class holds the configuration for context parallelism in PyTorch. |
| | """ |
| |
|
| | cp_comm_strategy: Optional[str] = field( |
| | default=None, |
| | metadata={ |
| | "help": "Communication strategy for context parallelism. Can be one of 'allgather' or 'alltoall'. Defaults to 'allgather'." |
| | }, |
| | ) |
| |
|
| | def __post_init__(self): |
| | if not is_torch_version(">=", BETA_CP_AVAILABLE_PYTORCH_VERSION): |
| | raise ValueError( |
| | f"FSDP2-based Context parallelism is only available in PyTorch {BETA_CP_AVAILABLE_PYTORCH_VERSION} and later versions. " |
| | "Please upgrade your PyTorch version." |
| | ) |
| |
|
| | if self.cp_comm_strategy is None: |
| | self.cp_comm_strategy = os.environ.get("PARALLELISM_CONFIG_CP_COMM_STRATEGY", "allgather") |
| | if self.cp_comm_strategy not in ["allgather", "alltoall"]: |
| | raise ValueError( |
| | f"Invalid cp_comm_strategy: {self.cp_comm_strategy}. Must be one of 'allgather' or 'alltoall'." |
| | ) |
| |
|
| |
|
| | @dataclass |
| | class DeepSpeedSequenceParallelConfig: |
| | sp_seq_length: Optional[int] = field( |
| | default=None, |
| | metadata={ |
| | "help": "Sequence length for when batches are all of the same length. For variable sequence lengths across batches set `sp_seq_length_is_variable=True` and leave this field unset" |
| | }, |
| | ) |
| | sp_seq_length_is_variable: Optional[bool] = field( |
| | default=None, |
| | metadata={ |
| | "help": "If `True` will work with a sequence length that may change between batches, in which case `sp_seq_length` value can be set to anything divisible by cp size or remain unset. If `False` then `sp_seq_length` needs to match the batch's sequence length dimension. The default is `True`." |
| | }, |
| | ) |
| | sp_attn_implementation: Optional[str] = field( |
| | default=None, |
| | metadata={ |
| | "help": "Attention implementation to use. Can be one of 'flash_attention_2', 'flash_attention_3' or 'sdpa'. Defaults to `sdpa`." |
| | }, |
| | ) |
| |
|
| | def __post_init__(self): |
| | |
| | if self.sp_seq_length_is_variable is None: |
| | self.sp_seq_length_is_variable = ( |
| | os.environ.get("PARALLELISM_CONFIG_SP_SEQ_LENGTH_IS_VARIABLE", "true").lower() == "true" |
| | ) |
| |
|
| | if not self.sp_seq_length_is_variable and self.sp_seq_length is None: |
| | if "PARALLELISM_CONFIG_SP_SEQ_LENGTH" not in os.environ: |
| | raise ValueError( |
| | "when `sp_seq_length_is_variable` is `False` `sp_seq_length` must be provided either through the constructor or the environment variable PARALLELISM_CONFIG_SP_SEQ_LENGTH" |
| | ) |
| | else: |
| | self.sp_seq_length = os.environ.get("PARALLELISM_CONFIG_SP_SEQ_LENGTH") |
| | self.sp_seq_length = None if self.sp_seq_length == "None" else int(self.sp_seq_length) |
| |
|
| | if self.sp_attn_implementation is None: |
| | self.sp_attn_implementation = os.environ.get("PARALLELISM_CONFIG_SP_ATTN_IMPLEMENTATION", None) |
| |
|
| | if self.sp_attn_implementation is not None and self.sp_attn_implementation not in [ |
| | "flash_attention_2", |
| | "flash_attention_3", |
| | "sdpa", |
| | ]: |
| | raise ValueError( |
| | f"Invalid sp_attn_implementation: {self.sp_attn_implementation}. Must be one of 'flash_attention_2', 'flash_attention_3' or 'sdpa'." |
| | ) |
| |
|
| |
|
| | @dataclass |
| | class TorchTensorParallelConfig: |
| | """ |
| | Use this object in your [`Accelerator`] to customize your torch tensor parallelism. |
| | """ |
| |
|
| | enable_async_tp: bool = False |
| |
|
| | def __post_init__(self): |
| | if not is_torch_version(">=", BETA_TP_AVAILABLE_PYTORCH_VERSION): |
| | raise ValueError( |
| | f"Torch tensor parallelism is only available in PyTorch {BETA_TP_AVAILABLE_PYTORCH_VERSION} and later versions. " |
| | "Please upgrade your PyTorch version." |
| | ) |
| |
|
| | if not compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION): |
| | raise ValueError(f"TP requires transformers >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}") |
| |
|
| | if self.enable_async_tp: |
| | warnings.warn("Async tensor parallelism is currently not supported, ignoring this option.") |
| |
|
| |
|
| | @dataclass |
| | class MegatronLMPlugin: |
| | """ |
| | Plugin for Megatron-LM to enable tensor, pipeline, sequence and data parallelism. Also to enable selective |
| | activation recomputation and optimized fused kernels. |
| | |
| | Args: |
| | tp_degree (`int`, defaults to `None`): |
| | Tensor parallelism degree. |
| | pp_degree (`int`, defaults to `None`): |
| | Pipeline parallelism degree. |
| | num_micro_batches (`int`, defaults to `None`): |
| | Number of micro-batches. |
| | gradient_clipping (`float`, defaults to `None`): |
| | Gradient clipping value based on global L2 Norm (0 to disable). |
| | sequence_parallelism (`bool`, defaults to `None`): |
| | Enable sequence parallelism. |
| | recompute_activations (`bool`, defaults to `None`): |
| | Enable selective activation recomputation. |
| | use_distributed_optimizr (`bool`, defaults to `None`): |
| | Enable distributed optimizer. |
| | pipeline_model_parallel_split_rank (`int`, defaults to `None`): |
| | Rank where encoder and decoder should be split. |
| | num_layers_per_virtual_pipeline_stage (`int`, defaults to `None`): |
| | Number of layers per virtual pipeline stage. |
| | is_train_batch_min (`str`, defaults to `True`): |
| | If both tran & eval dataloaders are specified, this will decide the `micro_batch_size`. |
| | train_iters (`int`, defaults to `None`): |
| | Total number of samples to train over all training runs. Note that either train-iters or train-samples |
| | should be provided when using `MegatronLMDummyScheduler`. |
| | train_samples (`int`, defaults to `None`): |
| | Total number of samples to train over all training runs. Note that either train-iters or train-samples |
| | should be provided when using `MegatronLMDummyScheduler`. |
| | weight_decay_incr_style (`str`, defaults to `'constant'`): |
| | Weight decay increment function. choices=["constant", "linear", "cosine"]. |
| | start_weight_decay (`float`, defaults to `None`): |
| | Initial weight decay coefficient for L2 regularization. |
| | end_weight_decay (`float`, defaults to `None`): |
| | End of run weight decay coefficient for L2 regularization. |
| | lr_decay_style (`str`, defaults to `'linear'`): |
| | Learning rate decay function. choices=['constant', 'linear', 'cosine']. |
| | lr_decay_iters (`int`, defaults to `None`): |
| | Number of iterations for learning rate decay. If None defaults to `train_iters`. |
| | lr_decay_samples (`int`, defaults to `None`): |
| | Number of samples for learning rate decay. If None defaults to `train_samples`. |
| | lr_warmup_iters (`int`, defaults to `None`): |
| | Number of iterations to linearly warmup learning rate over. |
| | lr_warmup_samples (`int`, defaults to `None`): |
| | Number of samples to linearly warmup learning rate over. |
| | lr_warmup_fraction (`float`, defaults to `None`): |
| | Fraction of lr-warmup-(iters/samples) to linearly warmup learning rate over. |
| | min_lr (`float`, defaults to `0`): |
| | Minimum value for learning rate. The scheduler clip values below this threshold. |
| | consumed_samples (`List`, defaults to `None`): |
| | Number of samples consumed in the same order as the dataloaders to `accelerator.prepare` call. |
| | no_wd_decay_cond (`Optional`, defaults to `None`): |
| | Condition to disable weight decay. |
| | scale_lr_cond (`Optional`, defaults to `None`): |
| | Condition to scale learning rate. |
| | lr_mult (`float`, defaults to `1.0`): |
| | Learning rate multiplier. |
| | megatron_dataset_flag (`bool`, defaults to `False`): |
| | Whether the format of dataset follows Megatron-LM Indexed/Cached/MemoryMapped format. |
| | seq_length (`int`, defaults to `None`): |
| | Maximum sequence length to process. |
| | encoder_seq_length (`int`, defaults to `None`): |
| | Maximum sequence length to process for the encoder. |
| | decoder_seq_length (`int`, defaults to `None`): |
| | Maximum sequence length to process for the decoder. |
| | tensorboard_dir (`str`, defaults to `None`): |
| | Path to save tensorboard logs. |
| | set_all_logging_options (`bool`, defaults to `False`): |
| | Whether to set all logging options. |
| | eval_iters (`int`, defaults to `100`): |
| | Number of iterations to run for evaluation validation/test for. |
| | eval_interval (`int`, defaults to `1000`): |
| | Interval between running evaluation on validation set. |
| | return_logits (`bool`, defaults to `False`): |
| | Whether to return logits from the model. |
| | custom_train_step_class (`Optional`, defaults to `None`): |
| | Custom train step class. |
| | custom_train_step_kwargs (`Optional`, defaults to `None`): |
| | Custom train step kwargs. |
| | custom_model_provider_function (`Optional`, defaults to `None`): |
| | Custom model provider function. |
| | custom_prepare_model_function (`Optional`, defaults to `None`): |
| | Custom prepare model function. |
| | custom_megatron_datasets_provider_function (`Optional`, defaults to `None`): |
| | Custom megatron train_valid_test datasets provider function. |
| | custom_get_batch_function (`Optional`, defaults to `None`): |
| | Custom get batch function. |
| | custom_loss_function (`Optional`, defaults to `None`): |
| | Custom loss function. |
| | other_megatron_args (`Optional`, defaults to `None`): |
| | Other Megatron-LM arguments. Please refer Megatron-LM. |
| | """ |
| |
|
| | tp_degree: int = field(default=None, metadata={"help": "tensor parallelism degree."}) |
| | pp_degree: int = field(default=None, metadata={"help": "pipeline parallelism degree."}) |
| | num_micro_batches: int = field(default=None, metadata={"help": "number of micro-batches."}) |
| | gradient_clipping: float = field( |
| | default=None, |
| | metadata={"help": "gradient clipping value based on global L2 Norm (0 to disable)"}, |
| | ) |
| | sequence_parallelism: bool = field( |
| | default=None, |
| | metadata={"help": "enable sequence parallelism"}, |
| | ) |
| | recompute_activations: bool = field( |
| | default=None, |
| | metadata={"help": "enable selective activation recomputation"}, |
| | ) |
| | use_distributed_optimizer: bool = field( |
| | default=None, |
| | metadata={"help": "enable distributed optimizer"}, |
| | ) |
| | pipeline_model_parallel_split_rank: int = field( |
| | default=None, |
| | metadata={"help": "Rank where encoder and decoder should be split."}, |
| | ) |
| | num_layers_per_virtual_pipeline_stage: int = field( |
| | default=None, metadata={"help": "Number of layers per virtual pipeline stage."} |
| | ) |
| | is_train_batch_min: str = field( |
| | default=True, |
| | metadata={"help": "If both train & eval dataloaders are specified, this will decide the micro_batch_size"}, |
| | ) |
| | train_iters: int = field( |
| | default=None, |
| | metadata={ |
| | "help": "Total number of iterations to train over all training runs. " |
| | "Note that either train-iters or train-samples should be provided when using `MegatronLMDummyScheduler`" |
| | }, |
| | ) |
| | train_samples: int = field( |
| | default=None, |
| | metadata={ |
| | "help": "Total number of samples to train over all training runs. " |
| | "Note that either train-iters or train-samples should be provided when using `MegatronLMDummyScheduler`" |
| | }, |
| | ) |
| | weight_decay_incr_style: str = field( |
| | default="constant", |
| | metadata={"help": 'Weight decay increment function. choices=["constant", "linear", "cosine"]. '}, |
| | ) |
| | start_weight_decay: float = field( |
| | default=None, |
| | metadata={"help": "Initial weight decay coefficient for L2 regularization."}, |
| | ) |
| | end_weight_decay: float = field( |
| | default=None, |
| | metadata={"help": "End of run weight decay coefficient for L2 regularization."}, |
| | ) |
| | lr_decay_style: str = field( |
| | default="linear", |
| | metadata={"help": "Learning rate decay function. choices=['constant', 'linear', 'cosine']."}, |
| | ) |
| | lr_decay_iters: int = field( |
| | default=None, |
| | metadata={"help": "Number of iterations for learning rate decay. If None defaults to `train_iters`."}, |
| | ) |
| | lr_decay_samples: int = field( |
| | default=None, |
| | metadata={"help": "Number of samples for learning rate decay. If None defaults to `train_samples`."}, |
| | ) |
| | lr_warmup_iters: int = field( |
| | default=None, |
| | metadata={"help": "number of iterations to linearly warmup learning rate over."}, |
| | ) |
| | lr_warmup_samples: int = field( |
| | default=None, |
| | metadata={"help": "number of samples to linearly warmup learning rate over."}, |
| | ) |
| | lr_warmup_fraction: float = field( |
| | default=None, |
| | metadata={"help": "fraction of lr-warmup-(iters/samples) to linearly warmup learning rate over."}, |
| | ) |
| | min_lr: float = field( |
| | default=0, |
| | metadata={"help": "Minimum value for learning rate. The scheduler clip values below this threshold."}, |
| | ) |
| | consumed_samples: list[int] = field( |
| | default=None, |
| | metadata={ |
| | "help": "Number of samples consumed in the same order as the dataloaders to `accelerator.prepare` call." |
| | }, |
| | ) |
| | no_wd_decay_cond: Optional[Callable] = field(default=None, metadata={"help": "Condition to disable weight decay."}) |
| | scale_lr_cond: Optional[Callable] = field(default=None, metadata={"help": "Condition to scale learning rate."}) |
| | lr_mult: float = field(default=1.0, metadata={"help": "Learning rate multiplier."}) |
| | megatron_dataset_flag: bool = field( |
| | default=False, |
| | metadata={"help": "Whether the format of dataset follows Megatron-LM Indexed/Cached/MemoryMapped format."}, |
| | ) |
| | seq_length: int = field( |
| | default=None, |
| | metadata={"help": "Maximum sequence length to process."}, |
| | ) |
| | encoder_seq_length: int = field( |
| | default=None, |
| | metadata={"help": "Maximum sequence length to process for the encoder."}, |
| | ) |
| | decoder_seq_length: int = field( |
| | default=None, |
| | metadata={"help": "Maximum sequence length to process for the decoder."}, |
| | ) |
| | tensorboard_dir: str = field( |
| | default=None, |
| | metadata={"help": "Path to save tensorboard logs."}, |
| | ) |
| | set_all_logging_options: bool = field( |
| | default=False, |
| | metadata={"help": "Whether to set all logging options."}, |
| | ) |
| | eval_iters: int = field( |
| | default=100, |
| | metadata={"help": "Number of iterations to run for evaluation validation/test for."}, |
| | ) |
| | eval_interval: int = field( |
| | default=1000, |
| | metadata={"help": "Interval between running evaluation on validation set."}, |
| | ) |
| | return_logits: bool = field( |
| | default=False, |
| | metadata={"help": "Whether to return logits from the model."}, |
| | ) |
| |
|
| | |
| | custom_train_step_class: Optional[Any] = field( |
| | default=None, |
| | metadata={"help": "Custom train step class."}, |
| | ) |
| | custom_train_step_kwargs: Optional[dict[str, Any]] = field( |
| | default=None, |
| | metadata={"help": "Custom train step kwargs."}, |
| | ) |
| |
|
| | |
| | custom_model_provider_function: Optional[Callable] = field( |
| | default=None, |
| | metadata={"help": "Custom model provider function."}, |
| | ) |
| | custom_prepare_model_function: Optional[Callable] = field( |
| | default=None, |
| | metadata={"help": "Custom prepare model function."}, |
| | ) |
| | custom_megatron_datasets_provider_function: Optional[Callable] = field( |
| | default=None, |
| | metadata={"help": "Custom megatron train_valid_test datasets provider function."}, |
| | ) |
| | custom_get_batch_function: Optional[Callable] = field( |
| | default=None, |
| | metadata={"help": "Custom get batch function."}, |
| | ) |
| | custom_loss_function: Optional[Callable] = field( |
| | default=None, |
| | metadata={"help": "Custom loss function."}, |
| | ) |
| |
|
| | |
| | |
| | other_megatron_args: Optional[dict[str, Any]] = field( |
| | default=None, |
| | metadata={"help": "Other Megatron-LM arguments. Please refer Megatron-LM"}, |
| | ) |
| |
|
| | def __post_init__(self): |
| | prefix = "MEGATRON_LM_" |
| | if self.tp_degree is None: |
| | self.tp_degree = int(os.environ.get(prefix + "TP_DEGREE", 1)) |
| | if self.pp_degree is None: |
| | self.pp_degree = int(os.environ.get(prefix + "PP_DEGREE", 1)) |
| | if self.num_micro_batches is None: |
| | self.num_micro_batches = int(os.environ.get(prefix + "NUM_MICRO_BATCHES", 1)) |
| | if self.gradient_clipping is None: |
| | self.gradient_clipping = float(os.environ.get(prefix + "GRADIENT_CLIPPING", 1.0)) |
| | if self.recompute_activations is None: |
| | self.recompute_activations = str_to_bool(os.environ.get(prefix + "RECOMPUTE_ACTIVATIONS", "False")) == 1 |
| | if self.use_distributed_optimizer is None: |
| | self.use_distributed_optimizer = ( |
| | str_to_bool(os.environ.get(prefix + "USE_DISTRIBUTED_OPTIMIZER", "False")) == 1 |
| | ) |
| | if self.sequence_parallelism is None: |
| | self.sequence_parallelism = str_to_bool(os.environ.get(prefix + "SEQUENCE_PARALLELISM", "False")) == 1 |
| |
|
| | if self.pp_degree > 1 or self.use_distributed_optimizer: |
| | self.DDP_impl = "local" |
| | else: |
| | self.DDP_impl = "torch" |
| |
|
| | if self.consumed_samples is not None: |
| | if len(self.consumed_samples) == 1: |
| | self.consumed_samples.extend([0, 0]) |
| | elif len(self.consumed_samples) == 2: |
| | self.consumed_samples.append(0) |
| |
|
| | self.megatron_lm_default_args = { |
| | "tensor_model_parallel_size": self.tp_degree, |
| | "pipeline_model_parallel_size": self.pp_degree, |
| | "pipeline_model_parallel_split_rank": self.pipeline_model_parallel_split_rank, |
| | "num_layers_per_virtual_pipeline_stage": self.num_layers_per_virtual_pipeline_stage, |
| | "DDP_impl": self.DDP_impl, |
| | "use_distributed_optimizer": self.use_distributed_optimizer, |
| | "sequence_parallel": self.sequence_parallelism, |
| | "clip_grad": self.gradient_clipping, |
| | "num_micro_batches": self.num_micro_batches, |
| | "consumed_samples": self.consumed_samples, |
| | "no_wd_decay_cond": self.no_wd_decay_cond, |
| | "scale_lr_cond": self.scale_lr_cond, |
| | "lr_mult": self.lr_mult, |
| | "megatron_dataset_flag": self.megatron_dataset_flag, |
| | "eval_iters": self.eval_iters, |
| | "eval_interval": self.eval_interval, |
| | } |
| | if self.recompute_activations: |
| | self.megatron_lm_default_args["recompute_granularity"] = "selective" |
| | if self.tensorboard_dir is not None: |
| | self.megatron_lm_default_args["tensorboard_dir"] = self.tensorboard_dir |
| | if self.set_all_logging_options: |
| | self.set_tensorboard_logging_options() |
| | if self.other_megatron_args is not None: |
| | self.megatron_lm_default_args.update(self.other_megatron_args) |
| |
|
| | def set_network_size_args(self, model, batch_data=None): |
| | model_config_type = model.config.model_type.lower() |
| | for model_type in MODEL_CONFIGS_TO_MEGATRON_PARSERS.keys(): |
| | if model_type in model_config_type: |
| | MODEL_CONFIGS_TO_MEGATRON_PARSERS[model_type](self, model, batch_data) |
| | return |
| | raise ValueError( |
| | f"Accelerate Megatron-LM integration not supports {model_config_type} model. " |
| | "You can add your own model config parser." |
| | ) |
| |
|
| | def set_mixed_precision(self, mixed_precision): |
| | if mixed_precision == "fp16": |
| | self.megatron_lm_default_args["fp16"] = True |
| | elif mixed_precision == "bf16": |
| | self.megatron_lm_default_args["bf16"] = True |
| | self.DDP_impl = "local" |
| | self.megatron_lm_default_args["DDP_impl"] = self.DDP_impl |
| |
|
| | def set_training_args(self, micro_batch_size, dp_degree): |
| | self.data_parallel_size = dp_degree |
| | self.micro_batch_size = micro_batch_size |
| | self.global_batch_size = dp_degree * micro_batch_size * self.num_micro_batches |
| | self.megatron_lm_default_args["data_parallel_size"] = self.data_parallel_size |
| | self.megatron_lm_default_args["micro_batch_size"] = self.micro_batch_size |
| | self.megatron_lm_default_args["global_batch_size"] = self.global_batch_size |
| |
|
| | def set_optimizer_type(self, optimizer): |
| | optimizer_name = optimizer.__class__.__name__.lower() |
| | if "adam" in optimizer_name: |
| | self.megatron_lm_default_args["optimizer"] = "adam" |
| | self.megatron_lm_default_args["adam_beta1"] = optimizer.defaults["betas"][0] |
| | self.megatron_lm_default_args["adam_beta2"] = optimizer.defaults["betas"][1] |
| | self.megatron_lm_default_args["adam_eps"] = optimizer.defaults["eps"] |
| | elif "sgd" in optimizer_name: |
| | self.megatron_lm_default_args["optimizer"] = "sgd" |
| | self.megatron_lm_default_args["sgd_momentum"] = optimizer.defaults["momentum"] |
| | else: |
| | raise ValueError(f"Optimizer {optimizer_name} is not supported by Megatron-LM") |
| |
|
| | self.megatron_lm_default_args["lr"] = optimizer.defaults["lr"] |
| | self.megatron_lm_default_args["weight_decay"] = optimizer.defaults["weight_decay"] |
| |
|
| | def set_scheduler_args(self, scheduler): |
| | if self.train_iters is None: |
| | self.train_iters = scheduler.total_num_steps // self.megatron_lm_default_args["data_parallel_size"] |
| | if self.train_samples is not None: |
| | self.train_samples = None |
| | warnings.warn( |
| | "Ignoring `train_samples` as `train_iters` based on scheduler is being used for training." |
| | ) |
| | if self.lr_warmup_iters is None: |
| | self.lr_warmup_iters = scheduler.warmup_num_steps // self.megatron_lm_default_args["data_parallel_size"] |
| | if self.lr_warmup_samples is not None: |
| | warnings.warn( |
| | "Ignoring `lr_warmup_samples` as `lr_warmup_iters` based on scheduler is being used for training." |
| | ) |
| | self.lr_warmup_samples = 0 |
| |
|
| | self.megatron_lm_default_args["train_iters"] = self.train_iters |
| | self.megatron_lm_default_args["lr_warmup_iters"] = self.lr_warmup_iters |
| | self.megatron_lm_default_args["train_samples"] = self.train_samples |
| | self.megatron_lm_default_args["lr_warmup_samples"] = self.lr_warmup_samples |
| | self.megatron_lm_default_args["lr_decay_iters"] = self.lr_decay_iters |
| | self.megatron_lm_default_args["lr_decay_samples"] = self.lr_decay_samples |
| | self.megatron_lm_default_args["lr_warmup_fraction"] = self.lr_warmup_fraction |
| | self.megatron_lm_default_args["lr_decay_style"] = self.lr_decay_style |
| | self.megatron_lm_default_args["weight_decay_incr_style"] = self.weight_decay_incr_style |
| | self.megatron_lm_default_args["start_weight_decay"] = self.start_weight_decay |
| | self.megatron_lm_default_args["end_weight_decay"] = self.end_weight_decay |
| | self.megatron_lm_default_args["min_lr"] = self.min_lr |
| |
|
| | def set_tensorboard_logging_options(self): |
| | from megatron.training.arguments import _add_logging_args |
| |
|
| | parser = argparse.ArgumentParser() |
| | parser = _add_logging_args(parser) |
| | logging_args = parser.parse_known_args() |
| | self.dataset_args = vars(logging_args[0]) |
| | for key, value in self.dataset_args.items(): |
| | if key.startswith("log_"): |
| | self.megatron_lm_default_args[key] = True |
| | elif key.startswith("no_log_"): |
| | self.megatron_lm_default_args[key.replace("no_", "")] = True |
| |
|
| |
|
| | MODEL_CONFIGS_TO_MEGATRON_PARSERS = {} |
| |
|
| |
|
| | def add_model_config_to_megatron_parser(model_type: str): |
| | def add_model_config_parser_helper(func): |
| | @functools.wraps(func) |
| | def wrapper(*args, **kwargs): |
| | return func(*args, **kwargs) |
| |
|
| | MODEL_CONFIGS_TO_MEGATRON_PARSERS[model_type] = func |
| | return wrapper |
| |
|
| | return add_model_config_parser_helper |
| |
|
| |
|
| | @add_model_config_to_megatron_parser("megatron-bert") |
| | def parse_bert_config(megatron_lm_plugin, model, batch_data): |
| | model_type_name = "bert" |
| | num_layers = model.config.num_hidden_layers |
| | hidden_size = model.config.hidden_size |
| | num_attention_heads = model.config.num_attention_heads |
| | max_position_embeddings = model.config.max_position_embeddings |
| | num_labels = model.config.num_labels |
| | orig_vocab_size = model.config.vocab_size |
| | pretraining_flag = False |
| | if "maskedlm" in model.__class__.__name__.lower(): |
| | pretraining_flag = True |
| | if megatron_lm_plugin.seq_length is not None: |
| | if megatron_lm_plugin.encoder_seq_length is not None: |
| | warnings.warn("Both `seq_length` and `encoder_seq_length` are set. Using `encoder_seq_length`.") |
| | megatron_lm_plugin.seq_length = megatron_lm_plugin.encoder_seq_length |
| | elif megatron_lm_plugin.encoder_seq_length is not None: |
| | megatron_lm_plugin.seq_length = megatron_lm_plugin.encoder_seq_length |
| | elif batch_data is not None: |
| | megatron_lm_plugin.seq_length = batch_data["input_ids"].shape[1] |
| | else: |
| | megatron_lm_plugin.seq_length = max_position_embeddings |
| | megatron_lm_plugin.megatron_lm_default_args["seq_length"] = megatron_lm_plugin.seq_length |
| | megatron_lm_plugin.megatron_lm_default_args["model_type_name"] = model_type_name |
| | megatron_lm_plugin.megatron_lm_default_args["num_layers"] = num_layers |
| | megatron_lm_plugin.megatron_lm_default_args["hidden_size"] = hidden_size |
| | megatron_lm_plugin.megatron_lm_default_args["num_attention_heads"] = num_attention_heads |
| | megatron_lm_plugin.megatron_lm_default_args["max_position_embeddings"] = max_position_embeddings |
| | megatron_lm_plugin.megatron_lm_default_args["pretraining_flag"] = pretraining_flag |
| | megatron_lm_plugin.megatron_lm_default_args["orig_vocab_size"] = orig_vocab_size |
| | megatron_lm_plugin.megatron_lm_default_args["model_return_dict"] = model.config.return_dict |
| | megatron_lm_plugin.megatron_lm_default_args["num_labels"] = num_labels |
| |
|
| |
|
| | @add_model_config_to_megatron_parser("gpt2") |
| | def parse_gpt2_config(megatron_lm_plugin, model, batch_data): |
| | model_type_name = "gpt" |
| | num_layers = model.config.n_layer |
| | hidden_size = model.config.n_embd |
| | num_attention_heads = model.config.n_head |
| | max_position_embeddings = model.config.n_positions |
| | orig_vocab_size = model.config.vocab_size |
| | pretraining_flag = True |
| | if megatron_lm_plugin.seq_length is not None: |
| | if megatron_lm_plugin.decoder_seq_length is not None: |
| | warnings.warn("Both `seq_length` and `decoder_seq_length` are set. Using `decoder_seq_length`.") |
| | megatron_lm_plugin.seq_length = megatron_lm_plugin.decoder_seq_length |
| | elif megatron_lm_plugin.decoder_seq_length is not None: |
| | megatron_lm_plugin.seq_length = megatron_lm_plugin.decoder_seq_length |
| | elif batch_data is not None: |
| | megatron_lm_plugin.seq_length = batch_data["input_ids"].shape[1] |
| | else: |
| | megatron_lm_plugin.seq_length = max_position_embeddings |
| | megatron_lm_plugin.megatron_lm_default_args["seq_length"] = megatron_lm_plugin.seq_length |
| | megatron_lm_plugin.megatron_lm_default_args["return_logits"] = megatron_lm_plugin.return_logits |
| | megatron_lm_plugin.megatron_lm_default_args["tokenizer_type"] = "GPT2BPETokenizer" |
| | megatron_lm_plugin.megatron_lm_default_args["model_type_name"] = model_type_name |
| | megatron_lm_plugin.megatron_lm_default_args["num_layers"] = num_layers |
| | megatron_lm_plugin.megatron_lm_default_args["hidden_size"] = hidden_size |
| | megatron_lm_plugin.megatron_lm_default_args["num_attention_heads"] = num_attention_heads |
| | megatron_lm_plugin.megatron_lm_default_args["max_position_embeddings"] = max_position_embeddings |
| | megatron_lm_plugin.megatron_lm_default_args["pretraining_flag"] = pretraining_flag |
| | megatron_lm_plugin.megatron_lm_default_args["orig_vocab_size"] = orig_vocab_size |
| | megatron_lm_plugin.megatron_lm_default_args["model_return_dict"] = model.config.return_dict |
| |
|
| |
|
| | @add_model_config_to_megatron_parser("t5") |
| | def parse_t5_config(megatron_lm_plugin, model, batch_data): |
| | model_type_name = "t5" |
| | num_layers = model.config.num_layers |
| | hidden_size = model.config.d_model |
| | num_attention_heads = model.config.num_heads |
| | max_position_embeddings = model.config.n_positions if hasattr(model.config, "n_positions") else 1024 |
| | orig_vocab_size = model.config.vocab_size |
| | pretraining_flag = True |
| | if megatron_lm_plugin.encoder_seq_length is None: |
| | if batch_data is not None: |
| | megatron_lm_plugin.encoder_seq_length = batch_data["input_ids"].shape[1] |
| | else: |
| | megatron_lm_plugin.encoder_seq_length = max_position_embeddings |
| | if megatron_lm_plugin.decoder_seq_length is None: |
| | if batch_data is not None: |
| | megatron_lm_plugin.decoder_seq_length = batch_data["labels"].shape[1] |
| | else: |
| | megatron_lm_plugin.decoder_seq_length = max_position_embeddings |
| | megatron_lm_plugin.megatron_lm_default_args["encoder_seq_length"] = megatron_lm_plugin.encoder_seq_length |
| | megatron_lm_plugin.megatron_lm_default_args["decoder_seq_length"] = megatron_lm_plugin.decoder_seq_length |
| | megatron_lm_plugin.megatron_lm_default_args["model_type_name"] = model_type_name |
| | megatron_lm_plugin.megatron_lm_default_args["num_layers"] = num_layers |
| | megatron_lm_plugin.megatron_lm_default_args["hidden_size"] = hidden_size |
| | megatron_lm_plugin.megatron_lm_default_args["num_attention_heads"] = num_attention_heads |
| | megatron_lm_plugin.megatron_lm_default_args["max_position_embeddings"] = max_position_embeddings |
| | megatron_lm_plugin.megatron_lm_default_args["pretraining_flag"] = pretraining_flag |
| | megatron_lm_plugin.megatron_lm_default_args["orig_vocab_size"] = orig_vocab_size |
| | megatron_lm_plugin.megatron_lm_default_args["model_return_dict"] = model.config.return_dict |
| |
|
| |
|
| | @add_model_config_to_megatron_parser("llama") |
| | def parse_llama_config(megatron_lm_plugin, model, batch_data): |
| | model_type_name = "gpt" |
| | num_layers = model.config.num_hidden_layers |
| | pretraining_flag = True |
| | hidden_size = model.config.hidden_size |
| | num_attention_heads = model.config.num_attention_heads |
| | orig_vocab_size = model.config.vocab_size |
| |
|
| | max_position_embeddings = model.config.max_position_embeddings |
| | seq_length = getattr(model.config, "max_sequence_length", None) |
| | if megatron_lm_plugin.seq_length is None: |
| | if seq_length is not None: |
| | megatron_lm_plugin.seq_length = seq_length |
| | elif megatron_lm_plugin.decoder_seq_length is not None: |
| | megatron_lm_plugin.seq_length = megatron_lm_plugin.decoder_seq_length |
| | elif batch_data is not None: |
| | megatron_lm_plugin.seq_length = batch_data["input_ids"].shape[1] |
| | else: |
| | megatron_lm_plugin.seq_length = max_position_embeddings |
| |
|
| | megatron_lm_plugin.megatron_lm_default_args["return_logits"] = megatron_lm_plugin.return_logits |
| | megatron_lm_plugin.megatron_lm_default_args["tokenizer_type"] = "Llama2Tokenizer" |
| | megatron_lm_plugin.megatron_lm_default_args["model_type_name"] = model_type_name |
| | megatron_lm_plugin.megatron_lm_default_args["num_layers"] = num_layers |
| | megatron_lm_plugin.megatron_lm_default_args["pretraining_flag"] = pretraining_flag |
| | megatron_lm_plugin.megatron_lm_default_args["hidden_size"] = hidden_size |
| | megatron_lm_plugin.megatron_lm_default_args["num_attention_heads"] = num_attention_heads |
| | megatron_lm_plugin.megatron_lm_default_args["orig_vocab_size"] = orig_vocab_size |
| | megatron_lm_plugin.megatron_lm_default_args["max_position_embeddings"] = max_position_embeddings |
| | megatron_lm_plugin.megatron_lm_default_args["seq_length"] = megatron_lm_plugin.seq_length |
| | megatron_lm_plugin.megatron_lm_default_args["model_return_dict"] = model.config.return_dict |
| |
|
| |
|
| | @dataclass |
| | class BnbQuantizationConfig: |
| | """ |
| | A plugin to enable BitsAndBytes 4bit and 8bit quantization |
| | |
| | Args: |
| | load_in_8bit (`bool`, defaults to `False`): |
| | Enable 8bit quantization. |
| | llm_int8_threshold (`float`, defaults to `6.0`): |
| | Value of the outliner threshold. Only relevant when `load_in_8bit=True`. |
| | load_in_4bit (`bool`, defaults to `False`): |
| | Enable 4bit quantization. |
| | bnb_4bit_quant_type (`str`, defaults to `fp4`): |
| | Set the quantization data type in the `bnb.nn.Linear4Bit` layers. Options are {'fp4','np4'}. |
| | bnb_4bit_use_double_quant (`bool`, defaults to `False`): |
| | Enable nested quantization where the quantization constants from the first quantization are quantized |
| | again. |
| | bnb_4bit_compute_dtype (`bool`, defaults to `fp16`): |
| | This sets the computational type which might be different than the input time. For example, inputs might be |
| | fp32, but computation can be set to bf16 for speedups. Options are {'fp32','fp16','bf16'}. |
| | torch_dtype (`torch.dtype`, defaults to `None`): |
| | This sets the dtype of the remaining non quantized layers. `bitsandbytes` library suggests to set the value |
| | to `torch.float16` for 8 bit model and use the same dtype as the compute dtype for 4 bit model. |
| | skip_modules (`List[str]`, defaults to `None`): |
| | An explicit list of the modules that we don't quantize. The dtype of these modules will be `torch_dtype`. |
| | keep_in_fp32_modules (`List`, defaults to `None`): |
| | An explicit list of the modules that we don't quantize. We keep them in `torch.float32`. |
| | """ |
| |
|
| | load_in_8bit: bool = field(default=False, metadata={"help": "enable 8bit quantization."}) |
| |
|
| | llm_int8_threshold: float = field( |
| | default=6.0, |
| | metadata={"help": "value of the outliner threshold. only relevant when load_in_8bit=True"}, |
| | ) |
| |
|
| | load_in_4bit: bool = field(default=False, metadata={"help": "enable 4bit quantization."}) |
| |
|
| | bnb_4bit_quant_type: str = field( |
| | default="fp4", |
| | metadata={ |
| | "help": "set the quantization data type in the `bnb.nn.Linear4Bit` layers. Options are {'fp4','nf4'}." |
| | }, |
| | ) |
| |
|
| | bnb_4bit_use_double_quant: bool = field( |
| | default=False, |
| | metadata={ |
| | "help": "enable nested quantization where the quantization constants from the first quantization are quantized again." |
| | }, |
| | ) |
| |
|
| | bnb_4bit_compute_dtype: str = field( |
| | default="fp16", |
| | metadata={ |
| | "help": "This sets the computational type which might be different than the input time. For example, inputs might be " |
| | "fp32, but computation can be set to bf16 for speedups. Options are {'fp32','fp16','bf16'}." |
| | }, |
| | ) |
| |
|
| | torch_dtype: torch.dtype = field( |
| | default=None, |
| | metadata={ |
| | "help": "this sets the dtype of the remaining non quantized layers. `bitsandbytes` library suggests to set the value" |
| | "to `torch.float16` for 8 bit model and use the same dtype as the compute dtype for 4 bit model " |
| | }, |
| | ) |
| |
|
| | skip_modules: list[str] = field( |
| | default=None, |
| | metadata={ |
| | "help": "an explicit list of the modules that we don't quantize. The dtype of these modules will be `torch_dtype`." |
| | }, |
| | ) |
| |
|
| | keep_in_fp32_modules: list[str] = field( |
| | default=None, |
| | metadata={"help": "an explicit list of the modules that we don't quantize. We keep them in `torch.float32`."}, |
| | ) |
| |
|
| | def __post_init__(self): |
| | """ |
| | Safety checker that arguments are correct - also replaces some NoneType arguments with their default values. |
| | """ |
| | if not isinstance(self.load_in_8bit, bool): |
| | raise ValueError("load_in_8bit must be a boolean") |
| |
|
| | if not isinstance(self.load_in_4bit, bool): |
| | raise ValueError("load_in_4bit must be a boolean") |
| |
|
| | if self.load_in_4bit and self.load_in_8bit: |
| | raise ValueError("load_in_4bit and load_in_8bit can't be both True") |
| |
|
| | if not self.load_in_4bit and not self.load_in_8bit: |
| | raise ValueError("load_in_4bit and load_in_8bit can't be both False") |
| |
|
| | if not isinstance(self.llm_int8_threshold, (int, float)): |
| | raise ValueError("llm_int8_threshold must be a float or an int") |
| |
|
| | if not isinstance(self.bnb_4bit_quant_type, str): |
| | raise ValueError("bnb_4bit_quant_type must be a string") |
| | elif self.bnb_4bit_quant_type not in ["fp4", "nf4"]: |
| | raise ValueError(f"bnb_4bit_quant_type must be in ['fp4','nf4'] but found {self.bnb_4bit_quant_type}") |
| |
|
| | if not isinstance(self.bnb_4bit_use_double_quant, bool): |
| | raise ValueError("bnb_4bit_use_double_quant must be a boolean") |
| |
|
| | if isinstance(self.bnb_4bit_compute_dtype, str): |
| | if self.bnb_4bit_compute_dtype == "fp32": |
| | self.bnb_4bit_compute_dtype = torch.float32 |
| | elif self.bnb_4bit_compute_dtype == "fp16": |
| | self.bnb_4bit_compute_dtype = torch.float16 |
| | elif self.bnb_4bit_compute_dtype == "bf16": |
| | self.bnb_4bit_compute_dtype = torch.bfloat16 |
| | else: |
| | raise ValueError( |
| | f"bnb_4bit_compute_dtype must be in ['fp32','fp16','bf16'] but found {self.bnb_4bit_compute_dtype}" |
| | ) |
| | elif not isinstance(self.bnb_4bit_compute_dtype, torch.dtype): |
| | raise ValueError("bnb_4bit_compute_dtype must be a string or a torch.dtype") |
| |
|
| | if self.skip_modules is not None and not isinstance(self.skip_modules, list): |
| | raise ValueError("skip_modules must be a list of strings") |
| |
|
| | if self.keep_in_fp32_modules is not None and not isinstance(self.keep_in_fp32_modules, list): |
| | raise ValueError("keep_in_fp_32_modules must be a list of strings") |
| |
|
| | if self.load_in_4bit: |
| | self.target_dtype = CustomDtype.INT4 |
| |
|
| | if self.load_in_8bit: |
| | self.target_dtype = torch.int8 |
| |
|
| | if self.load_in_4bit and self.llm_int8_threshold != 6.0: |
| | warnings.warn("llm_int8_threshold can only be used for model loaded in 8bit") |
| |
|
| | if isinstance(self.torch_dtype, str): |
| | if self.torch_dtype == "fp32": |
| | self.torch_dtype = torch.float32 |
| | elif self.torch_dtype == "fp16": |
| | self.torch_dtype = torch.float16 |
| | elif self.torch_dtype == "bf16": |
| | self.torch_dtype = torch.bfloat16 |
| | else: |
| | raise ValueError(f"torch_dtype must be in ['fp32','fp16','bf16'] but found {self.torch_dtype}") |
| | if self.load_in_8bit and self.torch_dtype is None: |
| | self.torch_dtype = torch.float16 |
| |
|
| | if self.load_in_4bit and self.torch_dtype is None: |
| | self.torch_dtype = self.bnb_4bit_compute_dtype |
| |
|
| | if not isinstance(self.torch_dtype, torch.dtype): |
| | raise ValueError("torch_dtype must be a torch.dtype") |
| |
|
| |
|
| | def get_module_class_from_name(module, name): |
| | """ |
| | Gets a class from a module by its name. |
| | |
| | Args: |
| | module (`torch.nn.Module`): The module to get the class from. |
| | name (`str`): The name of the class. |
| | """ |
| | modules_children = list(module.children()) |
| | if module.__class__.__name__ == name: |
| | return module.__class__ |
| | elif len(modules_children) == 0: |
| | return |
| | else: |
| | for child_module in modules_children: |
| | module_class = get_module_class_from_name(child_module, name) |
| | if module_class is not None: |
| | return module_class |
| |
|