Prompt48 commited on
Commit
77b5dcc
·
verified ·
1 Parent(s): d1cd809

Upload edit\Qwen3-TTS-test\.venv\Lib\site-packages\accelerate\state.py with huggingface_hub

Browse files
edit//Qwen3-TTS-test//.venv//Lib//site-packages//accelerate//state.py ADDED
@@ -0,0 +1,1373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import logging
18
+ import os
19
+ import threading
20
+ import warnings
21
+ import weakref
22
+ from contextlib import contextmanager
23
+ from functools import partial
24
+ from typing import Any, Callable
25
+
26
+ import torch
27
+
28
+ from .utils import (
29
+ DistributedType,
30
+ DynamoBackend,
31
+ GradientAccumulationPlugin,
32
+ check_cuda_fp8_capability,
33
+ check_cuda_p2p_ib_support,
34
+ deepspeed_required,
35
+ get_cpu_distributed_information,
36
+ get_int_from_env,
37
+ is_ccl_available,
38
+ is_datasets_available,
39
+ is_deepspeed_available,
40
+ is_fp8_available,
41
+ is_habana_gaudi1,
42
+ is_hpu_available,
43
+ is_ipex_available,
44
+ is_mlu_available,
45
+ is_mps_available,
46
+ is_musa_available,
47
+ is_npu_available,
48
+ is_sdaa_available,
49
+ is_torch_xla_available,
50
+ is_xccl_available,
51
+ is_xpu_available,
52
+ parse_choice_from_env,
53
+ parse_flag_from_env,
54
+ set_numa_affinity,
55
+ )
56
+ from .utils.dataclasses import SageMakerDistributedType
57
+
58
+
59
+ if is_torch_xla_available():
60
+ import torch_xla.core.xla_model as xm
61
+ import torch_xla.runtime as xr
62
+
63
+ if is_mlu_available(check_device=False):
64
+ import torch_mlu # noqa: F401
65
+
66
+ if is_sdaa_available(check_device=False):
67
+ import torch_sdaa # noqa: F401
68
+
69
+ if is_musa_available(check_device=False):
70
+ import torch_musa # noqa: F401
71
+
72
+ if is_npu_available(check_device=False):
73
+ import torch_npu # noqa: F401
74
+
75
+
76
+ logger = logging.getLogger(__name__)
77
+
78
+
79
def is_initialized() -> bool:
    """
    Module-level counterpart of `AcceleratorState.initialized`.

    Returns:
        `bool`: `False` while the dictionary shared by `AcceleratorState` instances compares equal to `{}`, `True`
        otherwise.
    """
    shared = AcceleratorState._shared_state
    return not (shared == {})
85
+
86
+
87
# No-op stand-in: accepts anything, does nothing
def do_nothing(*args, **kwargs):
    """Ignore every positional and keyword argument and return `None`."""
    return
90
+
91
+
92
class ThreadLocalSharedDict(threading.local):
    """
    Descriptor holding one dict that is shared by all instances of the owning class within the same thread.

    Note: a descriptor does not behave exactly like a plain dict attribute.
    `PartialState(...)._shared_state` and `PartialState._shared_state` (instance vs class access) both return the
    underlying `_storage` dict, and `PartialState(...)._shared_state = {...}` replaces the `_storage` dict inside the
    descriptor, as expected. However, `PartialState._shared_state = {}` replaces the descriptor object itself with a
    plain dict. To reset the state, modify the `_storage` dict in place instead (e.g. `_shared_state.clear()`).

    See the Python documentation for an explanation of descriptors: https://docs.python.org/3/howto/descriptor.html

    This is required for using PyTorch/XLA with PJRT in multithreaded mode (required for TPU v2 and v3).

    See https://github.com/pytorch/xla/blob/r2.0/docs/pjrt.md#multithreading-on-tpu-v2v3
    """

    def __init__(self, thread_local: bool = False):
        # `thread_local` is accepted but not read by this constructor.
        self._storage = dict()

    def __get__(self, obj, objtype=None):
        # Class-level and instance-level access both resolve to the same backing dict.
        storage = self._storage
        return storage

    def __set__(self, obj, value):
        # Assigning through an instance swaps the backing dict.
        self._storage = value
117
+
118
+
119
# Only torch_xla (TPU) needs the thread-local descriptor; every other setup uses a plain global dict.
SharedDict = ThreadLocalSharedDict if is_torch_xla_available() else dict
121
+
122
+
123
+ # Inspired by Alex Martelli's 'Borg'.
124
+ class PartialState:
125
+ """
126
+ Singleton class that has information about the current training environment and functions to help with process
127
+ control. Designed to be used when only process control and device execution states are needed. Does *not* need to
128
+ be initialized from `Accelerator`.
129
+
130
+ Args:
131
+ cpu (`bool`, *optional*):
132
+ Whether or not to force the script to execute on CPU. Will ignore any accelerators available if set to
133
+ `True` and force the execution on the CPU.
134
+ kwargs (additional keyword arguments, *optional*):
135
+ Additional keyword arguments to pass to the relevant `init_process_group` function. Valid `kwargs` can be
136
+ found in [`utils.InitProcessGroupKwargs`]. See the example section for detailed usage.
137
+
138
+ **Available attributes:**
139
+
140
+ - **device** (`torch.device`) -- The device to use.
141
+ - **distributed_type** ([`~accelerate.state.DistributedType`]) -- The type of distributed environment currently
142
+ in use.
143
+ - **local_process_index** (`int`) -- The index of the current process on the current server.
144
+ - **mixed_precision** (`str`) -- Whether or not the current script will use mixed precision, and if so the type
145
+ of mixed precision being performed. (Choose from 'no','fp16','bf16 or 'fp8').
146
+ - **num_processes** (`int`) -- The number of processes currently launched in parallel.
147
+ - **process_index** (`int`) -- The index of the current process.
148
+ - **is_last_process** (`bool`) -- Whether or not the current process is the last one.
149
+ - **is_main_process** (`bool`) -- Whether or not the current process is the main one.
150
+ - **is_local_main_process** (`bool`) -- Whether or not the current process is the main one on the local node.
151
+ - **debug** (`bool`) -- Whether or not the current script is being run in debug mode.
152
+
153
+ Example:
154
+ ```python
155
+ from accelerate.utils import InitProcessGroupKwargs
156
+
157
+ # To include `InitProcessGroupKwargs`, init then call `.to_kwargs()`
158
+ kwargs = InitProcessGroupKwargs(...).to_kwargs()
159
+ state = PartialState(**kwargs)
160
+ ```
161
+ """
162
+
163
+ _shared_state = SharedDict()
164
+ _known_attrs = [
165
+ "_cpu",
166
+ "_mixed_precision",
167
+ "_shared_state",
168
+ "backend",
169
+ "debug",
170
+ "device",
171
+ "distributed_type",
172
+ "fork_launched",
173
+ "local_process_index",
174
+ "num_processes",
175
+ "process_index",
176
+ ]
177
+
178
def __init__(self, cpu: bool = False, **kwargs):
    """
    Attach this instance to the shared state and, on first construction only, set up the distributed environment.

    Args:
        cpu (`bool`, *optional*, defaults to `False`):
            Whether to force execution on CPU.
        kwargs (additional keyword arguments, *optional*):
            `_use_sagemaker_dp` and `backend` are popped and consumed here; everything left is forwarded to the
            process-group initialisation call (`torch.distributed.init_process_group` or DeepSpeed's
            `init_distributed`).

    Raises:
        ValueError: If an explicit `backend` was requested but `_prepare_backend` selected a different one, or if a
            multi-node XPU/CPU launch has no `MASTER_ADDR` and the backend is not `"mpi"`.
        ImportError: If `ACCELERATE_USE_DEEPSPEED` is `"true"` but DeepSpeed is not available.
        NotImplementedError: If the device is CUDA, `check_cuda_p2p_ib_support()` is falsy and `NCCL_P2P_DISABLE` /
            `NCCL_IB_DISABLE` are not both present in the environment.
    """
    # Borg pattern: every instance shares one attribute dictionary.
    self.__dict__ = self._shared_state
    if not self.initialized:
        self._cpu = cpu
        self.backend = None
        env_device = os.environ.get("ACCELERATE_TORCH_DEVICE", None)
        self.device = torch.device(env_device) if env_device is not None else None
        self.debug = parse_flag_from_env("ACCELERATE_DEBUG_MODE")
        use_sagemaker_dp = kwargs.pop("_use_sagemaker_dp", None)
        dist_information = None
        if use_sagemaker_dp is None:
            use_sagemaker_dp = (
                os.environ.get("ACCELERATE_USE_SAGEMAKER", "false").lower() == "true"
                and os.environ.get("ACCELERATE_SAGEMAKER_DISTRIBUTED_TYPE") != SageMakerDistributedType.NO
            )

        # Sets up self.backend + imports
        original_backend = kwargs.pop("backend", None)
        backend, distributed_type = self._prepare_backend(cpu, use_sagemaker_dp, original_backend)
        if original_backend is not None and backend != original_backend:
            raise ValueError(f"Your assigned backend {original_backend} is not available, please use {backend}")
        self.backend = backend
        self.distributed_type = distributed_type
        use_deepspeed = False
        if not cpu and self.backend != "xla":
            if int(os.environ.get("LOCAL_RANK", -1)) != -1:
                # Deal with spawning deepspeed
                if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false").lower() == "true":
                    if not is_deepspeed_available():
                        raise ImportError(
                            "DeepSpeed is not available => install it using `pip3 install deepspeed` or build it from source"
                        )
                    from deepspeed import comm as dist

                    if not dist.is_initialized():
                        if self.backend == "tccl":
                            local_rank = os.environ.get("LOCAL_RANK", -1)
                            torch.sdaa.set_device(f"sdaa:{local_rank}")
                        dist.init_distributed(dist_backend=self.backend, auto_mpi_discovery=False, **kwargs)
                    # We need to flag to `use_deepspeed` to be True to override `distributed_type` later
                    use_deepspeed = True
                # Deal with all other backends but XPU and CPU, that gets handled special later
                elif (
                    self.distributed_type not in (DistributedType.MULTI_XPU, DistributedType.MULTI_CPU)
                    and not torch.distributed.is_initialized()
                ):
                    if self.backend == "tccl":
                        local_rank = os.environ.get("LOCAL_RANK", -1)
                        torch.sdaa.set_device(f"sdaa:{local_rank}")
                    if (
                        self.backend == "nccl"
                        and os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
                        and (
                            os.environ.get("FSDP_OFFLOAD_PARAMS", "false").lower() == "true"
                            or os.environ.get("FSDP_STATE_DICT_TYPE", "SHARDED_STATE_DICT") == "FULL_STATE_DICT"
                        )
                    ):
                        self.backend = "cuda:nccl,cpu:gloo"
                    torch.distributed.init_process_group(backend=self.backend, **kwargs)

        # XPU and CPU require special env configs to be set
        if self.distributed_type in (DistributedType.MULTI_XPU, DistributedType.MULTI_CPU):
            dist_information = get_cpu_distributed_information()
            os.environ["RANK"] = str(dist_information.rank)
            os.environ["WORLD_SIZE"] = str(dist_information.world_size)
            os.environ["LOCAL_RANK"] = str(dist_information.local_rank)
            os.environ["LOCAL_WORLD_SIZE"] = str(dist_information.local_world_size)
            if not os.environ.get("MASTER_PORT", None):
                os.environ["MASTER_PORT"] = "29500"
            if (
                not os.environ.get("MASTER_ADDR", None)
                and dist_information.local_world_size != dist_information.world_size
                and self.backend != "mpi"
            ):
                raise ValueError(
                    "Tried to launch on distributed with multinode, but `MASTER_ADDR` env was not set, "
                    "please try exporting rank 0's hostname as `MASTER_ADDR`"
                )
            kwargs["rank"] = dist_information.rank
            kwargs["world_size"] = dist_information.world_size

            if (
                self.distributed_type == DistributedType.MULTI_CPU
                and get_int_from_env(["OMP_NUM_THREADS"], 0) == 0
            ):
                import psutil

                # `psutil.cpu_count(logical=False)` returns None when the physical core count cannot be
                # determined; fall back to a single core instead of failing on `None / int`.
                physical_cores = psutil.cpu_count(logical=False) or 1
                num_cpu_threads_per_process = int(physical_cores / dist_information.local_world_size)
                if num_cpu_threads_per_process == 0:
                    num_cpu_threads_per_process = 1
                torch.set_num_threads(num_cpu_threads_per_process)
                warnings.warn(
                    f"OMP_NUM_THREADS/MKL_NUM_THREADS unset, we set it at {num_cpu_threads_per_process} to improve oob"
                    " performance."
                )

            if not torch.distributed.is_initialized():
                torch.distributed.init_process_group(backend=self.backend, **kwargs)

        # No backend == no distributed training
        if self.backend is None:
            self.distributed_type = DistributedType.NO
            self.num_processes = 1
            self.process_index = 0
            self.local_process_index = 0
        elif self.backend == "xla":
            # XLA needs device setting first for `set_replication`
            self.set_device()
            xm.set_replication(self.device, xm.get_xla_supported_devices())
            self.num_processes = xr.world_size()
            self.process_index = xr.global_ordinal()
            if is_torch_xla_available(check_is_tpu=True):
                self.local_process_index = xm.get_local_ordinal()
            else:
                self.local_process_index = int(os.environ.get("LOCAL_RANK", -1))
        else:
            self.num_processes = torch.distributed.get_world_size()
            self.process_index = torch.distributed.get_rank()
            self.local_process_index = (
                int(os.environ.get("LOCAL_RANK", -1)) if dist_information is None else dist_information.local_rank
            )
        self.set_device()
        # Now we can change to deepspeed
        if use_deepspeed:
            self.distributed_type = DistributedType.DEEPSPEED

        # Set CPU affinity if enabled
        if parse_flag_from_env("ACCELERATE_CPU_AFFINITY", False):
            set_numa_affinity(self.local_process_index)

        # Check for old RTX 4000's that can't use P2P or IB and are on old drivers
        if self.device.type == "cuda" and not check_cuda_p2p_ib_support():
            if "NCCL_P2P_DISABLE" not in os.environ or "NCCL_IB_DISABLE" not in os.environ:
                raise NotImplementedError(
                    "Using RTX 4000 series doesn't support faster communication broadband via P2P or IB. "
                    'Please set `NCCL_P2P_DISABLE="1"` and `NCCL_IB_DISABLE="1" or use `accelerate launch` which '
                    "will do this automatically."
                )

    # Important: This should be the *only* code outside of `self.initialized!`
    self.fork_launched = parse_flag_from_env("FORK_LAUNCHED", 0)
321
+
322
def __repr__(self) -> str:
    """
    Multi-line summary of the state: distributed type (plus backend when one is set), number of processes, process
    index, local process index and device. Every line, including the last, ends with a newline.
    """
    backend_suffix = (" Backend: " + self.backend) if self.backend else ""
    summary_lines = [
        f"Distributed environment: {self.distributed_type}{backend_suffix}",
        f"Num processes: {self.num_processes}",
        f"Process index: {self.process_index}",
        f"Local process index: {self.local_process_index}",
        f"Device: {self.device}",
    ]
    return "\n".join(summary_lines) + "\n"
330
+
331
@staticmethod
def _reset_state():
    "Empties `_shared_state` in place; internal helper that should not be called by users"
    shared = PartialState._shared_state
    shared.clear()
335
+
336
@property
def initialized(self) -> bool:
    "Whether the `PartialState` has been initialized, i.e. the shared state is no longer an empty dict"
    state_is_empty = self._shared_state == {}
    return not state_is_empty
340
+
341
@property
def use_distributed(self):
    """
    Whether distributed training is active: the distributed type is not `DistributedType.NO` and more than one
    process is running.
    """
    has_distributed_type = self.distributed_type != DistributedType.NO
    if not has_distributed_type:
        return False
    return self.num_processes > 1
347
+
348
@property
def is_last_process(self) -> bool:
    "Whether this process has the highest process index (`num_processes - 1`)"
    last_index = self.num_processes - 1
    return self.process_index == last_index
352
+
353
@property
def is_main_process(self) -> bool:
    "Whether this is the main process: process index 0, except under Megatron-LM where it is the last process"
    if self.distributed_type != DistributedType.MEGATRON_LM:
        return self.process_index == 0
    return self.is_last_process
359
+
360
@property
def is_local_main_process(self) -> bool:
    "Whether this is the main process on the local node: local index 0, except under Megatron-LM where it is the last process"
    if self.distributed_type != DistributedType.MEGATRON_LM:
        return self.local_process_index == 0
    return self.is_last_process
368
+
369
def wait_for_everyone(self):
    """
    Blocks the current process until every other process has reached this point (so this is a no-op when the script
    runs in a single process). Handy right before saving a model.

    Example:

    ```python
    >>> # Assuming two GPU processes
    >>> import time
    >>> from accelerate.state import PartialState

    >>> state = PartialState()
    >>> if state.is_main_process:
    ...     time.sleep(2)
    >>> else:
    ...     print("I'm waiting for the main process to finish its sleep...")
    >>> state.wait_for_everyone()
    >>> # Should print on every process at the same time
    >>> print("Everyone is here")
    ```
    """
    # Distributed types that synchronise through a `torch.distributed` barrier.
    barrier_types = (
        DistributedType.MULTI_GPU,
        DistributedType.MULTI_MLU,
        DistributedType.MULTI_SDAA,
        DistributedType.MULTI_MUSA,
        DistributedType.MULTI_NPU,
        DistributedType.MULTI_XPU,
        DistributedType.MULTI_CPU,
        DistributedType.MULTI_HPU,
        DistributedType.DEEPSPEED,
        DistributedType.FSDP,
    )
    current_type = self.distributed_type
    if current_type in barrier_types:
        torch.distributed.barrier(device_ids=[self.local_process_index])
        return
    if current_type == DistributedType.XLA:
        xm.rendezvous("accelerate.utils.wait_for_everyone")
406
+
407
def _goes_first(self, is_main: bool):
    """
    Generator behind the `*_process_first` context managers.

    When `is_main` is falsy the process waits at the barrier *before* the single `yield`; when it is truthy the
    process runs through the `yield` first and hits the barrier afterwards.
    """
    waits_before_body = not is_main
    if waits_before_body:
        self.wait_for_everyone()

    yield

    if not waits_before_body:
        self.wait_for_everyone()
415
+
416
@contextmanager
def split_between_processes(self, inputs: list | tuple | dict | torch.Tensor, apply_padding: bool = False):
    """
    Splits `input` between `self.num_processes` quickly and can be then used on that process. Useful when doing
    distributed inference, such as with different prompts.

    Note that when using a `dict`, all keys need to have the same number of elements. The dictionary passed in is
    left untouched; a new dictionary holding this process' share is yielded.

    Args:
        inputs (`list`, `tuple`, `torch.Tensor`, `dict` of `list`/`tuple`/`torch.Tensor`, or `datasets.Dataset`):
            The input to split between processes.
        apply_padding (`bool`, `optional`, defaults to `False`):
            Whether to apply padding by repeating the last element of the input so that all processes have the same
            number of elements. Useful when trying to perform actions such as `gather()` on the outputs or passing
            in less inputs than there are processes. If so, just remember to drop the padded elements afterwards.
            An empty list/tuple share has nothing to repeat and is returned empty.

    Raises:
        ValueError: If `inputs` is a `dict` whose values do not all have the same length.

    Example:

    ```python
    # Assume there are two processes
    from accelerate import PartialState

    state = PartialState()
    with state.split_between_processes(["A", "B", "C"]) as inputs:
        print(inputs)
    # Process 0
    ["A", "B"]
    # Process 1
    ["C"]

    with state.split_between_processes(["A", "B", "C"], apply_padding=True) as inputs:
        print(inputs)
    # Process 0
    ["A", "B"]
    # Process 1
    ["C", "C"]
    ```
    """
    if self.num_processes == 1:
        yield inputs
        return

    if isinstance(inputs, dict):
        # Every value is sliced with the same indices, so they must share one length.
        value_lengths = {len(value) for value in inputs.values()}
        if len(value_lengths) > 1:
            raise ValueError("All values in the dictionary must have the same length")
        length = value_lengths.pop() if value_lengths else 0
    else:
        length = len(inputs)

    num_samples_per_process, num_extras = divmod(length, self.num_processes)
    # The first `num_extras` processes each take one additional element.
    start_index = self.process_index * num_samples_per_process + min(self.process_index, num_extras)
    end_index = start_index + num_samples_per_process + (1 if self.process_index < num_extras else 0)
    # Size of the largest share; padded shares are grown to this length.
    padded_length = num_samples_per_process + (1 if num_extras > 0 else 0)

    def _split_values(inputs, start_index, end_index):
        if isinstance(inputs, (list, tuple, torch.Tensor)):
            if start_index >= len(inputs):
                # More processes than elements: this process re-uses the last element.
                result = inputs[-1:]
            else:
                result = inputs[start_index:end_index]
            if apply_padding:
                if isinstance(result, torch.Tensor):
                    from accelerate.utils import pad_across_processes, send_to_device

                    # The tensor needs to be on the device before we can pad it
                    tensorized_result = send_to_device(result, self.device)
                    result = pad_across_processes(tensorized_result, pad_index=inputs[-1])
                elif len(result) > 0:
                    missing = padded_length - len(result)
                    if missing > 0:
                        filler = [result[-1]] * missing
                        # Concatenate with a matching sequence type: `tuple += list` raises TypeError.
                        result = result + (tuple(filler) if isinstance(result, tuple) else filler)
            return result
        elif isinstance(inputs, dict):
            # Build a new dict so the caller's dictionary is not truncated in place.
            return {key: _split_values(value, start_index, end_index) for key, value in inputs.items()}
        else:
            if is_datasets_available():
                from datasets import Dataset

                if isinstance(inputs, Dataset):
                    if start_index >= len(inputs):
                        start_index = len(inputs) - 1
                    if end_index > len(inputs):
                        end_index = len(inputs)
                    result_idcs = list(range(start_index, end_index))
                    if apply_padding:
                        result_idcs += [end_index - 1] * (padded_length - len(result_idcs))
                    return inputs.select(result_idcs)
            return inputs

    yield _split_values(inputs, start_index, end_index)
506
+
507
@contextmanager
def main_process_first(self):
    """
    Context manager that lets the main process run the body of the `with` block first.

    All other processes enter the block only after the main process has left it.

    Example:

    ```python
    >>> from accelerate import Accelerator

    >>> accelerator = Accelerator()
    >>> with accelerator.main_process_first():
    ...     # This will be printed first by process 0 then in a seemingly
    ...     # random order by the other processes.
    ...     print(f"This will be printed by process {accelerator.process_index}")
    ```
    """
    ordering = self._goes_first(self.is_main_process)
    yield from ordering
527
+
528
@contextmanager
def local_main_process_first(self):
    """
    Context manager that lets the local main process run the body of the `with` block first.

    All other processes enter the block only after the main process has left it.

    Example:

    ```python
    >>> from accelerate.state import PartialState

    >>> state = PartialState()
    >>> with state.local_main_process_first():
    ...     # This will be printed first by local process 0 then in a seemingly
    ...     # random order by the other processes.
    ...     print(f"This will be printed by process {state.local_process_index}")
    ```
    """
    ordering = self._goes_first(self.is_local_main_process)
    yield from ordering
548
+
549
def on_main_process(self, function: Callable[..., Any] | None = None):
    """
    Decorator that makes the decorated function run on the main process only.

    Args:
        function (`Callable`): The function to decorate.

    Returns:
        `function` itself on the main process or when not running distributed, otherwise a no-op callable.

    Raises:
        ValueError: If the state has not been initialized yet.

    Example:

    ```python
    >>> from accelerate.state import PartialState

    >>> state = PartialState()


    >>> @state.on_main_process
    ... def print_something():
    ...     print("This will be printed by process 0 only.")


    >>> print_something()
    "This will be printed by process 0 only"
    ```
    """
    if not self.initialized:
        raise ValueError("The `PartialState` or `Accelerator` must be initialized before calling this function.")
    runs_here = self.is_main_process or not self.use_distributed
    return function if runs_here else do_nothing
578
+
579
def on_local_main_process(self, function: Callable[..., Any] | None = None):
    """
    Decorator that makes the decorated function run on the local main process only.

    Args:
        function (`Callable`): The function to decorate.

    Returns:
        `function` itself on the local main process or when not running distributed, otherwise a no-op callable.

    Example:
    ```python
    # Assume we have 2 servers with 4 processes each.
    from accelerate.state import PartialState

    state = PartialState()


    @state.on_local_main_process
    def print_something():
        print("This will be printed by process 0 only on each server.")


    print_something()
    # On server 1:
    "This will be printed by process 0 only"
    # On server 2:
    "This will be printed by process 0 only"
    ```
    """
    runs_here = self.is_local_main_process or not self.use_distributed
    return function if runs_here else do_nothing
609
+
610
def on_last_process(self, function: Callable[..., Any]):
    """
    Decorator that makes the decorated function run on the last process only.

    Args:
        function (`Callable`): The function to decorate.

    Returns:
        `function` itself on the last process or when not running distributed, otherwise a no-op callable.

    Example:
    ```python
    # Assume we have 4 processes.
    from accelerate.state import PartialState

    state = PartialState()


    @state.on_last_process
    def print_something():
        print(f"Printed on process {state.process_index}")


    print_something()
    "Printed on process 3"
    ```
    """
    runs_here = self.is_last_process or not self.use_distributed
    return function if runs_here else do_nothing
637
+
638
def on_process(self, function: Callable[..., Any] | None = None, process_index: int | None = None):
    """
    Decorator that makes the decorated function run only on the process with the given index.

    Args:
        function (`Callable`, `optional`):
            The function to decorate.
        process_index (`int`, `optional`):
            The index of the process on which to run the function.

    Returns:
        A decorator bound to `process_index` when `function` is `None`; otherwise `function` itself on the matching
        process or when not running distributed, and a no-op callable everywhere else.

    Example:
    ```python
    # Assume we have 4 processes.
    from accelerate.state import PartialState

    state = PartialState()


    @state.on_process(process_index=2)
    def print_something():
        print(f"Printed on process {state.process_index}")


    print_something()
    "Printed on process 2"
    ```
    """
    if function is None:
        # Used as `@state.on_process(process_index=...)`: return a decorator that remembers the index.
        return partial(self.on_process, process_index=process_index)
    runs_here = (self.process_index == process_index) or (not self.use_distributed)
    return function if runs_here else do_nothing
670
+
671
def on_local_process(self, function: Callable[..., Any] | None = None, local_process_index: int | None = None):
    """
    Decorator that makes the decorated function run only on the process with the given index on the current node.

    Args:
        function (`Callable`, *optional*):
            The function to decorate.
        local_process_index (`int`, *optional*):
            The index of the local process on which to run the function.

    Returns:
        A decorator bound to `local_process_index` when `function` is `None`; otherwise `function` itself on the
        matching local process or when not running distributed, and a no-op callable everywhere else.

    Example:
    ```python
    # Assume we have 2 servers with 4 processes each.
    from accelerate import Accelerator

    accelerator = Accelerator()


    @accelerator.on_local_process(local_process_index=2)
    def print_something():
        print(f"Printed on process {accelerator.local_process_index}")


    print_something()
    # On server 1:
    "Printed on process 2"
    # On server 2:
    "Printed on process 2"
    ```
    """
    if function is None:
        # Used as `@state.on_local_process(local_process_index=...)`: return a decorator that remembers the index.
        return partial(self.on_local_process, local_process_index=local_process_index)
    runs_here = (self.local_process_index == local_process_index) or (not self.use_distributed)
    return function if runs_here else do_nothing
706
+
707
def print(self, *args, **kwargs):
    "Forwards all arguments to `print`, but only on the local main process; other processes print nothing"
    if not self.is_local_main_process:
        return
    print(*args, **kwargs)
710
+
711
@property
def default_device(self) -> torch.device:
    """
    Returns the default device, i.e. the first match in this order:
    - MPS if `is_mps_available()` (also sets `PYTORCH_ENABLE_MPS_FALLBACK=1` in the environment)
    - MLU if `is_mlu_available()`
    - SDAA if `is_sdaa_available()`
    - MUSA if `is_musa_available()`
    - NPU if `is_npu_available()`
    - HPU if `is_hpu_available()`
    - CUDA if `torch.cuda.is_available()`
    - XPU if `is_xpu_available()`
    - CPU otherwise
    """
    if is_mps_available():
        os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
        return torch.device("mps")

    # Probed in order; the first available accelerator wins.
    # NPU should be checked before CUDA when using `transfer_to_npu`
    # See issue #3020: https://github.com/huggingface/accelerate/issues/3020
    probes = (
        (is_mlu_available, "mlu"),
        (is_sdaa_available, "sdaa"),
        (is_musa_available, "musa"),
        (is_npu_available, "npu"),
        (is_hpu_available, "hpu"),
        (torch.cuda.is_available, "cuda"),
        (is_xpu_available, "xpu"),
    )
    for is_available, device_name in probes:
        if is_available():
            return torch.device(device_name)
    return torch.device("cpu")
745
+
746
def _prepare_backend(
    self, cpu: bool = False, sagemaker_dp=False, backend: str | None = None
) -> tuple[str, DistributedType]:
    """
    Performs any imports needed before initializing the distributed backend and works out which backend and
    distributed type to use. Nothing is stored on `self`; the caller assigns the returned values.

    Args:
        cpu (`bool`, *optional*, defaults to `False`):
            Whether execution is forced on CPU; skips the accelerator detection.
        sagemaker_dp (`bool`, *optional*, defaults to `False`):
            Whether SageMaker data parallelism is in use (selects the `smddp` backend).
        backend (`str`, *optional*):
            Backend requested by the caller. It is kept for HPU, CUDA and XPU when given; the other accelerator
            branches always set their own backend.

    Returns:
        `tuple`: `(backend, distributed_type)`. `backend` stays `None` when no branch selects one and none was
        passed in; `distributed_type` falls back to `DistributedType.NO`.
    """
    distributed_type = None
    if sagemaker_dp:
        import smdistributed.dataparallel.torch.torch_smddp  # noqa

        backend = "smddp"
        distributed_type = DistributedType.MULTI_GPU
    elif is_torch_xla_available():
        backend = "xla"
        distributed_type = DistributedType.XLA

    elif int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu:
        # Single chain: exactly one accelerator family is selected. (SDAA used to start a second `if`, which let
        # a later branch overwrite the distributed type already chosen for MLU.)
        if is_mlu_available():
            backend = "cncl"
            distributed_type = DistributedType.MULTI_MLU
        elif is_sdaa_available():
            backend = "tccl"
            distributed_type = DistributedType.MULTI_SDAA
        elif is_musa_available():
            backend = "mccl"
            distributed_type = DistributedType.MULTI_MUSA
        # NPU should be checked before CUDA when using `transfer_to_npu`
        # See issue #3020: https://github.com/huggingface/accelerate/issues/3020
        elif is_npu_available():
            backend = "hccl"
            distributed_type = DistributedType.MULTI_NPU
        elif is_hpu_available(init_hccl=True):
            if backend is None:
                backend = "hccl"
            distributed_type = DistributedType.MULTI_HPU
        elif torch.cuda.is_available():
            if backend is None:
                backend = "nccl"
            distributed_type = DistributedType.MULTI_GPU
        elif is_xpu_available() and is_xccl_available():
            if backend is None:
                backend = "xccl"
            distributed_type = DistributedType.MULTI_XPU

    # No accelerator branch matched, but the environment still describes a multi-process launch.
    if distributed_type is None and (
        int(os.environ.get("LOCAL_RANK", -1)) != -1
        or get_int_from_env(["PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE", "WORLD_SIZE"], 1) > 1
    ):
        if not cpu and is_xpu_available():
            distributed_type = DistributedType.MULTI_XPU
        else:
            distributed_type = DistributedType.MULTI_CPU

        if (
            backend in (None, "ccl")
            and is_ccl_available()
            and (get_int_from_env(["CCL_WORKER_COUNT"], 0) > 0 or distributed_type == DistributedType.MULTI_XPU)
        ):
            import oneccl_bindings_for_pytorch  # noqa: F401

            backend = "ccl"
        elif backend in (None, "mpi") and torch.distributed.is_mpi_available():
            backend = "mpi"
        else:
            backend = "gloo"
    if distributed_type is None:
        distributed_type = DistributedType.NO

    return backend, distributed_type
813
+
814
def set_device(self):
    """
    Populates `self.device` based on the current distributed environment.

    Does nothing when `self.device` is already set. Outside of a distributed run the device is the CPU (when
    `self._cpu` is truthy) or `self.default_device`.

    Raises:
        `ValueError`: If the device family derived from `self.distributed_type` is not one of the supported ones.
    """
    if self.device is not None:
        return

    if self.distributed_type == DistributedType.NO:
        if self._cpu:
            self.device = torch.device("cpu")
        else:
            self.device = self.default_device
        return

    # e.g. `DistributedType.MULTI_GPU` -> "gpu"
    family = str(self.distributed_type).rsplit(".", 1)[-1].replace("MULTI_", "").lower()
    supported_families = ("cpu", "gpu", "mlu", "musa", "npu", "xpu", "xla", "hpu", "sdaa")
    if family not in supported_families:
        raise ValueError(
            f"Can't set device for {self.distributed_type} ({family}), verify we should be calling `_set_device()` for it!"
        )

    if family == "xla":
        self.device = xm.xla_device()
        return
    if family == "hpu":
        self.device = torch.device("hpu", torch.hpu.current_device())
        return

    # The remaining families map onto a `torch.<name>` module; "gpu" is exposed by torch as "cuda".
    torch_name = "cuda" if family == "gpu" else family
    backend_module = getattr(torch, torch_name)
    # Wrap the local rank around the number of devices the backend reports.
    local_index = self.local_process_index % backend_module.device_count()
    self.device = torch.device(torch_name, local_index)
    backend_module.set_device(self.device)
839
+
840
def destroy_process_group(self, group=None):
    """
    Tears down a `torch.distributed` process group; the default group is used when `group` is `None`.

    Nothing happens when `self.fork_launched` is truthy and no explicit `group` was given, or when
    `torch.distributed` has not been initialized.
    """
    keep_default_group = self.fork_launched and group is None
    # The `is_initialized` check is needed when `torch.distributed.init_process_group` was used.
    if not keep_default_group and torch.distributed.is_initialized():
        torch.distributed.destroy_process_group(group)
849
+
850
def __getattr__(self, name: str):
    """
    Fallback attribute hook: always raises `AttributeError`, with a more helpful message for attributes listed
    in `self._known_attrs`.
    """
    # Python only calls this hook after normal lookup failed, so `name` is definitely missing here.
    if name not in self._known_attrs:
        # Unknown attribute: raise a typical AttributeError
        raise AttributeError(f"'PartialState' object has no attribute '{name}'")
    # Known attribute that is missing: point the user at the likely cause.
    raise AttributeError(
        f"`PartialState` object has no attribute `{name}`. "
        "This happens if `PartialState._reset_state()` was called and "
        "an `Accelerator` or `PartialState` was not reinitialized."
    )
861
+
862
+
863
class AcceleratorState:
    """
    Singleton class that has information about the current training environment.

    **Available attributes:**

        - **device** (`torch.device`) -- The device to use.
        - **distributed_type** ([`~accelerate.state.DistributedType`]) -- The type of distributed environment currently
          in use.
        - **parallelism_config** ([`~accelerate.utils.ParallelismConfig`]) -- The parallelism configuration for the
          current training environment. This is used to configure the distributed training environment.
        - **initialized** (`bool`) -- Whether or not the `AcceleratorState` has been initialized from `Accelerator`.
        - **local_process_index** (`int`) -- The index of the current process on the current server.
        - **mixed_precision** (`str`) -- Whether or not the current script will use mixed precision, and if so the type
          of mixed precision being performed. (Choose from 'no', 'fp16', 'bf16' or 'fp8').
        - **num_processes** (`int`) -- The number of processes currently launched in parallel.
        - **process_index** (`int`) -- The index of the current process.
        - **is_last_process** (`bool`) -- Whether or not the current process is the last one.
        - **is_main_process** (`bool`) -- Whether or not the current process is the main one.
        - **is_local_main_process** (`bool`) -- Whether or not the current process is the main one on the local node.
        - **debug** (`bool`) -- Whether or not the current script is being run in debug mode.
    """

    # Every instance aliases its `__dict__` to this mapping, so all instances share the same attributes.
    _shared_state = SharedDict()
    _known_attrs = PartialState._known_attrs + [
        "deepspeed_plugin",
        "use_ipex",
        "fsdp_plugin",
        "megatron_lm_plugin",
        "dynamo_plugin",
    ]

    def __init__(
        self,
        mixed_precision: str | None = None,
        cpu: bool = False,
        dynamo_plugin=None,
        deepspeed_plugin=None,
        fsdp_plugin=None,
        torch_tp_plugin=None,
        megatron_lm_plugin=None,
        parallelism_config=None,
        _from_accelerator: bool = False,
        **kwargs,
    ):
        """
        Binds this instance to the shared state and, on first use, resolves mixed precision, the active plugins
        and the final distributed type.

        Args:
            mixed_precision (`str`, *optional*):
                Requested mixed precision mode. When `None`, it is read from `ACCELERATE_MIXED_PRECISION`
                (default `"no"`).
            cpu (`bool`, *optional*, defaults to `False`):
                Force CPU execution. Also enabled by the `ACCELERATE_USE_CPU` environment flag.
            dynamo_plugin, deepspeed_plugin, fsdp_plugin, torch_tp_plugin, megatron_lm_plugin:
                Plugin objects for the corresponding integrations. `deepspeed_plugin` may also be a `dict` of
                plugins, in which case the first one is selected as active.
            parallelism_config (*optional*):
                Parallelism configuration stored on the state and consulted for context parallelism.
            _from_accelerator (`bool`, *optional*, defaults to `False`):
                Must be `True` on first initialization, otherwise a `ValueError` is raised.
            **kwargs:
                Forwarded to `PartialState` when it has not been initialized yet.

        Raises:
            `ValueError`: If the state is first initialized outside of `Accelerator`, if `fp8` is requested without
            a supported library, or if context parallelism is requested with an incompatible FSDP setup.
        """
        self.__dict__ = self._shared_state
        if parse_flag_from_env("ACCELERATE_USE_CPU"):
            cpu = True
        if PartialState._shared_state == {}:
            PartialState(cpu, **kwargs)
        self.__dict__.update(PartialState._shared_state)
        self._check_initialized(mixed_precision, cpu)
        if not self.initialized:
            self.deepspeed_plugins = None
            self.use_ipex = None
            self.torch_tp_plugin = torch_tp_plugin
            self.parallelism_config = parallelism_config
            self.device_mesh = None
            mixed_precision = (
                parse_choice_from_env("ACCELERATE_MIXED_PRECISION", "no")
                if mixed_precision is None
                else mixed_precision.lower()
            )
            if mixed_precision == "fp8":
                # NOTE: `is_fp8_available` is only used here as a library-availability check; device capability
                # is verified separately in the branches below.
                if not is_fp8_available():
                    raise ValueError(
                        "Using `fp8` precision requires `transformer_engine` or `MS-AMP` to be installed."
                    )
                elif torch.cuda.is_available() and not check_cuda_fp8_capability():
                    logger.warning(
                        f"The current device has compute capability of {torch.cuda.get_device_capability()} which is "
                        "insufficient for FP8 mixed precision training (requires a GPU Hopper/Ada Lovelace "
                        "or higher, compute capability of 8.9 or higher). Will use FP16 instead."
                    )
                    mixed_precision = "fp16"
                elif is_habana_gaudi1():
                    logger.warning(
                        "The current HPU device is Gaudi1 which does not support FP8 mixed precision training (requires "
                        "Gaudi2 or higher). Will use BF16 instead."
                    )
                    mixed_precision = "bf16"

            self.dynamo_plugin = dynamo_plugin
            if not _from_accelerator:
                raise ValueError(
                    "Please make sure to properly initialize your accelerator via `accelerator = Accelerator()` "
                    "before using any functionality from the `accelerate` library."
                )
            # deepspeed handles mixed_precision using deepspeed_config. But we need to set it to fp8
            # if we're using fp8.
            if self.distributed_type == DistributedType.DEEPSPEED and mixed_precision != "fp8":
                self._mixed_precision = "no"
            else:
                self._mixed_precision = mixed_precision

            if self.distributed_type == DistributedType.XLA and is_torch_xla_available(check_is_tpu=True):
                if mixed_precision == "bf16":
                    if os.environ.get("ACCELERATE_DOWNCAST_BF16"):
                        os.environ["XLA_USE_BF16"] = str(0)
                        os.environ["XLA_DOWNCAST_BF16"] = str(1)
                        self.downcast_bfloat = True
                    else:
                        os.environ["XLA_USE_BF16"] = str(1)
                        os.environ["XLA_DOWNCAST_BF16"] = str(0)
                        self.downcast_bfloat = False
            elif os.environ.get("ACCELERATE_USE_DEEPSPEED", "false").lower() == "true" and not cpu:
                self.distributed_type = DistributedType.DEEPSPEED
                if not isinstance(deepspeed_plugin, dict):
                    deepspeed_plugin.set_mixed_precision(mixed_precision)
                    deepspeed_plugin.select(_from_accelerator_state=True)
                else:
                    for plugin in deepspeed_plugin.values():
                        plugin.set_mixed_precision(mixed_precision)
                    # The first plugin passed in is always the active one
                    first_plugin = next(iter(deepspeed_plugin.values()))
                    first_plugin.select(_from_accelerator_state=True)
                self.deepspeed_plugins = deepspeed_plugin
            elif self.distributed_type in [
                DistributedType.MULTI_GPU,
                DistributedType.MULTI_MLU,
                DistributedType.MULTI_SDAA,
                DistributedType.MULTI_MUSA,
                DistributedType.MULTI_NPU,
                DistributedType.MULTI_XPU,
                DistributedType.MULTI_HPU,
            ]:
                # TODO: Siro - remove when axolotl fixes their side
                if os.environ.get("ACCELERATE_ALLOW_CP_STANDALONE", "false").lower() != "true":
                    if self.parallelism_config and self.parallelism_config.cp_enabled and fsdp_plugin is None:
                        raise ValueError(
                            "`cp_size > 1` specified in the `parallelism_config`, but no `fsdp_plugin` was provided. We need a `fsdp_plugin` to use context parallelism with `cp_backend=torch`, as we also shard the model across the device mesh to save more memory"
                        )
                # `fsdp_plugin` can legitimately be `None` here when `ACCELERATE_ALLOW_CP_STANDALONE` is set, so
                # only inspect its version when a plugin was actually provided.
                if (
                    self.parallelism_config is not None
                    and self.parallelism_config.cp_enabled
                    and fsdp_plugin is not None
                    and fsdp_plugin.fsdp_version == 1
                ):
                    raise ValueError(
                        "Using `cp_size>1` requires FSDP2, but the provided `fsdp_plugin` is using FSDP1. "
                    )
                if (os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true" or fsdp_plugin is not None) or (
                    self.parallelism_config is not None and self.parallelism_config.cp_enabled
                ):
                    self.distributed_type = DistributedType.FSDP
                    if self._mixed_precision != "no" and fsdp_plugin is not None:
                        fsdp_plugin.set_mixed_precision(self._mixed_precision)
                    self.fsdp_plugin = fsdp_plugin
                if os.environ.get(
                    "ACCELERATE_USE_MEGATRON_LM", "false"
                ).lower() == "true" and self.distributed_type not in [
                    DistributedType.MULTI_XPU,
                ]:
                    self.distributed_type = DistributedType.MEGATRON_LM
                    megatron_lm_plugin.set_mixed_precision(self._mixed_precision)
                    self.megatron_lm_plugin = megatron_lm_plugin
            elif self.distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_XPU, DistributedType.NO]:
                if is_ipex_available():
                    # check if user disables it explicitly
                    self.use_ipex = parse_flag_from_env("ACCELERATE_USE_IPEX", default=True)
                else:
                    self.use_ipex = False
            # Turn on TF32 matmuls when a dynamo backend is used without mixed precision.
            if (
                self.dynamo_plugin.backend != DynamoBackend.NO
                and self._mixed_precision == "no"
                and self.device.type == "cuda"
            ):
                torch.backends.cuda.matmul.allow_tf32 = True
            if (
                self.dynamo_plugin.backend != DynamoBackend.NO
                and self._mixed_precision == "no"
                and self.device.type == "musa"
            ):
                torch.backends.musa.matmul.allow_tf32 = True
            # Mirror the (possibly upgraded) distributed type back into the partial state.
            PartialState._shared_state["distributed_type"] = self.distributed_type

    @property
    def initialized(self) -> bool:
        "Returns whether this state holds anything beyond what `PartialState` already shares"
        return self._shared_state != PartialState._shared_state

    def __repr__(self):
        "Returns the `PartialState` summary followed by the mixed precision type and, for DeepSpeed, its config"
        state_repr = PartialState().__repr__() + f"\nMixed precision type: {self.mixed_precision}\n"
        if self.distributed_type == DistributedType.DEEPSPEED:
            state_repr += f"ds_config: {self.deepspeed_plugin.deepspeed_config}\n"
        return state_repr

    def _check_initialized(self, mixed_precision=None, cpu=None):
        "Checks if a modification is trying to be made and the `AcceleratorState` has already been initialized"
        if self.initialized:
            err = "AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and pass `{flag}` to `Accelerator()`."
            if cpu and self.device.type != "cpu":
                raise ValueError(err.format(flag="cpu=True"))
            if (
                mixed_precision is not None
                and mixed_precision != self._mixed_precision
                and self.distributed_type != DistributedType.DEEPSPEED
            ):
                raise ValueError(err.format(flag=f"mixed_precision='{mixed_precision}'"))

    @property
    def mixed_precision(self):
        """
        The effective mixed precision mode. With DeepSpeed (and anything other than `fp8`) it is derived from the
        `fp16` / `bf16` sections of the active DeepSpeed config; otherwise the stored value is returned.
        """
        if self.distributed_type == DistributedType.DEEPSPEED and self._mixed_precision != "fp8":
            config = self.deepspeed_plugin.deepspeed_config
            if config.get("fp16", {}).get("enabled", False):
                mixed_precision = "fp16"
            elif config.get("bf16", {}).get("enabled", False):
                mixed_precision = "bf16"
            else:
                mixed_precision = "no"
        else:
            mixed_precision = self._mixed_precision
        return mixed_precision

    @staticmethod
    def _reset_state(reset_partial_state: bool = False):
        "Resets `_shared_state`, is used internally and should not be called"
        AcceleratorState._shared_state.clear()
        if reset_partial_state:
            PartialState._reset_state()

    def destroy_process_group(self, group=None):
        """
        Destroys the process group. If one is not specified, the default process group is destroyed.

        If `self.fork_launched` is `True` and `group` is `None`, nothing happens.
        """
        PartialState().destroy_process_group(group)

    @property
    def fork_launched(self):
        "Returns the `fork_launched` value of the underlying `PartialState`"
        return PartialState().fork_launched

    @property
    def use_distributed(self):
        """
        Whether the Accelerator is configured for distributed training
        """
        return PartialState().use_distributed

    @property
    def is_fsdp2(self) -> bool:
        "Returns whether FSDP is in use and the FSDP plugin reports version 2"
        return self.distributed_type == DistributedType.FSDP and self.fsdp_plugin.fsdp_version == 2

    @property
    def is_last_process(self) -> bool:
        "Returns whether the current process is the last one"
        return PartialState().is_last_process

    @property
    def is_main_process(self) -> bool:
        "Returns whether the current process is the main process"
        return PartialState().is_main_process

    @property
    def is_local_main_process(self) -> bool:
        "Returns whether the current process is the main process on the local node"
        return PartialState().is_local_main_process

    def wait_for_everyone(self):
        "Delegates to `PartialState().wait_for_everyone()`"
        PartialState().wait_for_everyone()

    @contextmanager
    def split_between_processes(self, inputs: list | tuple | dict | torch.Tensor, apply_padding: bool = False):
        """
        Splits `input` between `self.num_processes` quickly and can be then used on that process. Useful when doing
        distributed inference, such as with different prompts.

        Note that when using a `dict`, all keys need to have the same number of elements.

        Args:
            inputs (`list`, `tuple`, `torch.Tensor`, or `dict` of `list`/`tuple`/`torch.Tensor`):
                The input to split between processes.
            apply_padding (`bool`, `optional`, defaults to `False`):
                Whether to apply padding by repeating the last element of the input so that all processes have the same
                number of elements. Useful when trying to perform actions such as `gather()` on the outputs or passing
                in less inputs than there are processes. If so, just remember to drop the padded elements afterwards.


        Example:

        ```python
        # Assume there are two processes
        from accelerate.state import AcceleratorState

        state = AcceleratorState()
        with state.split_between_processes(["A", "B", "C"]) as inputs:
            print(inputs)
        # Process 0
        ["A", "B"]
        # Process 1
        ["C"]

        with state.split_between_processes(["A", "B", "C"], apply_padding=True) as inputs:
            print(inputs)
        # Process 0
        ["A", "B"]
        # Process 1
        ["C", "C"]
        ```
        """
        with PartialState().split_between_processes(inputs, apply_padding=apply_padding) as inputs:
            yield inputs

    @contextmanager
    def main_process_first(self):
        """
        Lets the main process go first inside a with block.

        The other processes will enter the with block after the main process exits.
        """
        with PartialState().main_process_first():
            yield

    @contextmanager
    def local_main_process_first(self):
        """
        Lets the local main process go inside a with block.

        The other processes will enter the with block after the main process exits.
        """
        with PartialState().local_main_process_first():
            yield

    @property
    def deepspeed_plugin(self):
        """
        Returns the currently active DeepSpeedPlugin.

        If not using deepspeed, returns `None`.
        """
        # To maintain original behavior, return None if not using deepspeed.
        if self.distributed_type != DistributedType.DEEPSPEED:
            return None
        # Imported lazily, only when DeepSpeed is actually in use.
        from accelerate.utils.deepspeed import get_active_deepspeed_plugin

        return get_active_deepspeed_plugin(self)

    @deepspeed_required
    def get_deepspeed_plugin(self, name: str):
        """
        Returns the DeepSpeedPlugin with the given plugin_key.
        """
        return self.deepspeed_plugins[name]

    @deepspeed_required
    def select_deepspeed_plugin(self, name: str | None = None):
        """
        Activates the DeepSpeedPlugin with the given `name`, and will disable all other plugins.
        """
        for key, plugin in self.deepspeed_plugins.items():
            if key != name:
                plugin._unselect()
        self.deepspeed_plugins[name].select(_from_accelerator_state=True)

    def print(self, *args, **kwargs):
        "Delegates to `PartialState().print(...)` with the same arguments"
        PartialState().print(*args, **kwargs)

    def __getattr__(self, name: str):
        "Always raises `AttributeError`, with a more helpful message for attributes listed in `_known_attrs`"
        # By this point we know that no attributes of `self` contain `name`,
        # so we just modify the error message
        if name in self._known_attrs:
            raise AttributeError(
                f"`AcceleratorState` object has no attribute `{name}`. "
                "This happens if `AcceleratorState._reset_state()` was called and "
                "an `Accelerator` or `PartialState` was not reinitialized."
            )
        # Raise a typical AttributeError
        raise AttributeError(f"'AcceleratorState' object has no attribute '{name}'")
1231
+
1232
+
1233
class GradientState:
    """
    Singleton-style holder for everything gradient accumulation needs to know about gradient synchronization.

    **Available attributes:**

        - **end_of_dataloader** (`bool`) -- Whether we have reached the end the current dataloader
        - **remainder** (`int`) -- The number of extra samples that were added from padding the dataloader
        - **sync_gradients** (`bool`) -- Whether the gradients should be synced across all devices
        - **active_dataloader** (`Optional[DataLoader]`) -- The dataloader that is currently being iterated over
        - **dataloader_references** (`List[Optional[DataLoader]]`) -- A list of references to the dataloaders that are
          being iterated over
        - **num_steps** (`int`) -- The number of steps to accumulate over
        - **adjust_scheduler** (`bool`) -- Whether the scheduler should be adjusted to account for the gradient
          accumulation
        - **sync_with_dataloader** (`bool`) -- Whether the gradients should be synced at the end of the dataloader
          iteration and the number of total steps reset
        - **is_xla_gradients_synced** (`bool`) -- Whether the XLA gradients have been synchronized. It is initialized
          as false. Once gradients have been reduced before the optimizer step, this flag is set to true. Subsequently,
          after each step, the flag is reset to false. FSDP will always synchronize the gradients, hence
          is_xla_gradients_synced is always true.
    """

    # All instances alias their `__dict__` to this mapping and therefore share their attributes.
    _shared_state = SharedDict()

    def __init__(self, gradient_accumulation_plugin: GradientAccumulationPlugin | None = None):
        """
        Attaches to the shared state, seeding defaults on first use and refreshing the plugin kwargs whenever a
        plugin with different settings is passed in.
        """
        self.__dict__ = self._shared_state
        has_plugin = gradient_accumulation_plugin is not None

        if not self.initialized:
            self.sync_gradients = True
            self._dataloader_references_ref = [None]
            if has_plugin:
                self.plugin_kwargs = gradient_accumulation_plugin.to_kwargs()
            else:
                self.plugin_kwargs = {}
            self._is_xla_gradients_synced = False

        # Plugin args are different and can be updated
        if has_plugin and self.plugin_kwargs != gradient_accumulation_plugin.to_kwargs():
            self.plugin_kwargs = gradient_accumulation_plugin.to_kwargs()

    @property
    def num_steps(self) -> int:
        "Number of steps gradients are accumulated over (1 when the plugin does not specify it)"
        steps = self.plugin_kwargs.get("num_steps", 1)
        return steps

    @property
    def adjust_scheduler(self) -> bool:
        "Whether the scheduler should be adjusted (`False` when the plugin does not specify it)"
        should_adjust = self.plugin_kwargs.get("adjust_scheduler", False)
        return should_adjust

    @property
    def sync_with_dataloader(self) -> bool:
        "Whether gradients are synced at the end of the dataloader iteration and the number of total steps reset"
        should_sync = self.plugin_kwargs.get("sync_with_dataloader", True)
        return should_sync

    @property
    def initialized(self) -> bool:
        "Whether the shared `GradientState` already holds any attributes"
        return GradientState._shared_state != {}

    @property
    def end_of_dataloader(self) -> bool:
        "Whether the active dataloader reports its end; always `False` outside of a dataloader"
        return self.active_dataloader.end_of_dataloader if self.in_dataloader else False

    @property
    def remainder(self) -> int:
        "Number of extra samples added by padding the active dataloader; `-1` outside of a dataloader"
        return self.active_dataloader.remainder if self.in_dataloader else -1

    def __repr__(self):
        "Multi-line summary of the sync flag, dataloader position, padding remainder and plugin kwargs"
        parts = [
            f"Sync Gradients: {self.sync_gradients}\n",
            f"At end of current dataloader: {self.end_of_dataloader}\n",
            f"Extra samples added: {self.remainder}\n",
            f"Gradient accumulation plugin: {self.plugin_kwargs}\n",
        ]
        return "".join(parts)

    @property
    def is_xla_gradients_synced(self):
        "Returns the value of is_xla_gradients_synced. FSDP will always synchronize the gradients, hence is_xla_gradients_synced is always true."
        fsdp_enabled = parse_flag_from_env("ACCELERATE_USE_FSDP", default=False)
        return True if fsdp_enabled else self._is_xla_gradients_synced

    @is_xla_gradients_synced.setter
    def is_xla_gradients_synced(self, is_synced):
        "Stores the given value in `_is_xla_gradients_synced`."
        self._is_xla_gradients_synced = is_synced

    def _set_sync_gradients(self, sync_gradients):
        "Private function that sets whether gradients should be synchronized. Users should not have to call this."
        self.sync_gradients = sync_gradients
        if not self.sync_gradients:
            return
        # Allow grad-sync to automatically work on TPUs
        if is_torch_xla_available(check_is_tpu=True):
            if PartialState().distributed_type == DistributedType.XLA:
                xm.mark_step()

    def _add_dataloader(self, dataloader):
        "Private function that appends a dataloader to `self.dataloader_references`, which makes `in_dataloader` `True`. Users should not have to call this."
        # Go through the property setter (required for garbage collection) by assigning a new list.
        # Calling `.append` on the list returned by the getter would never reach the setter.
        self.dataloader_references = [*self.dataloader_references, dataloader]

    def _remove_dataloader(self, dataloader):
        "Private function that drops a dataloader from `self.dataloader_references`; `in_dataloader` becomes `False` once none are left. Users should not have to call this."
        remaining = []
        for candidate in self.dataloader_references:
            if candidate != dataloader:
                remaining.append(candidate)
        # Assign (rather than mutate) so that the property setter is triggered.
        self.dataloader_references = remaining

    @property
    def active_dataloader(self):
        "The most recently added dataloader reference (the last entry of `dataloader_references`)"
        references = self.dataloader_references
        return references[-1]

    @property
    def dataloader_references(self):
        "The tracked dataloaders, resolved from their weak references; `None` placeholders are kept as they are"
        # Weakrefs are stored instead of the dataloaders themselves to avoid circular references that would
        # prevent garbage collection.
        resolved = []
        for reference in self._dataloader_references_ref:
            resolved.append(reference if reference is None else reference())
        return resolved

    @dataloader_references.setter
    def dataloader_references(self, references):
        "Stores a weak reference for every dataloader in `references`; `None` entries are stored unchanged"
        wrapped = []
        for dataloader in references:
            wrapped.append(dataloader if dataloader is None else weakref.ref(dataloader))
        self._dataloader_references_ref = wrapped

    @property
    def in_dataloader(self) -> bool:
        "Returns whether the current process is in a dataloader"
        return self.active_dataloader is not None

    @staticmethod
    def _reset_state():
        "Resets `_shared_state`, is used internally and should not be called"
        GradientState._shared_state.clear()