xiaoanyu123 committed on
Commit cf47cca · verified · 1 Parent(s): 8e2e3c8

Add files using upload-large-folder tool

pythonProject/.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
+<component name="InspectionProjectProfileManager">
+  <settings>
+    <option name="USE_PROJECT_PROFILE" value="false" />
+    <version value="1.0" />
+  </settings>
+</component>
pythonProject/.venv/Lib/site-packages/accelerate/accelerator.py ADDED
The diff for this file is too large to render. See raw diff
 
pythonProject/.venv/Lib/site-packages/accelerate/big_modeling.py ADDED
@@ -0,0 +1,789 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import re
+from contextlib import contextmanager
+from functools import wraps
+from typing import Optional, Union
+
+import torch
+import torch.nn as nn
+
+from .hooks import (
+    AlignDevicesHook,
+    CpuOffload,
+    LayerwiseCastingHook,
+    UserCpuOffloadHook,
+    add_hook_to_module,
+    attach_align_device_hook,
+    attach_align_device_hook_on_blocks,
+)
+from .utils import (
+    OffloadedWeightsLoader,
+    check_cuda_p2p_ib_support,
+    check_device_map,
+    extract_submodules_state_dict,
+    find_tied_parameters,
+    get_balanced_memory,
+    infer_auto_device_map,
+    is_bnb_available,
+    is_mlu_available,
+    is_musa_available,
+    is_npu_available,
+    is_sdaa_available,
+    is_xpu_available,
+    load_checkpoint_in_model,
+    offload_state_dict,
+    parse_flag_from_env,
+    retie_parameters,
+)
+from .utils.constants import SUPPORTED_PYTORCH_LAYERS_FOR_UPCASTING
+from .utils.other import recursive_getattr
+
+
+logger = logging.getLogger(__name__)
+
+
+@contextmanager
+def init_empty_weights(include_buffers: bool = None):
+    """
+    A context manager under which models are initialized with all parameters on the meta device, therefore creating an
+    empty model. Useful when just initializing the model would blow the available RAM.
+
+    Args:
+        include_buffers (`bool`, *optional*):
+            Whether or not to also put all buffers on the meta device while initializing.
+
+    Example:
+
+    ```python
+    import torch.nn as nn
+    from accelerate import init_empty_weights
+
+    # Initialize a model with 100 billion parameters in no time and without using any RAM.
+    with init_empty_weights():
+        tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
+    ```
+
+    <Tip warning={true}>
+
+    Any model created under this context manager has no weights. As such you can't do something like
+    `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`].
+    Make sure to overwrite the default device_map param for [`load_checkpoint_and_dispatch`], otherwise dispatch is not
+    called.
+
+    </Tip>
+    """
+    if include_buffers is None:
+        include_buffers = parse_flag_from_env("ACCELERATE_INIT_INCLUDE_BUFFERS", False)
+    with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f:
+        yield f
+
+
+@contextmanager
+def init_on_device(device: torch.device, include_buffers: bool = None):
+    """
+    A context manager under which models are initialized with all parameters on the specified device.
+
+    Args:
+        device (`torch.device`):
+            Device to initialize all parameters on.
+        include_buffers (`bool`, *optional*):
+            Whether or not to also put all buffers on the meta device while initializing.
+
+    Example:
+
+    ```python
+    import torch.nn as nn
+    from accelerate import init_on_device
+
+    with init_on_device(device=torch.device("cuda")):
+        tst = nn.Linear(100, 100)  # on `cuda` device
+    ```
+    """
+    if include_buffers is None:
+        include_buffers = parse_flag_from_env("ACCELERATE_INIT_INCLUDE_BUFFERS", False)
+
+    if include_buffers:
+        with device:
+            yield
+        return
+
+    old_register_parameter = nn.Module.register_parameter
+    if include_buffers:
+        old_register_buffer = nn.Module.register_buffer
+
+    def register_empty_parameter(module, name, param):
+        old_register_parameter(module, name, param)
+        if param is not None:
+            param_cls = type(module._parameters[name])
+            kwargs = module._parameters[name].__dict__
+            kwargs["requires_grad"] = param.requires_grad
+            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
+
+    def register_empty_buffer(module, name, buffer, persistent=True):
+        old_register_buffer(module, name, buffer, persistent=persistent)
+        if buffer is not None:
+            module._buffers[name] = module._buffers[name].to(device)
+
+    # Patch tensor creation
+    if include_buffers:
+        tensor_constructors_to_patch = {
+            torch_function_name: getattr(torch, torch_function_name)
+            for torch_function_name in ["empty", "zeros", "ones", "full"]
+        }
+    else:
+        tensor_constructors_to_patch = {}
+
+    def patch_tensor_constructor(fn):
+        def wrapper(*args, **kwargs):
+            kwargs["device"] = device
+            return fn(*args, **kwargs)
+
+        return wrapper
+
+    try:
+        nn.Module.register_parameter = register_empty_parameter
+        if include_buffers:
+            nn.Module.register_buffer = register_empty_buffer
+        for torch_function_name in tensor_constructors_to_patch.keys():
+            setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
+        yield
+    finally:
+        nn.Module.register_parameter = old_register_parameter
+        if include_buffers:
+            nn.Module.register_buffer = old_register_buffer
+        for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
+            setattr(torch, torch_function_name, old_torch_function)
+
+
+def cpu_offload(
+    model: nn.Module,
+    execution_device: Optional[torch.device] = None,
+    offload_buffers: bool = False,
+    state_dict: Optional[dict[str, torch.Tensor]] = None,
+    preload_module_classes: Optional[list[str]] = None,
+):
+    """
+    Activates full CPU offload for a model. As a result, all parameters of the model will be offloaded and only one
+    copy of the state dict of the model will be kept. During the forward pass, parameters will be extracted from that
+    state dict and put on the execution device passed as they are needed, then offloaded again.
+
+    Args:
+        model (`torch.nn.Module`):
+            The model to offload.
+        execution_device (`torch.device`, *optional*):
+            The device on which the forward pass of the model will be executed (should be a GPU). Will default to the
+            model's first parameter device.
+        offload_buffers (`bool`, *optional*, defaults to `False`):
+            Whether or not to offload the buffers with the model parameters.
+        state_dict (`Dict[str, torch.Tensor]`, *optional*):
+            The state dict of the model that will be kept on CPU.
+        preload_module_classes (`List[str]`, *optional*):
+            A list of classes whose instances should load all their weights (even in the submodules) at the beginning
+            of the forward. This should only be used for classes that have submodules which are registered but not
+            called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
+            `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
+    """
+    if execution_device is None:
+        execution_device = next(iter(model.parameters())).device
+    if state_dict is None:
+        state_dict = {n: p.to("cpu") for n, p in model.state_dict().items()}
+
+    add_hook_to_module(model, AlignDevicesHook(io_same_device=True), append=True)
+    attach_align_device_hook(
+        model,
+        execution_device=execution_device,
+        offload=True,
+        offload_buffers=offload_buffers,
+        weights_map=state_dict,
+        preload_module_classes=preload_module_classes,
+    )
+
+    return model
+
+
+def cpu_offload_with_hook(
+    model: torch.nn.Module,
+    execution_device: Optional[Union[int, str, torch.device]] = None,
+    prev_module_hook: Optional[UserCpuOffloadHook] = None,
+):
+    """
+    Offloads a model on the CPU and puts it back to an execution device when executed. The difference with
+    [`cpu_offload`] is that the model stays on the execution device after the forward and is only offloaded again when
+    the `offload` method of the returned `hook` is called. Useful for pipelines running a model in a loop.
+
+    Args:
+        model (`torch.nn.Module`):
+            The model to offload.
+        execution_device(`str`, `int` or `torch.device`, *optional*):
+            The device on which the model should be executed. Will default to the MPS device if it's available, then
+            GPU 0 if there is a GPU, and finally to the CPU.
+        prev_module_hook (`UserCpuOffloadHook`, *optional*):
+            The hook sent back by this function for a previous model in the pipeline you are running. If passed, its
+            offload method will be called just before the forward of the model to which this hook is attached.
+
+    Example:
+
+    ```py
+    model_1, hook_1 = cpu_offload_with_hook(model_1, cuda_device)
+    model_2, hook_2 = cpu_offload_with_hook(model_2, cuda_device, prev_module_hook=hook_1)
+    model_3, hook_3 = cpu_offload_with_hook(model_3, cuda_device, prev_module_hook=hook_2)
+
+    hid_1 = model_1(input)
+    for i in range(50):
+        # model1 is offloaded on the CPU at the first iteration, model 2 stays on the GPU for this whole loop.
+        hid_2 = model_2(hid_1)
+    # model2 is offloaded to the CPU just before this forward.
+    hid_3 = model_3(hid_2)
+
+    # For model3, you need to manually call the hook offload method.
+    hook_3.offload()
+    ```
+    """
+    hook = CpuOffload(execution_device=execution_device, prev_module_hook=prev_module_hook)
+    add_hook_to_module(model, hook, append=True)
+    user_hook = UserCpuOffloadHook(model, hook)
+    return model, user_hook
+
+
+def disk_offload(
+    model: nn.Module,
+    offload_dir: Union[str, os.PathLike],
+    execution_device: Optional[torch.device] = None,
+    offload_buffers: bool = False,
+    preload_module_classes: Optional[list[str]] = None,
+):
+    """
+    Activates full disk offload for a model. As a result, all parameters of the model will be offloaded as
+    memory-mapped arrays in a given folder. During the forward pass, parameters will be accessed from that folder and
+    put on the execution device passed as they are needed, then offloaded again.
+
+    Args:
+        model (`torch.nn.Module`): The model to offload.
+        offload_dir (`str` or `os.PathLike`):
+            The folder in which to offload the model weights (or where the model weights are already offloaded).
+        execution_device (`torch.device`, *optional*):
+            The device on which the forward pass of the model will be executed (should be a GPU). Will default to the
+            model's first parameter device.
+        offload_buffers (`bool`, *optional*, defaults to `False`):
+            Whether or not to offload the buffers with the model parameters.
+        preload_module_classes (`List[str]`, *optional*):
+            A list of classes whose instances should load all their weights (even in the submodules) at the beginning
+            of the forward. This should only be used for classes that have submodules which are registered but not
+            called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
+            `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
+    """
+    if not os.path.isdir(offload_dir) or not os.path.isfile(os.path.join(offload_dir, "index.json")):
+        offload_state_dict(offload_dir, model.state_dict())
+    if execution_device is None:
+        execution_device = next(iter(model.parameters())).device
+    weights_map = OffloadedWeightsLoader(save_folder=offload_dir)
+
+    add_hook_to_module(model, AlignDevicesHook(io_same_device=True), append=True)
+    attach_align_device_hook(
+        model,
+        execution_device=execution_device,
+        offload=True,
+        offload_buffers=offload_buffers,
+        weights_map=weights_map,
+        preload_module_classes=preload_module_classes,
+    )
+
+    return model
+
+
+def dispatch_model(
+    model: nn.Module,
+    device_map: dict[str, Union[str, int, torch.device]],
+    main_device: Optional[torch.device] = None,
+    state_dict: Optional[dict[str, torch.Tensor]] = None,
+    offload_dir: Optional[Union[str, os.PathLike]] = None,
+    offload_index: Optional[dict[str, str]] = None,
+    offload_buffers: bool = False,
+    skip_keys: Optional[Union[str, list[str]]] = None,
+    preload_module_classes: Optional[list[str]] = None,
+    force_hooks: bool = False,
+):
+    """
+    Dispatches a model according to a given device map. Layers of the model might be spread across GPUs, offloaded on
+    the CPU or even the disk.
+
+    Args:
+        model (`torch.nn.Module`):
+            The model to dispatch.
+        device_map (`Dict[str, Union[str, int, torch.device]]`):
+            A dictionary mapping module names in the model's `state_dict` to the device they should go to. Note that
+            `"disk"` is accepted even if it's not a proper value for `torch.device`.
+        main_device (`str`, `int` or `torch.device`, *optional*):
+            The main execution device. Will default to the first device in the `device_map` different from `"cpu"` or
+            `"disk"`.
+        state_dict (`Dict[str, torch.Tensor]`, *optional*):
+            The state dict of the part of the model that will be kept on CPU.
+        offload_dir (`str` or `os.PathLike`):
+            The folder in which to offload the model weights (or where the model weights are already offloaded).
+        offload_index (`Dict`, *optional*):
+            A dictionary from weight name to their information (`dtype`/`shape` or safetensors filename). Will default
+            to the index saved in `save_folder`.
+        offload_buffers (`bool`, *optional*, defaults to `False`):
+            Whether or not to offload the buffers with the model parameters.
+        skip_keys (`str` or `List[str]`, *optional*):
+            A list of keys to ignore when moving inputs or outputs between devices.
+        preload_module_classes (`List[str]`, *optional*):
+            A list of classes whose instances should load all their weights (even in the submodules) at the beginning
+            of the forward. This should only be used for classes that have submodules which are registered but not
+            called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
+            `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
+        force_hooks (`bool`, *optional*, defaults to `False`):
+            Whether or not to force device hooks to be attached to the model even if all layers are dispatched to a
+            single device.
+    """
+    # Error early if the device map is incomplete.
+    check_device_map(model, device_map)
+
+    # We need to force hooks for quantized models that can't be moved with to()
+    if getattr(model, "quantization_method", "bitsandbytes") == "bitsandbytes":
+        # since bnb 0.43.2, we can move 4-bit models
+        if getattr(model, "is_loaded_in_8bit", False) or (
+            getattr(model, "is_loaded_in_4bit", False) and not is_bnb_available(min_version="0.43.2")
+        ):
+            force_hooks = True
+
+    # We attach hooks if the device_map has at least 2 different devices or if
+    # force_hooks is set to `True`. Otherwise, the model is already loaded
+    # on the unique device and the user can decide where to dispatch the model.
+    # If the model is quantized, we always force-dispatch the model.
+    if (len(set(device_map.values())) > 1) or force_hooks:
+        if main_device is None:
+            if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {"cpu", "disk"}:
+                main_device = "cpu"
+            else:
+                main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0]
+
+        if main_device != "cpu":
+            cpu_modules = [name for name, device in device_map.items() if device == "cpu"]
+            if state_dict is None and len(cpu_modules) > 0:
+                state_dict = extract_submodules_state_dict(model.state_dict(), cpu_modules)
+
+        disk_modules = [name for name, device in device_map.items() if device == "disk"]
+        if offload_dir is None and offload_index is None and len(disk_modules) > 0:
+            raise ValueError(
+                "We need an `offload_dir` to dispatch this model according to this `device_map`, the following submodules "
+                f"need to be offloaded: {', '.join(disk_modules)}."
+            )
+        if (
+            len(disk_modules) > 0
+            and offload_index is None
+            and (not os.path.isdir(offload_dir) or not os.path.isfile(os.path.join(offload_dir, "index.json")))
+        ):
+            disk_state_dict = extract_submodules_state_dict(model.state_dict(), disk_modules)
+            offload_state_dict(offload_dir, disk_state_dict)
+
+        execution_device = {
+            name: main_device if device in ["cpu", "disk"] else device for name, device in device_map.items()
+        }
+        execution_device[""] = main_device
+        offloaded_devices = ["disk"] if main_device == "cpu" or main_device == "mps" else ["cpu", "disk"]
+        offload = {name: device in offloaded_devices for name, device in device_map.items()}
+        save_folder = offload_dir if len(disk_modules) > 0 else None
+        if state_dict is not None or save_folder is not None or offload_index is not None:
+            device = main_device if offload_index is not None else None
+            weights_map = OffloadedWeightsLoader(
+                state_dict=state_dict, save_folder=save_folder, index=offload_index, device=device
+            )
+        else:
+            weights_map = None
+
+        # When dispatching the model's parameters to the devices specified in device_map, we want to avoid allocating memory several times for the
+        # tied parameters. The dictionary tied_params_map keeps track of the already allocated data for a given tied parameter (represented by its
+        # original pointer) on each device.
+        tied_params = find_tied_parameters(model)
+
+        tied_params_map = {}
+        for group in tied_params:
+            for param_name in group:
+                # data_ptr() is enough here, as `find_tied_parameters` finds tied params simply by comparing `param1 is param2`, so we don't need
+                # to care about views of tensors through storage_offset.
+                data_ptr = recursive_getattr(model, param_name).data_ptr()
+                tied_params_map[data_ptr] = {}
+
+        # Note: To handle the disk offloading case, we can not simply use weights_map[param_name].data_ptr() as the reference pointer,
+        # as we have no guarantee that safetensors' `file.get_tensor()` will always give the same pointer.
+
+        attach_align_device_hook_on_blocks(
+            model,
+            execution_device=execution_device,
+            offload=offload,
+            offload_buffers=offload_buffers,
+            weights_map=weights_map,
+            skip_keys=skip_keys,
+            preload_module_classes=preload_module_classes,
+            tied_params_map=tied_params_map,
+        )
+
+        # Warn if any parameters are on the meta device
+        offloaded_devices_str = " and ".join(
+            [device for device in set(device_map.values()) if device in ("cpu", "disk")]
+        )
+        if len(offloaded_devices_str) > 0:
+            logger.warning(
+                f"Some parameters are on the meta device because they were offloaded to the {offloaded_devices_str}."
+            )
+
+        # Attaching the hook may break tied weights, so we retie them
+        retie_parameters(model, tied_params)
+
+        # Add a warning to `.to()` and the device-placement methods
+        def add_warning(fn, model):
+            @wraps(fn)
+            def wrapper(*args, **kwargs):
+                warning_msg = "You shouldn't move a model that is dispatched using accelerate hooks."
+                if str(fn.__name__) == "to":
+                    to_device = torch._C._nn._parse_to(*args, **kwargs)[0]
+                    if to_device is not None:
+                        logger.warning(warning_msg)
+                else:
+                    logger.warning(warning_msg)
+                for param in model.parameters():
+                    if param.device == torch.device("meta"):
+                        raise RuntimeError("You can't move a model that has some modules offloaded to cpu or disk.")
+                return fn(*args, **kwargs)
+
+            return wrapper
+
+        # Make sure to update _accelerate_added_attributes in hooks.py if you add any hook
+        model.to = add_warning(model.to, model)
+        if is_npu_available():
+            model.npu = add_warning(model.npu, model)
+        elif is_mlu_available():
+            model.mlu = add_warning(model.mlu, model)
+        elif is_sdaa_available():
+            model.sdaa = add_warning(model.sdaa, model)
+        elif is_musa_available():
+            model.musa = add_warning(model.musa, model)
+        elif is_xpu_available():
+            model.xpu = add_warning(model.xpu, model)
+        else:
+            model.cuda = add_warning(model.cuda, model)
+
+        # Check if we are using multi-gpus with RTX 4000 series
+        use_multi_gpu = len([device for device in set(device_map.values()) if device not in ("cpu", "disk")]) > 1
+        if use_multi_gpu and not check_cuda_p2p_ib_support():
+            logger.warning(
+                "We've detected an older driver with an RTX 4000 series GPU. These drivers have issues with P2P. "
+                "This can affect the multi-gpu inference when using accelerate device_map. "
+                "Please make sure to update your driver to the latest version which resolves this."
+            )
+    else:
+        device = list(device_map.values())[0]
+        # `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
+        if is_npu_available() and isinstance(device, int):
+            device = f"npu:{device}"
+        elif is_mlu_available() and isinstance(device, int):
+            device = f"mlu:{device}"
+        elif is_sdaa_available() and isinstance(device, int):
+            device = f"sdaa:{device}"
+        elif is_musa_available() and isinstance(device, int):
+            device = f"musa:{device}"
+        if device != "disk":
+            model.to(device)
+        else:
+            raise ValueError(
+                "You are trying to offload the whole model to the disk. Please use the `disk_offload` function instead."
+            )
+    # Convert OrderedDict back to dict for easier usage
+    model.hf_device_map = dict(device_map)
+    return model
+
+
+def load_checkpoint_and_dispatch(
+    model: nn.Module,
+    checkpoint: Union[str, os.PathLike],
+    device_map: Optional[Union[str, dict[str, Union[int, str, torch.device]]]] = None,
+    max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None,
+    no_split_module_classes: Optional[list[str]] = None,
+    offload_folder: Optional[Union[str, os.PathLike]] = None,
+    offload_buffers: bool = False,
+    dtype: Optional[Union[str, torch.dtype]] = None,
+    offload_state_dict: Optional[bool] = None,
+    skip_keys: Optional[Union[str, list[str]]] = None,
+    preload_module_classes: Optional[list[str]] = None,
+    force_hooks: bool = False,
+    strict: bool = False,
+    full_state_dict: bool = True,
+    broadcast_from_rank0: bool = False,
+):
+    """
+    Loads a (potentially sharded) checkpoint inside a model, potentially sending weights to a given device as they are
+    loaded and adds the various hooks that will make this model run properly (even if split across devices).
+
+    Args:
+        model (`torch.nn.Module`): The model in which we want to load a checkpoint.
+        checkpoint (`str` or `os.PathLike`):
+            The folder checkpoint to load. It can be:
+            - a path to a file containing a whole model state dict
+            - a path to a `.json` file containing the index to a sharded checkpoint
+            - a path to a folder containing a unique `.index.json` file and the shards of a checkpoint.
+        device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*):
+            A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer
+            name; once a given module name is inside, every submodule of it will be sent to the same device.
+
+            To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For more
+            information about each option see [here](../concept_guides/big_model_inference#designing-a-device-map).
+            Defaults to None, which means [`dispatch_model`] will not be called.
+        max_memory (`Dict`, *optional*):
+            A dictionary device identifier to maximum memory. Will default to the maximum memory available for each GPU
+            and the available CPU RAM if unset.
+        no_split_module_classes (`List[str]`, *optional*):
+            A list of layer class names that should never be split across devices (for instance any layer that has a
+            residual connection).
+        offload_folder (`str` or `os.PathLike`, *optional*):
+            If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
+        offload_buffers (`bool`, *optional*, defaults to `False`):
+            In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as
+            well as the parameters.
+        dtype (`str` or `torch.dtype`, *optional*):
+            If provided, the weights will be converted to that type when loaded.
+        offload_state_dict (`bool`, *optional*):
+            If `True`, will temporarily offload the CPU state dict on the hard drive to avoid getting out of CPU RAM if
+            the weight of the CPU state dict + the biggest shard does not fit. Will default to `True` if the device map
+            picked contains `"disk"` values.
+        skip_keys (`str` or `List[str]`, *optional*):
+            A list of keys to ignore when moving inputs or outputs between devices.
+        preload_module_classes (`List[str]`, *optional*):
+            A list of classes whose instances should load all their weights (even in the submodules) at the beginning
+            of the forward. This should only be used for classes that have submodules which are registered but not
+            called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
+            `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
+        force_hooks (`bool`, *optional*, defaults to `False`):
+            Whether or not to force device hooks to be attached to the model even if all layers are dispatched to a
+            single device.
+        strict (`bool`, *optional*, defaults to `False`):
+            Whether to strictly enforce that the keys in the checkpoint state_dict match the keys of the model's
+            state_dict.
+        full_state_dict (`bool`, *optional*, defaults to `True`): if this is set to `True`, all the tensors in the
+            loaded state_dict will be gathered. No ShardedTensor and DTensor will be in the loaded state_dict.
+        broadcast_from_rank0 (`bool`, *optional*, defaults to `False`): when the option is `True`, a distributed
+            `ProcessGroup` must be initialized. rank0 should receive a full state_dict and will broadcast the tensors
+            in the state_dict one by one to other ranks. Other ranks will receive the tensors and shard (if applicable)
+            according to the local shards in the model.
+
+    Example:
+
+    ```python
+    >>> from accelerate import init_empty_weights, load_checkpoint_and_dispatch
+    >>> from huggingface_hub import hf_hub_download
+    >>> from transformers import AutoConfig, AutoModelForCausalLM
+
+    >>> # Download the Weights
+    >>> checkpoint = "EleutherAI/gpt-j-6B"
+    >>> weights_location = hf_hub_download(checkpoint, "pytorch_model.bin")
+
+    >>> # Create a model and initialize it with empty weights
+    >>> config = AutoConfig.from_pretrained(checkpoint)
+    >>> with init_empty_weights():
+    ...     model = AutoModelForCausalLM.from_config(config)
+
+    >>> # Load the checkpoint and dispatch it to the right devices
+    >>> model = load_checkpoint_and_dispatch(
+    ...     model, weights_location, device_map="auto", no_split_module_classes=["GPTJBlock"]
+    ... )
+    ```
+    """
+    if isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
+        raise ValueError(
+            "If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or 'sequential'."
+        )
+    if isinstance(device_map, str):
+        if device_map != "sequential":
+            max_memory = get_balanced_memory(
+                model,
+                max_memory=max_memory,
+                no_split_module_classes=no_split_module_classes,
+                dtype=dtype,
+                low_zero=(device_map == "balanced_low_0"),
+            )
+        device_map = infer_auto_device_map(
+            model,
+            max_memory=max_memory,
+            no_split_module_classes=no_split_module_classes,
+            dtype=dtype,
+            offload_buffers=offload_buffers,
+        )
+    if offload_state_dict is None and device_map is not None and "disk" in device_map.values():
+        offload_state_dict = True
+    load_checkpoint_in_model(
+        model,
+        checkpoint,
+        device_map=device_map,
+        offload_folder=offload_folder,
+        dtype=dtype,
+        offload_state_dict=offload_state_dict,
+        offload_buffers=offload_buffers,
+        strict=strict,
+        full_state_dict=full_state_dict,
+        broadcast_from_rank0=broadcast_from_rank0,
+    )
+    if device_map is None:
+        return model
+    return dispatch_model(
+        model,
+        device_map=device_map,
+        offload_dir=offload_folder,
+        offload_buffers=offload_buffers,
+        skip_keys=skip_keys,
+        preload_module_classes=preload_module_classes,
+        force_hooks=force_hooks,
+    )
+
+
+def attach_layerwise_casting_hooks(
+    module: torch.nn.Module,
+    storage_dtype: torch.dtype,
+    compute_dtype: torch.dtype,
+    skip_modules_pattern: Union[str, tuple[str, ...]] = None,
+    skip_modules_classes: Optional[tuple[type[torch.nn.Module], ...]] = None,
+    non_blocking: bool = False,
+) -> None:
+    r"""
+    Applies layerwise casting to a given module. The module expected here is a PyTorch `nn.Module`. This is helpful for
+    reducing memory requirements when one doesn't want to fully quantize a model. Model params can be kept in, say,
+    `torch.float8_e4m3fn`, upcast to a higher precision like `torch.bfloat16` during the forward pass, and downcast
+    back to `torch.float8_e4m3fn` afterward to realize memory savings.
+
+    Args:
+        module (`torch.nn.Module`):
+            The module whose leaf modules will be cast to a high precision dtype for computation, and to a low
+            precision dtype for storage.
+        storage_dtype (`torch.dtype`):
+            The dtype to cast the module to before/after the forward pass for storage.
+        compute_dtype (`torch.dtype`):
+            The dtype to cast the module to during the forward pass for computation.
+        skip_modules_pattern (`tuple[str, ...]`, defaults to `None`):
+            A list of patterns to match the names of the modules to skip during the layerwise casting process. If set
+            to `None` alongside `skip_modules_classes` being `None`, the layerwise casting is applied directly to the
+            module instead of its internal submodules.
+        skip_modules_classes (`tuple[type[torch.nn.Module], ...]`, defaults to `None`):
+            A list of module classes to skip during the layerwise casting process.
+        non_blocking (`bool`, defaults to `False`):
+            If `True`, the weight casting operations are non-blocking.
+
+    Example:
+
+    ```python
+    >>> from accelerate.hooks import attach_layerwise_casting_hooks
+    >>> from transformers import AutoModelForCausalLM
+    >>> import torch
+
+    >>> # Model
+    >>> checkpoint = "EleutherAI/gpt-j-6B"
+    >>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
+
+    >>> # Attach hooks and perform inference
+    >>> attach_layerwise_casting_hooks(model, storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
+    >>> with torch.no_grad():
+    ...     model(...)
+    ```
+
+    Users can also pass modules they want to avoid from getting downcast.
+
+    ```py
+    >>> attach_layerwise_casting_hooks(
+    ...     model, storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16, skip_modules_pattern=["norm"]
+    ... )
+    ```
+    """
+    _attach_layerwise_casting_hooks(
+        module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes, non_blocking
+    )
+
+
+def _attach_layerwise_casting_hooks(
+    module: torch.nn.Module,
+    storage_dtype: torch.dtype,
+    compute_dtype: torch.dtype,
+    skip_modules_pattern: Union[str, tuple[str, ...]] = None,
+    skip_modules_classes: Optional[tuple[type[torch.nn.Module], ...]] = None,
+    non_blocking: bool = False,
+    _prefix: str = "",
+):
+    should_skip = (skip_modules_classes is not None and isinstance(module, skip_modules_classes)) or (
+        skip_modules_pattern is not None and any(re.search(pattern, _prefix) for pattern in skip_modules_pattern)
+    )
+    if should_skip:
+        logger.debug(f'Skipping layerwise casting for layer "{_prefix}"')
+        return
+
+    if isinstance(module, SUPPORTED_PYTORCH_LAYERS_FOR_UPCASTING):
+        logger.debug(f'Applying layerwise casting to layer "{_prefix}"')
+        add_hook_to_module(
+            module,
+            LayerwiseCastingHook(storage_dtype=storage_dtype, compute_dtype=compute_dtype, non_blocking=non_blocking),
+            append=True,
+        )
+        return
+
+    for name, submodule in module.named_children():
+        layer_name = f"{_prefix}.{name}" if _prefix else name
+        _attach_layerwise_casting_hooks(
+            submodule,
+            storage_dtype,
+            compute_dtype,
+            skip_modules_pattern,
+            skip_modules_classes,
+            non_blocking,
+            _prefix=layer_name,
+        )
+
+
+def _attach_context_parallel_hooks(
+    model: nn.Module,
+):
+    """
+    Monkeypatches a Hugging Face `transformers` model to fix attention mask issues when using context parallelism.
+
+    This function attaches forward_pre_hooks to each `self_attn` module of the model. Each hook inspects the call's
+    kwargs; if they contain an attention mask, the hook removes the mask and adds the kwarg `is_causal=True`, because
+    context parallelism does not support attention masks. This function modifies the model in place.
+
+    Args:
+        model (`nn.Module`):
+            The model to attach the hooks to.
+
+    """
+
+    def _self_attn_pre_forward_hook(_module, module_args, module_kwargs):
+        if "attention_mask" in module_kwargs:
+            module_kwargs["attention_mask"] = None
+            module_kwargs["is_causal"] = True
+
+        return module_args, module_kwargs
+
+    for name, module in model.named_modules():
+        # We assume that if the user brings their own model (without the structure `transformers` uses),
+        # they have read the docs saying attention masks can't be passed in. Then these cases can happen:
+        # 1) some modules end with a `self_attn` module, in which case we attach the hook, but
+        #    no attention mask kwarg is passed -> hook is a no-op
+        # 2) some modules end with a `self_attn` module, in which case we attach the hook, and the
+        #    attention mask kwarg is passed -> hook will remove the attention mask and add the
+        #    `is_causal=True` kwarg, which either crashes the training or fixes it
+        #    (training would crash anyway as attention masks aren't supported)
+        # 3) no modules end with a `self_attn` module, in which case we don't attach the hook; this is
+        #    a no-op as well
+        if name.endswith("self_attn"):
+            # We want the hook to be executed first, to avoid any other hooks doing work on the attention mask
+            module.register_forward_pre_hook(_self_attn_pre_forward_hook, with_kwargs=True, prepend=True)
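
For orientation, here is a minimal sketch of how the entry points above compose; it is editorial and not part of the committed file, and the model class and checkpoint folder are illustrative assumptions.

```python
# Sketch: big-model-inference flow built from init_empty_weights +
# load_checkpoint_and_dispatch. `TinyModel` and "checkpoint_dir" are
# hypothetical; any folder holding matching weights (plus index) would do.
import torch
import torch.nn as nn

from accelerate import init_empty_weights, load_checkpoint_and_dispatch


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(8)])

    def forward(self, x):
        return self.blocks(x)


# 1. Build the skeleton on the meta device: no memory is allocated for weights.
with init_empty_weights():
    model = TinyModel()

# 2. Stream the real weights in and let Accelerate place each submodule;
#    "auto" fills GPU(s) first, then CPU, then disk, attaching hooks as needed.
model = load_checkpoint_and_dispatch(
    model,
    checkpoint="checkpoint_dir",  # hypothetical folder with saved weights
    device_map="auto",
    no_split_module_classes=["Linear"],
)

# The AlignDevicesHook machinery now moves inputs/outputs between devices.
out = model(torch.randn(2, 1024))
```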
pythonProject/.venv/Lib/site-packages/accelerate/checkpointing.py ADDED
@@ -0,0 +1,330 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+from pathlib import Path
+
+import numpy as np
+import torch
+from safetensors.torch import load_model
+
+from .utils import (
+    MODEL_NAME,
+    OPTIMIZER_NAME,
+    RNG_STATE_NAME,
+    SAFE_MODEL_NAME,
+    SAFE_WEIGHTS_NAME,
+    SAMPLER_NAME,
+    SCALER_NAME,
+    SCHEDULER_NAME,
+    WEIGHTS_NAME,
+    get_pretty_name,
+    is_cuda_available,
+    is_hpu_available,
+    is_mlu_available,
+    is_musa_available,
+    is_sdaa_available,
+    is_torch_version,
+    is_torch_xla_available,
+    is_xpu_available,
+    load,
+    save,
+)
+
+
+if is_torch_version(">=", "2.4.0"):
+    from torch.amp import GradScaler
+else:
+    from torch.cuda.amp import GradScaler
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+from .logging import get_logger
+from .state import PartialState
+
+
+logger = get_logger(__name__)
+
+
+def save_accelerator_state(
+    output_dir: str,
+    model_states: list[dict],
+    optimizers: list,
+    schedulers: list,
+    dataloaders: list,
+    process_index: int,
+    step: int,
+    scaler: GradScaler = None,
+    save_on_each_node: bool = False,
+    safe_serialization: bool = True,
+):
+    """
+    Saves the current states of the models, optimizers, scaler, and RNG generators to a given directory.
+
+    <Tip>
+
+    If `safe_serialization` is `True`, models will be saved with `safetensors` while the rest are saved using native
+    `pickle`.
+
+    </Tip>
+
+    Args:
+        output_dir (`str` or `os.PathLike`):
+            The name of the folder to save all relevant weights and states.
+        model_states (`List[torch.nn.Module]`):
+            A list of model states
+        optimizers (`List[torch.optim.Optimizer]`):
+            A list of optimizer instances
+        schedulers (`List[torch.optim.lr_scheduler._LRScheduler]`):
+            A list of learning rate schedulers
+        dataloaders (`List[torch.utils.data.DataLoader]`):
+            A list of dataloader instances to save their sampler states
+        process_index (`int`):
+            The current process index in the Accelerator state
+        step (`int`):
+            The current step in the internal step tracker
+        scaler (`torch.amp.GradScaler`, *optional*):
+            An optional gradient scaler instance to save
+        save_on_each_node (`bool`, *optional*):
+            Whether to save on every node, or only the main node.
+        safe_serialization (`bool`, *optional*, defaults to `True`):
+            Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
+    """
+    output_dir = Path(output_dir)
+    # Model states
+    for i, state in enumerate(model_states):
+        weights_name = WEIGHTS_NAME if not safe_serialization else SAFE_WEIGHTS_NAME
+        if i > 0:
+            weights_name = weights_name.replace(".", f"_{i}.")
+        output_model_file = output_dir.joinpath(weights_name)
+        save(state, output_model_file, save_on_each_node=save_on_each_node, safe_serialization=safe_serialization)
+        logger.info(f"Model weights saved in {output_model_file}")
+    # Optimizer states
+    for i, opt in enumerate(optimizers):
+        state = opt.state_dict()
+        optimizer_name = f"{OPTIMIZER_NAME}.bin" if i == 0 else f"{OPTIMIZER_NAME}_{i}.bin"
+        output_optimizer_file = output_dir.joinpath(optimizer_name)
+        save(state, output_optimizer_file, save_on_each_node=save_on_each_node, safe_serialization=False)
+        logger.info(f"Optimizer state saved in {output_optimizer_file}")
+    # Scheduler states
+    for i, scheduler in enumerate(schedulers):
+        state = scheduler.state_dict()
+        scheduler_name = f"{SCHEDULER_NAME}.bin" if i == 0 else f"{SCHEDULER_NAME}_{i}.bin"
+        output_scheduler_file = output_dir.joinpath(scheduler_name)
+        save(state, output_scheduler_file, save_on_each_node=save_on_each_node, safe_serialization=False)
+        logger.info(f"Scheduler state saved in {output_scheduler_file}")
+    # DataLoader states
+    for i, dataloader in enumerate(dataloaders):
+        sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
+        output_sampler_file = output_dir.joinpath(sampler_name)
+        # Only save if we have our custom sampler
+        from .data_loader import IterableDatasetShard, SeedableRandomSampler
+
+        if isinstance(dataloader.dataset, IterableDatasetShard):
+            sampler = dataloader.get_sampler()
+            if isinstance(sampler, SeedableRandomSampler):
+                save(sampler, output_sampler_file, save_on_each_node=save_on_each_node, safe_serialization=False)
+        if getattr(dataloader, "use_stateful_dataloader", False):
+            dataloader_state_dict_name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
+            output_dataloader_state_dict_file = output_dir.joinpath(dataloader_state_dict_name)
+            state_dict = dataloader.state_dict()
+            torch.save(state_dict, output_dataloader_state_dict_file)
+        logger.info(f"Sampler state for dataloader {i} saved in {output_sampler_file}")
+
+    # GradScaler state
+    if scaler is not None:
+        state = scaler.state_dict()
+        output_scaler_file = output_dir.joinpath(SCALER_NAME)
+        torch.save(state, output_scaler_file)
+        logger.info(f"Gradient scaler state saved in {output_scaler_file}")
+    # Random number generator states
+    states = {}
+    states_name = f"{RNG_STATE_NAME}_{process_index}.pkl"
+    states["step"] = step
+    states["random_state"] = random.getstate()
+    states["numpy_random_seed"] = np.random.get_state()
+    states["torch_manual_seed"] = torch.get_rng_state()
+    if is_xpu_available():
+        states["torch_xpu_manual_seed"] = torch.xpu.get_rng_state_all()
+    if is_mlu_available():
+        states["torch_mlu_manual_seed"] = torch.mlu.get_rng_state_all()
+    elif is_sdaa_available():
+        states["torch_sdaa_manual_seed"] = torch.sdaa.get_rng_state_all()
+    elif is_musa_available():
+        states["torch_musa_manual_seed"] = torch.musa.get_rng_state_all()
+    if is_hpu_available():
+        states["torch_hpu_manual_seed"] = torch.hpu.get_rng_state_all()
+    if is_cuda_available():
+        states["torch_cuda_manual_seed"] = torch.cuda.get_rng_state_all()
+    if is_torch_xla_available():
+        states["xm_seed"] = xm.get_rng_state()
+    output_states_file = output_dir.joinpath(states_name)
+    torch.save(states, output_states_file)
+    logger.info(f"Random states saved in {output_states_file}")
+    return output_dir
+
+
+def load_accelerator_state(
+    input_dir,
+    models,
+    optimizers,
+    schedulers,
+    dataloaders,
+    process_index,
+    scaler=None,
+    map_location=None,
+    load_kwargs=None,
+    **load_model_func_kwargs,
+):
+    """
+    Loads states of the models, optimizers, scaler, and RNG generators from a given directory.
+
+    Args:
+        input_dir (`str` or `os.PathLike`):
+            The name of the folder to load all relevant weights and states.
+        models (`List[torch.nn.Module]`):
+            A list of model instances
+        optimizers (`List[torch.optim.Optimizer]`):
+            A list of optimizer instances
+        schedulers (`List[torch.optim.lr_scheduler._LRScheduler]`):
+            A list of learning rate schedulers
+        process_index (`int`):
+            The current process index in the Accelerator state
+        scaler (`torch.amp.GradScaler`, *optional*):
+            An optional *GradScaler* instance to load
+        map_location (`str`, *optional*):
+            What device to load the optimizer state onto. Should be one of either "cpu" or "on_device".
+        load_kwargs (`dict`, *optional*):
+            Additional arguments that can be passed to the `load` function.
+        load_model_func_kwargs (`dict`, *optional*):
+            Additional arguments that can be passed to the model's `load_state_dict` method.
+
+    Returns:
+        `dict`: Contains the `Accelerator` attributes to override while loading the state.
+    """
+    # stores the `Accelerator` attributes to override
+    override_attributes = dict()
+    if map_location not in [None, "cpu", "on_device"]:
+        raise TypeError(
+            "Unsupported optimizer map location passed, please choose one of `None`, `'cpu'`, or `'on_device'`"
+        )
+    if map_location is None:
+        map_location = "cpu"
+    elif map_location == "on_device":
+        map_location = PartialState().device
+
+    if load_kwargs is None:
+        load_kwargs = {}
+
+    input_dir = Path(input_dir)
+    # Model states
+    for i, model in enumerate(models):
+        ending = f"_{i}" if i > 0 else ""
+        input_model_file = input_dir.joinpath(f"{SAFE_MODEL_NAME}{ending}.safetensors")
+        if input_model_file.exists():
+            load_model(model, input_model_file, device=str(map_location), **load_model_func_kwargs)
+        else:
+            # Load with torch
+            input_model_file = input_dir.joinpath(f"{MODEL_NAME}{ending}.bin")
+            state_dict = load(input_model_file, map_location=map_location)
+            model.load_state_dict(state_dict, **load_model_func_kwargs)
+    logger.info("All model weights loaded successfully")
+
+    # Optimizer states
+    for i, opt in enumerate(optimizers):
+        optimizer_name = f"{OPTIMIZER_NAME}.bin" if i == 0 else f"{OPTIMIZER_NAME}_{i}.bin"
+        input_optimizer_file = input_dir.joinpath(optimizer_name)
+        optimizer_state = load(input_optimizer_file, map_location=map_location, **load_kwargs)
+        optimizers[i].load_state_dict(optimizer_state)
+    logger.info("All optimizer states loaded successfully")
+
+    # Scheduler states
+    for i, scheduler in enumerate(schedulers):
+        scheduler_name = f"{SCHEDULER_NAME}.bin" if i == 0 else f"{SCHEDULER_NAME}_{i}.bin"
+        input_scheduler_file = input_dir.joinpath(scheduler_name)
+        scheduler_state = load(input_scheduler_file, **load_kwargs)
+        scheduler.load_state_dict(scheduler_state)
+    logger.info("All scheduler states loaded successfully")
+
+    for i, dataloader in enumerate(dataloaders):
+        sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
+        input_sampler_file = input_dir.joinpath(sampler_name)
+        # Only load if we have our custom sampler
+        from .data_loader import IterableDatasetShard, SeedableRandomSampler
+
+        if isinstance(dataloader.dataset, IterableDatasetShard):
+            sampler = dataloader.get_sampler()
+            if isinstance(sampler, SeedableRandomSampler):
+                sampler = dataloader.set_sampler(load(input_sampler_file))
+        if getattr(dataloader, "use_stateful_dataloader", False):
+            dataloader_state_dict_name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
+            input_dataloader_state_dict_file = input_dir.joinpath(dataloader_state_dict_name)
+            if input_dataloader_state_dict_file.exists():
+                state_dict = load(input_dataloader_state_dict_file, **load_kwargs)
+                dataloader.load_state_dict(state_dict)
+    logger.info("All dataloader sampler states loaded successfully")
+
+    # GradScaler state
+    if scaler is not None:
+        input_scaler_file = input_dir.joinpath(SCALER_NAME)
+        scaler_state = load(input_scaler_file)
+        scaler.load_state_dict(scaler_state)
+        logger.info("GradScaler state loaded successfully")
+
+    # Random states
+    try:
+        states = load(input_dir.joinpath(f"{RNG_STATE_NAME}_{process_index}.pkl"))
+        if "step" in states:
+            override_attributes["step"] = states["step"]
+        random.setstate(states["random_state"])
+        np.random.set_state(states["numpy_random_seed"])
+        torch.set_rng_state(states["torch_manual_seed"])
+        if is_xpu_available():
+            torch.xpu.set_rng_state_all(states["torch_xpu_manual_seed"])
+        if is_mlu_available():
+            torch.mlu.set_rng_state_all(states["torch_mlu_manual_seed"])
+        elif is_sdaa_available():
+            torch.sdaa.set_rng_state_all(states["torch_sdaa_manual_seed"])
+        elif is_musa_available():
+            torch.musa.set_rng_state_all(states["torch_musa_manual_seed"])
+        else:
+            torch.cuda.set_rng_state_all(states["torch_cuda_manual_seed"])
+        if is_torch_xla_available():
+            xm.set_rng_state(states["xm_seed"])
+        logger.info("All random states loaded successfully")
+    except Exception:
+        logger.info("Could not load random states")
+
+    return override_attributes
+
+
+def save_custom_state(obj, path, index: int = 0, save_on_each_node: bool = False):
+    """
+    Saves the state of `obj` to `{path}/custom_checkpoint_{index}.pkl`
+    """
+    # Should this be the right way to get a qual_name type value from `obj`?
+    save_location = Path(path) / f"custom_checkpoint_{index}.pkl"
+    logger.info(f"Saving the state of {get_pretty_name(obj)} to {save_location}")
+    save(obj.state_dict(), save_location, save_on_each_node=save_on_each_node)
+
+
+def load_custom_state(obj, path, index: int = 0):
+    """
+    Loads the state of `obj` at `{path}/custom_checkpoint_{index}.pkl`. Will always set `weights_only=False` when
+    loading the state.
+    """
+    load_location = f"{path}/custom_checkpoint_{index}.pkl"
+    logger.info(f"Loading the state of {get_pretty_name(obj)} from {load_location}")
+    obj.load_state_dict(load(load_location, map_location="cpu", weights_only=False))
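
These helpers are normally reached through `Accelerator.save_state`/`load_state`, which wrap `save_accelerator_state` and `load_accelerator_state`. A minimal sketch follows (editorial, not part of the committed file; the checkpoint folder name is an assumption).

```python
# Sketch: round-tripping training state through the helpers above via the
# public Accelerator API. "ckpt_dir" is a hypothetical output folder.
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = accelerator.prepare(model, optimizer)

# Writes model weights (safetensors by default), optimizer/scheduler state,
# and the per-process RNG states (random/numpy/torch plus backend states).
accelerator.save_state("ckpt_dir")

# Restores everything saved above, including the step counter override.
accelerator.load_state("ckpt_dir")
```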
pythonProject/.venv/Lib/site-packages/accelerate/commands/merge.py ADDED
@@ -0,0 +1,69 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from accelerate.commands.utils import CustomArgumentParser
+from accelerate.utils import merge_fsdp_weights
+
+
+description = """Utility to merge the weights from multiple FSDP checkpoints into a single combined checkpoint. Should be used if
+`SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}`.
+
+This is a CPU-bound process and requires enough RAM to load the entire model state dict."""
+
+
+def merge_command(args):
+    merge_fsdp_weights(
+        args.checkpoint_directory, args.output_path, not args.unsafe_serialization, args.remove_checkpoint_dir
+    )
+
+
+def merge_command_parser(subparsers=None):
+    if subparsers is not None:
+        parser = subparsers.add_parser("merge-weights", description=description)
+    else:
+        parser = CustomArgumentParser(description=description)
+
+    parser.add_argument("checkpoint_directory", type=str, help="A directory containing sharded weights saved by FSDP.")
+    parser.add_argument(
+        "output_path",
+        type=str,
+        help="The path to save the merged weights. Defaults to the current directory.",
+    )
+    parser.add_argument(
+        "--unsafe_serialization",
+        action="store_true",
+        default=False,
+        help="Whether to save the merged weights as `.bin` rather than `.safetensors` (not recommended).",
+    )
+    parser.add_argument(
+        "--remove_checkpoint_dir",
+        action="store_true",
+        help="Whether to remove the checkpoint directory after merging.",
+        default=False,
+    )
+
+    if subparsers is not None:
+        parser.set_defaults(func=merge_command)
+    return parser
+
+
+def main():
+    parser = merge_command_parser()
+    args = parser.parse_args()
+    merge_command(args)
+
+
+if __name__ == "__main__":
+    main()
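
The CLI above is a thin wrapper around `merge_fsdp_weights`; here is a sketch of the equivalent programmatic call (editorial, with illustrative directory names).

```python
# Sketch: merging SHARDED_STATE_DICT FSDP shards without going through the
# `accelerate merge-weights` CLI. Both paths are hypothetical.
from accelerate.utils import merge_fsdp_weights

merge_fsdp_weights(
    "ckpt/pytorch_model_fsdp_0",  # directory containing the FSDP shards
    "ckpt/merged",                # where the combined weights are written
    safe_serialization=True,      # .safetensors output (the CLI default)
    remove_checkpoint_dir=False,  # keep the sharded checkpoint around
)
```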
pythonProject/.venv/Lib/site-packages/accelerate/commands/test.py ADDED
@@ -0,0 +1,65 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import argparse
18
+
19
+ from accelerate.test_utils import execute_subprocess_async, path_in_accelerate_package
20
+
21
+
22
+ def test_command_parser(subparsers=None):
23
+ if subparsers is not None:
24
+ parser = subparsers.add_parser("test")
25
+ else:
26
+ parser = argparse.ArgumentParser("Accelerate test command")
27
+
28
+ parser.add_argument(
29
+ "--config_file",
30
+ default=None,
31
+ help=(
32
+ "The path to use to store the config file. Will default to a file named default_config.yaml in the cache "
33
+ "location, which is the content of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have "
34
+ "such an environment variable, your cache directory ('~/.cache' or the content of `XDG_CACHE_HOME`) suffixed "
35
+ "with 'huggingface'."
36
+ ),
37
+ )
38
+
39
+ if subparsers is not None:
40
+ parser.set_defaults(func=test_command)
41
+ return parser
42
+
43
+
44
+ def test_command(args):
45
+ script_name = path_in_accelerate_package("test_utils", "scripts", "test_script.py")
46
+
47
+ if args.config_file is None:
48
+ test_args = [script_name]
49
+ else:
50
+ test_args = f"--config_file={args.config_file} {script_name}".split()
51
+
52
+ cmd = ["accelerate-launch"] + test_args
53
+ result = execute_subprocess_async(cmd)
54
+ if result.returncode == 0:
55
+ print("Test is a success! You are ready for your distributed training!")
56
+
57
+
58
+ def main():
59
+ parser = test_command_parser()
60
+ args = parser.parse_args()
61
+ test_command(args)
62
+
63
+
64
+ if __name__ == "__main__":
65
+ main()
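
In practice the command boils down to launching the bundled test script through `accelerate-launch`; a rough stand-alone equivalent (no config file) might look like this:

import subprocess

from accelerate.test_utils import path_in_accelerate_package

# Same script the command resolves, launched the same way.
script = path_in_accelerate_package("test_utils", "scripts", "test_script.py")
subprocess.run(["accelerate-launch", script], check=True)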
pythonProject/.venv/Lib/site-packages/accelerate/data_loader.py ADDED
@@ -0,0 +1,1451 @@
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import importlib
+ import math
+ from contextlib import suppress
+ from typing import Callable, Optional, Union
+
+ import torch
+ from packaging import version
+ from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler
+
+ from .logging import get_logger
+ from .state import DistributedType, GradientState, PartialState, is_torch_xla_available
+ from .utils import (
+     RNGType,
+     broadcast,
+     broadcast_object_list,
+     compare_versions,
+     concatenate,
+     find_batch_size,
+     get_data_structure,
+     initialize_tensors,
+     is_datasets_available,
+     is_torch_version,
+     is_torchdata_stateful_dataloader_available,
+     send_to_device,
+     slice_tensors,
+     synchronize_rng_states,
+ )
+
+
+ logger = get_logger(__name__)
+
+ # kwargs of the DataLoader in the minimum supported version, PyTorch 2.0
+ _PYTORCH_DATALOADER_KWARGS = {
+     "batch_size": 1,
+     "shuffle": False,
+     "sampler": None,
+     "batch_sampler": None,
+     "num_workers": 0,
+     "collate_fn": None,
+     "pin_memory": False,
+     "drop_last": False,
+     "timeout": 0,
+     "worker_init_fn": None,
+     "multiprocessing_context": None,
+     "generator": None,
+     "prefetch_factor": 2,
+     "persistent_workers": False,
+     "pin_memory_device": "",
+ }
+
+ # kwargs added in later releases, keyed by the PyTorch version that introduced them
+ _PYTORCH_DATALOADER_ADDITIONAL_KWARGS = {"2.6.0": {"in_order": True}}
+
+ for v, additional_kwargs in _PYTORCH_DATALOADER_ADDITIONAL_KWARGS.items():
+     if is_torch_version(">=", v):
+         _PYTORCH_DATALOADER_KWARGS.update(additional_kwargs)
+
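
The loop above feature-detects which `DataLoader` kwargs the installed torch actually supports; a minimal sketch of the same pattern, using `packaging` directly instead of the `is_torch_version` helper:

from packaging import version
import torch

# Forward a kwarg only when the installed DataLoader accepts it
# (`in_order` appeared in torch 2.6); local build suffixes are stripped first.
dl_kwargs = {"num_workers": 2}
torch_version = version.parse(torch.__version__.split("+")[0])
if torch_version >= version.parse("2.6.0"):
    dl_kwargs["in_order"] = True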
+
+ class SeedableRandomSampler(RandomSampler):
+     """
+     Same as a random sampler, except that in `__iter__` a seed can be used.
+
+     Needed specifically in distributed cases, when the random generator for each GPU needs to start from the same seed
+     and be fully reproducible across multiple iterations.
+
+     If a custom `generator` is passed, it will rely on its initial seed as well as the current iteration it is on
+     (stored in `self.epoch`).
+     """
+
+     def __init__(self, *args, **kwargs):
+         data_seed = kwargs.pop("data_seed", None)
+         super().__init__(*args, **kwargs)
+
+         self.initial_seed = data_seed if data_seed is not None else torch.random.initial_seed()
+         self.epoch = 0
+
+     def __iter__(self):
+         if self.generator is None:
+             self.generator = torch.Generator(
+                 device=torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu"
+             )
+             self.generator.manual_seed(self.initial_seed)
+
+         # Allow `self.epoch` to modify the seed of the generator
+         seed = self.epoch + self.initial_seed
+         # print("Setting seed at epoch", self.epoch, seed)
+         self.generator.manual_seed(seed)
+         yield from super().__iter__()
+         self.set_epoch(self.epoch + 1)
+
+     def set_epoch(self, epoch: int):
+         "Sets the current iteration of the sampler."
+         self.epoch = epoch
+
+
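A small reproducibility check of the sampler above (toy dataset, arbitrary seed):

import torch
from torch.utils.data import TensorDataset

dataset = TensorDataset(torch.arange(8))
sampler = SeedableRandomSampler(dataset, data_seed=42)

order_epoch0 = list(sampler)           # iterating advances sampler.epoch to 1
sampler.set_epoch(0)                   # rewind to epoch 0
assert list(sampler) == order_epoch0   # same seed + epoch -> same permutation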
+ class BatchSamplerShard(BatchSampler):
+     """
+     Wraps a PyTorch `BatchSampler` to generate batches for one of the processes only. Instances of this class will
+     always yield a number of batches that is a round multiple of `num_processes` and that all have the same size.
+     Depending on the value of the `drop_last` attribute of the batch sampler passed, it will either stop the iteration
+     at the first batch that would be too small / not present on all processes or loop with indices from the beginning.
+
+     Args:
+         batch_sampler (`torch.utils.data.sampler.BatchSampler`):
+             The batch sampler to split in several shards.
+         num_processes (`int`, *optional*, defaults to 1):
+             The number of processes running concurrently.
+         process_index (`int`, *optional*, defaults to 0):
+             The index of the current process.
+         split_batches (`bool`, *optional*, defaults to `False`):
+             Whether the shards should be created by splitting a batch to give a piece of it on each process, or by
+             yielding different full batches on each process.
+
+             On two processes with a sampler of `[[0, 1, 2, 3], [4, 5, 6, 7]]`, this will result in:
+
+             - the sampler on process 0 yielding `[0, 1, 2, 3]` and the sampler on process 1 yielding `[4, 5, 6, 7]` if
+               this argument is set to `False`.
+             - the sampler on process 0 yielding `[0, 1]` then `[4, 5]` and the sampler on process 1 yielding `[2, 3]`
+               then `[6, 7]` if this argument is set to `True`.
+         even_batches (`bool`, *optional*, defaults to `True`):
+             Whether or not to loop back to the beginning of the sampler when the number of samples is not a round
+             multiple of (original batch size / number of processes).
+
+     <Tip warning={true}>
+
+     `BatchSampler`s with varying batch sizes are not enabled by default. To enable this behaviour, set `even_batches`
+     equal to `False`
+
+     </Tip>"""
+
+     def __init__(
+         self,
+         batch_sampler: BatchSampler,
+         num_processes: int = 1,
+         process_index: int = 0,
+         split_batches: bool = False,
+         even_batches: bool = True,
+     ):
+         if split_batches and batch_sampler.batch_size % num_processes != 0:
+             raise ValueError(
+                 f"To use `BatchSamplerShard` in `split_batches` mode, the batch size ({batch_sampler.batch_size}) "
+                 f"needs to be a round multiple of the number of processes ({num_processes})."
+             )
+         self.batch_sampler = batch_sampler
+         self.num_processes = num_processes
+         self.process_index = process_index
+         self.split_batches = split_batches
+         self.even_batches = even_batches
+         self.batch_size = getattr(batch_sampler, "batch_size", None)
+         self.drop_last = getattr(batch_sampler, "drop_last", False)
+         if self.batch_size is None and self.even_batches:
+             raise ValueError(
+                 "You need to use `even_batches=False` when the batch sampler has no batch size. If you "
+                 "are not calling this method directly, set `accelerator.even_batches=False` instead."
+             )
+
+     @property
+     def total_length(self):
+         return len(self.batch_sampler)
+
+     def __len__(self):
+         if self.split_batches:
+             # Split batches does not change the length of the batch sampler
+             return len(self.batch_sampler)
+         if len(self.batch_sampler) % self.num_processes == 0:
+             # If the length is a round multiple of the number of processes, it's easy.
+             return len(self.batch_sampler) // self.num_processes
+         length = len(self.batch_sampler) // self.num_processes
+         if self.drop_last:
+             # Same if we drop the remainder.
+             return length
+         elif self.even_batches:
+             # When we even out batches, we always get +1
+             return length + 1
+         else:
+             # Otherwise it depends on the process index.
+             return length + 1 if self.process_index < len(self.batch_sampler) % self.num_processes else length
+
+     def __iter__(self):
+         return self._iter_with_split() if self.split_batches else self._iter_with_no_split()
+
+     def _iter_with_split(self):
+         initial_data = []
+         batch_length = self.batch_sampler.batch_size // self.num_processes
+         for idx, batch in enumerate(self.batch_sampler):
+             if idx == 0:
+                 initial_data = batch
+             if len(batch) == self.batch_size:
+                 # If the batch is full, we yield the part of it this process is responsible for.
+                 yield batch[batch_length * self.process_index : batch_length * (self.process_index + 1)]
+
+         # If drop_last is True or the last batch was full, iteration is over, otherwise...
+         if not self.drop_last and len(initial_data) > 0 and len(batch) < self.batch_size:
+             if not self.even_batches:
+                 if len(batch) > batch_length * self.process_index:
+                     yield batch[batch_length * self.process_index : batch_length * (self.process_index + 1)]
+             else:
+                 # For degenerate cases where the dataset has less than num_process * batch_size samples
+                 while len(initial_data) < self.batch_size:
+                     initial_data += initial_data
+                 batch = batch + initial_data
+                 yield batch[batch_length * self.process_index : batch_length * (self.process_index + 1)]
+
+     def _iter_with_no_split(self):
+         initial_data = []
+         batch_to_yield = []
+         for idx, batch in enumerate(self.batch_sampler):
+             # We gather the initial indices in case we need to circle back at the end.
+             if not self.drop_last and idx < self.num_processes:
+                 initial_data += batch
+             # We identify the batch to yield but wait until we are sure every process gets a full batch before
+             # actually yielding it.
+             if idx % self.num_processes == self.process_index:
+                 batch_to_yield = batch
+             if idx % self.num_processes == self.num_processes - 1 and (
+                 self.batch_size is None or len(batch) == self.batch_size
+             ):
+                 yield batch_to_yield
+                 batch_to_yield = []
+
+         # If drop_last is True, iteration is over, otherwise...
+         if not self.drop_last and len(initial_data) > 0:
+             if not self.even_batches:
+                 if len(batch_to_yield) > 0:
+                     yield batch_to_yield
+             else:
+                 # ... we yield the complete batch we had saved before if it has the proper length
+                 if len(batch_to_yield) == self.batch_size:
+                     yield batch_to_yield
+
+                 # For degenerate cases where the dataset has less than num_process * batch_size samples
+                 while len(initial_data) < self.num_processes * self.batch_size:
+                     initial_data += initial_data
+
+                 # If the last batch seen was of the proper size, it has been yielded by its process so we move to the next
+                 if len(batch) == self.batch_size:
+                     batch = []
+                     idx += 1
+
+                 # Make sure we yield a multiple of self.num_processes batches
+                 cycle_index = 0
+                 while idx % self.num_processes != 0 or len(batch) > 0:
+                     end_index = cycle_index + self.batch_size - len(batch)
+                     batch += initial_data[cycle_index:end_index]
+                     if idx % self.num_processes == self.process_index:
+                         yield batch
+                     cycle_index = end_index
+                     batch = []
+                     idx += 1
+
+
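A small illustration of the two sharding modes on a toy sampler (8 samples, batch size 4, values chosen so there is no remainder):

from torch.utils.data import BatchSampler, SequentialSampler

base = BatchSampler(SequentialSampler(range(8)), batch_size=4, drop_last=False)

# split_batches=False: each process takes alternating full batches.
shard0 = BatchSamplerShard(base, num_processes=2, process_index=0)
shard1 = BatchSamplerShard(base, num_processes=2, process_index=1)
assert list(shard0) == [[0, 1, 2, 3]] and list(shard1) == [[4, 5, 6, 7]]

# split_batches=True: every process takes a slice of every batch.
shard0 = BatchSamplerShard(base, num_processes=2, process_index=0, split_batches=True)
assert list(shard0) == [[0, 1], [4, 5]]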
+ class IterableDatasetShard(IterableDataset):
+     """
+     Wraps a PyTorch `IterableDataset` to generate samples for one of the processes only. Instances of this class will
+     always yield a number of samples that is a round multiple of the actual batch size (depending on the value of
+     `split_batches`, this is either `batch_size` or `batch_size x num_processes`). Depending on the value of the
+     `drop_last` attribute of the batch sampler passed, it will either stop the iteration at the first batch that would
+     be too small or loop with indices from the beginning.
+
+     Args:
+         dataset (`torch.utils.data.dataset.IterableDataset`):
+             The iterable dataset to split in several shards.
+         batch_size (`int`, *optional*, defaults to 1):
+             The size of the batches per shard (if `split_batches=False`) or the size of the batches (if
+             `split_batches=True`).
+         drop_last (`bool`, *optional*, defaults to `False`):
+             Whether or not to drop the last incomplete batch or complete the last batches by using the samples from the
+             beginning.
+         num_processes (`int`, *optional*, defaults to 1):
+             The number of processes running concurrently.
+         process_index (`int`, *optional*, defaults to 0):
+             The index of the current process.
+         split_batches (`bool`, *optional*, defaults to `False`):
+             Whether the shards should be created by splitting a batch to give a piece of it on each process, or by
+             yielding different full batches on each process.
+
+             On two processes with an iterable dataset yielding `[0, 1, 2, 3, 4, 5, 6, 7]`, this will result in:
+
+             - the shard on process 0 yielding `[0, 1, 2, 3]` and the shard on process 1 yielding `[4, 5, 6, 7]` if this
+               argument is set to `False`.
+             - the shard on process 0 yielding `[0, 1, 4, 5]` and the shard on process 1 yielding `[2, 3, 6, 7]` if
+               this argument is set to `True`.
+     """
+
+     def __init__(
+         self,
+         dataset: IterableDataset,
+         batch_size: int = 1,
+         drop_last: bool = False,
+         num_processes: int = 1,
+         process_index: int = 0,
+         split_batches: bool = False,
+     ):
+         if split_batches and batch_size > 1 and batch_size % num_processes != 0:
+             raise ValueError(
+                 f"To use `IterableDatasetShard` in `split_batches` mode, the batch size ({batch_size}) "
+                 f"needs to be a round multiple of the number of processes ({num_processes})."
+             )
+         self.dataset = dataset
+         self.batch_size = batch_size
+         self.drop_last = drop_last
+         self.num_processes = num_processes
+         self.process_index = process_index
+         self.split_batches = split_batches
+
+     def set_epoch(self, epoch):
+         self.epoch = epoch
+         if hasattr(self.dataset, "set_epoch"):
+             self.dataset.set_epoch(epoch)
+
+     def __len__(self):
+         # We will just raise the downstream error if the underlying dataset is not sized
+         if self.drop_last:
+             return (len(self.dataset) // (self.batch_size * self.num_processes)) * self.batch_size
+         else:
+             return math.ceil(len(self.dataset) / (self.batch_size * self.num_processes)) * self.batch_size
+
+     def __iter__(self):
+         if (
+             not hasattr(self.dataset, "set_epoch")
+             and hasattr(self.dataset, "generator")
+             and isinstance(self.dataset.generator, torch.Generator)
+         ):
+             self.dataset.generator.manual_seed(self.epoch)
+         real_batch_size = self.batch_size if self.split_batches else (self.batch_size * self.num_processes)
+         process_batch_size = (self.batch_size // self.num_processes) if self.split_batches else self.batch_size
+         process_slice = range(self.process_index * process_batch_size, (self.process_index + 1) * process_batch_size)
+
+         first_batch = None
+         current_batch = []
+         for element in self.dataset:
+             current_batch.append(element)
+             # Wait to have a full batch before yielding elements.
+             if len(current_batch) == real_batch_size:
+                 for i in process_slice:
+                     yield current_batch[i]
+                 if first_batch is None:
+                     first_batch = current_batch.copy()
+                 current_batch = []
+
+         # Finished if drop_last is True, otherwise complete the last batch with elements from the beginning.
+         if not self.drop_last and len(current_batch) > 0:
+             if first_batch is None:
+                 first_batch = current_batch.copy()
+             while len(current_batch) < real_batch_size:
+                 current_batch += first_batch
+             for i in process_slice:
+                 yield current_batch[i]
+
+
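The analogous behaviour for iterable datasets, where sharding happens at the sample level after a full global batch has been accumulated:

from torch.utils.data import IterableDataset

class Range(IterableDataset):
    def __iter__(self):
        yield from range(8)

# Global batches of batch_size * num_processes = 4 are formed,
# then each process keeps its contiguous slice of batch_size = 2.
shard0 = IterableDatasetShard(Range(), batch_size=2, num_processes=2, process_index=0)
shard1 = IterableDatasetShard(Range(), batch_size=2, num_processes=2, process_index=1)
assert list(shard0) == [0, 1, 4, 5]
assert list(shard1) == [2, 3, 6, 7]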
+ class DataLoaderStateMixin:
+     """
+     Mixin class that adds a state to a `DataLoader` to keep track of the status inside the dataloader, such as the
+     end of the iteration, the number of items in the dataset in the last batch relative to the batch size, and other
+     useful information that might be needed.
+
+     **Available attributes:**
+
+         - **end_of_dataloader** (`bool`) -- Whether at the last iteration or batch
+         - **remainder** (`int`) -- The number of items that are remaining in the last batch, relative to the total
+           batch size
+
+     <Tip warning={true}>
+
+     Inheritors of this class should ensure that the class creates a `GradientState()` instance, stored in
+     `self.gradient_state`.
+
+     </Tip>
+
+     """
+
+     def __init_subclass__(cls, **kwargs):
+         cls.end_of_dataloader = False
+         cls.remainder = -1
+
+     def reset(self):
+         self.end_of_dataloader = False
+         self.remainder = -1
+
+     def begin(self):
+         "Prepares the gradient state for the current dataloader"
+         self.reset()
+         with suppress(Exception):
+             if not self._drop_last:
+                 length = getattr(self.dataset, "total_dataset_length", len(self.dataset))
+                 self.remainder = length % self.total_batch_size
+         self.gradient_state._add_dataloader(self)
+
+     def end(self):
+         "Cleans up the gradient state after exiting the dataloader"
+         self.gradient_state._remove_dataloader(self)
+
+
+ class DataLoaderAdapter:
+     """
+     A class which wraps around a PyTorch `DataLoader` (or variants of it) to be used with the `Accelerator`. For
+     compatibility reasons, this class presents itself as the class it wraps around, so it can be used as a drop-in.
+     """
+
+     def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, **kwargs):
+         self.use_stateful_dataloader = use_stateful_dataloader
+         if is_torchdata_stateful_dataloader_available():
+             from torchdata.stateful_dataloader import StatefulDataLoader
+
+         if use_stateful_dataloader and not is_torchdata_stateful_dataloader_available():
+             raise ImportError(
+                 "StatefulDataLoader is not available. Please install torchdata version 0.8.0 or higher to use it."
+             )
+         if use_stateful_dataloader:
+             torchdata_version = version.parse(importlib.metadata.version("torchdata"))
+             if (
+                 "in_order" in kwargs
+                 and compare_versions(torchdata_version, "<", "0.11")
+                 and is_torch_version(">=", "2.6.0")
+             ):
+                 kwargs.pop("in_order")
+             self.base_dataloader = StatefulDataLoader(dataset, batch_sampler=batch_sampler, **kwargs)
+         else:
+             self.base_dataloader = DataLoader(dataset, batch_sampler=batch_sampler, **kwargs)
+
+         if hasattr(self.base_dataloader, "state_dict"):
+             self.dl_state_dict = self.base_dataloader.state_dict()
+
+     def __getattr__(self, name):
+         # Avoid infinite recursion if we try to access a nonexistent base_dataloader attribute.
+         if name == "base_dataloader":
+             raise AttributeError()
+         # Delegate attribute access to the internal dataloader
+         return getattr(self.base_dataloader, name)
+
+     def state_dict(self):
+         return self.dl_state_dict
+
+     def load_state_dict(self, state_dict):
+         self.base_dataloader.load_state_dict(state_dict)
+
+     @property
+     def __class__(self):
+         """
+         In order to maintain backwards compatibility with other code, we need to ensure `isinstance(obj, DataLoader)`
+         returns true. This is because some downstream code assumes that the `DataLoader` is the base class of the
+         object.
+         """
+         return self.base_dataloader.__class__
+
+     def __len__(self):
+         return len(self.base_dataloader)
+
+     def adjust_state_dict_for_prefetch(self):
+         """
+         Adjusts the state dict for prefetching. Natively, this will adjust all of the iter-yielded keys in
+         `self.dl_state_dict` by a factor of `num_processes - 1`; however, if a custom correction is needed, this can
+         be overridden.
+
+         This should modify `self.dl_state_dict` directly
+         """
+         # The state dict will be off by a factor of `n-1` batches too many during DDP,
+         # so we need to adjust it here
+         if PartialState().distributed_type != DistributedType.NO:
+             factor = PartialState().num_processes - 1
+             if self.dl_state_dict["_sampler_iter_yielded"] > 0:
+                 self.dl_state_dict["_sampler_iter_yielded"] -= factor
+             if self.dl_state_dict["_num_yielded"] > 0:
+                 self.dl_state_dict["_num_yielded"] -= factor
+             if self.dl_state_dict["_index_sampler_state"] is not None:
+                 if (
+                     "samples_yielded" in self.dl_state_dict["_index_sampler_state"]
+                     and self.dl_state_dict["_index_sampler_state"]["samples_yielded"] > 0
+                 ):
+                     self.dl_state_dict["_index_sampler_state"]["samples_yielded"] -= self.batch_size * factor
+
+     def _update_state_dict(self):
+         # The state_dict of the underlying base_dataloader may be ahead of what is currently being yielded.
+         # E.g. the implementation of DataLoaderShard involves having an underlying iterator 1 element ahead of
+         # what it wants to yield.
+         #
+         # _update_state_dict is called to snapshot the state_dict that would properly recover the DataLoaderAdapter.
+         if hasattr(self.base_dataloader, "state_dict"):
+             self.dl_state_dict = self.base_dataloader.state_dict()
+             # Potentially modify the state_dict to adjust for prefetching
+             self.adjust_state_dict_for_prefetch()
+             # Then tag if we are at the end of the dataloader
+             self.dl_state_dict["_iterator_finished"] = self.end_of_dataloader
+
+
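The `__class__` property above is the key trick: Python's `isinstance` consults `obj.__class__`, not just `type(obj)`, so the adapter can pose as the loader it wraps. A condensed, self-contained sketch of the pattern:

import torch
from torch.utils.data import DataLoader, TensorDataset

class Wrapper:
    def __init__(self, inner):
        self.inner = inner

    def __getattr__(self, name):
        # Guard against recursion before `inner` exists, then delegate.
        if name == "inner":
            raise AttributeError(name)
        return getattr(self.inner, name)

    @property
    def __class__(self):
        # Make isinstance() report the wrapped object's type.
        return self.inner.__class__

wrapped = Wrapper(DataLoader(TensorDataset(torch.zeros(4))))
assert isinstance(wrapped, DataLoader)
assert wrapped.batch_size == 1  # attribute access is delegated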
+ class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
+     """
+     Subclass of `DataLoaderAdapter` that will deal with device placement and current distributed setup.
+
+     Args:
+         dataset (`torch.utils.data.dataset.Dataset`):
+             The dataset to use to build this dataloader.
+         device (`torch.device`, *optional*):
+             If passed, the device to put all batches on.
+         rng_types (list of `str` or [`~utils.RNGType`]):
+             The list of random number generators to synchronize at the beginning of each iteration. Should be one or
+             several of:
+
+             - `"torch"`: the base torch random number generator
+             - `"cuda"`: the CUDA random number generator (GPU only)
+             - `"xla"`: the XLA random number generator (TPU only)
+             - `"generator"`: an optional `torch.Generator`
+         synchronized_generator (`torch.Generator`, *optional*):
+             A random number generator to keep synchronized across processes.
+         skip_batches (`int`, *optional*, defaults to 0):
+             The number of batches to skip at the beginning.
+         use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
+             Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.
+         **kwargs (additional keyword arguments, *optional*):
+             All other keyword arguments to pass to the regular `DataLoader` initialization.
+
+     **Available attributes:**
+
+         - **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes.
+           Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the total
+           number of processes
+
+         - **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes.
+     """
+
+     def __init__(
+         self,
+         dataset,
+         device=None,
+         rng_types=None,
+         synchronized_generator=None,
+         skip_batches=0,
+         use_stateful_dataloader=False,
+         _drop_last: bool = False,
+         _non_blocking: bool = False,
+         torch_device_mesh=None,
+         **kwargs,
+     ):
+         super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
+         self.device = device
+         self.rng_types = rng_types
+         self.synchronized_generator = synchronized_generator
+         self.skip_batches = skip_batches
+         self.gradient_state = GradientState()
+         self._drop_last = _drop_last
+         self._non_blocking = _non_blocking
+         self.iteration = 0
+
+     def __iter__(self):
+         if self.rng_types is not None:
+             synchronize_rng_states(self.rng_types, self.synchronized_generator)
+         self.begin()
+
+         self.set_epoch(self.iteration)
+         dataloader_iter = self.base_dataloader.__iter__()
+         # We iterate one batch ahead to check when we are at the end
+         try:
+             current_batch = next(dataloader_iter)
+         except StopIteration:
+             self.end()
+             return
+
+         batch_index = 0
+         while True:
+             try:
+                 # But we still move it to the device so it is done before `StopIteration` is reached
+                 if self.device is not None:
+                     current_batch = send_to_device(current_batch, self.device, non_blocking=self._non_blocking)
+                 self._update_state_dict()
+                 next_batch = next(dataloader_iter)
+                 if batch_index >= self.skip_batches:
+                     yield current_batch
+                 batch_index += 1
+                 current_batch = next_batch
+             except StopIteration:
+                 self.end_of_dataloader = True
+                 self._update_state_dict()
+                 if batch_index >= self.skip_batches:
+                     yield current_batch
+                 break
+
+         self.iteration += 1
+         self.end()
+
+     def __reduce__(self):
+         """
+         Define the `__reduce__` method to ensure a `DataLoaderShard` can be pickled and unpickled. This needs to be
+         explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
+         `__class__` member.
+         """
+         args = super().__reduce__()
+         return (DataLoaderShard, *args[1:])
+
+     def set_epoch(self, epoch: int):
+         # In case it is manually passed in, the user can set it to what they like
+         if self.iteration != epoch:
+             self.iteration = epoch
+         if hasattr(self.batch_sampler, "set_epoch"):
+             self.batch_sampler.set_epoch(epoch)
+         if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, "set_epoch"):
+             self.batch_sampler.sampler.set_epoch(epoch)
+         if (
+             hasattr(self.batch_sampler, "batch_sampler")
+             and hasattr(self.batch_sampler.batch_sampler, "sampler")
+             and hasattr(self.batch_sampler.batch_sampler.sampler, "set_epoch")
+         ):
+             self.batch_sampler.batch_sampler.sampler.set_epoch(epoch)
+         # We support if a custom `Dataset` implementation has `set_epoch`
+         # or in general HF datasets `Datasets`
+         elif hasattr(self.dataset, "set_epoch"):
+             self.dataset.set_epoch(epoch)
+
+     @property
+     def total_batch_size(self):
+         batch_sampler = self.sampler if isinstance(self.sampler, BatchSampler) else self.batch_sampler
+         return (
+             batch_sampler.batch_size
+             if getattr(batch_sampler, "split_batches", False)
+             else (batch_sampler.batch_size * getattr(batch_sampler, "num_processes", 1))
+         )
+
+     @property
+     def total_dataset_length(self):
+         if hasattr(self.dataset, "total_length"):
+             return self.dataset.total_length
+         else:
+             return len(self.dataset)
+
+     def get_sampler(self):
+         return get_sampler(self)
+
+     def set_sampler(self, sampler):
+         sampler_is_batch_sampler = isinstance(self.sampler, BatchSampler)
+         if sampler_is_batch_sampler:
+             self.sampler.sampler = sampler
+         else:
+             self.batch_sampler.sampler = sampler
+             if hasattr(self.batch_sampler, "batch_sampler"):
+                 self.batch_sampler.batch_sampler.sampler = sampler
+
+
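The `__iter__` above stays one batch ahead so it can flag `end_of_dataloader` before yielding the final batch. The control flow, reduced to its essence:

def lookahead(iterable):
    """Yield (item, is_last) pairs by keeping one element of headroom."""
    it = iter(iterable)
    try:
        current = next(it)
    except StopIteration:
        return
    while True:
        try:
            nxt = next(it)
        except StopIteration:
            yield current, True   # the last item is flagged before being yielded
            return
        yield current, False
        current = nxt

assert list(lookahead([1, 2, 3])) == [(1, False), (2, False), (3, True)]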
+ if is_torch_xla_available():
+     import torch_xla.distributed.parallel_loader as xpl
+
+     class MpDeviceLoaderWrapper(xpl.MpDeviceLoader):
+         """
+         Wrapper for the xpl.MpDeviceLoader class that knows the total batch size.
+
+         XLA preloading threads will all call DataLoaderShard's __iter__(). Remove rng_types from DataLoaderShard to
+         prevent it from using the XLA device in the preloading threads, and synchronize the RNG once from the main
+         thread only.
+
+         **Available attributes:**
+
+             - **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes.
+               Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the
+               total number of processes
+
+             - **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes.
+         """
+
+         def __init__(self, dataloader: DataLoaderShard, device: torch.device):
+             super().__init__(dataloader, device)
+             self._rng_types = self._loader.rng_types
+             self._loader.rng_types = None
+             self.device = device
+
+         def __iter__(self):
+             if self._rng_types is not None:
+                 synchronize_rng_states(self._rng_types, self._loader.synchronized_generator)
+
+             return super().__iter__()
+
+         def set_epoch(self, epoch: int):
+             if hasattr(self.dataloader, "set_epoch"):
+                 self.dataloader.set_epoch(epoch)
+
+         @property
+         def total_batch_size(self):
+             return self._loader.total_batch_size
+
+         @property
+         def total_dataset_length(self):
+             return self._loader.total_dataset_length
+
+         @property
+         def batch_sampler(self):
+             return self._loader.batch_sampler
+
+         @property
+         def dataloader(self):
+             return self._loader
+
+
+ class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
+     """
+     Subclass of `DataLoaderAdapter` that will iterate and preprocess on process 0 only, then dispatch to each process
+     its part of the batch.
+
+     Args:
+         split_batches (`bool`, *optional*, defaults to `False`):
+             Whether the resulting `DataLoader` should split the batches of the original data loader across devices or
+             yield full batches (in which case it will yield batches starting at the `process_index`-th and advancing by
+             `num_processes` batches at each iteration). Another way to see this is that the observed batch size will be
+             the same as the initial `dataloader` if this option is set to `True`, the batch size of the initial
+             `dataloader` multiplied by `num_processes` otherwise. Setting this option to `True` requires that the batch
+             size of the `dataloader` is a round multiple of `batch_size`.
+         skip_batches (`int`, *optional*, defaults to 0):
+             The number of batches to skip at the beginning of an iteration.
+         use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
+             Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.
+
+     **Available attributes:**
+
+         - **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes.
+           Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the total
+           number of processes
+
+         - **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes.
+     """
+
+     def __init__(
+         self,
+         dataset,
+         split_batches: bool = False,
+         skip_batches=0,
+         use_stateful_dataloader=False,
+         _drop_last: bool = False,
+         _non_blocking: bool = False,
+         slice_fn=None,
+         torch_device_mesh=None,
+         **kwargs,
+     ):
+         shuffle = False
+         from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe
+
+         # We need to save the shuffling state of the DataPipe
+         if isinstance(dataset, ShufflerIterDataPipe):
+             shuffle = dataset._shuffle_enabled
+         super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
+         self.split_batches = split_batches
+         if shuffle:
+             torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)
+
+         self.gradient_state = GradientState()
+         self.state = PartialState()
+         self._drop_last = _drop_last
+         self._non_blocking = _non_blocking
+         self.skip_batches = skip_batches
+         self.torch_device_mesh = torch_device_mesh
+
+         self.slice_fn = slice_tensors if slice_fn is None else slice_fn
+         self.iteration = 0
+
+         # If a device mesh is provided, extract each dimension (dp, fsdp, tp).
+         # A device mesh may hold any number of dimensions; however,
+         # the code below provides targeted support for dp, fsdp and tp.
+
+         # The device mesh will be used only if there is tp involved,
+         # or any multi-dimensional parallelism involving tp
+         # (dp, tp) (fsdp, tp) (dp, fsdp, tp);
+         # otherwise the default behaviour, not using a device mesh, should be sufficient,
+         # since multi-dimensional parallelism devoid of tp would anyway need
+         # different batches for each process irrespective of dp or fsdp.
+         self.submesh_tp = None
+         self.submesh_dp = None
+         self.submesh_fsdp = None
+         if self.torch_device_mesh and "tp" in self.torch_device_mesh.mesh_dim_names:
+             self.submesh_tp = self.torch_device_mesh["tp"]
+             if "dp" in self.torch_device_mesh.mesh_dim_names:
+                 self.submesh_dp = self.torch_device_mesh["dp"]
+             if "fsdp" in self.torch_device_mesh.mesh_dim_names:
+                 self.submesh_fsdp = self.torch_device_mesh["fsdp"]
+         if self.submesh_tp and (self.submesh_dp or self.submesh_fsdp):
+             raise ValueError("TP + (DP/FSDP) is not yet supported in dispatch mode")
+
+     def _fetch_batches(self, iterator):
+         batches, batch = None, None
+         # On process 0, we gather the batch to dispatch.
+         if self.state.process_index == 0:
+             # The procedure to support TP only is simpler,
+             # since we want to dispatch the same batch of samples across all ranks.
+             # This removes the complexity of handling multiple tp rank groups when a TP + DP
+             # combination is involved.
+
+             try:
+                 # For the TP case, avoid using split_batches,
+                 # since it would mean that the dataloader should be producing
+                 # duplicates of batches.
+                 if self.split_batches:
+                     # One batch of the main iterator is dispatched and split.
+                     if self.submesh_tp:
+                         logger.warning(
+                             "Use of split_batches for TP would need the dataloader to produce duplicate batches; "
+                             "otherwise, use dispatch_batches=True instead."
+                         )
+                     self._update_state_dict()
+                     batch = next(iterator)
+                 else:
+                     # num_processes batches of the main iterator are concatenated then dispatched and split.
+                     # We add the batches one by one so we have the remainder available when drop_last=False.
+                     batches = []
+                     if self.submesh_tp:
+                         # When tp, extract a single batch and then replicate it.
+                         self._update_state_dict()
+                         batch = next(iterator)
+                         batches = [batch] * self.state.num_processes
+                     else:
+                         for _ in range(self.state.num_processes):
+                             self._update_state_dict()
+                             batches.append(next(iterator))
+                     try:
+                         batch = concatenate(batches, dim=0)
+                     except RuntimeError as e:
+                         raise RuntimeError(
+                             "You can't use batches of different size with `dispatch_batches=True` or when using an `IterableDataset`. "
+                             "Either pass `dispatch_batches=False` and have each process fetch its own batch, "
+                             "or pass `split_batches=True`. By doing so, the main process will fetch a full batch and "
+                             "slice it into `num_processes` batches for each process."
+                         ) from e
+                 # In both cases, we need to get the structure of the batch that we will broadcast on other
+                 # processes to initialize the tensors with the right shape.
+                 # data_structure, stop_iteration
+                 batch_info = [get_data_structure(batch), False]
+             except StopIteration:
+                 batch_info = [None, True]
+         else:
+             batch_info = [None, self._stop_iteration]
+         # This is done in place, so after this instruction, every process has the same `batch_info` as process 0.
+         broadcast_object_list(batch_info)
+         self._stop_iteration = batch_info[1]
+         if self._stop_iteration:
+             # If drop_last is False and split_batches is False, we may have a remainder to take care of.
+             if not self.split_batches and not self._drop_last:
+                 if self.state.process_index == 0 and len(batches) > 0:
+                     batch = concatenate(batches, dim=0)
+                     batch_info = [get_data_structure(batch), False]
+                 else:
+                     batch_info = [None, True]
+                 broadcast_object_list(batch_info)
+         return batch, batch_info
+
+     def __iter__(self):
+         self.begin()
+         self.set_epoch(self.iteration)
+         main_iterator = None
+         if is_torch_version(">=", "2.0.1"):
+             # NOTE PyTorch DataLoader adds forward compatibilities for DataPipes, which broadcasts
+             # the shared seed to all dist processes. Thus, we need to create the iterator for all dist processes.
+             # But, we only iterate through the DataLoader on process 0.
+             main_iterator = self.base_dataloader.__iter__()
+         elif self.state.process_index == 0:
+             main_iterator = self.base_dataloader.__iter__()
+         stop_iteration = False
+         self._stop_iteration = False
+         first_batch = None
+         next_batch, next_batch_info = self._fetch_batches(main_iterator)
+         batch_index = 0
+         while not stop_iteration:
+             batch, batch_info = next_batch, next_batch_info
+
+             if self.state.process_index != 0:
+                 # Initialize tensors on other processes than process 0.
+                 batch = initialize_tensors(batch_info[0])
+             batch = send_to_device(batch, self.state.device, non_blocking=self._non_blocking)
+             # Broadcast the batch before splitting it.
+             batch = broadcast(batch, from_process=0)
+
+             if not self._drop_last and first_batch is None:
+                 # We keep at least num processes elements of the first batch to be able to complete the last batch
+                 first_batch = self.slice_fn(
+                     batch,
+                     slice(0, self.state.num_processes),
+                     process_index=self.state.process_index,
+                     num_processes=self.state.num_processes,
+                 )
+
+             if batch is None:
+                 raise ValueError(
+                     f"Batch does not contain any data (`{batch}`). The end of all available iterable data was reached before the expected stop iteration."
+                 )
+
+             observed_batch_size = find_batch_size(batch)
+             batch_size = observed_batch_size // self.state.num_processes
+
+             stop_iteration = self._stop_iteration
+             if not stop_iteration:
+                 # We may still be at the end of the dataloader without knowing it yet: if there is nothing left in
+                 # the dataloader since the number of batches is a round multiple of the number of processes.
+                 next_batch, next_batch_info = self._fetch_batches(main_iterator)
+                 # next_batch_info[0] is None when there are no more batches, otherwise we still need to process them.
+                 if self._stop_iteration and next_batch_info[0] is None:
+                     stop_iteration = True
+
+             if not self._drop_last and stop_iteration and observed_batch_size % self.state.num_processes != 0:
+                 # If the last batch is not complete, let's add the first batch to it.
+                 batch = concatenate([batch, first_batch], dim=0)
+                 # The batch size computation above is wrong, it's off by 1, so we fix it.
+                 batch_size += 1
+
+             data_slice = slice(self.state.process_index * batch_size, (self.state.process_index + 1) * batch_size)
+             batch = self.slice_fn(
+                 batch,
+                 data_slice,
+                 process_index=self.state.process_index,
+                 num_processes=self.state.num_processes,
+             )
+
+             if stop_iteration:
+                 self.end_of_dataloader = True
+                 self._update_state_dict()
+                 self.remainder = observed_batch_size
+             if batch_index >= self.skip_batches:
+                 yield batch
+             batch_index += 1
+         self.iteration += 1
+         self.end()
+
+     def set_epoch(self, epoch: int):
+         # In case it is manually passed in, the user can set it to what they like
+         if self.iteration != epoch:
+             self.iteration = epoch
+         if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, "set_epoch"):
+             self.batch_sampler.sampler.set_epoch(epoch)
+         elif hasattr(self.dataset, "set_epoch"):
+             self.dataset.set_epoch(epoch)
+
+     def __len__(self):
+         whole_length = len(self.base_dataloader)
+         if self.split_batches:
+             return whole_length
+         elif self._drop_last:
+             return whole_length // self.state.num_processes
+         else:
+             return math.ceil(whole_length / self.state.num_processes)
+
+     def __reduce__(self):
+         """
+         Define the `__reduce__` method to ensure a `DataLoaderDispatcher` can be pickled and unpickled. This needs to
+         be explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
+         `__class__` member.
+         """
+         args = super().__reduce__()
+         return (DataLoaderDispatcher, *args[1:])
+
+     @property
+     def total_batch_size(self):
+         return (
+             self.dataset.batch_size if self.split_batches else (self.dataset.batch_size * self.dataset.num_processes)
+         )
+
+     @property
+     def total_dataset_length(self):
+         return len(self.dataset)
+
+     def get_sampler(self):
+         return get_sampler(self)
+
+     def set_sampler(self, sampler):
+         sampler_is_batch_sampler = isinstance(self.sampler, BatchSampler)
+         if sampler_is_batch_sampler:
+             self.sampler.sampler = sampler
+         else:
+             self.batch_sampler.sampler = sampler
+             if hasattr(self.batch_sampler, "batch_sampler"):
+                 self.batch_sampler.batch_sampler.sampler = sampler
+
+
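Stripped of the `broadcast`/`broadcast_object_list` machinery, the dispatch step reduces to rank 0 assembling a global batch and every rank slicing out its share; a toy single-process illustration of that slicing:

import torch

global_batch = torch.arange(12).reshape(12, 1)  # what rank 0 would fetch
num_processes = 4
per_rank = global_batch.shape[0] // num_processes

shards = [
    global_batch[rank * per_rank : (rank + 1) * per_rank]  # slice_fn's job
    for rank in range(num_processes)
]
assert all(s.shape == (3, 1) for s in shards)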
+ def get_sampler(dataloader):
+     """
+     Get the sampler associated with the dataloader
+
+     Args:
+         dataloader (`torch.utils.data.dataloader.DataLoader`):
+             The data loader to split across several devices.
+     Returns:
+         `torch.utils.data.Sampler`: The sampler associated with the dataloader
+     """
+     sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
+     if sampler_is_batch_sampler:
+         sampler = getattr(dataloader.sampler, "sampler", None)
+     else:
+         sampler = getattr(dataloader.batch_sampler, "sampler", None)
+     return sampler
+
+
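For example, the helper digs the underlying sampler out regardless of how the loader was built (a shuffled toy loader here):

import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

loader = DataLoader(TensorDataset(torch.zeros(4)), shuffle=True)
assert isinstance(get_sampler(loader), RandomSampler)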
+ def prepare_data_loader(
+     dataloader: DataLoader,
+     device: Optional[torch.device] = None,
+     num_processes: Optional[int] = None,
+     process_index: Optional[int] = None,
+     split_batches: bool = False,
+     put_on_device: bool = False,
+     rng_types: Optional[list[Union[str, RNGType]]] = None,
+     dispatch_batches: Optional[bool] = None,
+     even_batches: bool = True,
+     slice_fn_for_dispatch: Optional[Callable] = None,
+     use_seedable_sampler: bool = False,
+     data_seed: Optional[int] = None,
+     non_blocking: bool = False,
+     use_stateful_dataloader: bool = False,
+     torch_device_mesh=None,
+ ) -> DataLoader:
+     """
+     Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.
+
+     Depending on the value of the `drop_last` attribute of the `dataloader` passed, it will either stop the iteration
+     at the first batch that would be too small / not present on all processes or loop with indices from the beginning.
+
+     Args:
+         dataloader (`torch.utils.data.dataloader.DataLoader`):
+             The data loader to split across several devices.
+         device (`torch.device`):
+             The target device for the returned `DataLoader`.
+         num_processes (`int`, *optional*):
+             The number of processes running concurrently. Will default to the value given by [`~state.PartialState`].
+         process_index (`int`, *optional*):
+             The index of the current process. Will default to the value given by [`~state.PartialState`].
+         split_batches (`bool`, *optional*, defaults to `False`):
+             Whether the resulting `DataLoader` should split the batches of the original data loader across devices or
+             yield full batches (in which case it will yield batches starting at the `process_index`-th and advancing by
+             `num_processes` batches at each iteration).
+
+             Another way to see this is that the observed batch size will be the same as the initial `dataloader` if
+             this option is set to `True`, the batch size of the initial `dataloader` multiplied by `num_processes`
+             otherwise.
+
+             Setting this option to `True` requires that the batch size of the `dataloader` is a round multiple of
+             `batch_size`.
+         put_on_device (`bool`, *optional*, defaults to `False`):
+             Whether or not to put the batches on `device` (only works if the batches are nested lists, tuples or
+             dictionaries of tensors).
+         rng_types (list of `str` or [`~utils.RNGType`]):
+             The list of random number generators to synchronize at the beginning of each iteration. Should be one or
+             several of:
+
+             - `"torch"`: the base torch random number generator
+             - `"cuda"`: the CUDA random number generator (GPU only)
+             - `"xla"`: the XLA random number generator (TPU only)
+             - `"generator"`: the `torch.Generator` of the sampler (or batch sampler if there is no sampler in your
+               dataloader) or of the iterable dataset (if it exists) if the underlying dataset is of that type.
+
+         dispatch_batches (`bool`, *optional*):
+             If set to `True`, the dataloader prepared is only iterated through on the main process and then the batches
+             are split and broadcast to each process. Will default to `True` when the underlying dataset is an
+             `IterableDataset`, `False` otherwise.
+         even_batches (`bool`, *optional*, defaults to `True`):
+             If set to `True`, in cases where the total batch size across all processes does not exactly divide the
+             dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among
+             all workers.
+         slice_fn_for_dispatch (`Callable`, *optional*):
+             If passed, this function will be used to slice tensors across `num_processes`. Will default to
+             [`~utils.slice_tensors`]. This argument is used only when `dispatch_batches` is set to `True` and will be
+             ignored otherwise.
+         use_seedable_sampler (`bool`, *optional*, defaults to `False`):
+             Whether to use the [`~data_loader.SeedableRandomSampler`] instead of a `RandomSampler` for better
+             reproducibility. Comes at a cost of potentially different performances due to different shuffling
+             algorithms but ensures results will be the *exact* same. Should be paired with `set_seed()` at every
+             `self.set_epoch`
+         data_seed (`int`, *optional*, defaults to `None`):
+             The seed to use for the underlying generator when using `use_seedable_sampler`. If `None`, the generator
+             will use the current default seed from torch.
+         non_blocking (`bool`, *optional*, defaults to `False`):
+             If set to `True`, the dataloader will utilize non-blocking host-to-device transfers. If the dataloader has
+             `pin_memory` set to `True`, this will help to increase overlap between data transfer and computations.
+         use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
+             If set to `True`, the dataloader prepared by the Accelerator will be backed by
+             [torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader).
+             This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed.
+         torch_device_mesh (`torch.distributed.DeviceMesh`, *optional*, defaults to `None`):
+             PyTorch device mesh.
+
+
+     Returns:
+         `torch.utils.data.dataloader.DataLoader`: A new data loader that will yield the portion of the batches
+
+     <Tip warning={true}>
+
+     `BatchSampler`s with varying batch sizes are not enabled by default. To enable this behaviour, set `even_batches`
+     equal to `False`
+
+     </Tip>
+     """
+     if dispatch_batches is None:
+         if not put_on_device:
+             dispatch_batches = False
+         else:
+             dispatch_batches = isinstance(dataloader.dataset, IterableDataset)
+
+     if dispatch_batches and not put_on_device:
+         raise ValueError("Using `dispatch_batches=True` requires `put_on_device=True`.")
+     # Grab defaults from PartialState
+     state = PartialState()
+     if num_processes is None:
+         num_processes = state.num_processes
+
+     if process_index is None:
+         process_index = state.process_index
+
+     if torch_device_mesh:
+         if state.distributed_type == DistributedType.DEEPSPEED:
+             # In DeepSpeed, the optimizer sharding level in DP is determined by the config file.
+             # Only considers "dp" and "tp".
+             # Given a device mesh (dp, tp) = (2, 3):
+             # - From the data parallel perspective, ranks should be structured as: 0 0 0 1 1 1
+             # - Processes with the same DP rank will receive the same batch.
+             submesh_tp_size = 1
+             if "tp" in torch_device_mesh.mesh_dim_names:
+                 submesh_tp_size = torch_device_mesh["tp"].size()
+             process_index = process_index // submesh_tp_size
+             num_processes = num_processes // submesh_tp_size
+         else:
+             # When a device mesh is used, specifically with TP,
+             # process_index and num_processes need to be updated
+             # so that the same batch is generated across TP ranks
+             # and different batches across FSDP and DP ranks.
+             # Example:
+             # if the device mesh is (dp, fsdp, tp) = (2, 2, 3),
+             # ranks would range from 0...11;
+             # from the data angle, ranks should look like 0 0 0 1 1 1 2 2 2 3 3 3,
+             # and processes with the same ranks/ids would receive the same batch.
+             # For CP, the same as TP applies.
+             submesh_fsdp_size = 1
+             submesh_dp_size = 1
+             submesh_tp_size = 1
+             submesh_cp_size = 1
+             if "tp" in torch_device_mesh.mesh_dim_names:
+                 submesh_tp_size = torch_device_mesh["tp"].size()
+             if "cp" in torch_device_mesh.mesh_dim_names:
+                 submesh_cp_size = torch_device_mesh["cp"].size()
+             if "dp_replicate" in torch_device_mesh.mesh_dim_names:
+                 submesh_dp_size = torch_device_mesh["dp_replicate"].size()
+             if "dp_shard" in torch_device_mesh.mesh_dim_names:
+                 submesh_fsdp_size = torch_device_mesh["dp_shard"].size()
+             process_index = process_index // (submesh_tp_size * submesh_cp_size)
+             num_processes = submesh_fsdp_size * submesh_dp_size
+
+     # Sanity check
+     if split_batches:
+         if dataloader.batch_size is not None:
+             batch_size_for_check = dataloader.batch_size
+         else:
+             # For custom batch_sampler
+             if hasattr(dataloader.batch_sampler, "batch_size"):
+                 batch_size_for_check = dataloader.batch_sampler.batch_size
+             else:
+                 raise ValueError(
+                     "In order to use `split_batches==True` you must have a `batch_size` attribute either in the passed "
+                     "`dataloader` or `dataloader.batch_sampler` objects, and it has to return a natural number. "
+                     "Your `dataloader.batch_size` is None and `dataloader.batch_sampler` "
+                     f"(`{type(dataloader.batch_sampler)}`) does not have the `batch_size` attribute set."
+                 )
+
+         if batch_size_for_check > 1 and batch_size_for_check % num_processes != 0:
+             raise ValueError(
+                 f"To use a `DataLoader` in `split_batches` mode, the batch size ({dataloader.batch_size}) "
+                 f"needs to be a round multiple of the number of processes ({num_processes})."
+             )
+
+     new_dataset = dataloader.dataset
+     # An iterable dataset doesn't like batch_sampler, but a DataLoader creates a default one for it
+     new_batch_sampler = dataloader.batch_sampler if not isinstance(new_dataset, IterableDataset) else None
+     sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
+     synchronized_generator = None
+
+     sampler = get_sampler(dataloader)
+     if isinstance(sampler, RandomSampler) and use_seedable_sampler:
+         # When iterating through the dataloader during distributed processes
+         # we want to ensure that on each process we are iterating through the same
+         # samples in the same order if a seed is set. This requires a tweak
+         # to the `torch.utils.data.RandomSampler` class (if used).
+         sampler = SeedableRandomSampler(
+             data_source=sampler.data_source,
+             replacement=sampler.replacement,
+             num_samples=sampler._num_samples,
+             generator=getattr(
+                 sampler,
+                 "generator",
+                 torch.Generator(device=torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu"),
+             ),
+             data_seed=data_seed,
+         )
+
+     if isinstance(dataloader.sampler, RandomSampler) and state.distributed_type == DistributedType.XLA:
+         # isinstance(dataloader.sampler, RandomSampler) indicates the original dataloader has `shuffle` enabled.
+         generator = torch.Generator(
+             device=torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu"
+         )
+         seed = int(torch.empty((), dtype=torch.int64).random_().item())
+         generator.manual_seed(seed)
+         dataloader.generator = generator
+         dataloader.sampler.generator = generator
+     # No change if no multiprocess
+     if (num_processes != 1 or state.distributed_type == DistributedType.MEGATRON_LM) and not dispatch_batches:
+         if is_datasets_available():
+             from datasets import IterableDataset as DatasetsIterableDataset
+         if (
+             is_datasets_available()
+             and isinstance(new_dataset, DatasetsIterableDataset)
+             and not split_batches
+             and new_dataset.n_shards > num_processes
+         ):
+             new_dataset = new_dataset.shard(num_shards=num_processes, index=process_index)
+         elif isinstance(new_dataset, IterableDataset):
+             if getattr(dataloader.dataset, "generator", None) is not None:
+                 synchronized_generator = dataloader.dataset.generator
+             new_dataset = IterableDatasetShard(
+                 new_dataset,
+                 batch_size=dataloader.batch_size,
+                 drop_last=dataloader.drop_last,
+                 num_processes=num_processes,
+                 process_index=process_index,
+                 split_batches=split_batches,
+             )
+         else:
+             if not use_seedable_sampler and hasattr(sampler, "generator"):
+                 if sampler.generator is None:
+                     sampler.generator = torch.Generator(
+                         device=torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu"
+                     )
+                     seed = int(torch.empty((), dtype=torch.int64).random_().item())
+                     sampler.generator.manual_seed(seed)
+                 synchronized_generator = sampler.generator
+             batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
+             new_batch_sampler = BatchSamplerShard(
+                 batch_sampler,
+                 num_processes=num_processes,
+                 process_index=process_index,
+                 split_batches=split_batches,
+                 even_batches=even_batches,
+             )
+
+     # We ignore all of those since they are all dealt with by our new_batch_sampler
+     ignore_kwargs = [
+         "batch_size",
+         "shuffle",
+         "sampler",
+         "batch_sampler",
+         "drop_last",
+     ]
+
+     if rng_types is not None and synchronized_generator is None and "generator" in rng_types:
+         rng_types.remove("generator")
+
+     kwargs = {
+         k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k])
+         for k in _PYTORCH_DATALOADER_KWARGS
+         if k not in ignore_kwargs
+     }
+
+     # Need to provide batch_size, as batch_sampler is None for iterable datasets
+     if new_batch_sampler is None:
+         kwargs["drop_last"] = dataloader.drop_last
+         kwargs["batch_size"] = (
+             dataloader.batch_size // num_processes if split_batches and not dispatch_batches else dataloader.batch_size
+         )
+     if dispatch_batches:
+         kwargs.pop("generator")
+         dataloader = DataLoaderDispatcher(
+             new_dataset,
+             split_batches=split_batches,
+             batch_sampler=new_batch_sampler,
+             _drop_last=dataloader.drop_last,
+             _non_blocking=non_blocking,
+             slice_fn=slice_fn_for_dispatch,
+             use_stateful_dataloader=use_stateful_dataloader,
+             torch_device_mesh=torch_device_mesh,
+             **kwargs,
+         )
+     elif sampler_is_batch_sampler:
+         dataloader = DataLoaderShard(
+             new_dataset,
+             device=device if put_on_device and state.distributed_type != DistributedType.XLA else None,
+             sampler=new_batch_sampler,
+             batch_size=dataloader.batch_size,
+             rng_types=rng_types,
+             _drop_last=dataloader.drop_last,
+             _non_blocking=non_blocking,
+             synchronized_generator=synchronized_generator,
1289
+ use_stateful_dataloader=use_stateful_dataloader,
1290
+ **kwargs,
1291
+ )
1292
+ else:
1293
+ dataloader = DataLoaderShard(
1294
+ new_dataset,
1295
+ device=device if put_on_device and state.distributed_type != DistributedType.XLA else None,
1296
+ batch_sampler=new_batch_sampler,
1297
+ rng_types=rng_types,
1298
+ synchronized_generator=synchronized_generator,
1299
+ _drop_last=dataloader.drop_last,
1300
+ _non_blocking=non_blocking,
1301
+ use_stateful_dataloader=use_stateful_dataloader,
1302
+ **kwargs,
1303
+ )
1304
+
1305
+ if isinstance(sampler, SeedableRandomSampler) and use_seedable_sampler:
1306
+ dataloader.set_sampler(sampler)
1307
+ if state.distributed_type == DistributedType.XLA:
1308
+ return MpDeviceLoaderWrapper(dataloader, device)
1309
+ return dataloader
1310
+
1311
+
1312
+ class SkipBatchSampler(BatchSampler):
1313
+ """
1314
+ A `torch.utils.data.BatchSampler` that skips the first `n` batches of another `torch.utils.data.BatchSampler`.
1315
+ Should not be used if the original dataloader is a `StatefulDataLoader`.
1316
+ """
1317
+
1318
+ def __init__(self, batch_sampler, skip_batches=0):
1319
+ self.batch_sampler = batch_sampler
1320
+ self.skip_batches = skip_batches
1321
+
1322
+ def __iter__(self):
1323
+ for index, samples in enumerate(self.batch_sampler):
1324
+ if index >= self.skip_batches:
1325
+ yield samples
1326
+
1327
+ @property
1328
+ def total_length(self):
1329
+ return len(self.batch_sampler)
1330
+
1331
+ def __len__(self):
1332
+ return len(self.batch_sampler) - self.skip_batches
1333
+
1334
+
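A minimal usage sketch of `SkipBatchSampler` (toy indices; not part of the vendored file):

from torch.utils.data import BatchSampler, SequentialSampler

# Four batches of four indices; the first two batches are skipped.
base = BatchSampler(SequentialSampler(range(16)), batch_size=4, drop_last=False)
skipping = SkipBatchSampler(base, skip_batches=2)
assert list(skipping) == [[8, 9, 10, 11], [12, 13, 14, 15]]
assert len(skipping) == 2 and skipping.total_length == 4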
1335
+ class SkipDataLoader(DataLoaderAdapter, DataLoaderStateMixin):
1336
+ """
1337
+ Subclass of a PyTorch `DataLoader` that will skip the first batches. Generally it's preferable to use
1338
+ `skip_first_batches`/`torchdata.StatefulDataLoader` instead of this class.
1339
+
1340
+ Args:
1341
+ dataset (`torch.utils.data.dataset.Dataset`):
1342
+ The dataset to use to build this dataloader.
1343
+ skip_batches (`int`, *optional*, defaults to 0):
1344
+ The number of batches to skip at the beginning.
1345
+ kwargs:
1346
+ All other keyword arguments to pass to the regular `DataLoader` initialization.
1347
+ """
1348
+
1349
+ def __init__(self, dataset, skip_batches=0, use_stateful_dataloader=False, **kwargs):
1350
+ super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
1351
+ self.skip_batches = skip_batches
1352
+ self.gradient_state = GradientState()
1353
+
1354
+ def __iter__(self):
1355
+ self.begin()
1356
+ for index, batch in enumerate(self.base_dataloader.__iter__()):
1357
+ if index >= self.skip_batches:
1358
+ self._update_state_dict()
1359
+ yield batch
1360
+ self.end()
1361
+
1362
+ def __len__(self):
1363
+ return len(self.base_dataloader) - self.skip_batches
1364
+
1365
+ def __reduce__(self):
1366
+ """
1367
+ Define the `__reduce__` method to ensure a `SkipDataLoader` can be pickled and unpickled. This needs to be
1368
+ explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
1369
+ `__class__` member.
1370
+ """
1371
+ args = super().__reduce__()
1372
+ return (SkipDataLoader, *args[1:])
1373
+
1374
+
1375
+ def skip_first_batches(dataloader, num_batches=0):
1376
+ """
1377
+ Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`. Should not be used if
1378
+ the original dataloader is a `StatefulDataLoader`.
1379
+ """
1380
+ state = PartialState()
1381
+ if state.distributed_type == DistributedType.XLA:
1382
+ device = dataloader.device
1383
+ dataloader = dataloader.dataloader
1384
+
1385
+ dataset = dataloader.dataset
1386
+ sampler_is_batch_sampler = False
1387
+ if isinstance(dataset, IterableDataset):
1388
+ new_batch_sampler = None
1389
+ else:
1390
+ sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
1391
+ batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
1392
+ new_batch_sampler = SkipBatchSampler(batch_sampler, skip_batches=num_batches)
1393
+
1394
+ # We ignore all of those since they are all dealt with by our new_batch_sampler
1395
+ ignore_kwargs = [
1396
+ "batch_size",
1397
+ "shuffle",
1398
+ "sampler",
1399
+ "batch_sampler",
1400
+ "drop_last",
1401
+ ]
1402
+
1403
+ kwargs = {
1404
+ k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k])
1405
+ for k in _PYTORCH_DATALOADER_KWARGS
1406
+ if k not in ignore_kwargs
1407
+ }
1408
+
1409
+ # Need to provide batch_size as batch_sampler is None for Iterable dataset
1410
+ if new_batch_sampler is None:
1411
+ kwargs["drop_last"] = dataloader.drop_last
1412
+ kwargs["batch_size"] = dataloader.batch_size
1413
+
1414
+ if isinstance(dataloader, DataLoaderDispatcher):
1415
+ if new_batch_sampler is None:
1416
+ # Need to manually skip batches in the dataloader
1417
+ kwargs["skip_batches"] = num_batches
1418
+ dataloader = DataLoaderDispatcher(
1419
+ dataset,
1420
+ split_batches=dataloader.split_batches,
1421
+ batch_sampler=new_batch_sampler,
1422
+ _drop_last=dataloader._drop_last,
1423
+ **kwargs,
1424
+ )
1425
+ elif isinstance(dataloader, DataLoaderShard):
1426
+ if new_batch_sampler is None:
1427
+ # Need to manually skip batches in the dataloader
1428
+ kwargs["skip_batches"] = num_batches
1429
+ elif sampler_is_batch_sampler:
1430
+ kwargs["sampler"] = new_batch_sampler
1431
+ kwargs["batch_size"] = dataloader.batch_size
1432
+ else:
1433
+ kwargs["batch_sampler"] = new_batch_sampler
1434
+ dataloader = DataLoaderShard(
1435
+ dataset,
1436
+ device=dataloader.device,
1437
+ rng_types=dataloader.rng_types,
1438
+ synchronized_generator=dataloader.synchronized_generator,
1439
+ **kwargs,
1440
+ )
1441
+ else:
1442
+ if new_batch_sampler is None:
1443
+ # Need to manually skip batches in the dataloader
1444
+ dataloader = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs)
1445
+ else:
1446
+ dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs)
1447
+
1448
+ if state.distributed_type == DistributedType.XLA:
1449
+ dataloader = MpDeviceLoaderWrapper(dataloader, device)
1450
+
1451
+ return dataloader
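A hedged sketch of resuming mid-epoch with `skip_first_batches` (toy tensors, single process; illustrative only):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(100))
dataloader = DataLoader(dataset, batch_size=10, shuffle=False)

# Resume an epoch as if the first three batches were already consumed.
resumed = skip_first_batches(dataloader, num_batches=3)
(first_batch,) = next(iter(resumed))
assert first_batch[0].item() == 30  # items 0-29 were skipped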
pythonProject/.venv/Lib/site-packages/accelerate/hooks.py ADDED
@@ -0,0 +1,776 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import functools
16
+ from collections.abc import Mapping
17
+ from typing import Optional, Union
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+
22
+ from .state import PartialState
23
+ from .utils import (
24
+ PrefixedDataset,
25
+ find_device,
26
+ named_module_tensors,
27
+ send_to_device,
28
+ set_module_tensor_to_device,
29
+ )
30
+ from .utils.imports import (
31
+ is_mlu_available,
32
+ is_musa_available,
33
+ is_npu_available,
34
+ )
35
+ from .utils.memory import clear_device_cache
36
+ from .utils.modeling import get_non_persistent_buffers
37
+ from .utils.other import recursive_getattr
38
+
39
+
40
+ _accelerate_added_attributes = ["to", "cuda", "npu", "xpu", "mlu", "sdaa", "musa"]
41
+
42
+
43
+ class ModelHook:
44
+ """
45
+ A hook that contains callbacks to be executed just before and after the forward method of a model. The difference
46
+ from existing PyTorch hooks is that they also get passed the kwargs.
47
+
48
+ Class attribute:
49
+ - **no_grad** (`bool`, *optional*, defaults to `False`) -- Whether or not to execute the actual forward pass under
50
+ the `torch.no_grad()` context manager.
51
+ """
52
+
53
+ no_grad = False
54
+
55
+ def init_hook(self, module):
56
+ """
57
+ To be executed when the hook is attached to the module.
58
+
59
+ Args:
60
+ module (`torch.nn.Module`): The module attached to this hook.
61
+ """
62
+ return module
63
+
64
+ def pre_forward(self, module, *args, **kwargs):
65
+ """
66
+ To be executed just before the forward method of the model.
67
+
68
+ Args:
69
+ module (`torch.nn.Module`): The module whose forward pass will be executed just after this event.
70
+ args (`Tuple[Any]`): The positional arguments passed to the module.
71
+ kwargs (`Dict[Str, Any]`): The keyword arguments passed to the module.
72
+
73
+ Returns:
74
+ `Tuple[Tuple[Any], Dict[Str, Any]]`: A tuple with the treated `args` and `kwargs`.
75
+ """
76
+ return args, kwargs
77
+
78
+ def post_forward(self, module, output):
79
+ """
80
+ To be executed just after the forward method of the model.
81
+
82
+ Args:
83
+ module (`torch.nn.Module`): The module whose forward pass has been executed just before this event.
84
+ output (`Any`): The output of the module.
85
+
86
+ Returns:
87
+ `Any`: The processed `output`.
88
+ """
89
+ return output
90
+
91
+ def detach_hook(self, module):
92
+ """
93
+ To be executed when the hook is detached from a module.
94
+
95
+ Args:
96
+ module (`torch.nn.Module`): The module detached from this hook.
97
+ """
98
+ return module
99
+
100
+
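As a concrete illustration, a sketch of a custom hook (the `LoggingHook` name is invented for this example):

class LoggingHook(ModelHook):
    """Print a message around every forward pass of the hooked module."""

    def pre_forward(self, module, *args, **kwargs):
        print(f"{module.__class__.__name__}: forward starting")
        return args, kwargs

    def post_forward(self, module, output):
        print(f"{module.__class__.__name__}: forward finished")
        return output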
101
+ class SequentialHook(ModelHook):
102
+ """
103
+ A hook that can contain several hooks and iterates through them at each event.
104
+ """
105
+
106
+ def __init__(self, *hooks):
107
+ self.hooks = hooks
108
+
109
+ def init_hook(self, module):
110
+ for hook in self.hooks:
111
+ module = hook.init_hook(module)
112
+ return module
113
+
114
+ def pre_forward(self, module, *args, **kwargs):
115
+ for hook in self.hooks:
116
+ args, kwargs = hook.pre_forward(module, *args, **kwargs)
117
+ return args, kwargs
118
+
119
+ def post_forward(self, module, output):
120
+ for hook in self.hooks:
121
+ output = hook.post_forward(module, output)
122
+ return output
123
+
124
+ def detach_hook(self, module):
125
+ for hook in self.hooks:
126
+ module = hook.detach_hook(module)
127
+ return module
128
+
129
+
130
+ def add_hook_to_module(module: nn.Module, hook: ModelHook, append: bool = False):
131
+ """
132
+ Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook; to remove
133
+ this behavior and restore the original `forward` method, use `remove_hook_from_module`.
134
+
135
+ <Tip warning={true}>
136
+
137
+ If the module already contains a hook, this will replace it with the new hook passed by default. To chain two hooks
138
+ together, pass `append=True`, so it chains the current and new hook into an instance of the `SequentialHook` class.
139
+
140
+ </Tip>
141
+
142
+ Args:
143
+ module (`torch.nn.Module`):
144
+ The module to attach a hook to.
145
+ hook (`ModelHook`):
146
+ The hook to attach.
147
+ append (`bool`, *optional*, defaults to `False`):
148
+ Whether the hook should be chained with an existing one (if module already contains a hook) or not.
149
+
150
+ Returns:
151
+ `torch.nn.Module`: The same module, with the hook attached (the module is modified in place, so the result can
152
+ be discarded).
153
+ """
154
+ if append and (getattr(module, "_hf_hook", None) is not None):
155
+ old_hook = module._hf_hook
156
+ remove_hook_from_module(module)
157
+ hook = SequentialHook(old_hook, hook)
158
+
159
+ if hasattr(module, "_hf_hook") and hasattr(module, "_old_forward"):
160
+ # If we already put some hook on this module, we replace it with the new one.
161
+ old_forward = module._old_forward
162
+ else:
163
+ old_forward = module.forward
164
+ module._old_forward = old_forward
165
+
166
+ module = hook.init_hook(module)
167
+ module._hf_hook = hook
168
+
169
+ def new_forward(module, *args, **kwargs):
170
+ args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
171
+ if module._hf_hook.no_grad:
172
+ with torch.no_grad():
173
+ output = module._old_forward(*args, **kwargs)
174
+ else:
175
+ output = module._old_forward(*args, **kwargs)
176
+ return module._hf_hook.post_forward(module, output)
177
+
178
+ # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
179
+ # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
180
+ if "GraphModuleImpl" in str(type(module)):
181
+ module.__class__.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)
182
+ else:
183
+ module.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)
184
+
185
+ return module
186
+
187
+
188
+ def remove_hook_from_module(module: nn.Module, recurse=False):
189
+ """
190
+ Removes any hook attached to a module via `add_hook_to_module`.
191
+
192
+ Args:
193
+ module (`torch.nn.Module`): The module to remove the hook from.
194
+ recurse (`bool`, *optional*): Whether to remove the hooks recursively.
195
+
196
+ Returns:
197
+ `torch.nn.Module`: The same module, with the hook detached (the module is modified in place, so the result can
198
+ be discarded).
199
+ """
200
+
201
+ if hasattr(module, "_hf_hook"):
202
+ module._hf_hook.detach_hook(module)
203
+ delattr(module, "_hf_hook")
204
+
205
+ if hasattr(module, "_old_forward"):
206
+ # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
207
+ # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
208
+ if "GraphModuleImpl" in str(type(module)):
209
+ module.__class__.forward = module._old_forward
210
+ else:
211
+ module.forward = module._old_forward
212
+ delattr(module, "_old_forward")
213
+
214
+ # Remove accelerate added warning hooks from dispatch_model
215
+ for attr in _accelerate_added_attributes:
216
+ module.__dict__.pop(attr, None)
217
+
218
+ if recurse:
219
+ for child in module.children():
220
+ remove_hook_from_module(child, recurse)
221
+
222
+ return module
223
+
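Putting the two functions above together with the hypothetical `LoggingHook` sketched earlier (illustrative only):

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
add_hook_to_module(layer, LoggingHook())
_ = layer(torch.randn(2, 4))                            # prints the before/after messages
add_hook_to_module(layer, LoggingHook(), append=True)   # chains both hooks via SequentialHook
remove_hook_from_module(layer)                          # restores the original forward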
224
+
225
+ class AlignDevicesHook(ModelHook):
226
+ """
227
+ A generic `ModelHook` that ensures inputs and model weights are on the same device for the forward pass of the
228
+ associated module, potentially offloading the weights after the forward pass.
229
+
230
+ Args:
231
+ execution_device (`torch.device`, *optional*):
232
+ The device on which inputs and model weights should be placed before the forward pass.
233
+ offload (`bool`, *optional*, defaults to `False`):
234
+ Whether or not the weights should be offloaded after the forward pass.
235
+ io_same_device (`bool`, *optional*, defaults to `False`):
236
+ Whether or not the output should be placed on the same device as the input was.
237
+ weights_map (`Mapping[str, torch.Tensor]`, *optional*):
238
+ When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.
239
+ offload_buffers (`bool`, *optional*, defaults to `False`):
240
+ Whether or not to include the associated module's buffers when offloading.
241
+ place_submodules (`bool`, *optional*, defaults to `False`):
242
+ Whether to place the submodules on `execution_device` during the `init_hook` event.
243
+ """
244
+
245
+ def __init__(
246
+ self,
247
+ execution_device: Optional[Union[int, str, torch.device]] = None,
248
+ offload: bool = False,
249
+ io_same_device: bool = False,
250
+ weights_map: Optional[Mapping] = None,
251
+ offload_buffers: bool = False,
252
+ place_submodules: bool = False,
253
+ skip_keys: Optional[Union[str, list[str]]] = None,
254
+ tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,
255
+ ):
256
+ self.execution_device = execution_device
257
+ self.offload = offload
258
+ self.io_same_device = io_same_device
259
+ self.weights_map = weights_map
260
+ self.offload_buffers = offload_buffers
261
+ self.place_submodules = place_submodules
262
+ self.skip_keys = skip_keys
263
+
264
+ # Will contain the input device when `io_same_device=True`.
265
+ self.input_device = None
266
+ self.param_original_devices = {}
267
+ self.buffer_original_devices = {}
268
+ self.tied_params_names = set()
269
+
270
+ # The hook pre_forward/post_forward need to have knowledge of this dictionary, as with offloading we want to avoid duplicating memory
271
+ # for tied weights already loaded on the target execution device.
272
+ self.tied_params_map = tied_params_map
273
+
274
+ def __repr__(self):
275
+ return (
276
+ f"AlignDevicesHook(execution_device={self.execution_device}, offload={self.offload}, "
277
+ f"io_same_device={self.io_same_device}, offload_buffers={self.offload_buffers}, "
278
+ f"place_submodules={self.place_submodules}, skip_keys={repr(self.skip_keys)})"
279
+ )
280
+
281
+ def init_hook(self, module):
282
+ # In case the AlignDevicesHook is on meta device, ignore tied weights as data_ptr() is then always zero.
283
+ if self.execution_device == "meta" or self.execution_device == torch.device("meta"):
284
+ self.tied_params_map = None
285
+
286
+ if not self.offload and self.execution_device is not None:
287
+ for name, _ in named_module_tensors(module, recurse=self.place_submodules):
288
+ set_module_tensor_to_device(module, name, self.execution_device, tied_params_map=self.tied_params_map)
289
+ elif self.offload:
290
+ self.original_devices = {
291
+ name: param.device for name, param in named_module_tensors(module, recurse=self.place_submodules)
292
+ }
293
+ if self.weights_map is None:
294
+ self.weights_map = {
295
+ name: param.to("cpu")
296
+ for name, param in named_module_tensors(
297
+ module, include_buffers=self.offload_buffers, recurse=self.place_submodules
298
+ )
299
+ }
300
+ for name, _ in named_module_tensors(
301
+ module, include_buffers=self.offload_buffers, recurse=self.place_submodules, remove_non_persistent=True
302
+ ):
303
+ # When using disk offloading, we cannot rely on `weights_map[name].data_ptr()` as the reference pointer,
304
+ # as we have no guarantee that safetensors' `file.get_tensor()` will always give the same pointer.
305
+ # As we have no reliable way to track the shared data pointer of tied weights in this case, we use tied_params_names: List[str]
306
+ # to add on the fly pointers to `tied_params_map` in the pre_forward call.
307
+ if (
308
+ self.tied_params_map is not None
309
+ and recursive_getattr(module, name).data_ptr() in self.tied_params_map
310
+ ):
311
+ self.tied_params_names.add(name)
312
+
313
+ set_module_tensor_to_device(module, name, "meta")
314
+
315
+ if not self.offload_buffers and self.execution_device is not None:
316
+ for name, _ in module.named_buffers(recurse=self.place_submodules):
317
+ set_module_tensor_to_device(
318
+ module, name, self.execution_device, tied_params_map=self.tied_params_map
319
+ )
320
+ elif self.offload_buffers and self.execution_device is not None:
321
+ for name in get_non_persistent_buffers(module, recurse=self.place_submodules):
322
+ set_module_tensor_to_device(
323
+ module, name, self.execution_device, tied_params_map=self.tied_params_map
324
+ )
325
+
326
+ return module
327
+
328
+ def pre_forward(self, module, *args, **kwargs):
329
+ if self.io_same_device:
330
+ self.input_device = find_device([args, kwargs])
331
+ if self.offload:
332
+ self.tied_pointers_to_remove = set()
333
+
334
+ for name, _ in named_module_tensors(
335
+ module,
336
+ include_buffers=self.offload_buffers,
337
+ recurse=self.place_submodules,
338
+ remove_non_persistent=True,
339
+ ):
340
+ fp16_statistics = None
341
+ value = self.weights_map[name]
342
+ if "weight" in name and name.replace("weight", "SCB") in self.weights_map.keys():
343
+ if value.dtype == torch.int8:
344
+ fp16_statistics = self.weights_map[name.replace("weight", "SCB")]
345
+
346
+ # In case we are using offloading with tied weights, we need to keep track of the offloaded weights
347
+ # that are loaded on device at this point, as we will need to remove them as well from the dictionary
348
+ # self.tied_params_map in order to allow to free memory.
349
+ if name in self.tied_params_names and value.data_ptr() not in self.tied_params_map:
350
+ self.tied_params_map[value.data_ptr()] = {}
351
+
352
+ if (
353
+ value is not None
354
+ and self.tied_params_map is not None
355
+ and value.data_ptr() in self.tied_params_map
356
+ and self.execution_device not in self.tied_params_map[value.data_ptr()]
357
+ ):
358
+ self.tied_pointers_to_remove.add((value.data_ptr(), self.execution_device))
359
+
360
+ set_module_tensor_to_device(
361
+ module,
362
+ name,
363
+ self.execution_device,
364
+ value=value,
365
+ fp16_statistics=fp16_statistics,
366
+ tied_params_map=self.tied_params_map,
367
+ )
368
+
369
+ return send_to_device(args, self.execution_device), send_to_device(
370
+ kwargs, self.execution_device, skip_keys=self.skip_keys
371
+ )
372
+
373
+ def post_forward(self, module, output):
374
+ if self.offload:
375
+ for name, _ in named_module_tensors(
376
+ module,
377
+ include_buffers=self.offload_buffers,
378
+ recurse=self.place_submodules,
379
+ remove_non_persistent=True,
380
+ ):
381
+ set_module_tensor_to_device(module, name, "meta")
382
+ if type(module).__name__ == "Linear8bitLt":
383
+ module.state.SCB = None
384
+ module.state.CxB = None
385
+
386
+ # We may have loaded tied weights into self.tied_params_map (avoiding loading them several times in e.g. submodules): remove them from
387
+ # this dictionary to allow the garbage collector to do its job.
388
+ for value_pointer, device in self.tied_pointers_to_remove:
389
+ if isinstance(device, int):
390
+ if is_npu_available():
391
+ device = f"npu:{device}"
392
+ elif is_mlu_available():
393
+ device = f"mlu:{device}"
394
+ elif is_musa_available():
395
+ device = f"musa:{device}"
396
+ if device in self.tied_params_map[value_pointer]:
397
+ del self.tied_params_map[value_pointer][device]
398
+ self.tied_pointers_to_remove = set()
399
+ if self.io_same_device and self.input_device is not None:
400
+ output = send_to_device(output, self.input_device, skip_keys=self.skip_keys)
401
+
402
+ return output
403
+
404
+ def detach_hook(self, module):
405
+ if self.offload:
406
+ for name, device in self.original_devices.items():
407
+ if device != torch.device("meta"):
408
+ set_module_tensor_to_device(module, name, device, value=self.weights_map.get(name, None))
409
+ return module
410
+
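A minimal sketch of the offload path on CPU (toy module; in practice `dispatch_model` wires this up):

import torch
import torch.nn as nn

layer = nn.Linear(2, 2)
weights = {name: param.detach().clone() for name, param in layer.named_parameters()}
hook = AlignDevicesHook(execution_device="cpu", offload=True, weights_map=weights)
add_hook_to_module(layer, hook)

assert layer.weight.device.type == "meta"   # offloaded between calls
_ = layer(torch.randn(1, 2))                # weights restored just for the forward
assert layer.weight.device.type == "meta"   # and offloaded again afterwards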
411
+
412
+ def attach_execution_device_hook(
413
+ module: torch.nn.Module,
414
+ execution_device: Union[int, str, torch.device],
415
+ skip_keys: Optional[Union[str, list[str]]] = None,
416
+ preload_module_classes: Optional[list[str]] = None,
417
+ tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,
418
+ ):
419
+ """
420
+ Recursively attaches `AlignDevicesHook` to all submodules of a given model to make sure they have the right
421
+ execution device
422
+
423
+ Args:
424
+ module (`torch.nn.Module`):
425
+ The module where we want to attach the hooks.
426
+ execution_device (`int`, `str` or `torch.device`):
427
+ The device on which inputs and model weights should be placed before the forward pass.
428
+ skip_keys (`str` or `List[str]`, *optional*):
429
+ A list of keys to ignore when moving inputs or outputs between devices.
430
+ preload_module_classes (`List[str]`, *optional*):
431
+ A list of classes whose instances should load all their weights (even in the submodules) at the beginning
432
+ of the forward. This should only be used for classes that have submodules which are registered but not
433
+ called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
434
+ `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
435
+ tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):
436
+ A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution
437
+ device, this parameter is useful to reuse the first available pointer of a shared weight for all others,
438
+ instead of duplicating memory.
439
+ """
440
+ if not hasattr(module, "_hf_hook") and len(module.state_dict()) > 0:
441
+ add_hook_to_module(
442
+ module,
443
+ AlignDevicesHook(execution_device, skip_keys=skip_keys, tied_params_map=tied_params_map),
444
+ )
445
+
446
+ # Break the recursion if we get to a preload module.
447
+ if preload_module_classes is not None and module.__class__.__name__ in preload_module_classes:
448
+ return
449
+
450
+ for child in module.children():
451
+ attach_execution_device_hook(
452
+ child,
453
+ execution_device,
454
+ skip_keys=skip_keys,
455
+ preload_module_classes=preload_module_classes,
456
+ tied_params_map=tied_params_map,
457
+ )
458
+
459
+
460
+ def attach_align_device_hook(
461
+ module: torch.nn.Module,
462
+ execution_device: Optional[torch.device] = None,
463
+ offload: bool = False,
464
+ weights_map: Optional[Mapping] = None,
465
+ offload_buffers: bool = False,
466
+ module_name: str = "",
467
+ skip_keys: Optional[Union[str, list[str]]] = None,
468
+ preload_module_classes: Optional[list[str]] = None,
469
+ tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,
470
+ ):
471
+ """
472
+ Recursively attaches `AlignDevicesHook` to all submodules of a given model that have direct parameters and/or
473
+ buffers.
474
+
475
+ Args:
476
+ module (`torch.nn.Module`):
477
+ The module where we want to attach the hooks.
478
+ execution_device (`torch.device`, *optional*):
479
+ The device on which inputs and model weights should be placed before the forward pass.
480
+ offload (`bool`, *optional*, defaults to `False`):
481
+ Whether or not the weights should be offloaded after the forward pass.
482
+ weights_map (`Mapping[str, torch.Tensor]`, *optional*):
483
+ When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.
484
+ offload_buffers (`bool`, *optional*, defaults to `False`):
485
+ Whether or not to include the associated module's buffers when offloading.
486
+ module_name (`str`, *optional*, defaults to `""`):
487
+ The name of the module.
488
+ skip_keys (`str` or `List[str]`, *optional*):
489
+ A list of keys to ignore when moving inputs or outputs between devices.
490
+ preload_module_classes (`List[str]`, *optional*):
491
+ A list of classes whose instances should load all their weights (even in the submodules) at the beginning
492
+ of the forward. This should only be used for classes that have submodules which are registered but not
493
+ called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
494
+ `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
495
+ tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):
496
+ A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution
497
+ device, this parameter is useful to reuse the first available pointer of a shared weight for all others,
498
+ instead of duplicating memory.
499
+ """
500
+ # Attach the hook on this module if it has any direct tensor.
501
+ directs = named_module_tensors(module)
502
+ full_offload = (
503
+ offload and preload_module_classes is not None and module.__class__.__name__ in preload_module_classes
504
+ )
505
+
506
+ if len(list(directs)) > 0 or full_offload:
507
+ if weights_map is not None:
508
+ prefix = f"{module_name}." if len(module_name) > 0 else ""
509
+ prefixed_weights_map = PrefixedDataset(weights_map, prefix)
510
+ else:
511
+ prefixed_weights_map = None
512
+ hook = AlignDevicesHook(
513
+ execution_device=execution_device,
514
+ offload=offload,
515
+ weights_map=prefixed_weights_map,
516
+ offload_buffers=offload_buffers,
517
+ place_submodules=full_offload,
518
+ skip_keys=skip_keys,
519
+ tied_params_map=tied_params_map,
520
+ )
521
+ add_hook_to_module(module, hook, append=True)
522
+
523
+ # We stop the recursion in case we hit the full offload.
524
+ if full_offload:
525
+ return
526
+
527
+ # Recurse on all children of the module.
528
+ for child_name, child in module.named_children():
529
+ child_name = f"{module_name}.{child_name}" if len(module_name) > 0 else child_name
530
+ attach_align_device_hook(
531
+ child,
532
+ execution_device=execution_device,
533
+ offload=offload,
534
+ weights_map=weights_map,
535
+ offload_buffers=offload_buffers,
536
+ module_name=child_name,
537
+ preload_module_classes=preload_module_classes,
538
+ skip_keys=skip_keys,
539
+ tied_params_map=tied_params_map,
540
+ )
541
+
542
+
543
+ def remove_hook_from_submodules(module: nn.Module):
544
+ """
545
+ Recursively removes all hooks attached on the submodules of a given model.
546
+
547
+ Args:
548
+ module (`torch.nn.Module`): The module on which to remove all hooks.
549
+ """
550
+ remove_hook_from_module(module)
551
+ for child in module.children():
552
+ remove_hook_from_submodules(child)
553
+
554
+
555
+ def attach_align_device_hook_on_blocks(
556
+ module: nn.Module,
557
+ execution_device: Optional[Union[torch.device, dict[str, torch.device]]] = None,
558
+ offload: Union[bool, dict[str, bool]] = False,
559
+ weights_map: Mapping = None,
560
+ offload_buffers: bool = False,
561
+ module_name: str = "",
562
+ skip_keys: Optional[Union[str, list[str]]] = None,
563
+ preload_module_classes: Optional[list[str]] = None,
564
+ tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,
565
+ ):
566
+ """
567
+ Attaches `AlignDevicesHook` to all blocks of a given model as needed.
568
+
569
+ Args:
570
+ module (`torch.nn.Module`):
571
+ The module where we want to attach the hooks.
572
+ execution_device (`torch.device` or `Dict[str, torch.device]`, *optional*):
573
+ The device on which inputs and model weights should be placed before the forward pass. It can be one device
574
+ for the whole module, or a dictionary mapping module name to device.
575
+ offload (`bool`, *optional*, defaults to `False`):
576
+ Whether or not the weights should be offloaded after the forward pass. It can be one boolean for the whole
577
+ module, or a dictionary mapping module name to boolean.
578
+ weights_map (`Mapping[str, torch.Tensor]`, *optional*):
579
+ When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.
580
+ offload_buffers (`bool`, *optional*, defaults to `False`):
581
+ Whether or not to include the associated module's buffers when offloading.
582
+ module_name (`str`, *optional*, defaults to `""`):
583
+ The name of the module.
584
+ skip_keys (`str` or `List[str]`, *optional*):
585
+ A list of keys to ignore when moving inputs or outputs between devices.
586
+ preload_module_classes (`List[str]`, *optional*):
587
+ A list of classes whose instances should load all their weights (even in the submodules) at the beginning
588
+ of the forward. This should only be used for classes that have submodules which are registered but not
589
+ called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
590
+ `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
591
+ tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):
592
+ A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution
593
+ device, this parameter is useful to reuse the first available pointer of a shared weight for all others,
594
+ instead of duplicating memory.
595
+ """
596
+ # If one device and one offload, we've got one hook.
597
+ if not isinstance(execution_device, Mapping) and not isinstance(offload, dict):
598
+ if not offload:
599
+ hook = AlignDevicesHook(
600
+ execution_device=execution_device,
601
+ io_same_device=True,
602
+ skip_keys=skip_keys,
603
+ place_submodules=True,
604
+ tied_params_map=tied_params_map,
605
+ )
606
+ add_hook_to_module(module, hook)
607
+ else:
608
+ attach_align_device_hook(
609
+ module,
610
+ execution_device=execution_device,
611
+ offload=True,
612
+ weights_map=weights_map,
613
+ offload_buffers=offload_buffers,
614
+ module_name=module_name,
615
+ skip_keys=skip_keys,
616
+ tied_params_map=tied_params_map,
617
+ )
618
+ return
619
+
620
+ if not isinstance(execution_device, Mapping):
621
+ execution_device = {key: execution_device for key in offload.keys()}
622
+ if not isinstance(offload, Mapping):
623
+ offload = {key: offload for key in execution_device.keys()}
624
+
625
+ if module_name in execution_device and module_name in offload and not offload[module_name]:
626
+ hook = AlignDevicesHook(
627
+ execution_device=execution_device[module_name],
628
+ offload_buffers=offload_buffers,
629
+ io_same_device=(module_name == ""),
630
+ place_submodules=True,
631
+ skip_keys=skip_keys,
632
+ tied_params_map=tied_params_map,
633
+ )
634
+ add_hook_to_module(module, hook)
635
+ attach_execution_device_hook(
636
+ module, execution_device[module_name], skip_keys=skip_keys, tied_params_map=tied_params_map
637
+ )
638
+ elif module_name in execution_device and module_name in offload:
639
+ attach_align_device_hook(
640
+ module,
641
+ execution_device=execution_device[module_name],
642
+ offload=True,
643
+ weights_map=weights_map,
644
+ offload_buffers=offload_buffers,
645
+ module_name=module_name,
646
+ skip_keys=skip_keys,
647
+ preload_module_classes=preload_module_classes,
648
+ tied_params_map=tied_params_map,
649
+ )
650
+ if not hasattr(module, "_hf_hook"):
651
+ hook = AlignDevicesHook(
652
+ execution_device=execution_device[module_name],
653
+ io_same_device=(module_name == ""),
654
+ skip_keys=skip_keys,
655
+ tied_params_map=tied_params_map,
656
+ )
657
+ add_hook_to_module(module, hook)
658
+ attach_execution_device_hook(
659
+ module,
660
+ execution_device[module_name],
661
+ preload_module_classes=preload_module_classes,
662
+ skip_keys=skip_keys,
663
+ tied_params_map=tied_params_map,
664
+ )
665
+ elif module_name == "":
666
+ hook = AlignDevicesHook(
667
+ execution_device=execution_device.get(""),
668
+ io_same_device=True,
669
+ skip_keys=skip_keys,
670
+ tied_params_map=tied_params_map,
671
+ )
672
+ add_hook_to_module(module, hook)
673
+
674
+ for child_name, child in module.named_children():
675
+ child_name = f"{module_name}.{child_name}" if len(module_name) > 0 else child_name
676
+ attach_align_device_hook_on_blocks(
677
+ child,
678
+ execution_device=execution_device,
679
+ offload=offload,
680
+ weights_map=weights_map,
681
+ offload_buffers=offload_buffers,
682
+ module_name=child_name,
683
+ preload_module_classes=preload_module_classes,
684
+ skip_keys=skip_keys,
685
+ tied_params_map=tied_params_map,
686
+ )
687
+
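A hedged sketch of block-level dispatch for a toy two-block model (CPU stand-ins where a real setup would use GPU indices):

import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
# One execution device per top-level block; normally e.g. {"0": 0, "1": 1}.
attach_align_device_hook_on_blocks(
    model,
    execution_device={"0": "cpu", "1": "cpu"},
    offload={"0": False, "1": False},
)
assert hasattr(model[0], "_hf_hook") and hasattr(model[1], "_hf_hook")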
688
+
689
+ class CpuOffload(ModelHook):
690
+ """
691
+ Offloads a model on the CPU until its forward pass is called. The model will not be offloaded back to the CPU after
692
+ the forward, the user needs to call the `init_hook` method again for this.
693
+
694
+ Args:
695
+ execution_device(`str`, `int` or `torch.device`, *optional*):
696
+ The device on which the model should be executed. Will default to the MPS device if it's available, then
697
+ GPU 0 if there is a GPU, and finally to the CPU.
698
+ prev_module_hook (`UserCpuOffloadHook`, *optional*):
699
+ The hook sent back by [`cpu_offload_with_hook`] for a previous model in the pipeline you are running. If
700
+ passed, its offload method will be called just before the forward of the model to which this hook is
701
+ attached.
702
+ """
703
+
704
+ def __init__(
705
+ self,
706
+ execution_device: Optional[Union[str, int, torch.device]] = None,
707
+ prev_module_hook: Optional["UserCpuOffloadHook"] = None,
708
+ ):
709
+ self.prev_module_hook = prev_module_hook
710
+
711
+ self.execution_device = execution_device if execution_device is not None else PartialState().default_device
712
+
713
+ def init_hook(self, module):
714
+ return module.to("cpu")
715
+
716
+ def pre_forward(self, module, *args, **kwargs):
717
+ if self.prev_module_hook is not None and isinstance(self.prev_module_hook, UserCpuOffloadHook):
718
+ prev_module = self.prev_module_hook.model
719
+ prev_device = next(prev_module.parameters()).device
720
+
721
+ # Only offload the previous module if it is not already on CPU.
722
+ if prev_device != torch.device("cpu"):
723
+ self.prev_module_hook.offload()
724
+ clear_device_cache()
725
+
726
+ # If the current device is already the self.execution_device, we can skip the transfer.
727
+ current_device = next(module.parameters()).device
728
+ if current_device == self.execution_device:
729
+ return args, kwargs
730
+
731
+ module.to(self.execution_device)
732
+ return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device)
733
+
734
+
735
+ class UserCpuOffloadHook:
736
+ """
737
+ A simple hook grouping a model and a `ModelHook`, which provides easy APIs to call the init method of the hook
738
+ or remove it entirely.
739
+ """
740
+
741
+ def __init__(self, model, hook):
742
+ self.model = model
743
+ self.hook = hook
744
+
745
+ def offload(self):
746
+ self.hook.init_hook(self.model)
747
+
748
+ def remove(self):
749
+ remove_hook_from_module(self.model)
750
+
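A sketch of chaining offload hooks across a two-stage pipeline, mirroring what `cpu_offload_with_hook` in `big_modeling` builds (device choice hedged to CPU when no accelerator is present):

import torch
import torch.nn as nn

device = "cuda:0" if torch.cuda.is_available() else "cpu"
stage_1, stage_2 = nn.Linear(4, 4), nn.Linear(4, 4)

hook_1 = CpuOffload(execution_device=device)
add_hook_to_module(stage_1, hook_1)
user_hook_1 = UserCpuOffloadHook(stage_1, hook_1)

# stage_2 offloads stage_1 back to CPU just before its own forward.
hook_2 = CpuOffload(execution_device=device, prev_module_hook=user_hook_1)
add_hook_to_module(stage_2, hook_2)

out = stage_2(stage_1(torch.randn(1, 4)))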
751
+
752
+ class LayerwiseCastingHook(ModelHook):
753
+ r"""
754
+ A hook that casts the weights of a module to a high precision dtype for computation, and to a low precision dtype
755
+ for storage. This process may lead to quality loss in the output, but can significantly reduce the memory
756
+ footprint.
757
+ """
758
+
759
+ _is_stateful = False
760
+
761
+ def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool) -> None:
762
+ self.storage_dtype = storage_dtype
763
+ self.compute_dtype = compute_dtype
764
+ self.non_blocking = non_blocking
765
+
766
+ def init_hook(self, module: torch.nn.Module):
767
+ module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
768
+ return module
769
+
770
+ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
771
+ module.to(dtype=self.compute_dtype, non_blocking=self.non_blocking)
772
+ return args, kwargs
773
+
774
+ def post_forward(self, module: torch.nn.Module, output):
775
+ module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
776
+ return output
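A small sketch of layerwise casting (bfloat16 storage, float32 compute; illustrative only):

import torch
import torch.nn as nn

layer = nn.Linear(8, 8)
hook = LayerwiseCastingHook(storage_dtype=torch.bfloat16, compute_dtype=torch.float32, non_blocking=False)
add_hook_to_module(layer, hook)

assert layer.weight.dtype == torch.bfloat16   # stored in low precision
_ = layer(torch.randn(2, 8))                  # computed in float32
assert layer.weight.dtype == torch.bfloat16   # cast back after the forward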
pythonProject/.venv/Lib/site-packages/accelerate/inference.py ADDED
@@ -0,0 +1,184 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+ from types import MethodType
16
+ from typing import Any, Optional, Union
17
+
18
+ from .state import PartialState
19
+ from .utils import (
20
+ calculate_maximum_sizes,
21
+ convert_bytes,
22
+ copy_tensor_to_devices,
23
+ ignorant_find_batch_size,
24
+ infer_auto_device_map,
25
+ is_pippy_available,
26
+ pad_input_tensors,
27
+ send_to_device,
28
+ )
29
+
30
+
31
+ def generate_device_map(model, num_processes: int = 1, no_split_module_classes=None, max_memory: dict = None):
32
+ """
33
+ Calculates the device map for `model` with an offset for PiPPy
34
+ """
35
+ if num_processes == 1:
36
+ return infer_auto_device_map(model, no_split_module_classes=no_split_module_classes, clean_result=False)
37
+ if max_memory is None:
38
+ model_size, shared = calculate_maximum_sizes(model)
39
+
40
+ # Split into `n` chunks for each GPU
41
+ memory = (model_size + shared[0]) / num_processes
42
+ memory = convert_bytes(memory)
43
+ value, ending = memory.split(" ")
44
+
45
+ # Add a chunk to deal with potential extra shared memory instances
46
+ memory = math.ceil(float(value)) * 1.1
47
+ memory = f"{memory} {ending}"
48
+ max_memory = {i: memory for i in range(num_processes)}
49
+ device_map = infer_auto_device_map(
50
+ model,
51
+ max_memory=max_memory,
52
+ no_split_module_classes=no_split_module_classes,
53
+ clean_result=False,
54
+ )
55
+ return device_map
56
+
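A worked sketch of the per-rank budget computed above (toy byte counts; the `convert_bytes` string handling is approximated):

import math

model_size, shared = 12 * 2**30, (0, [])                # 12 GiB of weights, none shared
num_processes = 4
per_rank = (model_size + shared[0]) / num_processes     # 3 GiB per rank
budget = math.ceil(per_rank / 2**30) * 1.1              # ~"3.3 GB" after the headroom factor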
57
+
58
+ def find_pippy_batch_size(args, kwargs):
59
+ found_batch_size = None
60
+ if args is not None:
61
+ for arg in args:
62
+ found_batch_size = ignorant_find_batch_size(arg)
63
+ if found_batch_size is not None:
64
+ break
65
+ if kwargs is not None and found_batch_size is None:
66
+ for kwarg in kwargs.values():
67
+ found_batch_size = ignorant_find_batch_size(kwarg)
68
+ if found_batch_size is not None:
69
+ break
70
+ return found_batch_size
71
+
72
+
73
+ def build_pipeline(model, split_points, args, kwargs, num_chunks):
74
+ """
75
+ Attaches the split points to the model based on `self.device_map` and generates a `PipelineStage`. Requires passing
76
+ in the `args` and `kwargs` the model expects, placed on the CPU.
77
+
78
+ Users can pass in a custom `num_chunks` as an optional hyperparameter. By default it will use
79
+ `AcceleratorState.num_processes`
80
+ """
81
+ # Note: We import here to reduce import time from general modules, and isolate outside dependencies
82
+ from torch.distributed.pipelining import ScheduleGPipe, SplitPoint, pipeline
83
+
84
+ # We need to annotate the split points in the model for PiPPy
85
+ state = PartialState()
86
+ split_spec = {split_point: SplitPoint.BEGINNING for split_point in split_points}
87
+ pipe = pipeline(
88
+ model,
89
+ mb_args=args,
90
+ mb_kwargs=kwargs,
91
+ split_spec=split_spec,
92
+ )
93
+ stage = pipe.build_stage(state.local_process_index, device=state.device)
94
+ schedule = ScheduleGPipe(stage, num_chunks)
95
+
96
+ return schedule
97
+
98
+
99
+ def pippy_forward(forward, num_chunks, gather_output, *args, **kwargs):
100
+ state = PartialState()
101
+ output = None
102
+
103
+ if state.num_processes == 1:
104
+ output = forward(*args, **kwargs)
105
+ elif state.is_local_main_process:
106
+ found_batch_size = find_pippy_batch_size(args, kwargs)
107
+ if found_batch_size is None:
108
+ raise ValueError("Could not find batch size from args or kwargs")
109
+ else:
110
+ if found_batch_size != num_chunks:
111
+ args = pad_input_tensors(args, found_batch_size, num_chunks)
112
+ kwargs = pad_input_tensors(kwargs, found_batch_size, num_chunks)
113
+ forward(*args, **kwargs)
114
+ elif state.is_last_process:
115
+ output = forward()
116
+ else:
117
+ forward()
118
+ if gather_output:
119
+ # Each node will get a copy of the full output which is only on the last GPU
120
+ output = copy_tensor_to_devices(output)
121
+ return output
122
+
123
+
124
+ def prepare_pippy(
125
+ model,
126
+ split_points: Optional[Union[str, list[str]]] = "auto",
127
+ no_split_module_classes: Optional[list[str]] = None,
128
+ example_args: Optional[tuple[Any]] = (),
129
+ example_kwargs: Optional[dict[str, Any]] = None,
130
+ num_chunks: Optional[int] = None,
131
+ gather_output: Optional[bool] = False,
132
+ ):
133
+ """
134
+ Wraps `model` for pipeline parallel inference.
135
+
136
+ Args:
137
+ model (`torch.nn.Module`):
138
+ A model we want to split for pipeline-parallel inference
139
+ split_points (`str` or `List[str]`, defaults to 'auto'):
140
+ How to generate the split points and chunk the model across each GPU. 'auto' will find the best balanced
141
+ split given any model. Otherwise, it should be a list of layer names in the model to split by.
142
+ no_split_module_classes (`List[str]`):
143
+ A list of class names for layers we don't want to be split.
144
+ example_args (tuple of model inputs):
145
+ The expected inputs for the model that uses order-based inputs for a *single process*. Recommended to use
146
+ this method if possible.
147
+ example_kwargs (dict of model inputs):
148
+ The expected inputs for the model that uses dictionary-based inputs for a *single process*. This is a
149
+ *highly* limiting structure that requires the same keys be present at *all* inference calls. Not
150
+ recommended unless the prior condition is true for all cases.
151
+ num_chunks (`int`, defaults to the number of available GPUs):
152
+ The number of different stages the Pipeline will have. By default it will assign one chunk per GPU, but
153
+ this can be tuned and played with. In general one should have num_chunks >= num_gpus.
154
+ gather_output (`bool`, defaults to `False`):
155
+ If `True`, the output from the last GPU (which holds the true outputs) is sent across to all GPUs.
156
+ """
157
+ if not is_pippy_available():
158
+ raise ImportError("Using `torch.distributed.pipelining` requires PyTorch 2.4.0 or later.")
159
+ state = PartialState()
160
+ example_args = send_to_device(example_args, "cpu")
161
+ example_kwargs = send_to_device(example_kwargs, "cpu")
162
+ if num_chunks is None:
163
+ num_chunks = state.num_processes
164
+ if split_points == "auto":
165
+ device_map = generate_device_map(model, num_chunks, no_split_module_classes=no_split_module_classes)
166
+ split_points = []
167
+ for i in range(1, num_chunks):
168
+ split_points.append(next(k for k, v in device_map.items() if v == i))
169
+ model.hf_split_points = split_points
170
+ stage = build_pipeline(model, split_points, example_args, example_kwargs, num_chunks)
171
+ model._original_forward = model.forward
172
+ model._original_call = model.__call__
173
+ model.pippy_stage = stage
174
+ model.hf_split_points = split_points
175
+
176
+ def forward(*args, **kwargs):
177
+ return pippy_forward(stage.step, num_chunks, gather_output, *args, **kwargs)
178
+
179
+ # To act like a decorator so that it can be popped when doing `extract_model_from_parallel`
180
+ # Note: creates an infinite recursion loop with `generate`
181
+ model_forward = MethodType(forward, model)
182
+ forward.__wrapped__ = model_forward
183
+ model.forward = forward
184
+ return model
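A hedged usage sketch of `prepare_pippy` (model name and shapes are illustrative; assumes a multi-process launch, e.g. via `accelerate launch`):

import torch
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased")
# The example batch is split into microbatches across the pipeline stages.
example_input = torch.randint(0, 1000, (2, 128))
model = prepare_pippy(model, example_args=(example_input,), gather_output=True)

with torch.no_grad():
    output = model(example_input)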
pythonProject/.venv/Lib/site-packages/accelerate/launchers.py ADDED
@@ -0,0 +1,306 @@
1
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import sys
17
+ import tempfile
18
+
19
+ import torch
20
+
21
+ from .state import AcceleratorState, PartialState
22
+ from .utils import (
23
+ PrecisionType,
24
+ PrepareForLaunch,
25
+ are_libraries_initialized,
26
+ check_cuda_p2p_ib_support,
27
+ get_gpu_info,
28
+ is_mps_available,
29
+ is_torch_version,
30
+ patch_environment,
31
+ )
32
+ from .utils.constants import ELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION
33
+
34
+
35
+ def test_launch():
36
+ "Verify a `PartialState` can be initialized."
37
+ _ = PartialState()
38
+
39
+
40
+ def notebook_launcher(
41
+ function,
42
+ args=(),
43
+ num_processes=None,
44
+ mixed_precision="no",
45
+ use_port="29500",
46
+ master_addr="127.0.0.1",
47
+ node_rank=0,
48
+ num_nodes=1,
49
+ rdzv_backend="static",
50
+ rdzv_endpoint="",
51
+ rdzv_conf=None,
52
+ rdzv_id="none",
53
+ max_restarts=0,
54
+ monitor_interval=0.1,
55
+ log_line_prefix_template=None,
56
+ ):
57
+ """
58
+ Launches a training function, using several processes or multiple nodes if it's possible in the current environment
59
+ (TPU with multiple cores for instance).
60
+
61
+ <Tip warning={true}>
62
+
63
+ To use this function, absolutely no calls to a device must have been made in the notebook session before calling it. If any
64
+ have been made, you will need to restart the notebook and make sure no cells use any device capability.
65
+
66
+ Setting `ACCELERATE_DEBUG_MODE="1"` in your environment will run a test before truly launching to ensure that none
67
+ of those calls have been made.
68
+
69
+ </Tip>
70
+
71
+ Args:
72
+ function (`Callable`):
73
+ The training function to execute. If it accepts arguments, the first argument should be the index of the
74
+ process run.
75
+ args (`Tuple`):
76
+ Tuple of arguments to pass to the function (it will receive `*args`).
77
+ num_processes (`int`, *optional*):
78
+ The number of processes to use for training. Will default to 8 in Colab/Kaggle if a TPU is available, to
79
+ the number of devices available otherwise.
80
+ mixed_precision (`str`, *optional*, defaults to `"no"`):
81
+ If `fp16` or `bf16`, will use mixed precision training on multi-device.
82
+ use_port (`str`, *optional*, defaults to `"29500"`):
83
+ The port to use to communicate between processes when launching a multi-device training.
84
+ master_addr (`str`, *optional*, defaults to `"127.0.0.1"`):
85
+ The address to use for communication between processes.
86
+ node_rank (`int`, *optional*, defaults to 0):
87
+ The rank of the current node.
88
+ num_nodes (`int`, *optional*, defaults to 1):
89
+ The number of nodes to use for training.
90
+ rdzv_backend (`str`, *optional*, defaults to `"static"`):
91
+ The rendezvous method to use, such as 'static' (the default) or 'c10d'
92
+ rdzv_endpoint (`str`, *optional*, defaults to `""`):
93
+ The endpoint of the rdzv sync. storage.
94
+ rdzv_conf (`Dict`, *optional*, defaults to `None`):
95
+ Additional rendezvous configuration.
96
+ rdzv_id (`str`, *optional*, defaults to `"none"`):
97
+ The unique run id of the job.
98
+ max_restarts (`int`, *optional*, defaults to 0):
99
+ The maximum number of restarts that the elastic agent will conduct on workers before failure.
100
+ monitor_interval (`float`, *optional*, defaults to 0.1):
101
+ The interval in seconds that is used by the elastic_agent as a period of monitoring workers.
102
+ log_line_prefix_template (`str`, *optional*, defaults to `None`):
103
+ The prefix template for elastic launch logging. Available from PyTorch 2.2.0.
104
+
105
+ Example:
106
+
107
+ ```python
108
+ # Assume this is defined in a Jupyter Notebook on an instance with two devices
109
+ from accelerate import notebook_launcher
110
+
111
+
112
+ def train(*args):
113
+ # Your training function here
114
+ ...
115
+
116
+
117
+ notebook_launcher(train, args=(arg1, arg2), num_processes=2, mixed_precision="fp16")
118
+ ```
119
+ """
120
+ # Are we in a google colab or a Kaggle Kernel?
121
+ in_colab = False
122
+ in_kaggle = False
123
+ if any(key.startswith("KAGGLE") for key in os.environ.keys()):
124
+ in_kaggle = True
125
+ elif "IPython" in sys.modules:
126
+ in_colab = "google.colab" in str(sys.modules["IPython"].get_ipython())
127
+
128
+ try:
129
+ mixed_precision = PrecisionType(mixed_precision.lower())
130
+ except ValueError:
131
+ raise ValueError(
132
+ f"Unknown mixed_precision mode: {args.mixed_precision.lower()}. Choose between {PrecisionType.list()}."
133
+ )
134
+
135
+ if (in_colab or in_kaggle) and (
136
+ (os.environ.get("TPU_NAME", None) is not None) or (os.environ.get("PJRT_DEVICE", "") == "TPU")
137
+ ):
138
+ # TPU launch
139
+ import torch_xla.distributed.xla_multiprocessing as xmp
140
+
141
+ if len(AcceleratorState._shared_state) > 0:
142
+ raise ValueError(
143
+ "To train on TPU in Colab or Kaggle Kernel, the `Accelerator` should only be initialized inside "
144
+ "your training function. Restart your notebook and make sure no cells initializes an "
145
+ "`Accelerator`."
146
+ )
147
+
148
+ launcher = PrepareForLaunch(function, distributed_type="XLA")
149
+ print("Launching a training on TPU cores.")
150
+ xmp.spawn(launcher, args=args, start_method="fork")
151
+ elif in_colab and get_gpu_info()[1] < 2:
152
+ # No need for a distributed launch otherwise as it's either CPU or one GPU.
153
+ if torch.cuda.is_available():
154
+ print("Launching training on one GPU.")
155
+ else:
156
+ print("Launching training on one CPU.")
157
+ function(*args)
158
+ else:
159
+ if num_processes is None:
160
+ raise ValueError(
161
+ "You have to specify the number of devices you would like to use, add `num_processes=...` to your call."
162
+ )
163
+ if node_rank >= num_nodes:
164
+ raise ValueError("The node_rank must be less than the number of nodes.")
165
+ if num_processes > 1:
166
+ # Multi-device launch
167
+ from torch.distributed.launcher.api import LaunchConfig, elastic_launch
168
+ from torch.multiprocessing import start_processes
169
+ from torch.multiprocessing.spawn import ProcessRaisedException
170
+
171
+ if len(AcceleratorState._shared_state) > 0:
172
+ raise ValueError(
173
+ "To launch a multi-device training from your notebook, the `Accelerator` should only be initialized "
174
+ "inside your training function. Restart your notebook and make sure no cells initializes an "
175
+ "`Accelerator`."
176
+ )
177
+ # Check for specific libraries known to initialize device that users constantly use
178
+ problematic_imports = are_libraries_initialized("bitsandbytes")
179
+ if len(problematic_imports) > 0:
180
+ err = (
181
+ "Could not start distributed process. Libraries known to initialize device upon import have been "
182
+ "imported already. Please keep these imports inside your training function to try and help with this:"
183
+ )
184
+ for lib_name in problematic_imports:
185
+ err += f"\n\t* `{lib_name}`"
186
+ raise RuntimeError(err)
187
+
188
+ patched_env = dict(
189
+ nproc=num_processes,
190
+ node_rank=node_rank,
191
+ world_size=num_nodes * num_processes,
192
+ master_addr=master_addr,
193
+ master_port=use_port,
194
+ mixed_precision=mixed_precision,
195
+ )
196
+
197
+ # Check for CUDA P2P and IB issues
198
+ if not check_cuda_p2p_ib_support():
199
+ patched_env["nccl_p2p_disable"] = "1"
200
+ patched_env["nccl_ib_disable"] = "1"
201
+
202
+ # torch.distributed will expect a few environment variable to be here. We set the ones common to each
203
+ # process here (the other ones will be set be the launcher).
204
+ with patch_environment(**patched_env):
205
+ # First dummy launch
206
+ device_type = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
207
+ distributed_type = "MULTI_XPU" if device_type == "xpu" else "MULTI_GPU"
208
+ if os.environ.get("ACCELERATE_DEBUG_MODE", "false").lower() == "true":
209
+ launcher = PrepareForLaunch(test_launch, distributed_type=distributed_type)
210
+ try:
211
+ start_processes(launcher, args=(), nprocs=num_processes, start_method="fork")
212
+ except ProcessRaisedException as e:
213
+ err = "An issue was found when verifying a stable environment for the notebook launcher."
214
+ if f"Cannot re-initialize {device_type.upper()} in forked subprocess" in e.args[0]:
215
+ raise RuntimeError(
216
+ f"{err}"
217
+ "This likely stems from an outside import causing issues once the `notebook_launcher()` is called. "
218
+ "Please review your imports and test them when running the `notebook_launcher()` to identify "
219
+ f"which one is problematic and causing {device_type.upper()} to be initialized."
220
+ ) from e
221
+ else:
222
+ raise RuntimeError(f"{err} The following error was raised: {e}") from e
223
+ # Now the actual launch
224
+ launcher = PrepareForLaunch(function, distributed_type=distributed_type)
225
+ print(f"Launching training on {num_processes} {device_type.upper()}s.")
226
+ try:
227
+ if rdzv_conf is None:
228
+ rdzv_conf = {}
229
+ if rdzv_backend == "static":
230
+ rdzv_conf["rank"] = node_rank
231
+ if not rdzv_endpoint:
232
+ rdzv_endpoint = f"{master_addr}:{use_port}"
233
+ launch_config_kwargs = dict(
234
+ min_nodes=num_nodes,
235
+ max_nodes=num_nodes,
236
+ nproc_per_node=num_processes,
237
+ run_id=rdzv_id,
238
+ rdzv_endpoint=rdzv_endpoint,
239
+ rdzv_backend=rdzv_backend,
240
+ rdzv_configs=rdzv_conf,
241
+ max_restarts=max_restarts,
242
+ monitor_interval=monitor_interval,
243
+ start_method="fork",
244
+ )
245
+ if is_torch_version(">=", ELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION):
246
+ launch_config_kwargs["log_line_prefix_template"] = log_line_prefix_template
247
+ elastic_launch(config=LaunchConfig(**launch_config_kwargs), entrypoint=function)(*args)
248
+ except ProcessRaisedException as e:
249
+ if f"Cannot re-initialize {device_type.upper()} in forked subprocess" in e.args[0]:
250
+ raise RuntimeError(
251
+ f"{device_type.upper()} has been initialized before the `notebook_launcher` could create a forked subprocess. "
252
+ "This likely stems from an outside import causing issues once the `notebook_launcher()` is called. "
253
+ "Please review your imports and test them when running the `notebook_launcher()` to identify "
254
+ f"which one is problematic and causing {device_type.upper()} to be initialized."
255
+ ) from e
256
+ else:
257
+ raise RuntimeError(f"An issue was found when launching the training: {e}") from e
258
+
259
+ else:
260
+ # No need for a distributed launch otherwise as it's either CPU, GPU, XPU or MPS.
261
+ if is_mps_available():
262
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
263
+ print("Launching training on MPS.")
264
+ elif torch.cuda.is_available():
265
+ print("Launching training on one GPU.")
266
+ elif torch.xpu.is_available():
267
+ print("Launching training on one XPU.")
268
+ else:
269
+ print("Launching training on CPU.")
270
+ function(*args)
271
+
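For orientation, a minimal sketch of driving `notebook_launcher` from a notebook cell; the `training_loop` body, its learning-rate argument, and the two-process count are illustrative assumptions, not part of this file:

```python
# Hypothetical notebook cell; num_processes=2 assumes a machine with two GPUs.
from accelerate import Accelerator, notebook_launcher


def training_loop(lr):
    # The Accelerator must be created *inside* the launched function,
    # per the AcceleratorState._shared_state guard above.
    accelerator = Accelerator(mixed_precision="fp16")
    accelerator.print(f"rank {accelerator.process_index} on {accelerator.device}, lr={lr}")


notebook_launcher(training_loop, args=(1e-3,), num_processes=2)
```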
+
+ def debug_launcher(function, args=(), num_processes=2):
+     """
+     Launches a training function using several processes on CPU for debugging purposes.
+
+     <Tip warning={true}>
+
+     This function is provided for internal testing and debugging, but it's not intended for real trainings. It will
+     only use the CPU.
+
+     </Tip>
+
+     Args:
+         function (`Callable`):
+             The training function to execute.
+         args (`Tuple`):
+             Tuple of arguments to pass to the function (it will receive `*args`).
+         num_processes (`int`, *optional*, defaults to 2):
+             The number of processes to use for training.
+     """
+     from torch.multiprocessing import start_processes
+
+     with tempfile.NamedTemporaryFile() as tmp_file:
+         # torch.distributed will expect a few environment variables to be here. We set the ones common to each
+         # process here (the other ones will be set by the launcher).
+         with patch_environment(
+             world_size=num_processes,
+             master_addr="127.0.0.1",
+             master_port="29500",
+             accelerate_mixed_precision="no",
+             accelerate_debug_rdv_file=tmp_file.name,
+             accelerate_use_cpu="yes",
+         ):
+             launcher = PrepareForLaunch(function, debug=True)
+             start_processes(launcher, args=args, nprocs=num_processes, start_method="fork")
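`debug_launcher` can be exercised the same way from a plain script; a minimal sketch (the `check_ranks` function is illustrative), which forks two CPU processes regardless of available accelerators:

```python
from accelerate import Accelerator, debug_launcher


def check_ranks():
    # The launcher exports ACCELERATE_USE_CPU=yes, so this runs on CPU.
    accelerator = Accelerator()
    accelerator.print(f"{accelerator.num_processes} CPU processes launched")


debug_launcher(check_ranks, num_processes=2)
```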
pythonProject/.venv/Lib/site-packages/accelerate/local_sgd.py ADDED
@@ -0,0 +1,106 @@
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import torch
+
+ from accelerate import Accelerator, DistributedType
+
+
+ class LocalSGD:
+     """
+     A helper class to support local SGD on top of Accelerator. It simply runs a given number of updates independently
+     on each device, and averages model weights once every `local_sgd_steps` steps.
+
+     It should be used only in the multi-GPU (or multi-CPU) setup without extensions such as DeepSpeed. In particular,
+     this is a simple implementation that cannot support scenarios such as model parallelism.
+
+
+     Although we are not aware of the true origins of this simple approach, the idea of local SGD is quite old and goes
+     back to at least:
+
+     Zhang, J., De Sa, C., Mitliagkas, I., & Ré, C. (2016). [Parallel SGD: When does averaging help?. arXiv preprint
+     arXiv:1606.07365.](https://arxiv.org/abs/1606.07365)
+
+     We credit the term Local SGD to the following paper (but there might be earlier references we are not aware of).
+
+     Stich, Sebastian Urban. ["Local SGD Converges Fast and Communicates Little." ICLR 2019-International Conference on
+     Learning Representations. No. CONF. 2019.](https://arxiv.org/abs/1805.09767)
+     """
+
+     def __enter__(self):
+         if self.enabled:
+             self.model_sync_obj = self.model.no_sync()
+             self.model_sync_obj.__enter__()
+
+         return self
+
+     def __exit__(self, type, value, tb):
+         if self.enabled:
+             # Average all models on exit
+             self._sync_and_avg_model_params()
+             self.model_sync_obj.__exit__(type, value, tb)
+
+     def __init__(self, accelerator: Accelerator, model: torch.nn.Module, local_sgd_steps: int, enabled: bool = True):
+         """
+         Constructor.
+
+         Args:
+             model (`torch.nn.Module`):
+                 The model whose parameters we need to average.
+             accelerator (`Accelerator`):
+                 Accelerator object.
+             local_sgd_steps (`int`):
+                 The number of local SGD steps (before model parameters are synchronized).
+             enabled (`bool`):
+                 Local SGD is disabled if this parameter is set to `False`.
+         """
+         if accelerator.distributed_type not in [
+             DistributedType.NO,
+             DistributedType.MULTI_CPU,
+             DistributedType.MULTI_GPU,
+             DistributedType.MULTI_XPU,
+             DistributedType.MULTI_MLU,
+             DistributedType.MULTI_HPU,
+             DistributedType.MULTI_SDAA,
+             DistributedType.MULTI_MUSA,
+             DistributedType.MULTI_NPU,
+         ]:
+             raise NotImplementedError("LocalSGD is supported only for CPUs and GPUs (no DeepSpeed or MegatronLM)")
+         self.enabled = enabled and accelerator.distributed_type != DistributedType.NO
+         self.num_steps = 0
+         if self.enabled:
+             self.accelerator = accelerator
+             self.model = model
+             self.local_sgd_steps = local_sgd_steps
+
+     def step(self):
+         """
+         This function makes a "step" and synchronizes model parameters if necessary.
+         """
+         self.num_steps += 1
+         if not self.enabled:
+             return
+
+         if self.num_steps % self.local_sgd_steps == 0:
+             self._sync_and_avg_model_params()
+
+     def _sync_and_avg_model_params(self):
+         """
+         Synchronize + Average model parameters across all GPUs
+         """
+
+         self.accelerator.wait_for_everyone()
+         with self.accelerator.autocast():
+             for param in self.model.parameters():
+                 param.data = self.accelerator.reduce(param.data, reduction="mean")
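A minimal runnable sketch of the intended loop, assuming a toy model and dataset (all names below are illustrative); on a single process `LocalSGD` silently disables itself, while under multi-GPU DDP every fourth `step()` averages parameters across ranks:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator
from accelerate.local_sgd import LocalSGD

accelerator = Accelerator()
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
dataloader = DataLoader(TensorDataset(torch.randn(64, 4), torch.randn(64, 1)), batch_size=8)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

with LocalSGD(accelerator=accelerator, model=model, local_sgd_steps=4) as local_sgd:
    for x, y in dataloader:
        loss = torch.nn.functional.mse_loss(model(x), y)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
        local_sgd.step()  # every 4th call triggers _sync_and_avg_model_params()
```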
pythonProject/.venv/Lib/site-packages/accelerate/logging.py ADDED
@@ -0,0 +1,125 @@
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import functools
+ import logging
+ import os
+
+ from .state import PartialState
+
+
+ class MultiProcessAdapter(logging.LoggerAdapter):
+     """
+     An adapter to assist with logging in multiprocess.
+
+     `log` takes in an additional `main_process_only` kwarg, which dictates whether it should be called on all processes
+     or only the main executed one. Default is `main_process_only=True`.
+
+     Does not require an `Accelerator` object to be created first.
+     """
+
+     @staticmethod
+     def _should_log(main_process_only):
+         "Check if log should be performed"
+         state = PartialState()
+         return not main_process_only or (main_process_only and state.is_main_process)
+
+     def log(self, level, msg, *args, **kwargs):
+         """
+         Delegates logger call after checking if we should log.
+
+         Accepts a new kwarg of `main_process_only`, which will dictate whether it will be logged across all processes
+         or only the main executed one. Default is `True` if not passed.
+
+         Also accepts "in_order", which if `True` makes the processes log one by one, in order. This is much easier to
+         read, but comes at the cost of sometimes needing to wait for the other processes. Default is `False` to not
+         break with the previous behavior.
+
+         `in_order` is ignored if `main_process_only` is passed.
+         """
+         if PartialState._shared_state == {}:
+             raise RuntimeError(
+                 "You must initialize the accelerate state by calling either `PartialState()` or `Accelerator()` before using the logging utility."
+             )
+         main_process_only = kwargs.pop("main_process_only", True)
+         in_order = kwargs.pop("in_order", False)
+         # set `stacklevel` to exclude ourself in `Logger.findCaller()` while respecting user's choice
+         kwargs.setdefault("stacklevel", 2)
+
+         if self.isEnabledFor(level):
+             if self._should_log(main_process_only):
+                 msg, kwargs = self.process(msg, kwargs)
+                 self.logger.log(level, msg, *args, **kwargs)
+
+             elif in_order:
+                 state = PartialState()
+                 for i in range(state.num_processes):
+                     if i == state.process_index:
+                         msg, kwargs = self.process(msg, kwargs)
+                         self.logger.log(level, msg, *args, **kwargs)
+                     state.wait_for_everyone()
+
+     @functools.lru_cache(None)
+     def warning_once(self, *args, **kwargs):
+         """
+         This method is identical to `logger.warning()`, but will emit the warning with the same message only once.
+
+         Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the
+         cache. The assumption here is that all warning messages are unique across the code. If they aren't, we would
+         need to switch to another type of cache that includes the caller frame information in the hashing function.
+         """
+         self.warning(*args, **kwargs)
+
+
+ def get_logger(name: str, log_level: str = None):
+     """
+     Returns a `logging.Logger` for `name` that can handle multiprocessing.
+
+     If a log should be called on all processes, pass `main_process_only=False`. If a log should be called on all
+     processes and in order, also pass `in_order=True`.
+
+     Args:
+         name (`str`):
+             The name for the logger, such as `__file__`
+         log_level (`str`, *optional*):
+             The log level to use. If not passed, will default to the `ACCELERATE_LOG_LEVEL` environment variable; if
+             that is also unset, the logger's level is left unchanged.
+
+     Example:
+
+     ```python
+     >>> from accelerate.logging import get_logger
+     >>> from accelerate import Accelerator
+
+     >>> logger = get_logger(__name__)
+
+     >>> accelerator = Accelerator()
+     >>> logger.info("My log", main_process_only=False)
+     >>> logger.debug("My log", main_process_only=True)
+
+     >>> logger = get_logger(__name__, log_level="DEBUG")
+     >>> logger.info("My log")
+     >>> logger.debug("My second log")
+
+     >>> array = ["a", "b", "c", "d"]
+     >>> letter_at_rank = array[accelerator.process_index]
+     >>> logger.info(letter_at_rank, in_order=True)
+     ```
+     """
+     if log_level is None:
+         log_level = os.environ.get("ACCELERATE_LOG_LEVEL", None)
+     logger = logging.getLogger(name)
+     if log_level is not None:
+         logger.setLevel(log_level.upper())
+         logger.root.setLevel(log_level.upper())
+     return MultiProcessAdapter(logger, {})
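A short sketch of typical usage; note that the accelerate state must be initialized (here via `Accelerator()`) before the first log call, as enforced in `log()` above:

```python
from accelerate import Accelerator
from accelerate.logging import get_logger

accelerator = Accelerator()  # initializes PartialState, required by MultiProcessAdapter.log
logger = get_logger(__name__, log_level="INFO")

logger.info("visible on the main process only")
logger.info("visible on every process", main_process_only=False)
logger.warning_once("emitted once per unique message")
```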
pythonProject/.venv/Lib/site-packages/accelerate/memory_utils.py ADDED
@@ -0,0 +1,22 @@
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import warnings
+
+
+ warnings.warn(
+     "memory_utils has been reorganized to utils.memory. Import `find_executable_batch_size` from the main `__init__`: "
+     "`from accelerate import find_executable_batch_size` to avoid this warning.",
+     FutureWarning,
+ )
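The replacement import the warning points at works as a decorator; a minimal sketch (the `train` body is illustrative):

```python
from accelerate import find_executable_batch_size


@find_executable_batch_size(starting_batch_size=128)
def train(batch_size):
    # On an out-of-memory error the decorator halves batch_size and retries.
    print(f"Trying batch size {batch_size}")


train()  # called without arguments; the decorator injects batch_size
```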
pythonProject/.venv/Lib/site-packages/accelerate/optimizer.py ADDED
@@ -0,0 +1,213 @@
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import inspect
+
+ import torch
+
+ from .state import AcceleratorState, GradientState
+ from .utils import DistributedType, honor_type, is_lomo_available, is_torch_xla_available
+
+
+ if is_torch_xla_available():
+     import torch_xla.core.xla_model as xm
+     import torch_xla.runtime as xr
+
+
+ def move_to_device(state, device):
+     if isinstance(state, (list, tuple)):
+         return honor_type(state, (move_to_device(t, device) for t in state))
+     elif isinstance(state, dict):
+         return type(state)({k: move_to_device(v, device) for k, v in state.items()})
+     elif isinstance(state, torch.Tensor):
+         return state.to(device)
+     return state
+
+
+ class AcceleratedOptimizer(torch.optim.Optimizer):
+     """
+     Internal wrapper around a torch optimizer.
+
+     Conditionally will perform `step` and `zero_grad` if gradients should be synchronized when performing gradient
+     accumulation.
+
+     Args:
+         optimizer (`torch.optim.optimizer.Optimizer`):
+             The optimizer to wrap.
+         device_placement (`bool`, *optional*, defaults to `True`):
+             Whether or not the optimizer should handle device placement. If so, it will place the state dictionary of
+             `optimizer` on the right device.
+         scaler (`torch.amp.GradScaler` or `torch.cuda.amp.GradScaler`, *optional*):
+             The scaler to use in the step function if training with mixed precision.
+     """
+
+     def __init__(self, optimizer, device_placement=True, scaler=None):
+         self.optimizer = optimizer
+         self.scaler = scaler
+         self.accelerator_state = AcceleratorState()
+         self.gradient_state = GradientState()
+         self.device_placement = device_placement
+         self._is_overflow = False
+
+         if self.scaler is not None:
+             self._accelerate_step_called = False
+             self._optimizer_original_step_method = self.optimizer.step
+             self._optimizer_patched_step_method = patch_optimizer_step(self, self.optimizer.step)
+
+         # Handle device placement
+         if device_placement:
+             state_dict = self.optimizer.state_dict()
+             if self.accelerator_state.distributed_type == DistributedType.XLA:
+                 xm.send_cpu_data_to_device(state_dict, self.accelerator_state.device)
+             else:
+                 state_dict = move_to_device(state_dict, self.accelerator_state.device)
+             self.optimizer.load_state_dict(state_dict)
+
+     @property
+     def state(self):
+         return self.optimizer.state
+
+     @state.setter
+     def state(self, state):
+         self.optimizer.state = state
+
+     @property
+     def param_groups(self):
+         return self.optimizer.param_groups
+
+     @param_groups.setter
+     def param_groups(self, param_groups):
+         self.optimizer.param_groups = param_groups
+
+     @property
+     def defaults(self):
+         return self.optimizer.defaults
+
+     @defaults.setter
+     def defaults(self, defaults):
+         self.optimizer.defaults = defaults
+
+     def add_param_group(self, param_group):
+         self.optimizer.add_param_group(param_group)
+
+     def load_state_dict(self, state_dict):
+         if self.accelerator_state.distributed_type == DistributedType.XLA and self.device_placement:
+             xm.send_cpu_data_to_device(state_dict, self.accelerator_state.device)
+         self.optimizer.load_state_dict(state_dict)
+
+     def state_dict(self):
+         return self.optimizer.state_dict()
+
+     def zero_grad(self, set_to_none=None):
+         if self.gradient_state.sync_gradients:
+             accept_arg = "set_to_none" in inspect.signature(self.optimizer.zero_grad).parameters
+             if accept_arg:
+                 if set_to_none is None:
+                     set_to_none = True
+                 self.optimizer.zero_grad(set_to_none=set_to_none)
+             else:
+                 if set_to_none is not None:
+                     raise ValueError("`set_to_none` for `Optimizer.zero_grad` is not supported by this optimizer.")
+                 self.optimizer.zero_grad()
+
+     def train(self):
+         """
+         Sets the optimizer to "train" mode. Useful for optimizers like `schedule_free`.
+         """
+         if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
+             self.optimizer.train()
+         elif (
+             hasattr(self.optimizer, "optimizer")
+             and hasattr(self.optimizer.optimizer, "train")
+             and callable(self.optimizer.optimizer.train)
+         ):
+             # the DeepSpeed optimizer further wraps the optimizer
+             self.optimizer.optimizer.train()
+
+     def eval(self):
+         """
+         Sets the optimizer to "eval" mode. Useful for optimizers like `schedule_free`.
+         """
+         if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval):
+             self.optimizer.eval()
+
+     def step(self, closure=None):
+         if is_lomo_available():
+             from lomo_optim import AdaLomo, Lomo
+
+         if (
+             not self.gradient_state.is_xla_gradients_synced
+             and self.accelerator_state.distributed_type == DistributedType.XLA
+         ):
+             gradients = xm._fetch_gradients(self.optimizer)
+             xm.all_reduce("sum", gradients, scale=1.0 / xr.world_size())
+             self.gradient_state.is_xla_gradients_synced = True
+
+         if is_lomo_available():
+             # `step` should be a no-op for LOMO optimizers.
+             if isinstance(self.optimizer, (Lomo, AdaLomo)):
+                 return
+
+         if self.gradient_state.sync_gradients:
+             if self.scaler is not None:
+                 self.optimizer.step = self._optimizer_patched_step_method
+
+                 self.scaler.step(self.optimizer, closure)
+                 self.scaler.update()
+
+                 if not self._accelerate_step_called:
+                     # If the optimizer step was skipped, gradient overflow was detected.
+                     self._is_overflow = True
+                 else:
+                     self._is_overflow = False
+                 # Reset the step method to the original one
+                 self.optimizer.step = self._optimizer_original_step_method
+                 # Reset the indicator
+                 self._accelerate_step_called = False
+             else:
+                 self.optimizer.step(closure)
+             if self.accelerator_state.distributed_type == DistributedType.XLA:
+                 self.gradient_state.is_xla_gradients_synced = False
+
+     def _switch_parameters(self, parameters_map):
+         for param_group in self.optimizer.param_groups:
+             param_group["params"] = [parameters_map.get(p, p) for p in param_group["params"]]
+
+     @property
+     def step_was_skipped(self):
+         """Whether or not the optimizer step was skipped."""
+         return self._is_overflow
+
+     def __getstate__(self):
+         _ignored_keys = [
+             "_accelerate_step_called",
+             "_optimizer_original_step_method",
+             "_optimizer_patched_step_method",
+         ]
+         return {k: v for k, v in self.__dict__.items() if k not in _ignored_keys}
+
+     def __setstate__(self, state):
+         self.__dict__.update(state)
+         if self.scaler is not None:
+             self._accelerate_step_called = False
+             self._optimizer_original_step_method = self.optimizer.step
+             self._optimizer_patched_step_method = patch_optimizer_step(self, self.optimizer.step)
+
+
+ def patch_optimizer_step(accelerated_optimizer: AcceleratedOptimizer, method):
+     def patched_step(*args, **kwargs):
+         accelerated_optimizer._accelerate_step_called = True
+         return method(*args, **kwargs)
+
+     return patched_step
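In practice the wrapper is obtained through `Accelerator.prepare` rather than constructed directly; a minimal single-process sketch:

```python
import torch

from accelerate import Accelerator
from accelerate.optimizer import AcceleratedOptimizer

accelerator = Accelerator()
model = torch.nn.Linear(2, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model, optimizer = accelerator.prepare(model, optimizer)
assert isinstance(optimizer, AcceleratedOptimizer)

loss = model(torch.randn(1, 2)).sum()
accelerator.backward(loss)
optimizer.step()
print(optimizer.step_was_skipped)  # False unless a GradScaler detected overflow
optimizer.zero_grad()
```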
pythonProject/.venv/Lib/site-packages/accelerate/parallelism_config.py ADDED
@@ -0,0 +1,322 @@
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ import warnings
+ from dataclasses import dataclass
+ from typing import TYPE_CHECKING, Optional, Union
+
+ from accelerate.utils.dataclasses import TorchContextParallelConfig, TorchTensorParallelConfig
+ from accelerate.utils.versions import is_torch_version
+
+
+ if TYPE_CHECKING:
+     from accelerate import Accelerator
+
+
+ @dataclass
+ class ParallelismConfig:
+     """
+     A dataclass to configure parallelisms applied to the model. Inspired by torchtitan's `ParallelDims`
+     https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/parallel_dims.py
+
+     Args:
+         dp_replicate_size (`int`, defaults to `1`):
+             The size of the data parallel group. If `dp_replicate_size` is set to 1, the data parallel replication
+             group will not be used.
+         dp_shard_size (`int`, defaults to `1`):
+             The size of the model shard group. If `dp_replicate_size > 1` and `tp_size > 1`, `dp_shard_size` must also
+             be greater than 1, as composing DDP + TP is currently not supported.
+         tp_size (`int`, defaults to `1`):
+             The size of the tensor parallel group. If `tp_size` is set to `1`, the tensor parallel group will not be
+             used.
+         cp_size (`int`, defaults to `1`):
+             The size of the context parallel group. Currently not supported, but reserved for future use and enabled
+             for downstream libraries.
+         tp_handler (`~utils.TorchTensorParallelConfig`, defaults to `None`):
+             The handler for the tensor parallel group.
+
+     You may obtain different distributed data parallel paradigms by configuring `dp_replicate_size` and `dp_shard_size`
+     together:
+         - `dp_replicate_size == 1` and `dp_shard_size > 1`, we obtain Fully Sharded Data Parallel (FSDP).
+         - `dp_replicate_size > 1` and `dp_shard_size > 1`, we obtain Hybrid Sharded Data Parallel (HSDP).
+         - `dp_replicate_size > 1` and `dp_shard_size == 1` is an invalid configuration; to use pure DP, use
+           `DistributedDataParallelKwargs` instead.
+     """
+
+     dp_replicate_size: int = None
+     dp_shard_size: int = None
+     tp_size: int = None
+     cp_size: int = None
+
+     # we use Union because we might support other x parallel plugins (i.e. deepspeed, etc)
+     tp_handler: Union[None, TorchTensorParallelConfig] = None
+     cp_handler: Union[None, TorchContextParallelConfig] = None
+
+     device_mesh = None
+
+     def __repr__(self):
+         return (
+             "ParallelismConfig(\n"
+             f"\tdp_replicate_size={self.dp_replicate_size},\n"
+             f"\tdp_shard_size={self.dp_shard_size},\n"
+             f"\ttp_size={self.tp_size},\n"
+             f"\tcp_size={self.cp_size},\n"
+             f"\ttotal_size={self.total_size},\n"
+             f"\ttp_handler={self.tp_handler},\n"
+             f"\tcp_handler={self.cp_handler})\n"
+         )
+
+     def to_json(self):
+         import copy
+
+         _non_serializable_fields = ["device_mesh"]
+
+         return copy.deepcopy(
+             {
+                 k: copy.deepcopy(v.__dict__) if hasattr(v, "__dict__") else v
+                 for k, v in self.__dict__.items()
+                 if k not in _non_serializable_fields
+             }
+         )
+
+     @property
+     def dp_dim_names(self):
+         """Names of enabled dimensions across which data parallelism is applied."""
+         dims = []
+         if self.dp_replicate_enabled:
+             dims += ["dp_replicate"]
+         if self.dp_shard_enabled:
+             dims += ["dp_shard"]
+         return dims
+
+     @property
+     def non_dp_dim_names(self):
+         """Names of enabled dimensions which will receive the same batch (non-data parallel dimensions)."""
+         dims = []
+         if self.tp_enabled:
+             dims += ["tp"]
+         if self.cp_enabled:
+             dims += ["cp"]
+         return dims
+
+     @property
+     def dp_shard_cp_dim_names(self):
+         """Names of enabled dimensions which will be flattened into a joint mesh across which the model is sharded in FSDP."""
+         dims = []
+         if self.dp_shard_enabled:
+             dims += ["dp_shard"]
+         if self.cp_enabled:
+             dims += ["cp"]
+         return dims
+
+     @property
+     def dp_cp_dim_names(self):
+         """Names of enabled dimensions across which loss should be averaged."""
+         dims = []
+         if self.dp_replicate_enabled:
+             dims += ["dp_replicate"]
+         if self.dp_shard_enabled:
+             dims += ["dp_shard"]
+         if self.cp_enabled:
+             dims += ["cp"]
+         return dims
+
+     @property
+     def fsdp_dim_names(self):
+         """Names of enabled dimensions across which FSDP is applied, including data parallel replication."""
+         dims = []
+         if self.dp_replicate_enabled:
+             dims += ["dp_replicate"]
+         dims += ["dp_shard_cp"]
+         return dims
+
+     @property
+     def total_size(self):
+         """The total size of the parallelism configuration, which is the product of all sizes."""
+         return self.dp_replicate_size * self.dp_shard_size * self.tp_size * self.cp_size
+
+     @property
+     def non_data_parallel_size(self):
+         """The size of the non-data parallel dimensions, which is the product of tensor and context parallel sizes."""
+         return self.tp_size * self.cp_size
+
+     @property
+     def data_parallel_size(self):
+         """The size of the data parallel dimensions, which is the product of data parallel replication and sharding sizes."""
+         return self.dp_replicate_size * self.dp_shard_size
+
+     @property
+     def dp_replicate_enabled(self):
+         """True if data parallel replication is enabled, i.e. `dp_replicate_size > 1`."""
+         return self.dp_replicate_size > 1
+
+     @property
+     def dp_shard_enabled(self):
+         """True if data parallel sharding is enabled, i.e. `dp_shard_size > 1`."""
+         return self.dp_shard_size > 1
+
+     @property
+     def tp_enabled(self):
+         """True if tensor parallelism is enabled, i.e. `tp_size > 1`."""
+         return self.tp_size > 1
+
+     @property
+     def cp_enabled(self):
+         """True if context parallelism is enabled, i.e. `cp_size > 1`."""
+         return self.cp_size > 1
+
+     @property
+     def active_mesh_dims(self):
+         """Names of all active mesh dimensions."""
+         return self.dp_dim_names + self.non_dp_dim_names
+
+     def build_device_mesh(self, device_type: str):
+         """Builds a device mesh for the given device type based on the parallelism configuration.
+         This method will also create required joint meshes (e.g. `dp_shard_cp`, `dp_cp`, `dp`).
+
+         Args:
+             device_type (`str`): The type of device for which to build the mesh, e.g. `"cuda"`.
+         """
+         if is_torch_version(">=", "2.2.0"):
+             from torch.distributed.device_mesh import init_device_mesh
+         else:
+             raise RuntimeError("Building a device_mesh requires torch>=2.2.0")
+
+         mesh = self._get_mesh()
+         if len(mesh) == 0:
+             return None
+         mesh_dim_names, mesh_shape = mesh
+         device_mesh = init_device_mesh(
+             device_type,
+             mesh_shape,
+             mesh_dim_names=mesh_dim_names,
+         )
+         if self.dp_dim_names:
+             device_mesh[self.dp_dim_names]._flatten("dp")
+         if self.dp_shard_cp_dim_names:
+             device_mesh[self.dp_shard_cp_dim_names]._flatten("dp_shard_cp")
+         if self.dp_cp_dim_names:
+             device_mesh[self.dp_cp_dim_names]._flatten("dp_cp")
+
+         return device_mesh
+
+     def get_device_mesh(self, device_type: Optional[str] = None):
+         if self.device_mesh is None:
+             if device_type is not None:
+                 self.device_mesh = self.build_device_mesh(device_type)
+             else:
+                 raise ValueError("You need to pass a device_type, e.g. `cuda`, to build the device mesh")
+         else:
+             if device_type is not None:
+                 if self.device_mesh.device_type != device_type:
+                     raise ValueError(
+                         f"The device_mesh is already created with device type {self.device_mesh.device_type}. However, you are trying to get a device mesh with device_type {device_type}. Please check if you correctly initialized your device_mesh"
+                     )
+         return self.device_mesh
+
+     def _get_mesh(self) -> tuple[tuple[str, ...], tuple[int, ...]]:
+         """Generate mesh dimension names and shape for torch.distributed.init_device_mesh()."""
+
+         # Build mesh dimensions dictionary
+         mesh_dims = {parallelism: self._sizes[parallelism] for parallelism in self.active_mesh_dims}
+
+         # Apply canonical ordering
+         mesh_order = ["dp_replicate", "dp_shard", "cp", "tp"]
+         sorted_items = sorted(
+             mesh_dims.items(),
+             key=lambda x: (mesh_order.index(x[0])),
+         )
+         return tuple(zip(*sorted_items))
+
+     def __post_init__(self):
+         # Basic size validation
+         if self.dp_replicate_size is None:
+             self.dp_replicate_size = int(os.environ.get("PARALLELISM_CONFIG_DP_REPLICATE_SIZE", "1"))
+         if self.dp_shard_size is None:
+             self.dp_shard_size = int(os.environ.get("PARALLELISM_CONFIG_DP_SHARD_SIZE", "1"))
+         if self.tp_size is None:
+             self.tp_size = int(os.environ.get("PARALLELISM_CONFIG_TP_SIZE", "1"))
+         if self.cp_size is None:
+             self.cp_size = int(os.environ.get("PARALLELISM_CONFIG_CP_SIZE", "1"))
+
+         if self.tp_size > 1:
+             if self.tp_handler is None:
+                 self.tp_handler = TorchTensorParallelConfig()
+
+         if self.cp_size > 1:
+             if self.cp_handler is None:
+                 self.cp_handler = TorchContextParallelConfig()
+
+         if self.dp_replicate_size < 1:
+             raise ValueError(f"dp_replicate_size must be at least 1, but got {self.dp_replicate_size}")
+         if self.dp_shard_size < 1:
+             raise ValueError(f"dp_shard_size must be at least 1, but got {self.dp_shard_size}")
+         if self.tp_size < 1:
+             raise ValueError(f"tp_size must be at least 1, but got {self.tp_size}")
+         if self.cp_size < 1:
+             raise ValueError(f"cp_size must be at least 1, but got {self.cp_size}")
+
+         if (self.tp_size > 1 or self.cp_size > 1) and self.dp_replicate_size > 1 and self.dp_shard_size == 1:
+             raise ValueError(
+                 "Tensor/Context parallelism (tp/cp_size > 1) cannot be used with pure data parallelism (dp_replicate_size > 1 and dp_shard_size == 1). "
+                 "Please set dp_shard_size > 1 and dp_replicate_size == 1 to compose FSDP + TP/CP for 2D parallel, "
+                 "or set dp_replicate_size > 1 and dp_shard_size > 1 to compose HSDP + TP/CP for 3D parallel."
+             )
+         self._sizes = {
+             "dp_replicate": self.dp_replicate_size,
+             "dp_shard": self.dp_shard_size,
+             "tp": self.tp_size,
+             "cp": self.cp_size,
+         }
+
+     def _set_size(self, parallelism: str, size: int):
+         assert parallelism in self._sizes.keys(), f"Parallelism must be one of {self._sizes.keys()}"
+         self._sizes[parallelism] = size
+         setattr(self, f"{parallelism}_size", size)
+
+     def _validate_accelerator(self, accelerator: "Accelerator"):
+         _warnings = set()
+         if not accelerator.multi_device and self.total_size == 1:
+             # No distributed setup, valid parallelism config
+             return
+
+         # We need this to ensure DDP works
+         if self.total_size == 1:
+             self._set_size("dp_replicate", accelerator.num_processes)
+
+         if self.total_size != accelerator.num_processes:
+             raise ValueError(
+                 f"ParallelismConfig total_size ({self.total_size}) does not match "
+                 f"num_processes ({accelerator.num_processes}). Please adjust dp_replicate_size/ "
+                 f"dp_shard_size/tp_size/cp_size."
+             )
+
+         if self.total_size > 1 and not (accelerator.is_fsdp2 or accelerator.multi_device):
+             raise ValueError(
+                 f"ParallelismConfig is only compatible with DistributedType.FSDP (version 2) or DistributedType.Multi{{Device}}, but got {accelerator.distributed_type}."
+             )
+
+         for parallelism, size in self._sizes.items():
+             if size == 1 and getattr(self, f"{parallelism}_handler", None) is not None:
+                 _warnings.add(
+                     f"ParallelismConfig.{parallelism}_handler is set, but {parallelism}_size is set to 1. This handler will be ignored."
+                 )
+
+         if _warnings and accelerator.is_main_process:
+             warnings.warn(
+                 "ParallelismConfig has the following warnings:\n" + "\n".join(_warnings),
+                 UserWarning,
+             )
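A small sketch of the size and dimension bookkeeping this dataclass performs; constructing it requires no distributed initialization, so the properties can be inspected directly (an 8-process HSDP layout is assumed purely for illustration):

```python
from accelerate.parallelism_config import ParallelismConfig

# 2-way replication x 4-way sharding, no tensor/context parallelism.
pc = ParallelismConfig(dp_replicate_size=2, dp_shard_size=4)
print(pc.total_size)        # 8 -> must match Accelerator.num_processes
print(pc.active_mesh_dims)  # ['dp_replicate', 'dp_shard']
print(pc.fsdp_dim_names)    # ['dp_replicate', 'dp_shard_cp']
```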
pythonProject/.venv/Lib/site-packages/accelerate/scheduler.py ADDED
@@ -0,0 +1,98 @@
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # We ignore warnings about stepping the scheduler since we step it ourselves during gradient accumulation
+
+ import warnings
+
+ from .state import AcceleratorState, GradientState
+
+
+ warnings.filterwarnings("ignore", category=UserWarning, module="torch.optim.lr_scheduler")
+
+
+ class AcceleratedScheduler:
+     """
+     A wrapper around a learning rate scheduler that will only step when the optimizer(s) have a training step. Useful
+     to avoid making a scheduler step too fast when gradients overflowed and there was no training step (in mixed
+     precision training).
+
+     When performing gradient accumulation, scheduler lengths should not be changed accordingly; Accelerate will always
+     step the scheduler to account for it.
+
+     Args:
+         scheduler (`torch.optim.lr_scheduler._LRScheduler`):
+             The scheduler to wrap.
+         optimizers (one or a list of `torch.optim.Optimizer`):
+             The optimizers used.
+         step_with_optimizer (`bool`, *optional*, defaults to `True`):
+             Whether or not the scheduler should be stepped at each optimizer step.
+         split_batches (`bool`, *optional*, defaults to `False`):
+             Whether or not the dataloaders split one batch across the different processes (so batch size is the same
+             regardless of the number of processes) or create batches on each process (so batch size is the original
+             batch size multiplied by the number of processes).
+     """
+
+     def __init__(self, scheduler, optimizers, step_with_optimizer: bool = True, split_batches: bool = False):
+         self.scheduler = scheduler
+         self.optimizers = optimizers if isinstance(optimizers, (list, tuple)) else [optimizers]
+         self.split_batches = split_batches
+         self.step_with_optimizer = step_with_optimizer
+         self.gradient_state = GradientState()
+
+     def step(self, *args, **kwargs):
+         if not self.step_with_optimizer:
+             # No link between scheduler and optimizer -> just step
+             self.scheduler.step(*args, **kwargs)
+             return
+
+         # Otherwise, first make sure the optimizer was stepped.
+         if not self.gradient_state.sync_gradients:
+             if self.gradient_state.adjust_scheduler:
+                 self.scheduler._step_count += 1
+             return
+
+         for opt in self.optimizers:
+             if opt.step_was_skipped:
+                 return
+         if self.split_batches:
+             # Split batches -> the training dataloader batch size is not changed so one step per training step
+             self.scheduler.step(*args, **kwargs)
+         else:
+             # Otherwise the training dataloader batch size was multiplied by `num_processes`, so we need to do
+             # num_processes steps per training step
+             num_processes = AcceleratorState().num_processes
+             for _ in range(num_processes):
+                 # Special case when using OneCycle and `drop_last` was not used
+                 if hasattr(self.scheduler, "total_steps"):
+                     if self.scheduler._step_count <= self.scheduler.total_steps:
+                         self.scheduler.step(*args, **kwargs)
+                 else:
+                     self.scheduler.step(*args, **kwargs)
+
+     # Passthroughs
+     def get_last_lr(self):
+         return self.scheduler.get_last_lr()
+
+     def state_dict(self):
+         return self.scheduler.state_dict()
+
+     def load_state_dict(self, state_dict):
+         self.scheduler.load_state_dict(state_dict)
+
+     def get_lr(self):
+         return self.scheduler.get_lr()
+
+     def print_lr(self, *args, **kwargs):
+         return self.scheduler.print_lr(*args, **kwargs)
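The wrapper is normally produced by `Accelerator.prepare`; a single-process sketch showing that the scheduler only advances once the optimizer actually stepped:

```python
import torch

from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler)

loss = model(torch.randn(1, 2)).sum()
accelerator.backward(loss)
optimizer.step()
scheduler.step()  # steps `num_processes` times unless split_batches=True
print(scheduler.get_last_lr())
```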
pythonProject/.venv/Lib/site-packages/isympy.py ADDED
@@ -0,0 +1,342 @@
+ """
+ Python shell for SymPy.
+
+ This is just a normal Python shell (IPython shell if you have the
+ IPython package installed), that executes the following commands for
+ the user:
+
+     >>> from __future__ import division
+     >>> from sympy import *
+     >>> x, y, z, t = symbols('x y z t')
+     >>> k, m, n = symbols('k m n', integer=True)
+     >>> f, g, h = symbols('f g h', cls=Function)
+     >>> init_printing()
+
+ So starting 'isympy' is equivalent to starting Python (or IPython) and
+ executing the above commands by hand. It is intended for easy and quick
+ experimentation with SymPy. isympy is a good way to use SymPy as an
+ interactive calculator. If you have IPython and Matplotlib installed, then
+ interactive plotting is enabled by default.
+
+ COMMAND LINE OPTIONS
+ --------------------
+
+ -c CONSOLE, --console=CONSOLE
+
+     Use the specified shell (Python or IPython) as the console
+     backend instead of the default one (IPython if present, Python
+     otherwise), e.g.:
+
+         $isympy -c python
+
+     CONSOLE must be one of 'ipython' or 'python'
+
+ -p PRETTY, --pretty PRETTY
+
+     Setup pretty-printing in SymPy. When pretty-printing is enabled,
+     expressions can be printed with Unicode or ASCII. The default is
+     to use pretty-printing (with Unicode if the terminal supports it).
+     When this option is 'no', expressions will not be pretty-printed
+     and ASCII will be used:
+
+         $isympy -p no
+
+     PRETTY must be one of 'unicode', 'ascii', or 'no'
+
+ -t TYPES, --types=TYPES
+
+     Setup the ground types for the polys. By default, gmpy ground types
+     are used if gmpy2 or gmpy is installed, otherwise it falls back to python
+     ground types, which are a little bit slower. You can manually
+     choose python ground types even if gmpy is installed (e.g., for
+     testing purposes):
+
+         $isympy -t python
+
+     TYPES must be one of 'gmpy', 'gmpy1' or 'python'
+
+     Note that the ground type gmpy1 is primarily intended for testing; it
+     forces the use of gmpy version 1 even if gmpy2 is available.
+
+     This is the same as setting the environment variable
+     SYMPY_GROUND_TYPES to the given ground type (e.g.,
+     SYMPY_GROUND_TYPES='gmpy')
+
+     The ground types can be determined interactively from the variable
+     sympy.polys.domains.GROUND_TYPES.
+
+ -o ORDER, --order ORDER
+
+     Setup the ordering of terms for printing. The default is lex, which
+     orders terms lexicographically (e.g., x**2 + x + 1). You can choose
+     other orderings, such as rev-lex, which will use reverse
+     lexicographic ordering (e.g., 1 + x + x**2):
+
+         $isympy -o rev-lex
+
+     ORDER must be one of 'lex', 'rev-lex', 'grlex', 'rev-grlex',
+     'grevlex', 'rev-grevlex', 'old', or 'none'.
+
+     Note that for very large expressions, ORDER='none' may speed up
+     printing considerably but the terms will have no canonical order.
+
+ -q, --quiet
+
+     Print only Python's and SymPy's versions to stdout at startup.
+
+ -d, --doctest
+
+     Use the same format that should be used for doctests. This is
+     equivalent to -c python -p no.
+
+ -C, --no-cache
+
+     Disable the caching mechanism. Disabling the cache may slow certain
+     operations down considerably. This is useful for testing the cache,
+     or for benchmarking, as the cache can result in deceptive timings.
+
+     This is equivalent to setting the environment variable
+     SYMPY_USE_CACHE to 'no'.
+
+ -a, --auto-symbols (requires at least IPython 0.11)
+
+     Automatically create missing symbols. Normally, typing a name of a
+     Symbol that has not been instantiated first would raise NameError,
+     but with this option enabled, any undefined name will be
+     automatically created as a Symbol.
+
+     Note that this is intended only for interactive, calculator style
+     usage. In a script that uses SymPy, Symbols should be instantiated
+     at the top, so that it's clear what they are.
+
+     This will not override any names that are already defined, which
+     includes the single character letters represented by the mnemonic
+     QCOSINE (see the "Gotchas and Pitfalls" document in the
+     documentation). You can delete existing names by executing "del
+     name". If a name is defined, typing "'name' in dir()" will return True.
+
+     The Symbols that are created using this have default assumptions.
+     If you want to place assumptions on symbols, you should create them
+     using symbols() or var().
+
+     Finally, this only works in the top level namespace. So, for
+     example, if you define a function in isympy with an undefined
+     Symbol, it will not work.
+
+     See also the -i and -I options.
+
+ -i, --int-to-Integer (requires at least IPython 0.11)
+
+     Automatically wrap int literals with Integer. This makes it so that
+     things like 1/2 will come out as Rational(1, 2), rather than 0.5. This
+     works by preprocessing the source and wrapping all int literals with
+     Integer. Note that this will not change the behavior of int literals
+     assigned to variables, and it also won't change the behavior of functions
+     that return int literals.
+
+     If you want an int, you can wrap the literal in int(), e.g. int(3)/int(2)
+     gives 1.5 (with division imported from __future__).
+
+ -I, --interactive (requires at least IPython 0.11)
+
+     This is equivalent to --auto-symbols --int-to-Integer. Future options
+     designed for ease of interactive use may be added to this.
+
+ -D, --debug
+
+     Enable debugging output. This is the same as setting the
+     environment variable SYMPY_DEBUG to 'True'. The debug status is set
+     in the variable SYMPY_DEBUG within isympy.
+
+ -- IPython options
+
+     Additionally you can pass command line options directly to the IPython
+     interpreter (the standard Python shell is not supported). However you
+     need to add the '--' separator between the two types of options, e.g. the
+     startup banner option and the colors option. You need to enter the
+     options as required by the version of IPython that you are using, too:
+
+     in IPython 0.11,
+
+         $isympy -q -- --colors=NoColor
+
+     or in older versions of IPython,
+
+         $isympy -q -- -colors NoColor
+
+ See also isympy --help.
+ """
+
+ import os
+ import sys
+
+ # DO NOT IMPORT SYMPY HERE! Or the setting of the sympy environment variables
+ # by the command line will break.
+
+ def main() -> None:
+     from argparse import ArgumentParser, RawDescriptionHelpFormatter
+
+     VERSION = None
+     if '--version' in sys.argv:
+         # We cannot import sympy before this is run, because flags like -C and
+         # -t set environment variables that must be set before SymPy is
+         # imported. The only thing we need to import it for is to get the
+         # version, which only matters with the --version flag.
+         import sympy
+         VERSION = sympy.__version__
+
+     usage = 'isympy [options] -- [ipython options]'
+     parser = ArgumentParser(
+         usage=usage,
+         description=__doc__,
+         formatter_class=RawDescriptionHelpFormatter,
+     )
+
+     parser.add_argument('--version', action='version', version=VERSION)
+
+     parser.add_argument(
+         '-c', '--console',
+         dest='console',
+         action='store',
+         default=None,
+         choices=['ipython', 'python'],
+         metavar='CONSOLE',
+         help='select type of interactive session: ipython | python; defaults '
+         'to ipython if IPython is installed, otherwise python')
+
+     parser.add_argument(
+         '-p', '--pretty',
+         dest='pretty',
+         action='store',
+         default=None,
+         metavar='PRETTY',
+         choices=['unicode', 'ascii', 'no'],
+         help='setup pretty printing: unicode | ascii | no; defaults to '
+         'unicode printing if the terminal supports it, otherwise ascii')
+
+     parser.add_argument(
+         '-t', '--types',
+         dest='types',
+         action='store',
+         default=None,
+         metavar='TYPES',
+         choices=['gmpy', 'gmpy1', 'python'],
+         help='setup ground types: gmpy | gmpy1 | python; defaults to gmpy if gmpy2 '
+         'or gmpy is installed, otherwise python')
+
+     parser.add_argument(
+         '-o', '--order',
+         dest='order',
+         action='store',
+         default=None,
+         metavar='ORDER',
+         choices=['lex', 'grlex', 'grevlex', 'rev-lex', 'rev-grlex', 'rev-grevlex', 'old', 'none'],
+         help='setup ordering of terms: [rev-]lex | [rev-]grlex | [rev-]grevlex | old | none; defaults to lex')
+
+     parser.add_argument(
+         '-q', '--quiet',
+         dest='quiet',
+         action='store_true',
+         default=False,
+         help='print only version information at startup')
+
+     parser.add_argument(
+         '-d', '--doctest',
+         dest='doctest',
+         action='store_true',
+         default=False,
+         help='use the doctest format for output (you can just copy and paste it)')
+
+     parser.add_argument(
+         '-C', '--no-cache',
+         dest='cache',
+         action='store_false',
+         default=True,
+         help='disable caching mechanism')
+
+     parser.add_argument(
+         '-a', '--auto-symbols',
+         dest='auto_symbols',
+         action='store_true',
+         default=False,
+         help='automatically construct missing symbols')
+
+     parser.add_argument(
+         '-i', '--int-to-Integer',
+         dest='auto_int_to_Integer',
+         action='store_true',
+         default=False,
+         help="automatically wrap int literals with Integer")
+
+     parser.add_argument(
+         '-I', '--interactive',
+         dest='interactive',
+         action='store_true',
+         default=False,
+         help="equivalent to -a -i")
+
+     parser.add_argument(
+         '-D', '--debug',
+         dest='debug',
+         action='store_true',
+         default=False,
+         help='enable debugging output')
+
+     (options, ipy_args) = parser.parse_known_args()
+     if '--' in ipy_args:
+         ipy_args.remove('--')
+
+     if not options.cache:
+         os.environ['SYMPY_USE_CACHE'] = 'no'
+
+     if options.types:
+         os.environ['SYMPY_GROUND_TYPES'] = options.types
+
+     if options.debug:
+         os.environ['SYMPY_DEBUG'] = str(options.debug)
+
+     if options.doctest:
+         options.pretty = 'no'
+         options.console = 'python'
+
+     session = options.console
+
+     if session is not None:
+         ipython = session == 'ipython'
+     else:
+         try:
+             import IPython
+             ipython = True
+         except ImportError:
+             if not options.quiet:
+                 from sympy.interactive.session import no_ipython
+                 print(no_ipython)
+             ipython = False
+
+     args = {
+         'pretty_print': True,
+         'use_unicode': None,
+         'use_latex': None,
+         'order': None,
+         'argv': ipy_args,
+     }
+
+     if options.pretty == 'unicode':
+         args['use_unicode'] = True
+     elif options.pretty == 'ascii':
+         args['use_unicode'] = False
+     elif options.pretty == 'no':
+         args['pretty_print'] = False
+
+     if options.order is not None:
+         args['order'] = options.order
+
+     args['quiet'] = options.quiet
+     args['auto_symbols'] = options.auto_symbols or options.interactive
+     args['auto_int_to_Integer'] = options.auto_int_to_Integer or options.interactive
+
+     from sympy.interactive import init_session
+     init_session(ipython, **args)
+
+ if __name__ == "__main__":
+     main()
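For a plain script, the non-interactive equivalent of the preamble described in the module docstring (requires sympy to be installed):

```python
from sympy import Function, init_printing, symbols

x, y, z, t = symbols("x y z t")
k, m, n = symbols("k m n", integer=True)
f, g, h = symbols("f g h", cls=Function)
init_printing()
print((x + 1) ** 2)  # prints the unexpanded square, pretty-printed if possible
```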
pythonProject/.venv/Lib/site-packages/numpy-2.2.6-cp310-cp310-win_amd64.whl ADDED
File without changes
pythonProject/.venv/Lib/site-packages/typing_extensions.py ADDED
The diff for this file is too large to render. See raw diff
 
pythonProject/.venv/pyvenv.cfg ADDED
@@ -0,0 +1,3 @@
+ home = C:\Users\ADMIN\AppData\Local\Programs\Python\Python310
+ include-system-site-packages = false
+ version = 3.10.11