| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from __future__ import annotations |
| |
|
| | import logging |
| | import os |
| | import threading |
| | import warnings |
| | import weakref |
| | from contextlib import contextmanager |
| | from functools import partial |
| | from typing import Any, Callable |
| |
|
| | import torch |
| |
|
| | from .utils import ( |
| | DistributedType, |
| | DynamoBackend, |
| | GradientAccumulationPlugin, |
| | check_cuda_fp8_capability, |
| | check_cuda_p2p_ib_support, |
| | deepspeed_required, |
| | get_cpu_distributed_information, |
| | get_int_from_env, |
| | is_ccl_available, |
| | is_datasets_available, |
| | is_deepspeed_available, |
| | is_fp8_available, |
| | is_habana_gaudi1, |
| | is_hpu_available, |
| | is_ipex_available, |
| | is_mlu_available, |
| | is_mps_available, |
| | is_musa_available, |
| | is_npu_available, |
| | is_sdaa_available, |
| | is_torch_xla_available, |
| | is_xccl_available, |
| | is_xpu_available, |
| | parse_choice_from_env, |
| | parse_flag_from_env, |
| | set_numa_affinity, |
| | ) |
| | from .utils.dataclasses import SageMakerDistributedType |
| |
|
| |
|
| | if is_torch_xla_available(): |
| | import torch_xla.core.xla_model as xm |
| | import torch_xla.runtime as xr |
| |
|
| | if is_mlu_available(check_device=False): |
| | import torch_mlu |
| |
|
| | if is_sdaa_available(check_device=False): |
| | import torch_sdaa |
| |
|
| | if is_musa_available(check_device=False): |
| | import torch_musa |
| |
|
| | if is_npu_available(check_device=False): |
| | import torch_npu |
| |
|
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
def is_initialized() -> bool:
    """
    Module-level check for whether the `AcceleratorState` has been initialized from
    `Accelerator`. Equivalent to the `AcceleratorState.initialized` property.
    """
    # A non-empty shared-state dict means an `AcceleratorState` has been constructed.
    return bool(AcceleratorState._shared_state)
| |
|
| |
|
| | |
def do_nothing(*args, **kwargs):
    """No-op accepting any arguments; returned by the `on_*` decorators on non-selected processes."""
    return None
| |
|
| |
|
class ThreadLocalSharedDict(threading.local):
    """
    Descriptor holding a dict shared between all instances of the owning class, kept separate per
    thread (it subclasses `threading.local`).

    Unlike a plain class-level dict attribute, reading through either the class or an instance
    (`PartialState._shared_state` / `PartialState(...)._shared_state`) resolves to the same
    underlying `_storage` dict, and assigning through an *instance* swaps `_storage` inside the
    descriptor. Assigning through the *class* (`PartialState._shared_state = {}`) would replace the
    descriptor object itself with a plain dict, so the storage should instead be mutated in place
    (e.g. `_shared_state.clear()`).

    See the Python descriptor guide: https://docs.python.org/3/howto/descriptor.html

    Required for using PyTorch/XLA with PJRT in multithreaded mode (TPU v2 and v3):
    https://github.com/pytorch/xla/blob/r2.0/docs/pjrt.md#multithreading-on-tpu-v2v3
    """

    def __init__(self, thread_local: bool = False):
        # `threading.local` re-runs `__init__` in each new thread, giving every
        # thread a fresh, independent storage dict.
        self._storage = {}

    def __get__(self, obj, objtype=None):
        # Class-level and instance-level reads both expose the storage dict.
        return self._storage

    def __set__(self, obj, value):
        # Instance-level assignment replaces the storage dict wholesale.
        self._storage = value
| |
|
| |
|
| | |
# PJRT multithreaded mode on TPU v2/v3 runs one thread per device, so the singleton
# state dict must be thread-local under XLA; everywhere else a plain dict is shared.
SharedDict = dict if not is_torch_xla_available() else ThreadLocalSharedDict
| |
|
| |
|
| | |
| | class PartialState: |
| | """ |
| | Singleton class that has information about the current training environment and functions to help with process |
| | control. Designed to be used when only process control and device execution states are needed. Does *not* need to |
| | be initialized from `Accelerator`. |
| | |
| | Args: |
| | cpu (`bool`, *optional*): |
| | Whether or not to force the script to execute on CPU. Will ignore any accelerators available if set to |
| | `True` and force the execution on the CPU. |
| | kwargs (additional keyword arguments, *optional*): |
| | Additional keyword arguments to pass to the relevant `init_process_group` function. Valid `kwargs` can be |
| | found in [`utils.InitProcessGroupKwargs`]. See the example section for detailed usage. |
| | |
| | **Available attributes:** |
| | |
| | - **device** (`torch.device`) -- The device to use. |
| | - **distributed_type** ([`~accelerate.state.DistributedType`]) -- The type of distributed environment currently |
| | in use. |
| | - **local_process_index** (`int`) -- The index of the current process on the current server. |
| | - **mixed_precision** (`str`) -- Whether or not the current script will use mixed precision, and if so the type |
| | of mixed precision being performed. (Choose from 'no','fp16','bf16 or 'fp8'). |
| | - **num_processes** (`int`) -- The number of processes currently launched in parallel. |
| | - **process_index** (`int`) -- The index of the current process. |
| | - **is_last_process** (`bool`) -- Whether or not the current process is the last one. |
| | - **is_main_process** (`bool`) -- Whether or not the current process is the main one. |
| | - **is_local_main_process** (`bool`) -- Whether or not the current process is the main one on the local node. |
| | - **debug** (`bool`) -- Whether or not the current script is being run in debug mode. |
| | |
| | Example: |
| | ```python |
| | from accelerate.utils import InitProcessGroupKwargs |
| | |
| | # To include `InitProcessGroupKwargs`, init then call `.to_kwargs()` |
| | kwargs = InitProcessGroupKwargs(...).to_kwargs() |
| | state = PartialState(**kwargs) |
| | ``` |
| | """ |
| |
|
| | _shared_state = SharedDict() |
| | _known_attrs = [ |
| | "_cpu", |
| | "_mixed_precision", |
| | "_shared_state", |
| | "backend", |
| | "debug", |
| | "device", |
| | "distributed_type", |
| | "fork_launched", |
| | "local_process_index", |
| | "num_processes", |
| | "process_index", |
| | ] |
| |
|
    def __init__(self, cpu: bool = False, **kwargs):
        # All instances alias the class-level shared dict, making this a singleton:
        # every attribute write below lands in `_shared_state` and is visible to
        # any other `PartialState` (or `AcceleratorState`) instance.
        self.__dict__ = self._shared_state
        if not self.initialized:
            self._cpu = cpu
            self.backend = None
            # Explicit device override set by the launcher, if any.
            env_device = os.environ.get("ACCELERATE_TORCH_DEVICE", None)
            self.device = torch.device(env_device) if env_device is not None else None
            self.debug = parse_flag_from_env("ACCELERATE_DEBUG_MODE")
            use_sagemaker_dp = kwargs.pop("_use_sagemaker_dp", None)
            dist_information = None
            if use_sagemaker_dp is None:
                use_sagemaker_dp = (
                    os.environ.get("ACCELERATE_USE_SAGEMAKER", "false").lower() == "true"
                    and os.environ.get("ACCELERATE_SAGEMAKER_DISTRIBUTED_TYPE") != SageMakerDistributedType.NO
                )

            # Resolve the process-group backend; a user-requested backend must match
            # what the environment can actually provide.
            original_backend = kwargs.pop("backend", None)
            backend, distributed_type = self._prepare_backend(cpu, use_sagemaker_dp, original_backend)
            if original_backend is not None and backend != original_backend:
                raise ValueError(f"Your assigned backend {original_backend} is not available, please use {backend}")
            self.backend = backend
            self.distributed_type = distributed_type
            use_deepspeed = False
            if not cpu and self.backend != "xla":
                if int(os.environ.get("LOCAL_RANK", -1)) != -1:
                    # Launched for distributed work: initialize the process group either
                    # through DeepSpeed's comm wrapper or directly via torch.distributed.
                    if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false").lower() == "true":
                        if not is_deepspeed_available():
                            raise ImportError(
                                "DeepSpeed is not available => install it using `pip3 install deepspeed` or build it from source"
                            )
                        from deepspeed import comm as dist

                        if not dist.is_initialized():
                            if self.backend == "tccl":
                                # tccl requires selecting the SDAA device before init.
                                local_rank = os.environ.get("LOCAL_RANK", -1)
                                torch.sdaa.set_device(f"sdaa:{local_rank}")
                            dist.init_distributed(dist_backend=self.backend, auto_mpi_discovery=False, **kwargs)
                        # Flag so `distributed_type` is overridden to DEEPSPEED below,
                        # after process counts have been populated.
                        use_deepspeed = True
                    elif (
                        self.distributed_type not in (DistributedType.MULTI_XPU, DistributedType.MULTI_CPU)
                        and not torch.distributed.is_initialized()
                    ):
                        if self.backend == "tccl":
                            # tccl requires selecting the SDAA device before init.
                            local_rank = os.environ.get("LOCAL_RANK", -1)
                            torch.sdaa.set_device(f"sdaa:{local_rank}")
                        # FSDP with CPU offload or full state dicts also needs a gloo
                        # group for the CPU side alongside nccl.
                        if (
                            self.backend == "nccl"
                            and os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
                            and (
                                os.environ.get("FSDP_OFFLOAD_PARAMS", "false").lower() == "true"
                                or os.environ.get("FSDP_STATE_DICT_TYPE", "SHARDED_STATE_DICT") == "FULL_STATE_DICT"
                            )
                        ):
                            self.backend = "cuda:nccl,cpu:gloo"
                        torch.distributed.init_process_group(backend=self.backend, **kwargs)

                # CPU/XPU runs may be launched through MPI-style launchers; normalize
                # the rank/world-size env vars from launcher-specific variables.
                if self.distributed_type in (DistributedType.MULTI_XPU, DistributedType.MULTI_CPU):
                    dist_information = get_cpu_distributed_information()
                    os.environ["RANK"] = str(dist_information.rank)
                    os.environ["WORLD_SIZE"] = str(dist_information.world_size)
                    os.environ["LOCAL_RANK"] = str(dist_information.local_rank)
                    os.environ["LOCAL_WORLD_SIZE"] = str(dist_information.local_world_size)
                    if not os.environ.get("MASTER_PORT", None):
                        os.environ["MASTER_PORT"] = "29500"
                    if (
                        not os.environ.get("MASTER_ADDR", None)
                        and dist_information.local_world_size != dist_information.world_size
                        and self.backend != "mpi"
                    ):
                        raise ValueError(
                            "Tried to launch on distributed with multinode, but `MASTER_ADDR` env was not set, "
                            "please try exporting rank 0's hostname as `MASTER_ADDR`"
                        )
                    kwargs["rank"] = dist_information.rank
                    kwargs["world_size"] = dist_information.world_size

                    # Without an explicit OMP_NUM_THREADS, split the physical cores
                    # evenly between local processes to avoid oversubscription.
                    if (
                        self.distributed_type == DistributedType.MULTI_CPU
                        and get_int_from_env(["OMP_NUM_THREADS"], 0) == 0
                    ):
                        import psutil

                        num_cpu_threads_per_process = int(
                            psutil.cpu_count(logical=False) / dist_information.local_world_size
                        )
                        if num_cpu_threads_per_process == 0:
                            num_cpu_threads_per_process = 1
                        torch.set_num_threads(num_cpu_threads_per_process)
                        warnings.warn(
                            f"OMP_NUM_THREADS/MKL_NUM_THREADS unset, we set it at {num_cpu_threads_per_process} to improve oob"
                            " performance."
                        )

                    if not torch.distributed.is_initialized():
                        torch.distributed.init_process_group(backend=self.backend, **kwargs)

            # Populate process counts and indices. No backend means single-process.
            if self.backend is None:
                self.distributed_type = DistributedType.NO
                self.num_processes = 1
                self.process_index = 0
                self.local_process_index = 0
            elif self.backend == "xla":
                # XLA needs the device set before replication can be configured.
                self.set_device()
                xm.set_replication(self.device, xm.get_xla_supported_devices())
                self.num_processes = xr.world_size()
                self.process_index = xr.global_ordinal()
                if is_torch_xla_available(check_is_tpu=True):
                    self.local_process_index = xm.get_local_ordinal()
                else:
                    # Non-TPU XLA devices: fall back to the launcher-provided rank.
                    self.local_process_index = int(os.environ.get("LOCAL_RANK", -1))
            else:
                self.num_processes = torch.distributed.get_world_size()
                self.process_index = torch.distributed.get_rank()
                self.local_process_index = (
                    int(os.environ.get("LOCAL_RANK", -1)) if dist_information is None else dist_information.local_rank
                )
                self.set_device()
            # Now that the process group exists, record that DeepSpeed owns it.
            if use_deepspeed:
                self.distributed_type = DistributedType.DEEPSPEED

            # Optionally pin this process to the NUMA node of its local index.
            if parse_flag_from_env("ACCELERATE_CPU_AFFINITY", False):
                set_numa_affinity(self.local_process_index)

            # RTX 4000-series GPUs lack P2P/IB support; require NCCL to be told so
            # explicitly (done automatically by `accelerate launch`).
            if self.device.type == "cuda" and not check_cuda_p2p_ib_support():
                if "NCCL_P2P_DISABLE" not in os.environ or "NCCL_IB_DISABLE" not in os.environ:
                    raise NotImplementedError(
                        "Using RTX 4000 series doesn't support faster communication broadband via P2P or IB. "
                        'Please set `NCCL_P2P_DISABLE="1"` and `NCCL_IB_DISABLE="1" or use `accelerate launch` which '
                        "will do this automatically."
                    )

        # Refreshed on every construction (even when already initialized) so a
        # forked process is recognized after the fact.
        self.fork_launched = parse_flag_from_env("FORK_LAUNCHED", 0)
| |
|
| | def __repr__(self) -> str: |
| | return ( |
| | f"Distributed environment: {self.distributed_type}{(' Backend: ' + self.backend) if self.backend else ''}\n" |
| | f"Num processes: {self.num_processes}\n" |
| | f"Process index: {self.process_index}\n" |
| | f"Local process index: {self.local_process_index}\n" |
| | f"Device: {self.device}\n" |
| | ) |
| |
|
| | @staticmethod |
| | def _reset_state(): |
| | "Resets `_shared_state`, is used internally and should not be called" |
| | PartialState._shared_state.clear() |
| |
|
| | @property |
| | def initialized(self) -> bool: |
| | "Returns whether the `PartialState` has been initialized" |
| | return self._shared_state != {} |
| |
|
| | @property |
| | def use_distributed(self): |
| | """ |
| | Whether the Accelerator is configured for distributed training |
| | """ |
| | return self.distributed_type != DistributedType.NO and self.num_processes > 1 |
| |
|
| | @property |
| | def is_last_process(self) -> bool: |
| | "Returns whether the current process is the last one" |
| | return self.process_index == self.num_processes - 1 |
| |
|
| | @property |
| | def is_main_process(self) -> bool: |
| | "Returns whether the current process is the main process" |
| | return ( |
| | self.process_index == 0 if self.distributed_type != DistributedType.MEGATRON_LM else self.is_last_process |
| | ) |
| |
|
| | @property |
| | def is_local_main_process(self) -> bool: |
| | "Returns whether the current process is the main process on the local node" |
| | return ( |
| | self.local_process_index == 0 |
| | if self.distributed_type != DistributedType.MEGATRON_LM |
| | else self.is_last_process |
| | ) |
| |
|
| | def wait_for_everyone(self): |
| | """ |
| | Will stop the execution of the current process until every other process has reached that point (so this does |
| | nothing when the script is only run in one process). Useful to do before saving a model. |
| | |
| | Example: |
| | |
| | ```python |
| | >>> # Assuming two GPU processes |
| | >>> import time |
| | >>> from accelerate.state import PartialState |
| | |
| | >>> state = PartialState() |
| | >>> if state.is_main_process: |
| | ... time.sleep(2) |
| | >>> else: |
| | ... print("I'm waiting for the main process to finish its sleep...") |
| | >>> state.wait_for_everyone() |
| | >>> # Should print on every process at the same time |
| | >>> print("Everyone is here") |
| | ``` |
| | """ |
| | if self.distributed_type in ( |
| | DistributedType.MULTI_GPU, |
| | DistributedType.MULTI_MLU, |
| | DistributedType.MULTI_SDAA, |
| | DistributedType.MULTI_MUSA, |
| | DistributedType.MULTI_NPU, |
| | DistributedType.MULTI_XPU, |
| | DistributedType.MULTI_CPU, |
| | DistributedType.MULTI_HPU, |
| | DistributedType.DEEPSPEED, |
| | DistributedType.FSDP, |
| | ): |
| | torch.distributed.barrier(device_ids=[self.local_process_index]) |
| | elif self.distributed_type == DistributedType.XLA: |
| | xm.rendezvous("accelerate.utils.wait_for_everyone") |
| |
|
| | def _goes_first(self, is_main: bool): |
| | if not is_main: |
| | self.wait_for_everyone() |
| |
|
| | yield |
| |
|
| | if is_main: |
| | self.wait_for_everyone() |
| |
|
    @contextmanager
    def split_between_processes(self, inputs: list | tuple | dict | torch.Tensor, apply_padding: bool = False):
        """
        Splits `input` between `self.num_processes` quickly and can be then used on that process. Useful when doing
        distributed inference, such as with different prompts.

        Note that when using a `dict`, all keys need to have the same number of elements.

        Args:
            inputs (`list`, `tuple`, `torch.Tensor`, `dict` of `list`/`tuple`/`torch.Tensor`, or `datasets.Dataset`):
                The input to split between processes.
            apply_padding (`bool`, `optional`, defaults to `False`):
                Whether to apply padding by repeating the last element of the input so that all processes have the same
                number of elements. Useful when trying to perform actions such as `gather()` on the outputs or passing
                in less inputs than there are processes. If so, just remember to drop the padded elements afterwards.


        Example:

        ```python
        # Assume there are two processes
        from accelerate import PartialState

        state = PartialState()
        with state.split_between_processes(["A", "B", "C"]) as inputs:
            print(inputs)
        # Process 0
        ["A", "B"]
        # Process 1
        ["C"]

        with state.split_between_processes(["A", "B", "C"], apply_padding=True) as inputs:
            print(inputs)
        # Process 0
        ["A", "B"]
        # Process 1
        ["C", "C"]
        ```
        """
        # Single process: nothing to split.
        if self.num_processes == 1:
            yield inputs
            return
        length = len(inputs)
        # Dicts are split per value; every value must have the same length.
        if isinstance(inputs, dict):
            length = len(inputs[list(inputs.keys())[0]])
            if not all(len(v) == length for v in inputs.values()):
                raise ValueError("All values in the dictionary must have the same length")
        # The first `num_extras` processes each receive one extra element.
        num_samples_per_process, num_extras = divmod(length, self.num_processes)
        start_index = self.process_index * num_samples_per_process + min(self.process_index, num_extras)
        end_index = start_index + num_samples_per_process + (1 if self.process_index < num_extras else 0)

        def _split_values(inputs, start_index, end_index):
            # Recursively slice sequences/tensors (and dict values) down to this
            # process's share of the data.
            if isinstance(inputs, (list, tuple, torch.Tensor)):
                if start_index >= len(inputs):
                    # More processes than elements: fall back to the last element.
                    result = inputs[-1:]
                else:
                    result = inputs[start_index:end_index]
                if apply_padding:
                    if isinstance(result, torch.Tensor):
                        from accelerate.utils import pad_across_processes, send_to_device

                        # Tensors must be on-device before padding across processes.
                        tensorized_result = send_to_device(result, self.device)
                        result = pad_across_processes(tensorized_result, pad_index=inputs[-1])
                    else:
                        # Repeat the final element until every process holds the
                        # same number of items.
                        result += [result[-1]] * (num_samples_per_process + (1 if num_extras > 0 else 0) - len(result))
                return result
            elif isinstance(inputs, dict):
                for key in inputs.keys():
                    inputs[key] = _split_values(inputs[key], start_index, end_index)
                return inputs
            else:
                if is_datasets_available():
                    from datasets import Dataset

                    if isinstance(inputs, Dataset):
                        # Clamp the slice bounds, then select (and optionally pad) rows by index.
                        if start_index >= len(inputs):
                            start_index = len(inputs) - 1
                        if end_index > len(inputs):
                            end_index = len(inputs)
                        result_idcs = list(range(start_index, end_index))
                        if apply_padding:
                            result_idcs += [end_index - 1] * (
                                num_samples_per_process + (1 if num_extras > 0 else 0) - len(result_idcs)
                            )
                        return inputs.select(result_idcs)
                # Unrecognized types are handed back unsplit.
                return inputs

        yield _split_values(inputs, start_index, end_index)
| |
|
| | @contextmanager |
| | def main_process_first(self): |
| | """ |
| | Lets the main process go first inside a with block. |
| | |
| | The other processes will enter the with block after the main process exits. |
| | |
| | Example: |
| | |
| | ```python |
| | >>> from accelerate import Accelerator |
| | |
| | >>> accelerator = Accelerator() |
| | >>> with accelerator.main_process_first(): |
| | ... # This will be printed first by process 0 then in a seemingly |
| | ... # random order by the other processes. |
| | ... print(f"This will be printed by process {accelerator.process_index}") |
| | ``` |
| | """ |
| | yield from self._goes_first(self.is_main_process) |
| |
|
| | @contextmanager |
| | def local_main_process_first(self): |
| | """ |
| | Lets the local main process go inside a with block. |
| | |
| | The other processes will enter the with block after the main process exits. |
| | |
| | Example: |
| | |
| | ```python |
| | >>> from accelerate.state import PartialState |
| | |
| | >>> state = PartialState() |
| | >>> with state.local_main_process_first(): |
| | ... # This will be printed first by local process 0 then in a seemingly |
| | ... # random order by the other processes. |
| | ... print(f"This will be printed by process {state.local_process_index}") |
| | ``` |
| | """ |
| | yield from self._goes_first(self.is_local_main_process) |
| |
|
| | def on_main_process(self, function: Callable[..., Any] | None = None): |
| | """ |
| | Decorator that only runs the decorated function on the main process. |
| | |
| | Args: |
| | function (`Callable`): The function to decorate. |
| | |
| | Example: |
| | |
| | ```python |
| | >>> from accelerate.state import PartialState |
| | |
| | >>> state = PartialState() |
| | |
| | |
| | >>> @state.on_main_process |
| | ... def print_something(): |
| | ... print("This will be printed by process 0 only.") |
| | |
| | |
| | >>> print_something() |
| | "This will be printed by process 0 only" |
| | ``` |
| | """ |
| | if not self.initialized: |
| | raise ValueError("The `PartialState` or `Accelerator` must be initialized before calling this function.") |
| | if self.is_main_process or not self.use_distributed: |
| | return function |
| | return do_nothing |
| |
|
| | def on_local_main_process(self, function: Callable[..., Any] | None = None): |
| | """ |
| | Decorator that only runs the decorated function on the local main process. |
| | |
| | Args: |
| | function (`Callable`): The function to decorate. |
| | |
| | Example: |
| | ```python |
| | # Assume we have 2 servers with 4 processes each. |
| | from accelerate.state import PartialState |
| | |
| | state = PartialState() |
| | |
| | |
| | @state.on_local_main_process |
| | def print_something(): |
| | print("This will be printed by process 0 only on each server.") |
| | |
| | |
| | print_something() |
| | # On server 1: |
| | "This will be printed by process 0 only" |
| | # On server 2: |
| | "This will be printed by process 0 only" |
| | ``` |
| | """ |
| | if self.is_local_main_process or not self.use_distributed: |
| | return function |
| | return do_nothing |
| |
|
| | def on_last_process(self, function: Callable[..., Any]): |
| | """ |
| | Decorator that only runs the decorated function on the last process. |
| | |
| | Args: |
| | function (`Callable`): The function to decorate. |
| | |
| | Example: |
| | ```python |
| | # Assume we have 4 processes. |
| | from accelerate.state import PartialState |
| | |
| | state = PartialState() |
| | |
| | |
| | @state.on_last_process |
| | def print_something(): |
| | print(f"Printed on process {state.process_index}") |
| | |
| | |
| | print_something() |
| | "Printed on process 3" |
| | ``` |
| | """ |
| | if self.is_last_process or not self.use_distributed: |
| | return function |
| | return do_nothing |
| |
|
| | def on_process(self, function: Callable[..., Any] | None = None, process_index: int | None = None): |
| | """ |
| | Decorator that only runs the decorated function on the process with the given index. |
| | |
| | Args: |
| | function (`Callable`, `optional`): |
| | The function to decorate. |
| | process_index (`int`, `optional`): |
| | The index of the process on which to run the function. |
| | |
| | Example: |
| | ```python |
| | # Assume we have 4 processes. |
| | from accelerate.state import PartialState |
| | |
| | state = PartialState() |
| | |
| | |
| | @state.on_process(process_index=2) |
| | def print_something(): |
| | print(f"Printed on process {state.process_index}") |
| | |
| | |
| | print_something() |
| | "Printed on process 2" |
| | ``` |
| | """ |
| | if function is None: |
| | return partial(self.on_process, process_index=process_index) |
| | if (self.process_index == process_index) or (not self.use_distributed): |
| | return function |
| | return do_nothing |
| |
|
| | def on_local_process(self, function: Callable[..., Any] | None = None, local_process_index: int | None = None): |
| | """ |
| | Decorator that only runs the decorated function on the process with the given index on the current node. |
| | |
| | Args: |
| | function (`Callable`, *optional*): |
| | The function to decorate. |
| | local_process_index (`int`, *optional*): |
| | The index of the local process on which to run the function. |
| | |
| | Example: |
| | ```python |
| | # Assume we have 2 servers with 4 processes each. |
| | from accelerate import Accelerator |
| | |
| | accelerator = Accelerator() |
| | |
| | |
| | @accelerator.on_local_process(local_process_index=2) |
| | def print_something(): |
| | print(f"Printed on process {accelerator.local_process_index}") |
| | |
| | |
| | print_something() |
| | # On server 1: |
| | "Printed on process 2" |
| | # On server 2: |
| | "Printed on process 2" |
| | ``` |
| | """ |
| | if function is None: |
| | return partial(self.on_local_process, local_process_index=local_process_index) |
| | if (self.local_process_index == local_process_index) or (not self.use_distributed): |
| | return function |
| | return do_nothing |
| |
|
    def print(self, *args, **kwargs):
        """`print()` replacement that only emits output on the local main process."""
        if self.is_local_main_process:
            print(*args, **kwargs)
| |
|
| | @property |
| | def default_device(self) -> torch.device: |
| | """ |
| | Returns the default device which is: |
| | - MPS if `torch.backends.mps.is_available()` and `torch.backends.mps.is_built()` both return True. |
| | - CUDA if `torch.cuda.is_available()` |
| | - MLU if `is_mlu_available()` |
| | - SDAA if `is_sdaa_available()` |
| | - MUSA if `is_musa_available()` |
| | - NPU if `is_npu_available()` |
| | - HPU if `is_hpu_available()` |
| | - CPU otherwise |
| | """ |
| | if is_mps_available(): |
| | os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" |
| | return torch.device("mps") |
| | elif is_mlu_available(): |
| | return torch.device("mlu") |
| | elif is_sdaa_available(): |
| | return torch.device("sdaa") |
| | elif is_musa_available(): |
| | return torch.device("musa") |
| | |
| | |
| | elif is_npu_available(): |
| | return torch.device("npu") |
| | elif is_hpu_available(): |
| | return torch.device("hpu") |
| | elif torch.cuda.is_available(): |
| | return torch.device("cuda") |
| | elif is_xpu_available(): |
| | return torch.device("xpu") |
| | else: |
| | return torch.device("cpu") |
| |
|
| | def _prepare_backend( |
| | self, cpu: bool = False, sagemaker_dp=False, backend: str | None = None |
| | ) -> tuple[str, DistributedType]: |
| | "Prepares any imports needed before initializing the distributed backend and sets `self.backend` properly" |
| | distributed_type = None |
| | if sagemaker_dp: |
| | import smdistributed.dataparallel.torch.torch_smddp |
| |
|
| | backend = "smddp" |
| | distributed_type = DistributedType.MULTI_GPU |
| | elif is_torch_xla_available(): |
| | backend = "xla" |
| | distributed_type = DistributedType.XLA |
| |
|
| | elif int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu: |
| | if is_mlu_available(): |
| | backend = "cncl" |
| | distributed_type = DistributedType.MULTI_MLU |
| | if is_sdaa_available(): |
| | backend = "tccl" |
| | distributed_type = DistributedType.MULTI_SDAA |
| | elif is_musa_available(): |
| | backend = "mccl" |
| | distributed_type = DistributedType.MULTI_MUSA |
| | |
| | |
| | elif is_npu_available(): |
| | backend = "hccl" |
| | distributed_type = DistributedType.MULTI_NPU |
| | elif is_hpu_available(init_hccl=True): |
| | if backend is None: |
| | backend = "hccl" |
| | distributed_type = DistributedType.MULTI_HPU |
| | elif torch.cuda.is_available(): |
| | if backend is None: |
| | backend = "nccl" |
| | distributed_type = DistributedType.MULTI_GPU |
| | elif is_xpu_available() and is_xccl_available(): |
| | if backend is None: |
| | backend = "xccl" |
| | distributed_type = DistributedType.MULTI_XPU |
| |
|
| | if distributed_type is None and ( |
| | int(os.environ.get("LOCAL_RANK", -1)) != -1 |
| | or get_int_from_env(["PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE", "WORLD_SIZE"], 1) > 1 |
| | ): |
| | if not cpu and is_xpu_available(): |
| | distributed_type = DistributedType.MULTI_XPU |
| | else: |
| | distributed_type = DistributedType.MULTI_CPU |
| |
|
| | if ( |
| | backend in (None, "ccl") |
| | and is_ccl_available() |
| | and (get_int_from_env(["CCL_WORKER_COUNT"], 0) > 0 or distributed_type == DistributedType.MULTI_XPU) |
| | ): |
| | import oneccl_bindings_for_pytorch |
| |
|
| | backend = "ccl" |
| | elif backend in (None, "mpi") and torch.distributed.is_mpi_available(): |
| | backend = "mpi" |
| | else: |
| | backend = "gloo" |
| | if distributed_type is None: |
| | distributed_type = DistributedType.NO |
| |
|
| | return backend, distributed_type |
| |
|
| | def set_device(self): |
| | """ |
| | Sets the device in `self.device` to the current distributed environment. |
| | """ |
| | if self.device is not None: |
| | return |
| | if self.distributed_type == DistributedType.NO: |
| | self.device = torch.device("cpu") if self._cpu else self.default_device |
| | return |
| | device = str(self.distributed_type).split(".")[-1].replace("MULTI_", "").lower() |
| | if device not in ("cpu", "gpu", "mlu", "musa", "npu", "xpu", "xla", "hpu", "sdaa"): |
| | raise ValueError( |
| | f"Can't set device for {self.distributed_type} ({device}), verify we should be calling `_set_device()` for it!" |
| | ) |
| | if device == "xla": |
| | self.device = xm.xla_device() |
| | elif device == "hpu": |
| | self.device = torch.device("hpu", torch.hpu.current_device()) |
| | else: |
| | if device == "gpu": |
| | device = "cuda" |
| | device_module = getattr(torch, device) |
| | device_index = self.local_process_index % device_module.device_count() |
| | self.device = torch.device(device, device_index) |
| | device_module.set_device(self.device) |
| |
|
| | def destroy_process_group(self, group=None): |
| | """ |
| | Destroys the process group. If one is not specified, the default process group is destroyed. |
| | """ |
| | if self.fork_launched and group is None: |
| | return |
| | |
| | if torch.distributed.is_initialized(): |
| | torch.distributed.destroy_process_group(group) |
| |
|
| | def __getattr__(self, name: str): |
| | |
| | |
| | if name in self._known_attrs: |
| | raise AttributeError( |
| | f"`PartialState` object has no attribute `{name}`. " |
| | "This happens if `PartialState._reset_state()` was called and " |
| | "an `Accelerator` or `PartialState` was not reinitialized." |
| | ) |
| | |
| | raise AttributeError(f"'PartialState' object has no attribute '{name}'") |
| |
|
| |
|
class AcceleratorState:
    """
    Singleton class that has information about the current training environment.

    **Available attributes:**

    - **device** (`torch.device`) -- The device to use.
    - **distributed_type** ([`~accelerate.state.DistributedType`]) -- The type of distributed environment currently
      in use.
    - **parallelism_config** ([`~accelerate.utils.ParallelismConfig`]) -- The parallelism configuration for the
      current training environment. This is used to configure the distributed training environment.
    - **initialized** (`bool`) -- Whether or not the `AcceleratorState` has been initialized from `Accelerator`.
    - **local_process_index** (`int`) -- The index of the current process on the current server.
    - **mixed_precision** (`str`) -- Whether or not the current script will use mixed precision, and if so the type
      of mixed precision being performed. (Choose from 'no', 'fp16', 'bf16' or 'fp8').
    - **num_processes** (`int`) -- The number of processes currently launched in parallel.
    - **process_index** (`int`) -- The index of the current process.
    - **is_last_process** (`bool`) -- Whether or not the current process is the last one.
    - **is_main_process** (`bool`) -- Whether or not the current process is the main one.
    - **is_local_main_process** (`bool`) -- Whether or not the current process is the main one on the local node.
    - **debug** (`bool`) -- Whether or not the current script is being run in debug mode.
    """

    # Borg pattern: every instance aliases this dict, making the class a singleton.
    _shared_state = SharedDict()
    # Attributes `__getattr__` should report as "known but reset" rather than unknown.
    _known_attrs = PartialState._known_attrs + [
        "deepspeed_plugin",
        "use_ipex",
        "fsdp_plugin",
        "megatron_lm_plugin",
        "dynamo_plugin",
    ]

    def __init__(
        self,
        mixed_precision: str | None = None,
        cpu: bool = False,
        dynamo_plugin=None,
        deepspeed_plugin=None,
        fsdp_plugin=None,
        torch_tp_plugin=None,
        megatron_lm_plugin=None,
        parallelism_config=None,
        _from_accelerator: bool = False,
        **kwargs,
    ):
        # Alias the shared state so all instances see the same attributes.
        self.__dict__ = self._shared_state
        if parse_flag_from_env("ACCELERATE_USE_CPU"):
            cpu = True
        # Make sure the process/device-level `PartialState` exists first, then
        # mirror its attributes (device, process indices, distributed_type, ...).
        if PartialState._shared_state == {}:
            PartialState(cpu, **kwargs)
        self.__dict__.update(PartialState._shared_state)
        # Raises if a second construction tries to change already-frozen settings.
        self._check_initialized(mixed_precision, cpu)
        if not self.initialized:
            self.deepspeed_plugins = None
            self.use_ipex = None
            self.torch_tp_plugin = torch_tp_plugin
            self.parallelism_config = parallelism_config
            self.device_mesh = None
            # Explicit argument wins; otherwise fall back to the env var (default "no").
            mixed_precision = (
                parse_choice_from_env("ACCELERATE_MIXED_PRECISION", "no")
                if mixed_precision is None
                else mixed_precision.lower()
            )
            if mixed_precision == "fp8":
                # FP8 needs both a capable library and capable hardware; otherwise
                # degrade to the closest supported mixed-precision mode.
                if not is_fp8_available():
                    raise ValueError(
                        "Using `fp8` precision requires `transformer_engine` or `MS-AMP` to be installed."
                    )
                elif torch.cuda.is_available() and not check_cuda_fp8_capability():
                    logger.warning(
                        f"The current device has compute capability of {torch.cuda.get_device_capability()} which is "
                        "insufficient for FP8 mixed precision training (requires a GPU Hopper/Ada Lovelace "
                        "or higher, compute capability of 8.9 or higher). Will use FP16 instead."
                    )
                    mixed_precision = "fp16"
                elif is_habana_gaudi1():
                    logger.warning(
                        "The current HPU device is Gaudi1 which does not support FP8 mixed precision training (requires "
                        "Gaudi2 or higher). Will use BF16 instead."
                    )
                    mixed_precision = "bf16"

            self.dynamo_plugin = dynamo_plugin
            # `AcceleratorState` must always be created via `Accelerator`, never directly.
            if not _from_accelerator:
                raise ValueError(
                    "Please make sure to properly initialize your accelerator via `accelerator = Accelerator()` "
                    "before using any functionality from the `accelerate` library."
                )

            # DeepSpeed owns mixed precision through its own config, so record "no"
            # here — except fp8, which accelerate still has to drive itself.
            if self.distributed_type == DistributedType.DEEPSPEED and mixed_precision != "fp8":
                self._mixed_precision = "no"
            else:
                self._mixed_precision = mixed_precision

            if self.distributed_type == DistributedType.XLA and is_torch_xla_available(check_is_tpu=True):
                # Configure XLA bf16 handling via env vars; ACCELERATE_DOWNCAST_BF16
                # selects downcasting (fp32 tensors become bf16) instead of full use.
                if mixed_precision == "bf16":
                    if os.environ.get("ACCELERATE_DOWNCAST_BF16"):
                        os.environ["XLA_USE_BF16"] = str(0)
                        os.environ["XLA_DOWNCAST_BF16"] = str(1)
                        self.downcast_bfloat = True
                    else:
                        os.environ["XLA_USE_BF16"] = str(1)
                        os.environ["XLA_DOWNCAST_BF16"] = str(0)
                        self.downcast_bfloat = False
            elif os.environ.get("ACCELERATE_USE_DEEPSPEED", "false").lower() == "true" and not cpu:
                self.distributed_type = DistributedType.DEEPSPEED
                # `deepspeed_plugin` may be a single plugin or a named dict of plugins;
                # in either case propagate mixed precision and mark one plugin active.
                if not isinstance(deepspeed_plugin, dict):
                    deepspeed_plugin.set_mixed_precision(mixed_precision)
                    deepspeed_plugin.select(_from_accelerator_state=True)
                else:
                    for plugin in deepspeed_plugin.values():
                        plugin.set_mixed_precision(mixed_precision)

                    first_plugin = next(iter(deepspeed_plugin.values()))
                    first_plugin.select(_from_accelerator_state=True)
                self.deepspeed_plugins = deepspeed_plugin
            elif self.distributed_type in [
                DistributedType.MULTI_GPU,
                DistributedType.MULTI_MLU,
                DistributedType.MULTI_SDAA,
                DistributedType.MULTI_MUSA,
                DistributedType.MULTI_NPU,
                DistributedType.MULTI_XPU,
                DistributedType.MULTI_HPU,
            ]:
                # Context parallelism (cp_size > 1) normally requires FSDP2 sharding;
                # ACCELERATE_ALLOW_CP_STANDALONE is an escape hatch for these checks.
                if not os.environ.get("ACCELERATE_ALLOW_CP_STANDALONE", "false").lower() == "true":
                    if self.parallelism_config and self.parallelism_config.cp_enabled and fsdp_plugin is None:
                        raise ValueError(
                            "`cp_size > 1` specified in the `parallelism_config`, but no `fsdp_plugin` was provided. We need a `fsdp_plugin` to use context parallelism with `cp_backend=torch`, as we also shard the model across the device mesh to save more memory"
                        )
                    if (
                        self.parallelism_config is not None
                        and self.parallelism_config.cp_enabled
                        and fsdp_plugin.fsdp_version == 1
                    ):
                        raise ValueError(
                            "Using `cp_size>1` requires FSDP2, but the provided `fsdp_plugin` is using FSDP1. "
                        )
                # FSDP is enabled by env var, an explicit plugin, or implicitly by
                # context parallelism being requested.
                if (os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true" or fsdp_plugin is not None) or (
                    self.parallelism_config is not None and self.parallelism_config.cp_enabled
                ):
                    self.distributed_type = DistributedType.FSDP
                    if self._mixed_precision != "no" and fsdp_plugin is not None:
                        fsdp_plugin.set_mixed_precision(self._mixed_precision)
                    self.fsdp_plugin = fsdp_plugin
                # Megatron-LM can take over from the multi-device types above
                # (MULTI_XPU excluded).
                if os.environ.get(
                    "ACCELERATE_USE_MEGATRON_LM", "false"
                ).lower() == "true" and self.distributed_type not in [
                    DistributedType.MULTI_XPU,
                ]:
                    self.distributed_type = DistributedType.MEGATRON_LM
                    megatron_lm_plugin.set_mixed_precision(self._mixed_precision)
                    self.megatron_lm_plugin = megatron_lm_plugin
            elif self.distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_XPU, DistributedType.NO]:
                if is_ipex_available():
                    # IPEX (Intel Extension for PyTorch) defaults to on when installed,
                    # unless disabled via env var.
                    self.use_ipex = parse_flag_from_env("ACCELERATE_USE_IPEX", default=True)
                else:
                    self.use_ipex = False
            # With torch.compile active and no mixed precision, allow TF32 matmuls
            # on CUDA / MUSA for speed.
            if (
                self.dynamo_plugin.backend != DynamoBackend.NO
                and self._mixed_precision == "no"
                and self.device.type == "cuda"
            ):
                torch.backends.cuda.matmul.allow_tf32 = True
            if (
                self.dynamo_plugin.backend != DynamoBackend.NO
                and self._mixed_precision == "no"
                and self.device.type == "musa"
            ):
                torch.backends.musa.matmul.allow_tf32 = True
            # Push any `distributed_type` override (DeepSpeed/FSDP/Megatron) back
            # down into the `PartialState` so both singletons agree.
            PartialState._shared_state["distributed_type"] = self.distributed_type

    @property
    def initialized(self) -> bool:
        # Initialized once this state holds more than the mirrored PartialState attrs.
        return self._shared_state != PartialState._shared_state

    def __repr__(self):
        """Human-readable summary: PartialState info plus mixed precision (and DS config)."""
        repr = PartialState().__repr__() + f"\nMixed precision type: {self.mixed_precision}\n"
        if self.distributed_type == DistributedType.DEEPSPEED:
            repr += f"ds_config: {self.deepspeed_plugin.deepspeed_config}\n"
        return repr

    def _check_initialized(self, mixed_precision=None, cpu=None):
        "Checks if a modification is trying to be made and the `AcceleratorState` has already been initialized"
        if self.initialized:
            err = "AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and pass `{flag}` to `Accelerator()`."
            if cpu and self.device.type != "cpu":
                raise ValueError(err.format(flag="cpu=True"))
            # DeepSpeed is exempt: its mixed precision lives in the DS config.
            if (
                mixed_precision is not None
                and mixed_precision != self._mixed_precision
                and self.distributed_type != DistributedType.DEEPSPEED
            ):
                raise ValueError(err.format(flag=f"mixed_precision='{mixed_precision}'"))

    @property
    def mixed_precision(self):
        """The effective mixed-precision mode; for DeepSpeed (non-fp8) it is read from the DS config."""
        if self.distributed_type == DistributedType.DEEPSPEED and self._mixed_precision != "fp8":
            config = self.deepspeed_plugin.deepspeed_config
            if config.get("fp16", {}).get("enabled", False):
                mixed_precision = "fp16"
            elif config.get("bf16", {}).get("enabled", False):
                mixed_precision = "bf16"
            else:
                mixed_precision = "no"
        else:
            mixed_precision = self._mixed_precision
        return mixed_precision

    @staticmethod
    def _reset_state(reset_partial_state: bool = False):
        "Resets `_shared_state`, is used internally and should not be called"
        AcceleratorState._shared_state.clear()
        if reset_partial_state:
            PartialState._reset_state()

    def destroy_process_group(self, group=None):
        """
        Destroys the process group. If one is not specified, the default process group is destroyed.

        If `self.fork_launched` is `True` and `group` is `None`, nothing happens.
        """
        PartialState().destroy_process_group(group)

    @property
    def fork_launched(self):
        # Delegated to PartialState.
        return PartialState().fork_launched

    @property
    def use_distributed(self):
        """
        Whether the Accelerator is configured for distributed training
        """
        return PartialState().use_distributed

    @property
    def is_fsdp2(self) -> bool:
        """Whether training uses FSDP version 2 specifically."""
        return self.distributed_type == DistributedType.FSDP and self.fsdp_plugin.fsdp_version == 2

    @property
    def is_last_process(self) -> bool:
        "Returns whether the current process is the last one"
        return PartialState().is_last_process

    @property
    def is_main_process(self) -> bool:
        "Returns whether the current process is the main process"
        return PartialState().is_main_process

    @property
    def is_local_main_process(self) -> bool:
        "Returns whether the current process is the main process on the local node"
        return PartialState().is_local_main_process

    def wait_for_everyone(self):
        """Blocks until all processes reach this point (delegates to `PartialState`)."""
        PartialState().wait_for_everyone()

    @contextmanager
    def split_between_processes(self, inputs: list | tuple | dict | torch.Tensor, apply_padding: bool = False):
        """
        Splits `input` between `self.num_processes` quickly and can be then used on that process. Useful when doing
        distributed inference, such as with different prompts.

        Note that when using a `dict`, all keys need to have the same number of elements.

        Args:
            inputs (`list`, `tuple`, `torch.Tensor`, or `dict` of `list`/`tuple`/`torch.Tensor`):
                The input to split between processes.
            apply_padding (`bool`, `optional`, defaults to `False`):
                Whether to apply padding by repeating the last element of the input so that all processes have the same
                number of elements. Useful when trying to perform actions such as `gather()` on the outputs or passing
                in less inputs than there are processes. If so, just remember to drop the padded elements afterwards.


        Example:

        ```python
        # Assume there are two processes
        from accelerate.state import AcceleratorState

        state = AcceleratorState()
        with state.split_between_processes(["A", "B", "C"]) as inputs:
            print(inputs)
        # Process 0
        ["A", "B"]
        # Process 1
        ["C"]

        with state.split_between_processes(["A", "B", "C"], apply_padding=True) as inputs:
            print(inputs)
        # Process 0
        ["A", "B"]
        # Process 1
        ["C", "C"]
        ```
        """
        with PartialState().split_between_processes(inputs, apply_padding=apply_padding) as inputs:
            yield inputs

    @contextmanager
    def main_process_first(self):
        """
        Lets the main process go first inside a with block.

        The other processes will enter the with block after the main process exits.
        """
        with PartialState().main_process_first():
            yield

    @contextmanager
    def local_main_process_first(self):
        """
        Lets the local main process go inside a with block.

        The other processes will enter the with block after the main process exits.
        """
        with PartialState().local_main_process_first():
            yield

    @property
    def deepspeed_plugin(self):
        """
        Returns the currently active DeepSpeedPlugin.

        If not using deepspeed, returns `None`.
        """
        # Only relevant when using DeepSpeed.
        if self.distributed_type != DistributedType.DEEPSPEED:
            return None
        # Imported lazily to avoid a hard dependency on deepspeed utilities.
        from accelerate.utils.deepspeed import get_active_deepspeed_plugin

        return get_active_deepspeed_plugin(self)

    @deepspeed_required
    def get_deepspeed_plugin(self, name: str):
        """
        Returns the DeepSpeedPlugin with the given plugin_key.
        """
        return self.deepspeed_plugins[name]

    @deepspeed_required
    def select_deepspeed_plugin(self, name: str | None = None):
        """
        Activates the DeepSpeedPlugin with the given `name`, and will disable all other plugins.
        """
        for key, plugin in self.deepspeed_plugins.items():
            if key != name:
                plugin._unselect()
        self.deepspeed_plugins[name].select(_from_accelerator_state=True)

    def print(self, *args, **kwargs):
        """Prints only on the main process (delegates to `PartialState`)."""
        PartialState().print(*args, **kwargs)

    def __getattr__(self, name: str):
        # Only reached for attributes missing from the shared state; distinguish
        # "known attribute lost to a state reset" from a genuinely unknown name.
        if name in self._known_attrs:
            raise AttributeError(
                f"`AcceleratorState` object has no attribute `{name}`. "
                "This happens if `AcceleratorState._reset_state()` was called and "
                "an `Accelerator` or `PartialState` was not reinitialized."
            )

        raise AttributeError(f"'AcceleratorState' object has no attribute '{name}'")
| |
|
| |
|
class GradientState:
    """
    Singleton class that has information related to gradient synchronization for gradient accumulation

    **Available attributes:**

    - **end_of_dataloader** (`bool`) -- Whether we have reached the end the current dataloader
    - **remainder** (`int`) -- The number of extra samples that were added from padding the dataloader
    - **sync_gradients** (`bool`) -- Whether the gradients should be synced across all devices
    - **active_dataloader** (`Optional[DataLoader]`) -- The dataloader that is currently being iterated over
    - **dataloader_references** (`List[Optional[DataLoader]]`) -- A list of references to the dataloaders that are
      being iterated over
    - **num_steps** (`int`) -- The number of steps to accumulate over
    - **adjust_scheduler** (`bool`) -- Whether the scheduler should be adjusted to account for the gradient
      accumulation
    - **sync_with_dataloader** (`bool`) -- Whether the gradients should be synced at the end of the dataloader
      iteration and the number of total steps reset
    - **is_xla_gradients_synced** (`bool`) -- Whether the XLA gradients have been synchronized. It is initialized
      as false. Once gradients have been reduced before the optimizer step, this flag is set to true. Subsequently,
      after each step, the flag is reset to false. FSDP will always synchronize the gradients, hence
      is_xla_gradients_synced is always true.
    """

    # Borg pattern: every instance aliases this dict, making the class a singleton.
    _shared_state = SharedDict()

    def __init__(self, gradient_accumulation_plugin: GradientAccumulationPlugin | None = None):
        # Alias the shared state so all instances observe the same attributes.
        self.__dict__ = self._shared_state
        if not self.initialized:
            self.sync_gradients = True
            # Sentinel `None` keeps `active_dataloader` well-defined outside any loader.
            self._dataloader_references_ref = [None]
            if gradient_accumulation_plugin is not None:
                self.plugin_kwargs = gradient_accumulation_plugin.to_kwargs()
            else:
                self.plugin_kwargs = {}
            self._is_xla_gradients_synced = False

        # A later construction may carry a different plugin; adopt its settings.
        if gradient_accumulation_plugin is not None:
            latest_kwargs = gradient_accumulation_plugin.to_kwargs()
            if self.plugin_kwargs != latest_kwargs:
                self.plugin_kwargs = latest_kwargs

    @property
    def num_steps(self) -> int:
        "Returns the number of steps to accumulate over"
        return self.plugin_kwargs.get("num_steps", 1)

    @property
    def adjust_scheduler(self) -> bool:
        "Returns whether the scheduler should be adjusted"
        return self.plugin_kwargs.get("adjust_scheduler", False)

    @property
    def sync_with_dataloader(self) -> bool:
        "Returns whether the gradients should be synced at the end of the dataloader iteration and the number of total steps reset"
        return self.plugin_kwargs.get("sync_with_dataloader", True)

    @property
    def initialized(self) -> bool:
        "Returns whether the `GradientState` has been initialized"
        return bool(GradientState._shared_state)

    @property
    def end_of_dataloader(self) -> bool:
        "Returns whether we have reached the end of the current dataloader"
        return self.active_dataloader.end_of_dataloader if self.in_dataloader else False

    @property
    def remainder(self) -> int:
        "Returns the number of extra samples that were added from padding the dataloader"
        return self.active_dataloader.remainder if self.in_dataloader else -1

    def __repr__(self):
        """One-line-per-field summary of the current accumulation state."""
        fields = [
            f"Sync Gradients: {self.sync_gradients}",
            f"At end of current dataloader: {self.end_of_dataloader}",
            f"Extra samples added: {self.remainder}",
            f"Gradient accumulation plugin: {self.plugin_kwargs}",
        ]
        return "".join(field + "\n" for field in fields)

    @property
    def is_xla_gradients_synced(self):
        "Returns the value of is_xla_gradients_synced. FSDP will always synchronize the gradients, hence is_xla_gradients_synced is always true."
        if parse_flag_from_env("ACCELERATE_USE_FSDP", default=False):
            return True
        return self._is_xla_gradients_synced

    @is_xla_gradients_synced.setter
    def is_xla_gradients_synced(self, is_synced):
        "Set the _is_xla_gradients_synced attribute."
        self._is_xla_gradients_synced = is_synced

    def _set_sync_gradients(self, sync_gradients):
        "Private function that sets whether gradients should be synchronized. Users should not have to call this."
        self.sync_gradients = sync_gradients
        # On TPU via XLA, emit a step barrier when syncing to keep host memory bounded.
        if not self.sync_gradients:
            return
        if is_torch_xla_available(check_is_tpu=True) and PartialState().distributed_type == DistributedType.XLA:
            xm.mark_step()

    def _add_dataloader(self, dataloader):
        "Private function that adds a dataloader to `self.dataloader_references` and sets `in_dataloader` to `True`. Users should not have to call this."
        # Round-trips through the property so the new entry is stored as a weakref.
        self.dataloader_references = self.dataloader_references + [dataloader]

    def _remove_dataloader(self, dataloader):
        "Private function that removes a dataloader from `self.dataloader_references` and sets `in_dataloader` to `False` if there are no more dataloaders. Users should not have to call this."
        remaining = [existing for existing in self.dataloader_references if existing != dataloader]
        self.dataloader_references = remaining

    @property
    def active_dataloader(self):
        """The most recently entered dataloader (or `None` outside any loader)."""
        references = self.dataloader_references
        return references[-1]

    @property
    def dataloader_references(self):
        # Dereference the stored weakrefs; the `None` sentinel passes through as-is.
        dereferenced = []
        for ref in self._dataloader_references_ref:
            dereferenced.append(ref() if ref is not None else None)
        return dereferenced

    @dataloader_references.setter
    def dataloader_references(self, references):
        # Store weak references so this singleton never keeps a dataloader alive.
        wrapped = []
        for dataloader in references:
            wrapped.append(None if dataloader is None else weakref.ref(dataloader))
        self._dataloader_references_ref = wrapped

    @property
    def in_dataloader(self) -> bool:
        "Returns whether the current process is in a dataloader"
        return self.active_dataloader is not None

    @staticmethod
    def _reset_state():
        "Resets `_shared_state`, is used internally and should not be called"
        GradientState._shared_state.clear()
| |
|