Add files using upload-large-folder tool
- pythonProject/.idea/inspectionProfiles/profiles_settings.xml +6 -0
- pythonProject/.venv/Lib/site-packages/accelerate/accelerator.py +0 -0
- pythonProject/.venv/Lib/site-packages/accelerate/big_modeling.py +789 -0
- pythonProject/.venv/Lib/site-packages/accelerate/checkpointing.py +330 -0
- pythonProject/.venv/Lib/site-packages/accelerate/commands/merge.py +69 -0
- pythonProject/.venv/Lib/site-packages/accelerate/commands/test.py +65 -0
- pythonProject/.venv/Lib/site-packages/accelerate/data_loader.py +1451 -0
- pythonProject/.venv/Lib/site-packages/accelerate/hooks.py +776 -0
- pythonProject/.venv/Lib/site-packages/accelerate/inference.py +184 -0
- pythonProject/.venv/Lib/site-packages/accelerate/launchers.py +306 -0
- pythonProject/.venv/Lib/site-packages/accelerate/local_sgd.py +106 -0
- pythonProject/.venv/Lib/site-packages/accelerate/logging.py +125 -0
- pythonProject/.venv/Lib/site-packages/accelerate/memory_utils.py +22 -0
- pythonProject/.venv/Lib/site-packages/accelerate/optimizer.py +213 -0
- pythonProject/.venv/Lib/site-packages/accelerate/parallelism_config.py +322 -0
- pythonProject/.venv/Lib/site-packages/accelerate/scheduler.py +98 -0
- pythonProject/.venv/Lib/site-packages/isympy.py +342 -0
- pythonProject/.venv/Lib/site-packages/numpy-2.2.6-cp310-cp310-win_amd64.whl +0 -0
- pythonProject/.venv/Lib/site-packages/typing_extensions.py +0 -0
- pythonProject/.venv/pyvenv.cfg +3 -0
pythonProject/.idea/inspectionProfiles/profiles_settings.xml
ADDED
@@ -0,0 +1,6 @@
+<component name="InspectionProjectProfileManager">
+  <settings>
+    <option name="USE_PROJECT_PROFILE" value="false" />
+    <version value="1.0" />
+  </settings>
+</component>
pythonProject/.venv/Lib/site-packages/accelerate/accelerator.py
ADDED
The diff for this file is too large to render. See raw diff.
pythonProject/.venv/Lib/site-packages/accelerate/big_modeling.py
ADDED
@@ -0,0 +1,789 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import re
+from contextlib import contextmanager
+from functools import wraps
+from typing import Optional, Union
+
+import torch
+import torch.nn as nn
+
+from .hooks import (
+    AlignDevicesHook,
+    CpuOffload,
+    LayerwiseCastingHook,
+    UserCpuOffloadHook,
+    add_hook_to_module,
+    attach_align_device_hook,
+    attach_align_device_hook_on_blocks,
+)
+from .utils import (
+    OffloadedWeightsLoader,
+    check_cuda_p2p_ib_support,
+    check_device_map,
+    extract_submodules_state_dict,
+    find_tied_parameters,
+    get_balanced_memory,
+    infer_auto_device_map,
+    is_bnb_available,
+    is_mlu_available,
+    is_musa_available,
+    is_npu_available,
+    is_sdaa_available,
+    is_xpu_available,
+    load_checkpoint_in_model,
+    offload_state_dict,
+    parse_flag_from_env,
+    retie_parameters,
+)
+from .utils.constants import SUPPORTED_PYTORCH_LAYERS_FOR_UPCASTING
+from .utils.other import recursive_getattr
+
+
+logger = logging.getLogger(__name__)
+
+
+@contextmanager
+def init_empty_weights(include_buffers: bool = None):
+    """
+    A context manager under which models are initialized with all parameters on the meta device, therefore creating an
+    empty model. Useful when just initializing the model would blow the available RAM.
+
+    Args:
+        include_buffers (`bool`, *optional*):
+            Whether or not to also put all buffers on the meta device while initializing.
+
+    Example:
+
+    ```python
+    import torch.nn as nn
+    from accelerate import init_empty_weights
+
+    # Initialize a model with 100 billions parameters in no time and without using any RAM.
+    with init_empty_weights():
+        tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
+    ```
+
+    <Tip warning={true}>
+
+    Any model created under this context manager has no weights. As such you can't do something like
+    `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`].
+    Make sure to overwrite the default device_map param for [`load_checkpoint_and_dispatch`], otherwise dispatch is not
+    called.
+
+    </Tip>
+    """
+    if include_buffers is None:
+        include_buffers = parse_flag_from_env("ACCELERATE_INIT_INCLUDE_BUFFERS", False)
+    with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f:
+        yield f
+
+
+@contextmanager
+def init_on_device(device: torch.device, include_buffers: bool = None):
+    """
+    A context manager under which models are initialized with all parameters on the specified device.
+
+    Args:
+        device (`torch.device`):
+            Device to initialize all parameters on.
+        include_buffers (`bool`, *optional*):
+            Whether or not to also put all buffers on the meta device while initializing.
+
+    Example:
+
+    ```python
+    import torch.nn as nn
+    from accelerate import init_on_device
+
+    with init_on_device(device=torch.device("cuda")):
+        tst = nn.Linear(100, 100)  # on `cuda` device
+    ```
+    """
+    if include_buffers is None:
+        include_buffers = parse_flag_from_env("ACCELERATE_INIT_INCLUDE_BUFFERS", False)
+
+    if include_buffers:
+        with device:
+            yield
+        return
+
+    old_register_parameter = nn.Module.register_parameter
+    if include_buffers:
+        old_register_buffer = nn.Module.register_buffer
+
+    def register_empty_parameter(module, name, param):
+        old_register_parameter(module, name, param)
+        if param is not None:
+            param_cls = type(module._parameters[name])
+            kwargs = module._parameters[name].__dict__
+            kwargs["requires_grad"] = param.requires_grad
+            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
+
+    def register_empty_buffer(module, name, buffer, persistent=True):
+        old_register_buffer(module, name, buffer, persistent=persistent)
+        if buffer is not None:
+            module._buffers[name] = module._buffers[name].to(device)
+
+    # Patch tensor creation
+    if include_buffers:
+        tensor_constructors_to_patch = {
+            torch_function_name: getattr(torch, torch_function_name)
+            for torch_function_name in ["empty", "zeros", "ones", "full"]
+        }
+    else:
+        tensor_constructors_to_patch = {}
+
+    def patch_tensor_constructor(fn):
+        def wrapper(*args, **kwargs):
+            kwargs["device"] = device
+            return fn(*args, **kwargs)
+
+        return wrapper
+
+    try:
+        nn.Module.register_parameter = register_empty_parameter
+        if include_buffers:
+            nn.Module.register_buffer = register_empty_buffer
+        for torch_function_name in tensor_constructors_to_patch.keys():
+            setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
+        yield
+    finally:
+        nn.Module.register_parameter = old_register_parameter
+        if include_buffers:
+            nn.Module.register_buffer = old_register_buffer
+        for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
+            setattr(torch, torch_function_name, old_torch_function)
+
+
+def cpu_offload(
+    model: nn.Module,
+    execution_device: Optional[torch.device] = None,
+    offload_buffers: bool = False,
+    state_dict: Optional[dict[str, torch.Tensor]] = None,
+    preload_module_classes: Optional[list[str]] = None,
+):
+    """
+    Activates full CPU offload for a model. As a result, all parameters of the model will be offloaded and only one
+    copy of the state dict of the model will be kept. During the forward pass, parameters will be extracted from that
+    state dict and put on the execution device passed as they are needed, then offloaded again.
+
+    Args:
+        model (`torch.nn.Module`):
+            The model to offload.
+        execution_device (`torch.device`, *optional*):
+            The device on which the forward pass of the model will be executed (should be a GPU). Will default to the
+            model first parameter device.
+        offload_buffers (`bool`, *optional*, defaults to `False`):
+            Whether or not to offload the buffers with the model parameters.
+        state_dict (`Dict[str, torch.Tensor]`, *optional*):
+            The state dict of the model that will be kept on CPU.
+        preload_module_classes (`List[str]`, *optional*):
+            A list of classes whose instances should load all their weights (even in the submodules) at the beginning
+            of the forward. This should only be used for classes that have submodules which are registered but not
+            called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
+            `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
+    """
+    if execution_device is None:
+        execution_device = next(iter(model.parameters())).device
+    if state_dict is None:
+        state_dict = {n: p.to("cpu") for n, p in model.state_dict().items()}
+
+    add_hook_to_module(model, AlignDevicesHook(io_same_device=True), append=True)
+    attach_align_device_hook(
+        model,
+        execution_device=execution_device,
+        offload=True,
+        offload_buffers=offload_buffers,
+        weights_map=state_dict,
+        preload_module_classes=preload_module_classes,
+    )
+
+    return model
+
+
+def cpu_offload_with_hook(
+    model: torch.nn.Module,
+    execution_device: Optional[Union[int, str, torch.device]] = None,
+    prev_module_hook: Optional[UserCpuOffloadHook] = None,
+):
+    """
+    Offloads a model on the CPU and puts it back to an execution device when executed. The difference with
+    [`cpu_offload`] is that the model stays on the execution device after the forward and is only offloaded again when
+    the `offload` method of the returned `hook` is called. Useful for pipelines running a model in a loop.
+
+    Args:
+        model (`torch.nn.Module`):
+            The model to offload.
+        execution_device(`str`, `int` or `torch.device`, *optional*):
+            The device on which the model should be executed. Will default to the MPS device if it's available, then
+            GPU 0 if there is a GPU, and finally to the CPU.
+        prev_module_hook (`UserCpuOffloadHook`, *optional*):
+            The hook sent back by this function for a previous model in the pipeline you are running. If passed, its
+            offload method will be called just before the forward of the model to which this hook is attached.
+
+    Example:
+
+    ```py
+    model_1, hook_1 = cpu_offload_with_hook(model_1, cuda_device)
+    model_2, hook_2 = cpu_offload_with_hook(model_2, cuda_device, prev_module_hook=hook_1)
+    model_3, hook_3 = cpu_offload_with_hook(model_3, cuda_device, prev_module_hook=hook_2)
+
+    hid_1 = model_1(input)
+    for i in range(50):
+        # model1 is offloaded on the CPU at the first iteration, model 2 stays on the GPU for this whole loop.
+        hid_2 = model_2(hid_1)
+        # model2 is offloaded to the CPU just before this forward.
+        hid_3 = model_3(hid_3)
+
+    # For model3, you need to manually call the hook offload method.
+    hook_3.offload()
+    ```
+    """
+    hook = CpuOffload(execution_device=execution_device, prev_module_hook=prev_module_hook)
+    add_hook_to_module(model, hook, append=True)
+    user_hook = UserCpuOffloadHook(model, hook)
+    return model, user_hook
+
+
+def disk_offload(
+    model: nn.Module,
+    offload_dir: Union[str, os.PathLike],
+    execution_device: Optional[torch.device] = None,
+    offload_buffers: bool = False,
+    preload_module_classes: Optional[list[str]] = None,
+):
+    """
+    Activates full disk offload for a model. As a result, all parameters of the model will be offloaded as
+    memory-mapped array in a given folder. During the forward pass, parameters will be accessed from that folder and
+    put on the execution device passed as they are needed, then offloaded again.
+
+    Args:
+        model (`torch.nn.Module`): The model to offload.
+        offload_dir (`str` or `os.PathLike`):
+            The folder in which to offload the model weights (or where the model weights are already offloaded).
+        execution_device (`torch.device`, *optional*):
+            The device on which the forward pass of the model will be executed (should be a GPU). Will default to the
+            model's first parameter device.
+        offload_buffers (`bool`, *optional*, defaults to `False`):
+            Whether or not to offload the buffers with the model parameters.
+        preload_module_classes (`List[str]`, *optional*):
+            A list of classes whose instances should load all their weights (even in the submodules) at the beginning
+            of the forward. This should only be used for classes that have submodules which are registered but not
+            called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
+            `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
+    """
+    if not os.path.isdir(offload_dir) or not os.path.isfile(os.path.join(offload_dir, "index.json")):
+        offload_state_dict(offload_dir, model.state_dict())
+    if execution_device is None:
+        execution_device = next(iter(model.parameters())).device
+    weights_map = OffloadedWeightsLoader(save_folder=offload_dir)
+
+    add_hook_to_module(model, AlignDevicesHook(io_same_device=True), append=True)
+    attach_align_device_hook(
+        model,
+        execution_device=execution_device,
+        offload=True,
+        offload_buffers=offload_buffers,
+        weights_map=weights_map,
+        preload_module_classes=preload_module_classes,
+    )
+
+    return model
+
+
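Neither `cpu_offload` nor `disk_offload` above carries a usage example in its docstring. A minimal sketch of how they might be driven, assuming a CUDA device and a writable offload folder; the toy model, folder name, and sizes are illustrative only:

```python
import torch
import torch.nn as nn
from accelerate import cpu_offload, disk_offload

# A toy model standing in for one too large to keep resident on the GPU.
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))

# Full CPU offload: one CPU-side state dict is kept, and each submodule's
# weights are copied onto the execution device only for its forward pass.
model = cpu_offload(model, execution_device=torch.device("cuda:0"))
out = model(torch.randn(1, 512, device="cuda:0"))

# Full disk offload: weights are serialized once into the folder (with an
# index.json) and memory-mapped back onto the execution device per forward.
model2 = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))
model2 = disk_offload(model2, offload_dir="offloaded_weights", execution_device=torch.device("cuda:0"))
out = model2(torch.randn(1, 512, device="cuda:0"))
```

Both helpers return the same module instance with alignment hooks attached, so outputs follow the input device (`io_same_device=True`).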
+def dispatch_model(
+    model: nn.Module,
+    device_map: dict[str, Union[str, int, torch.device]],
+    main_device: Optional[torch.device] = None,
+    state_dict: Optional[dict[str, torch.Tensor]] = None,
+    offload_dir: Optional[Union[str, os.PathLike]] = None,
+    offload_index: Optional[dict[str, str]] = None,
+    offload_buffers: bool = False,
+    skip_keys: Optional[Union[str, list[str]]] = None,
+    preload_module_classes: Optional[list[str]] = None,
+    force_hooks: bool = False,
+):
+    """
+    Dispatches a model according to a given device map. Layers of the model might be spread across GPUs, offloaded on
+    the CPU or even the disk.
+
+    Args:
+        model (`torch.nn.Module`):
+            The model to dispatch.
+        device_map (`Dict[str, Union[str, int, torch.device]]`):
+            A dictionary mapping module names in the models `state_dict` to the device they should go to. Note that
+            `"disk"` is accepted even if it's not a proper value for `torch.device`.
+        main_device (`str`, `int` or `torch.device`, *optional*):
+            The main execution device. Will default to the first device in the `device_map` different from `"cpu"` or
+            `"disk"`.
+        state_dict (`Dict[str, torch.Tensor]`, *optional*):
+            The state dict of the part of the model that will be kept on CPU.
+        offload_dir (`str` or `os.PathLike`):
+            The folder in which to offload the model weights (or where the model weights are already offloaded).
+        offload_index (`Dict`, *optional*):
+            A dictionary from weight name to their information (`dtype`/ `shape` or safetensors filename). Will default
+            to the index saved in `save_folder`.
+        offload_buffers (`bool`, *optional*, defaults to `False`):
+            Whether or not to offload the buffers with the model parameters.
+        skip_keys (`str` or `List[str]`, *optional*):
+            A list of keys to ignore when moving inputs or outputs between devices.
+        preload_module_classes (`List[str]`, *optional*):
+            A list of classes whose instances should load all their weights (even in the submodules) at the beginning
+            of the forward. This should only be used for classes that have submodules which are registered but not
+            called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
+            `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
+        force_hooks (`bool`, *optional*, defaults to `False`):
+            Whether or not to force device hooks to be attached to the model even if all layers are dispatched to a
+            single device.
+    """
+    # Error early if the device map is incomplete.
+    check_device_map(model, device_map)
+
+    # We need to force hook for quantized model that can't be moved with to()
+    if getattr(model, "quantization_method", "bitsandbytes") == "bitsandbytes":
+        # since bnb 0.43.2, we can move 4-bit model
+        if getattr(model, "is_loaded_in_8bit", False) or (
+            getattr(model, "is_loaded_in_4bit", False) and not is_bnb_available(min_version="0.43.2")
+        ):
+            force_hooks = True
+
+    # We attach hooks if the device_map has at least 2 different devices or if
+    # force_hooks is set to `True`. Otherwise, the model in already loaded
+    # in the unique device and the user can decide where to dispatch the model.
+    # If the model is quantized, we always force-dispatch the model
+    if (len(set(device_map.values())) > 1) or force_hooks:
+        if main_device is None:
+            if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {"cpu", "disk"}:
+                main_device = "cpu"
+            else:
+                main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0]
+
+        if main_device != "cpu":
+            cpu_modules = [name for name, device in device_map.items() if device == "cpu"]
+            if state_dict is None and len(cpu_modules) > 0:
+                state_dict = extract_submodules_state_dict(model.state_dict(), cpu_modules)
+
+        disk_modules = [name for name, device in device_map.items() if device == "disk"]
+        if offload_dir is None and offload_index is None and len(disk_modules) > 0:
+            raise ValueError(
+                "We need an `offload_dir` to dispatch this model according to this `device_map`, the following submodules "
+                f"need to be offloaded: {', '.join(disk_modules)}."
+            )
+        if (
+            len(disk_modules) > 0
+            and offload_index is None
+            and (not os.path.isdir(offload_dir) or not os.path.isfile(os.path.join(offload_dir, "index.json")))
+        ):
+            disk_state_dict = extract_submodules_state_dict(model.state_dict(), disk_modules)
+            offload_state_dict(offload_dir, disk_state_dict)
+
+        execution_device = {
+            name: main_device if device in ["cpu", "disk"] else device for name, device in device_map.items()
+        }
+        execution_device[""] = main_device
+        offloaded_devices = ["disk"] if main_device == "cpu" or main_device == "mps" else ["cpu", "disk"]
+        offload = {name: device in offloaded_devices for name, device in device_map.items()}
+        save_folder = offload_dir if len(disk_modules) > 0 else None
+        if state_dict is not None or save_folder is not None or offload_index is not None:
+            device = main_device if offload_index is not None else None
+            weights_map = OffloadedWeightsLoader(
+                state_dict=state_dict, save_folder=save_folder, index=offload_index, device=device
+            )
+        else:
+            weights_map = None
+
+        # When dispatching the model's parameters to the devices specified in device_map, we want to avoid allocating memory several times for the
+        # tied parameters. The dictionary tied_params_map keeps track of the already allocated data for a given tied parameter (represented by its
+        # original pointer) on each devices.
+        tied_params = find_tied_parameters(model)
+
+        tied_params_map = {}
+        for group in tied_params:
+            for param_name in group:
+                # data_ptr() is enough here, as `find_tied_parameters` finds tied params simply by comparing `param1 is param2`, so we don't need
+                # to care about views of tensors through storage_offset.
+                data_ptr = recursive_getattr(model, param_name).data_ptr()
+                tied_params_map[data_ptr] = {}
+
+                # Note: To handle the disk offloading case, we can not simply use weights_map[param_name].data_ptr() as the reference pointer,
+                # as we have no guarantee that safetensors' `file.get_tensor()` will always give the same pointer.
+
+        attach_align_device_hook_on_blocks(
+            model,
+            execution_device=execution_device,
+            offload=offload,
+            offload_buffers=offload_buffers,
+            weights_map=weights_map,
+            skip_keys=skip_keys,
+            preload_module_classes=preload_module_classes,
+            tied_params_map=tied_params_map,
+        )
+
+        # warn if there is any params on the meta device
+        offloaded_devices_str = " and ".join(
+            [device for device in set(device_map.values()) if device in ("cpu", "disk")]
+        )
+        if len(offloaded_devices_str) > 0:
+            logger.warning(
+                f"Some parameters are on the meta device because they were offloaded to the {offloaded_devices_str}."
+            )
+
+        # Attaching the hook may break tied weights, so we retie them
+        retie_parameters(model, tied_params)
+
+        # add warning to cuda and to method
+        def add_warning(fn, model):
+            @wraps(fn)
+            def wrapper(*args, **kwargs):
+                warning_msg = "You shouldn't move a model that is dispatched using accelerate hooks."
+                if str(fn.__name__) == "to":
+                    to_device = torch._C._nn._parse_to(*args, **kwargs)[0]
+                    if to_device is not None:
+                        logger.warning(warning_msg)
+                else:
+                    logger.warning(warning_msg)
+                for param in model.parameters():
+                    if param.device == torch.device("meta"):
+                        raise RuntimeError("You can't move a model that has some modules offloaded to cpu or disk.")
+                return fn(*args, **kwargs)
+
+            return wrapper
+
+        # Make sure to update _accelerate_added_attributes in hooks.py if you add any hook
+        model.to = add_warning(model.to, model)
+        if is_npu_available():
+            model.npu = add_warning(model.npu, model)
+        elif is_mlu_available():
+            model.mlu = add_warning(model.mlu, model)
+        elif is_sdaa_available():
+            model.sdaa = add_warning(model.sdaa, model)
+        elif is_musa_available():
+            model.musa = add_warning(model.musa, model)
+        elif is_xpu_available():
+            model.xpu = add_warning(model.xpu, model)
+        else:
+            model.cuda = add_warning(model.cuda, model)
+
+        # Check if we are using multi-gpus with RTX 4000 series
+        use_multi_gpu = len([device for device in set(device_map.values()) if device not in ("cpu", "disk")]) > 1
+        if use_multi_gpu and not check_cuda_p2p_ib_support():
+            logger.warning(
+                "We've detected an older driver with an RTX 4000 series GPU. These drivers have issues with P2P. "
+                "This can affect the multi-gpu inference when using accelerate device_map."
+                "Please make sure to update your driver to the latest version which resolves this."
+            )
+    else:
+        device = list(device_map.values())[0]
+        # `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
+        if is_npu_available() and isinstance(device, int):
+            device = f"npu:{device}"
+        elif is_mlu_available() and isinstance(device, int):
+            device = f"mlu:{device}"
+        elif is_sdaa_available() and isinstance(device, int):
+            device = f"sdaa:{device}"
+        elif is_musa_available() and isinstance(device, int):
+            device = f"musa:{device}"
+        if device != "disk":
+            model.to(device)
+        else:
+            raise ValueError(
+                "You are trying to offload the whole model to the disk. Please use the `disk_offload` function instead."
+            )
+    # Convert OrderedDict back to dict for easier usage
+    model.hf_device_map = dict(device_map)
+    return model
+
+
+def load_checkpoint_and_dispatch(
+    model: nn.Module,
+    checkpoint: Union[str, os.PathLike],
+    device_map: Optional[Union[str, dict[str, Union[int, str, torch.device]]]] = None,
+    max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None,
+    no_split_module_classes: Optional[list[str]] = None,
+    offload_folder: Optional[Union[str, os.PathLike]] = None,
+    offload_buffers: bool = False,
+    dtype: Optional[Union[str, torch.dtype]] = None,
+    offload_state_dict: Optional[bool] = None,
+    skip_keys: Optional[Union[str, list[str]]] = None,
+    preload_module_classes: Optional[list[str]] = None,
+    force_hooks: bool = False,
+    strict: bool = False,
+    full_state_dict: bool = True,
+    broadcast_from_rank0: bool = False,
+):
+    """
+    Loads a (potentially sharded) checkpoint inside a model, potentially sending weights to a given device as they are
+    loaded and adds the various hooks that will make this model run properly (even if split across devices).
+
+    Args:
+        model (`torch.nn.Module`): The model in which we want to load a checkpoint.
+        checkpoint (`str` or `os.PathLike`):
+            The folder checkpoint to load. It can be:
+            - a path to a file containing a whole model state dict
+            - a path to a `.json` file containing the index to a sharded checkpoint
+            - a path to a folder containing a unique `.index.json` file and the shards of a checkpoint.
+        device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*):
+            A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer
+            name, once a given module name is inside, every submodule of it will be sent to the same device.
+
+            To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For more
+            information about each option see [here](../concept_guides/big_model_inference#designing-a-device-map).
+            Defaults to None, which means [`dispatch_model`] will not be called.
+        max_memory (`Dict`, *optional*):
+            A dictionary device identifier to maximum memory. Will default to the maximum memory available for each GPU
+            and the available CPU RAM if unset.
+        no_split_module_classes (`List[str]`, *optional*):
+            A list of layer class names that should never be split across device (for instance any layer that has a
+            residual connection).
+        offload_folder (`str` or `os.PathLike`, *optional*):
+            If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
+        offload_buffers (`bool`, *optional*, defaults to `False`):
+            In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as
+            well as the parameters.
+        dtype (`str` or `torch.dtype`, *optional*):
+            If provided, the weights will be converted to that type when loaded.
+        offload_state_dict (`bool`, *optional*):
+            If `True`, will temporarily offload the CPU state dict on the hard drive to avoid getting out of CPU RAM if
+            the weight of the CPU state dict + the biggest shard does not fit. Will default to `True` if the device map
+            picked contains `"disk"` values.
+        skip_keys (`str` or `List[str]`, *optional*):
+            A list of keys to ignore when moving inputs or outputs between devices.
+        preload_module_classes (`List[str]`, *optional*):
+            A list of classes whose instances should load all their weights (even in the submodules) at the beginning
+            of the forward. This should only be used for classes that have submodules which are registered but not
+            called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
+            `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
+        force_hooks (`bool`, *optional*, defaults to `False`):
+            Whether or not to force device hooks to be attached to the model even if all layers are dispatched to a
+            single device.
+        strict (`bool`, *optional*, defaults to `False`):
+            Whether to strictly enforce that the keys in the checkpoint state_dict match the keys of the model's
+            state_dict.
+        full_state_dict (`bool`, *optional*, defaults to `True`): if this is set to `True`, all the tensors in the
+            loaded state_dict will be gathered. No ShardedTensor and DTensor will be in the loaded state_dict.
+        broadcast_from_rank0 (`False`, *optional*, defaults to `False`): when the option is `True`, a distributed
+            `ProcessGroup` must be initialized. rank0 should receive a full state_dict and will broadcast the tensors
+            in the state_dict one by one to other ranks. Other ranks will receive the tensors and shard (if applicable)
+            according to the local shards in the model.
+
+    Example:
+
+    ```python
+    >>> from accelerate import init_empty_weights, load_checkpoint_and_dispatch
+    >>> from huggingface_hub import hf_hub_download
+    >>> from transformers import AutoConfig, AutoModelForCausalLM
+
+    >>> # Download the Weights
+    >>> checkpoint = "EleutherAI/gpt-j-6B"
+    >>> weights_location = hf_hub_download(checkpoint, "pytorch_model.bin")
+
+    >>> # Create a model and initialize it with empty weights
+    >>> config = AutoConfig.from_pretrained(checkpoint)
+    >>> with init_empty_weights():
+    ...     model = AutoModelForCausalLM.from_config(config)
+
+    >>> # Load the checkpoint and dispatch it to the right devices
+    >>> model = load_checkpoint_and_dispatch(
+    ...     model, weights_location, device_map="auto", no_split_module_classes=["GPTJBlock"]
+    ... )
+    ```
+    """
+    if isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
+        raise ValueError(
+            "If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or 'sequential'."
+        )
+    if isinstance(device_map, str):
+        if device_map != "sequential":
+            max_memory = get_balanced_memory(
+                model,
+                max_memory=max_memory,
+                no_split_module_classes=no_split_module_classes,
+                dtype=dtype,
+                low_zero=(device_map == "balanced_low_0"),
+            )
+        device_map = infer_auto_device_map(
+            model,
+            max_memory=max_memory,
+            no_split_module_classes=no_split_module_classes,
+            dtype=dtype,
+            offload_buffers=offload_buffers,
+        )
+    if offload_state_dict is None and device_map is not None and "disk" in device_map.values():
+        offload_state_dict = True
+    load_checkpoint_in_model(
+        model,
+        checkpoint,
+        device_map=device_map,
+        offload_folder=offload_folder,
+        dtype=dtype,
+        offload_state_dict=offload_state_dict,
+        offload_buffers=offload_buffers,
+        strict=strict,
+        full_state_dict=full_state_dict,
+        broadcast_from_rank0=broadcast_from_rank0,
+    )
+    if device_map is None:
+        return model
+    return dispatch_model(
+        model,
+        device_map=device_map,
+        offload_dir=offload_folder,
+        offload_buffers=offload_buffers,
+        skip_keys=skip_keys,
+        preload_module_classes=preload_module_classes,
+        force_hooks=force_hooks,
+    )
+
+
+def attach_layerwise_casting_hooks(
+    module: torch.nn.Module,
+    storage_dtype: torch.dtype,
+    compute_dtype: torch.dtype,
+    skip_modules_pattern: Union[str, tuple[str, ...]] = None,
+    skip_modules_classes: Optional[tuple[type[torch.nn.Module], ...]] = None,
+    non_blocking: bool = False,
+) -> None:
+    r"""
+    Applies layerwise casting to a given module. The module expected here is a PyTorch `nn.Module`. This is helpful for
+    reducing memory requirements when one doesn't want to fully quantize a model. Model params can be kept in say,
+    `torch.float8_e4m3fn` and upcasted to a higher precision like `torch.bfloat16` during forward pass and downcasted
+    back to `torch.float8_e4m3fn` to realize memory savings.
+
+    Args:
+        module (`torch.nn.Module`):
+            The module whose leaf modules will be cast to a high precision dtype for computation, and to a low
+            precision dtype for storage.
+        storage_dtype (`torch.dtype`):
+            The dtype to cast the module to before/after the forward pass for storage.
+        compute_dtype (`torch.dtype`):
+            The dtype to cast the module to during the forward pass for computation.
+        skip_modules_pattern (`tuple[str, ...]`, defaults to `None`):
+            A list of patterns to match the names of the modules to skip during the layerwise casting process. If set
+            to `None` alongside `skip_modules_classes` being `None`, the layerwise casting is applied directly to the
+            module instead of its internal submodules.
+        skip_modules_classes (`tuple[type[torch.nn.Module], ...]`, defaults to `None`):
+            A list of module classes to skip during the layerwise casting process.
+        non_blocking (`bool`, defaults to `False`):
+            If `True`, the weight casting operations are non-blocking.
+
+    Example:
+
+    ```python
+    >>> from accelerate.hooks import attach_layerwise_casting_hooks
+    >>> from transformers import AutoModelForCausalLM
+    >>> import torch
+
+    >>> # Model
+    >>> checkpoint = "EleutherAI/gpt-j-6B"
+    >>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
+
+    >>> # Attach hooks and perform inference
+    >>> attach_layerwise_casting_hooks(model, storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
+    >>> with torch.no_grad():
+    ...     model(...)
+    ```
+
+    Users can also pass modules they want to avoid from getting downcasted.
+
+    ```py
+    >>> attach_layerwise_casting_hooks(
+    ...     model, storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16, skip_modules_pattern=["norm"]
+    ... )
+    ```
+    """
+    _attach_layerwise_casting_hooks(
+        module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes, non_blocking
+    )
+
+
+def _attach_layerwise_casting_hooks(
+    module: torch.nn.Module,
+    storage_dtype: torch.dtype,
+    compute_dtype: torch.dtype,
+    skip_modules_pattern: Union[str, tuple[str, ...]] = None,
+    skip_modules_classes: Optional[tuple[type[torch.nn.Module], ...]] = None,
+    non_blocking: bool = False,
+    _prefix: str = "",
+):
+    should_skip = (skip_modules_classes is not None and isinstance(module, skip_modules_classes)) or (
+        skip_modules_pattern is not None and any(re.search(pattern, _prefix) for pattern in skip_modules_pattern)
+    )
+    if should_skip:
+        logger.debug(f'Skipping layerwise casting for layer "{_prefix}"')
+        return
+
+    if isinstance(module, SUPPORTED_PYTORCH_LAYERS_FOR_UPCASTING):
+        logger.debug(f'Applying layerwise casting to layer "{_prefix}"')
+        add_hook_to_module(
+            module,
+            LayerwiseCastingHook(storage_dtype=storage_dtype, compute_dtype=compute_dtype, non_blocking=non_blocking),
+            append=True,
+        )
+        return
+
+    for name, submodule in module.named_children():
+        layer_name = f"{_prefix}.{name}" if _prefix else name
+        _attach_layerwise_casting_hooks(
+            submodule,
+            storage_dtype,
+            compute_dtype,
+            skip_modules_pattern,
+            skip_modules_classes,
+            non_blocking,
+            _prefix=layer_name,
+        )
+
+
+def _attach_context_parallel_hooks(
+    model: nn.Module,
+):
+    """
+    Monkeypatch huggingface's `transformers` model to fix attention mask issues when using context parallelism.
+
+    This function attaches forward_pre_hooks to each self_attn module of the model, where each hook checks the
+    args/kwargs, if they contain an attention mask, if it does, it will remove this mask, check if it is a causal mask,
+    if yes, will add a kwarg `is_causal=True`, otherwise will raise an error. This is because context parallelism does
+    not support attention masks. This function modifies the model in place.
+
+    Args:
+        model (`nn.Module`):
+            The model to attach the hooks to.
+
+    """
+
+    def _self_attn_pre_forward_hook(_module, module_args, module_kwargs):
+        if "attention_mask" in module_kwargs:
+            module_kwargs["attention_mask"] = None
+            module_kwargs["is_causal"] = True
+
+        return module_args, module_kwargs
+
+    for name, module in model.named_modules():
+        # We hope (assume) that if user uses their own model (without this structure which transformers uses), they read the docs saying they can't pass in attention masks
+        # Then these cases can happen:
+        # 1) some modules end with a `self-attn` module, in which case we attach the hook, but
+        #    there's no attention mask kwarg -> hook is a no-op
+        # 2) some modules end with a `self-attn` module, in which case we attach the hook, and the
+        #    attention mask kwarg is passed -> hook will remove the attention mask and add
+        #    `is_causal=True` kwarg, which either crashes the training or fixes it
+        #    (training would crash anyway as attention mask isn't supported)
+        # 3) no modules end with a `self-attn` module, in which case we don't attach the hook, this is
+        #    a no-op as well
+        if name.endswith("self_attn"):
+            # we want the hook to be executed first, to avoid any other hooks doing work on the attention mask
+            module.register_forward_pre_hook(_self_attn_pre_forward_hook, with_kwargs=True, prepend=True)
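`dispatch_model` is usually reached indirectly through `load_checkpoint_and_dispatch` (whose docstring example appears above), but it can also be called with a hand-written `device_map`. A minimal sketch, assuming one CUDA device is available; the module names `"0"` and `"1"` are simply the submodule names of this toy `nn.Sequential`:

```python
import torch.nn as nn
from accelerate import dispatch_model

model = nn.Sequential(nn.Linear(512, 512), nn.Linear(512, 512))

# Keys are module names from the model; values are devices. "cpu" submodules
# are offloaded and aligned onto the main device (GPU 0 here) on demand.
device_map = {"0": 0, "1": "cpu"}
model = dispatch_model(model, device_map=device_map)

print(model.hf_device_map)  # {'0': 0, '1': 'cpu'}
```

Because the map contains two distinct devices, hooks are attached and `model.to(...)` afterwards will trigger the warning installed by `add_warning` above.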
pythonProject/.venv/Lib/site-packages/accelerate/checkpointing.py
ADDED
@@ -0,0 +1,330 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+from pathlib import Path
+
+import numpy as np
+import torch
+from safetensors.torch import load_model
+
+from .utils import (
+    MODEL_NAME,
+    OPTIMIZER_NAME,
+    RNG_STATE_NAME,
+    SAFE_MODEL_NAME,
+    SAFE_WEIGHTS_NAME,
+    SAMPLER_NAME,
+    SCALER_NAME,
+    SCHEDULER_NAME,
+    WEIGHTS_NAME,
+    get_pretty_name,
+    is_cuda_available,
+    is_hpu_available,
+    is_mlu_available,
+    is_musa_available,
+    is_sdaa_available,
+    is_torch_version,
+    is_torch_xla_available,
+    is_xpu_available,
+    load,
+    save,
+)
+
+
+if is_torch_version(">=", "2.4.0"):
+    from torch.amp import GradScaler
+else:
+    from torch.cuda.amp import GradScaler
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+from .logging import get_logger
+from .state import PartialState
+
+
+logger = get_logger(__name__)
+
+
+def save_accelerator_state(
+    output_dir: str,
+    model_states: list[dict],
+    optimizers: list,
+    schedulers: list,
+    dataloaders: list,
+    process_index: int,
+    step: int,
+    scaler: GradScaler = None,
+    save_on_each_node: bool = False,
+    safe_serialization: bool = True,
+):
+    """
+    Saves the current states of the models, optimizers, scaler, and RNG generators to a given directory.
+
+    <Tip>
+
+    If `safe_serialization` is `True`, models will be saved with `safetensors` while the rest are saved using native
+    `pickle`.
+
+    </Tip>
+
+    Args:
+        output_dir (`str` or `os.PathLike`):
+            The name of the folder to save all relevant weights and states.
+        model_states (`List[torch.nn.Module]`):
+            A list of model states
+        optimizers (`List[torch.optim.Optimizer]`):
+            A list of optimizer instances
+        schedulers (`List[torch.optim.lr_scheduler._LRScheduler]`):
+            A list of learning rate schedulers
+        dataloaders (`List[torch.utils.data.DataLoader]`):
+            A list of dataloader instances to save their sampler states
+        process_index (`int`):
+            The current process index in the Accelerator state
+        step (`int`):
+            The current step in the internal step tracker
+        scaler (`torch.amp.GradScaler`, *optional*):
+            An optional gradient scaler instance to save;
+        save_on_each_node (`bool`, *optional*):
+            Whether to save on every node, or only the main node.
+        safe_serialization (`bool`, *optional*, defaults to `True`):
+            Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
+    """
+    output_dir = Path(output_dir)
+    # Model states
+    for i, state in enumerate(model_states):
+        weights_name = WEIGHTS_NAME if not safe_serialization else SAFE_WEIGHTS_NAME
+        if i > 0:
+            weights_name = weights_name.replace(".", f"_{i}.")
+        output_model_file = output_dir.joinpath(weights_name)
+        save(state, output_model_file, save_on_each_node=save_on_each_node, safe_serialization=safe_serialization)
+        logger.info(f"Model weights saved in {output_model_file}")
+    # Optimizer states
+    for i, opt in enumerate(optimizers):
+        state = opt.state_dict()
+        optimizer_name = f"{OPTIMIZER_NAME}.bin" if i == 0 else f"{OPTIMIZER_NAME}_{i}.bin"
+        output_optimizer_file = output_dir.joinpath(optimizer_name)
+        save(state, output_optimizer_file, save_on_each_node=save_on_each_node, safe_serialization=False)
+        logger.info(f"Optimizer state saved in {output_optimizer_file}")
+    # Scheduler states
+    for i, scheduler in enumerate(schedulers):
+        state = scheduler.state_dict()
+        scheduler_name = f"{SCHEDULER_NAME}.bin" if i == 0 else f"{SCHEDULER_NAME}_{i}.bin"
+        output_scheduler_file = output_dir.joinpath(scheduler_name)
+        save(state, output_scheduler_file, save_on_each_node=save_on_each_node, safe_serialization=False)
+        logger.info(f"Scheduler state saved in {output_scheduler_file}")
+    # DataLoader states
+    for i, dataloader in enumerate(dataloaders):
+        sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
+        output_sampler_file = output_dir.joinpath(sampler_name)
+        # Only save if we have our custom sampler
+        from .data_loader import IterableDatasetShard, SeedableRandomSampler
+
+        if isinstance(dataloader.dataset, IterableDatasetShard):
+            sampler = dataloader.get_sampler()
+            if isinstance(sampler, SeedableRandomSampler):
+                save(sampler, output_sampler_file, save_on_each_node=save_on_each_node, safe_serialization=False)
+        if getattr(dataloader, "use_stateful_dataloader", False):
+            dataloader_state_dict_name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
+            output_dataloader_state_dict_file = output_dir.joinpath(dataloader_state_dict_name)
+            state_dict = dataloader.state_dict()
+            torch.save(state_dict, output_dataloader_state_dict_file)
+        logger.info(f"Sampler state for dataloader {i} saved in {output_sampler_file}")
+
+    # GradScaler state
+    if scaler is not None:
+        state = scaler.state_dict()
+        output_scaler_file = output_dir.joinpath(SCALER_NAME)
+        torch.save(state, output_scaler_file)
+        logger.info(f"Gradient scaler state saved in {output_scaler_file}")
+    # Random number generator states
+    states = {}
+    states_name = f"{RNG_STATE_NAME}_{process_index}.pkl"
+    states["step"] = step
+    states["random_state"] = random.getstate()
+    states["numpy_random_seed"] = np.random.get_state()
+    states["torch_manual_seed"] = torch.get_rng_state()
+    if is_xpu_available():
+        states["torch_xpu_manual_seed"] = torch.xpu.get_rng_state_all()
+    if is_mlu_available():
+        states["torch_mlu_manual_seed"] = torch.mlu.get_rng_state_all()
+    elif is_sdaa_available():
+        states["torch_sdaa_manual_seed"] = torch.sdaa.get_rng_state_all()
+    elif is_musa_available():
+        states["torch_musa_manual_seed"] = torch.musa.get_rng_state_all()
+    if is_hpu_available():
+        states["torch_hpu_manual_seed"] = torch.hpu.get_rng_state_all()
+    if is_cuda_available():
+        states["torch_cuda_manual_seed"] = torch.cuda.get_rng_state_all()
+    if is_torch_xla_available():
+        states["xm_seed"] = xm.get_rng_state()
+    output_states_file = output_dir.joinpath(states_name)
+    torch.save(states, output_states_file)
+    logger.info(f"Random states saved in {output_states_file}")
+    return output_dir
+
+
| 179 |
+
def load_accelerator_state(
|
| 180 |
+
input_dir,
|
| 181 |
+
models,
|
| 182 |
+
optimizers,
|
| 183 |
+
schedulers,
|
| 184 |
+
dataloaders,
|
| 185 |
+
process_index,
|
| 186 |
+
scaler=None,
|
| 187 |
+
map_location=None,
|
| 188 |
+
load_kwargs=None,
|
| 189 |
+
**load_model_func_kwargs,
|
| 190 |
+
):
|
| 191 |
+
"""
|
| 192 |
+
Loads states of the models, optimizers, scaler, and RNG generators from a given directory.
|
| 193 |
+
|
| 194 |
+
Args:
|
| 195 |
+
input_dir (`str` or `os.PathLike`):
|
| 196 |
+
The name of the folder to load all relevant weights and states.
|
| 197 |
+
models (`List[torch.nn.Module]`):
|
| 198 |
+
A list of model instances
|
| 199 |
+
optimizers (`List[torch.optim.Optimizer]`):
|
| 200 |
+
A list of optimizer instances
|
| 201 |
+
schedulers (`List[torch.optim.lr_scheduler._LRScheduler]`):
|
| 202 |
+
A list of learning rate schedulers
|
| 203 |
+
process_index (`int`):
|
| 204 |
+
The current process index in the Accelerator state
|
| 205 |
+
scaler (`torch.amp.GradScaler`, *optional*):
|
| 206 |
+
An optional *GradScaler* instance to load
|
| 207 |
+
map_location (`str`, *optional*):
|
| 208 |
+
What device to load the optimizer state onto. Should be one of either "cpu" or "on_device".
|
| 209 |
+
load_kwargs (`dict`, *optional*):
|
| 210 |
+
Additional arguments that can be passed to the `load` function.
|
| 211 |
+
load_model_func_kwargs (`dict`, *optional*):
|
| 212 |
+
Additional arguments that can be passed to the model's `load_state_dict` method.
|
| 213 |
+
|
| 214 |
+
Returns:
|
| 215 |
+
`dict`: Contains the `Accelerator` attributes to override while loading the state.
|
| 216 |
+
"""
|
| 217 |
+
# stores the `Accelerator` attributes to override
|
| 218 |
+
override_attributes = dict()
|
| 219 |
+
if map_location not in [None, "cpu", "on_device"]:
|
| 220 |
+
raise TypeError(
|
| 221 |
+
"Unsupported optimizer map location passed, please choose one of `None`, `'cpu'`, or `'on_device'`"
|
| 222 |
+
)
|
| 223 |
+
if map_location is None:
|
| 224 |
+
map_location = "cpu"
|
| 225 |
+
elif map_location == "on_device":
|
| 226 |
+
map_location = PartialState().device
|
| 227 |
+
|
| 228 |
+
if load_kwargs is None:
|
| 229 |
+
load_kwargs = {}
|
| 230 |
+
|
| 231 |
+
input_dir = Path(input_dir)
|
| 232 |
+
# Model states
|
| 233 |
+
for i, model in enumerate(models):
|
| 234 |
+
ending = f"_{i}" if i > 0 else ""
|
| 235 |
+
input_model_file = input_dir.joinpath(f"{SAFE_MODEL_NAME}{ending}.safetensors")
|
| 236 |
+
if input_model_file.exists():
|
| 237 |
+
load_model(model, input_model_file, device=str(map_location), **load_model_func_kwargs)
|
| 238 |
+
else:
|
| 239 |
+
# Load with torch
|
| 240 |
+
input_model_file = input_dir.joinpath(f"{MODEL_NAME}{ending}.bin")
|
| 241 |
+
state_dict = load(input_model_file, map_location=map_location)
|
| 242 |
+
model.load_state_dict(state_dict, **load_model_func_kwargs)
|
| 243 |
+
logger.info("All model weights loaded successfully")
|
| 244 |
+
|
| 245 |
+
# Optimizer states
|
| 246 |
+
for i, opt in enumerate(optimizers):
|
| 247 |
+
optimizer_name = f"{OPTIMIZER_NAME}.bin" if i == 0 else f"{OPTIMIZER_NAME}_{i}.bin"
|
| 248 |
+
input_optimizer_file = input_dir.joinpath(optimizer_name)
|
| 249 |
+
optimizer_state = load(input_optimizer_file, map_location=map_location, **load_kwargs)
|
| 250 |
+
optimizers[i].load_state_dict(optimizer_state)
|
| 251 |
+
logger.info("All optimizer states loaded successfully")
|
| 252 |
+
|
| 253 |
+
# Scheduler states
|
| 254 |
+
for i, scheduler in enumerate(schedulers):
|
| 255 |
+
scheduler_name = f"{SCHEDULER_NAME}.bin" if i == 0 else f"{SCHEDULER_NAME}_{i}.bin"
|
| 256 |
+
input_scheduler_file = input_dir.joinpath(scheduler_name)
|
| 257 |
+
scheduler_state = load(input_scheduler_file, **load_kwargs)
|
| 258 |
+
scheduler.load_state_dict(scheduler_state)
|
| 259 |
+
logger.info("All scheduler states loaded successfully")
|
| 260 |
+
|
| 261 |
+
for i, dataloader in enumerate(dataloaders):
|
| 262 |
+
sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
|
| 263 |
+
input_sampler_file = input_dir.joinpath(sampler_name)
|
| 264 |
+
# Only load if we have our custom sampler
|
| 265 |
+
from .data_loader import IterableDatasetShard, SeedableRandomSampler
|
| 266 |
+
|
| 267 |
+
if isinstance(dataloader.dataset, IterableDatasetShard):
|
| 268 |
+
sampler = dataloader.get_sampler()
|
| 269 |
+
if isinstance(sampler, SeedableRandomSampler):
|
| 270 |
+
sampler = dataloader.set_sampler(load(input_sampler_file))
|
| 271 |
+
if getattr(dataloader, "use_stateful_dataloader", False):
|
| 272 |
+
dataloader_state_dict_name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
|
| 273 |
+
input_dataloader_state_dict_file = input_dir.joinpath(dataloader_state_dict_name)
|
| 274 |
+
if input_dataloader_state_dict_file.exists():
|
| 275 |
+
state_dict = load(input_dataloader_state_dict_file, **load_kwargs)
|
| 276 |
+
dataloader.load_state_dict(state_dict)
|
| 277 |
+
logger.info("All dataloader sampler states loaded successfully")
|
| 278 |
+
|
| 279 |
+
# GradScaler state
|
| 280 |
+
if scaler is not None:
|
| 281 |
+
input_scaler_file = input_dir.joinpath(SCALER_NAME)
|
| 282 |
+
scaler_state = load(input_scaler_file)
|
| 283 |
+
scaler.load_state_dict(scaler_state)
|
| 284 |
+
logger.info("GradScaler state loaded successfully")
|
| 285 |
+
|
| 286 |
+
# Random states
|
| 287 |
+
try:
|
| 288 |
+
states = load(input_dir.joinpath(f"{RNG_STATE_NAME}_{process_index}.pkl"))
|
| 289 |
+
if "step" in states:
|
| 290 |
+
override_attributes["step"] = states["step"]
|
| 291 |
+
random.setstate(states["random_state"])
|
| 292 |
+
np.random.set_state(states["numpy_random_seed"])
|
| 293 |
+
torch.set_rng_state(states["torch_manual_seed"])
|
| 294 |
+
if is_xpu_available():
|
| 295 |
+
torch.xpu.set_rng_state_all(states["torch_xpu_manual_seed"])
|
| 296 |
+
if is_mlu_available():
|
| 297 |
+
torch.mlu.set_rng_state_all(states["torch_mlu_manual_seed"])
|
| 298 |
+
elif is_sdaa_available():
|
| 299 |
+
torch.sdaa.set_rng_state_all(states["torch_sdaa_manual_seed"])
|
| 300 |
+
elif is_musa_available():
|
| 301 |
+
torch.musa.set_rng_state_all(states["torch_musa_manual_seed"])
|
| 302 |
+
else:
|
| 303 |
+
torch.cuda.set_rng_state_all(states["torch_cuda_manual_seed"])
|
| 304 |
+
if is_torch_xla_available():
|
| 305 |
+
xm.set_rng_state(states["xm_seed"])
|
| 306 |
+
logger.info("All random states loaded successfully")
|
| 307 |
+
except Exception:
|
| 308 |
+
logger.info("Could not load random states")
|
| 309 |
+
|
| 310 |
+
return override_attributes
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def save_custom_state(obj, path, index: int = 0, save_on_each_node: bool = False):
|
| 314 |
+
"""
|
| 315 |
+
Saves the state of `obj` to `{path}/custom_checkpoint_{index}.pkl`
|
| 316 |
+
"""
|
| 317 |
+
# Should this be the right way to get a qual_name type value from `obj`?
|
| 318 |
+
save_location = Path(path) / f"custom_checkpoint_{index}.pkl"
|
| 319 |
+
logger.info(f"Saving the state of {get_pretty_name(obj)} to {save_location}")
|
| 320 |
+
save(obj.state_dict(), save_location, save_on_each_node=save_on_each_node)
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def load_custom_state(obj, path, index: int = 0):
|
| 324 |
+
"""
|
| 325 |
+
Loads the state of `obj` at `{path}/custom_checkpoint_{index}.pkl`. Will always set `weights_only=False` when
|
| 326 |
+
loading the state.
|
| 327 |
+
"""
|
| 328 |
+
load_location = f"{path}/custom_checkpoint_{index}.pkl"
|
| 329 |
+
logger.info(f"Loading the state of {get_pretty_name(obj)} from {load_location}")
|
| 330 |
+
obj.load_state_dict(load(load_location, map_location="cpu", weights_only=False))
|
pythonProject/.venv/Lib/site-packages/accelerate/commands/merge.py
ADDED
@@ -0,0 +1,69 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from accelerate.commands.utils import CustomArgumentParser
from accelerate.utils import merge_fsdp_weights


description = """Utility to merge the weights from multiple FSDP checkpoints into a single combined checkpoint. Should be used if
`SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}`.

This is a CPU-bound process and requires enough RAM to load the entire model state dict."""


def merge_command(args):
    merge_fsdp_weights(
        args.checkpoint_directory, args.output_path, not args.unsafe_serialization, args.remove_checkpoint_dir
    )


def merge_command_parser(subparsers=None):
    if subparsers is not None:
        parser = subparsers.add_parser("merge-weights", description=description)
    else:
        parser = CustomArgumentParser(description=description)

    parser.add_argument("checkpoint_directory", type=str, help="A directory containing sharded weights saved by FSDP.")
    parser.add_argument(
        "output_path",
        type=str,
        help="The path to save the merged weights. Defaults to the current directory.",
    )
    parser.add_argument(
        "--unsafe_serialization",
        action="store_true",
        default=False,
        help="Whether to save the merged weights as `.bin` rather than `.safetensors` (not recommended).",
    )
    parser.add_argument(
        "--remove_checkpoint_dir",
        action="store_true",
        help="Whether to remove the checkpoint directory after merging.",
        default=False,
    )

    if subparsers is not None:
        parser.set_defaults(func=merge_command)
    return parser


def main():
    parser = merge_command_parser()
    args = parser.parse_args()
    merge_command(args)


if __name__ == "__main__":
    main()
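
# Example (illustrative sketch, not part of the library source): driving the parser
# programmatically; both paths are placeholders.
#
#     parser = merge_command_parser()
#     args = parser.parse_args(["path/to/sharded_checkpoint", "path/to/merged_output"])
#     merge_command(args)  # merges the FSDP shards, saving as safetensors by default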
pythonProject/.venv/Lib/site-packages/accelerate/commands/test.py
ADDED
@@ -0,0 +1,65 @@
#!/usr/bin/env python

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

from accelerate.test_utils import execute_subprocess_async, path_in_accelerate_package


def test_command_parser(subparsers=None):
    if subparsers is not None:
        parser = subparsers.add_parser("test")
    else:
        parser = argparse.ArgumentParser("Accelerate test command")

    parser.add_argument(
        "--config_file",
        default=None,
        help=(
            "The path to use to store the config file. Will default to a file named default_config.yaml in the cache "
            "location, which is the content of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have "
            "such an environment variable, your cache directory ('~/.cache' or the content of `XDG_CACHE_HOME`) suffixed "
            "with 'huggingface'."
        ),
    )

    if subparsers is not None:
        parser.set_defaults(func=test_command)
    return parser


def test_command(args):
    script_name = path_in_accelerate_package("test_utils", "scripts", "test_script.py")

    if args.config_file is None:
        test_args = [script_name]
    else:
        test_args = f"--config_file={args.config_file} {script_name}".split()

    cmd = ["accelerate-launch"] + test_args
    result = execute_subprocess_async(cmd)
    if result.returncode == 0:
        print("Test is a success! You are ready for your distributed training!")


def main():
    parser = test_command_parser()
    args = parser.parse_args()
    test_command(args)


if __name__ == "__main__":
    main()
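
# Example (illustrative sketch, not part of the library source): the equivalent
# command-line invocation, assuming an existing accelerate config file:
#
#     accelerate test --config_file path/to/default_config.yaml
#
# which matches calling, from Python:
#
#     test_command(test_command_parser().parse_args(["--config_file", "path/to/default_config.yaml"]))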
pythonProject/.venv/Lib/site-packages/accelerate/data_loader.py
ADDED
@@ -0,0 +1,1451 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import math
from contextlib import suppress
from typing import Callable, Optional, Union

import torch
from packaging import version
from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler

from .logging import get_logger
from .state import DistributedType, GradientState, PartialState, is_torch_xla_available
from .utils import (
    RNGType,
    broadcast,
    broadcast_object_list,
    compare_versions,
    concatenate,
    find_batch_size,
    get_data_structure,
    initialize_tensors,
    is_datasets_available,
    is_torch_version,
    is_torchdata_stateful_dataloader_available,
    send_to_device,
    slice_tensors,
    synchronize_rng_states,
)


logger = get_logger(__name__)

# kwargs of the DataLoader as of the minimum supported PyTorch version (2.0)
_PYTORCH_DATALOADER_KWARGS = {
    "batch_size": 1,
    "shuffle": False,
    "sampler": None,
    "batch_sampler": None,
    "num_workers": 0,
    "collate_fn": None,
    "pin_memory": False,
    "drop_last": False,
    "timeout": 0,
    "worker_init_fn": None,
    "multiprocessing_context": None,
    "generator": None,
    "prefetch_factor": 2,
    "persistent_workers": False,
    "pin_memory_device": "",
}

# kwargs added afterwards, keyed by the PyTorch version that introduced them
_PYTORCH_DATALOADER_ADDITIONAL_KWARGS = {"2.6.0": {"in_order": True}}

for v, additional_kwargs in _PYTORCH_DATALOADER_ADDITIONAL_KWARGS.items():
    if is_torch_version(">=", v):
        _PYTORCH_DATALOADER_KWARGS.update(additional_kwargs)


class SeedableRandomSampler(RandomSampler):
    """
    Same as a random sampler, except that in `__iter__` a seed can be used.

    Needed specifically in distributed cases, when the random generator for each GPU needs to start from the same seed
    and be fully reproducible on multiple iterations.

    If a custom `generator` is passed, it will rely on its initial seed as well as the current iteration it is on
    (stored in `self.epoch`).
    """

    def __init__(self, *args, **kwargs):
        data_seed = kwargs.pop("data_seed", None)
        super().__init__(*args, **kwargs)

        self.initial_seed = data_seed if data_seed is not None else torch.random.initial_seed()
        self.epoch = 0

    def __iter__(self):
        if self.generator is None:
            self.generator = torch.Generator(
                device=torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu"
            )
            self.generator.manual_seed(self.initial_seed)

        # Allow `self.epoch` to modify the seed of the generator
        seed = self.epoch + self.initial_seed
        # print("Setting seed at epoch", self.epoch, seed)
        self.generator.manual_seed(seed)
        yield from super().__iter__()
        self.set_epoch(self.epoch + 1)

    def set_epoch(self, epoch: int):
        "Sets the current iteration of the sampler."
        self.epoch = epoch


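# Example (illustrative sketch, not part of the library source): two samplers built
# over the same sized container with the same `data_seed` walk through identical
# permutations, since the generator is reseeded with `initial_seed + epoch` on every
# iteration; `set_epoch` reshuffles them in lockstep.
#
#     sampler_a = SeedableRandomSampler(range(8), data_seed=42)
#     sampler_b = SeedableRandomSampler(range(8), data_seed=42)
#     assert list(sampler_a) == list(sampler_b)
#     sampler_a.set_epoch(1)  # the next pass reshuffles with seed 42 + 1

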
class BatchSamplerShard(BatchSampler):
    """
    Wraps a PyTorch `BatchSampler` to generate batches for one of the processes only. Instances of this class will
    always yield a number of batches that is a round multiple of `num_processes` and that all have the same size.
    Depending on the value of the `drop_last` attribute of the batch sampler passed, it will either stop the iteration
    at the first batch that would be too small / not present on all processes or loop with indices from the beginning.

    Args:
        batch_sampler (`torch.utils.data.sampler.BatchSampler`):
            The batch sampler to split in several shards.
        num_processes (`int`, *optional*, defaults to 1):
            The number of processes running concurrently.
        process_index (`int`, *optional*, defaults to 0):
            The index of the current process.
        split_batches (`bool`, *optional*, defaults to `False`):
            Whether the shards should be created by splitting a batch to give a piece of it on each process, or by
            yielding different full batches on each process.

            On two processes with a sampler of `[[0, 1, 2, 3], [4, 5, 6, 7]]`, this will result in:

            - the sampler on process 0 to yield `[0, 1, 2, 3]` and the sampler on process 1 to yield `[4, 5, 6, 7]` if
              this argument is set to `False`.
            - the sampler on process 0 to yield `[0, 1]` then `[4, 5]` and the sampler on process 1 to yield `[2, 3]`
              then `[6, 7]` if this argument is set to `True`.
        even_batches (`bool`, *optional*, defaults to `True`):
            Whether or not to loop back at the beginning of the sampler when the number of samples is not a round
            multiple of (original batch size / number of processes).

    <Tip warning={true}>

    `BatchSampler`s with varying batch sizes are not enabled by default. To enable this behaviour, set `even_batches`
    equal to `False`

    </Tip>"""

    def __init__(
        self,
        batch_sampler: BatchSampler,
        num_processes: int = 1,
        process_index: int = 0,
        split_batches: bool = False,
        even_batches: bool = True,
    ):
        if split_batches and batch_sampler.batch_size % num_processes != 0:
            raise ValueError(
                f"To use `BatchSamplerShard` in `split_batches` mode, the batch size ({batch_sampler.batch_size}) "
                f"needs to be a round multiple of the number of processes ({num_processes})."
            )
        self.batch_sampler = batch_sampler
        self.num_processes = num_processes
        self.process_index = process_index
        self.split_batches = split_batches
        self.even_batches = even_batches
        self.batch_size = getattr(batch_sampler, "batch_size", None)
        self.drop_last = getattr(batch_sampler, "drop_last", False)
        if self.batch_size is None and self.even_batches:
            raise ValueError(
                "You need to use `even_batches=False` when the batch sampler has no batch size. If you "
                "are not calling this method directly, set `accelerator.even_batches=False` instead."
            )

    @property
    def total_length(self):
        return len(self.batch_sampler)

    def __len__(self):
        if self.split_batches:
            # Split batches does not change the length of the batch sampler
            return len(self.batch_sampler)
        if len(self.batch_sampler) % self.num_processes == 0:
            # If the length is a round multiple of the number of processes, it's easy.
            return len(self.batch_sampler) // self.num_processes
        length = len(self.batch_sampler) // self.num_processes
        if self.drop_last:
            # Same if we drop the remainder.
            return length
        elif self.even_batches:
            # When we even out the batches, we always get +1.
            return length + 1
        else:
            # Otherwise it depends on the process index.
            return length + 1 if self.process_index < len(self.batch_sampler) % self.num_processes else length

    def __iter__(self):
        return self._iter_with_split() if self.split_batches else self._iter_with_no_split()

    def _iter_with_split(self):
        initial_data = []
        batch_length = self.batch_sampler.batch_size // self.num_processes
        for idx, batch in enumerate(self.batch_sampler):
            if idx == 0:
                initial_data = batch
            if len(batch) == self.batch_size:
                # If the batch is full, we yield the part of it this process is responsible for.
                yield batch[batch_length * self.process_index : batch_length * (self.process_index + 1)]

        # If drop_last is True or the last batch was full, iteration is over, otherwise...
        if not self.drop_last and len(initial_data) > 0 and len(batch) < self.batch_size:
            if not self.even_batches:
                if len(batch) > batch_length * self.process_index:
                    yield batch[batch_length * self.process_index : batch_length * (self.process_index + 1)]
            else:
                # For degenerate cases where the dataset has less than num_process * batch_size samples
                while len(initial_data) < self.batch_size:
                    initial_data += initial_data
                batch = batch + initial_data
                yield batch[batch_length * self.process_index : batch_length * (self.process_index + 1)]

    def _iter_with_no_split(self):
        initial_data = []
        batch_to_yield = []
        for idx, batch in enumerate(self.batch_sampler):
            # We gather the initial indices in case we need to circle back at the end.
            if not self.drop_last and idx < self.num_processes:
                initial_data += batch
            # We identify the batch to yield but wait until we are sure every process gets a full batch before
            # actually yielding it.
            if idx % self.num_processes == self.process_index:
                batch_to_yield = batch
            if idx % self.num_processes == self.num_processes - 1 and (
                self.batch_size is None or len(batch) == self.batch_size
            ):
                yield batch_to_yield
                batch_to_yield = []

        # If drop_last is True, iteration is over, otherwise...
        if not self.drop_last and len(initial_data) > 0:
            if not self.even_batches:
                if len(batch_to_yield) > 0:
                    yield batch_to_yield
            else:
                # ... we yield the complete batch we had saved before if it has the proper length
                if len(batch_to_yield) == self.batch_size:
                    yield batch_to_yield

                # For degenerate cases where the dataset has less than num_process * batch_size samples
                while len(initial_data) < self.num_processes * self.batch_size:
                    initial_data += initial_data

                # If the last batch seen was of the proper size, it has been yielded by its process so we move to the next
                if len(batch) == self.batch_size:
                    batch = []
                    idx += 1

                # Make sure we yield a multiple of self.num_processes batches
                cycle_index = 0
                while idx % self.num_processes != 0 or len(batch) > 0:
                    end_index = cycle_index + self.batch_size - len(batch)
                    batch += initial_data[cycle_index:end_index]
                    if idx % self.num_processes == self.process_index:
                        yield batch
                    cycle_index = end_index
                    batch = []
                    idx += 1


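# Example (illustrative sketch, not part of the library source): sharding a batch
# sampler over eight indices across two processes without `split_batches`; each
# process receives alternating full batches.
#
#     base = BatchSampler(range(8), batch_size=4, drop_last=False)
#     list(BatchSamplerShard(base, num_processes=2, process_index=0))  # [[0, 1, 2, 3]]
#     list(BatchSamplerShard(base, num_processes=2, process_index=1))  # [[4, 5, 6, 7]]

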
class IterableDatasetShard(IterableDataset):
    """
    Wraps a PyTorch `IterableDataset` to generate samples for one of the processes only. Instances of this class will
    always yield a number of samples that is a round multiple of the actual batch size (depending on the value of
    `split_batches`, this is either `batch_size` or `batch_size x num_processes`). Depending on the value of the
    `drop_last` attribute of the batch sampler passed, it will either stop the iteration at the first batch that would
    be too small or loop with indices from the beginning.

    Args:
        dataset (`torch.utils.data.dataset.IterableDataset`):
            The iterable dataset to split in several shards.
        batch_size (`int`, *optional*, defaults to 1):
            The size of the batches per shard (if `split_batches=False`) or the size of the batches (if
            `split_batches=True`).
        drop_last (`bool`, *optional*, defaults to `False`):
            Whether or not to drop the last incomplete batch or complete the last batches by using the samples from
            the beginning.
        num_processes (`int`, *optional*, defaults to 1):
            The number of processes running concurrently.
        process_index (`int`, *optional*, defaults to 0):
            The index of the current process.
        split_batches (`bool`, *optional*, defaults to `False`):
            Whether the shards should be created by splitting a batch to give a piece of it on each process, or by
            yielding different full batches on each process.

            On two processes with an iterable dataset yielding `[0, 1, 2, 3, 4, 5, 6, 7]`, this will result in:

            - the shard on process 0 to yield `[0, 1, 2, 3]` and the shard on process 1 to yield `[4, 5, 6, 7]` if this
              argument is set to `False`.
            - the shard on process 0 to yield `[0, 1, 4, 5]` and the shard on process 1 to yield `[2, 3, 6, 7]` if
              this argument is set to `True`.
    """

    def __init__(
        self,
        dataset: IterableDataset,
        batch_size: int = 1,
        drop_last: bool = False,
        num_processes: int = 1,
        process_index: int = 0,
        split_batches: bool = False,
    ):
        if split_batches and batch_size > 1 and batch_size % num_processes != 0:
            raise ValueError(
                f"To use `IterableDatasetShard` in `split_batches` mode, the batch size ({batch_size}) "
                f"needs to be a round multiple of the number of processes ({num_processes})."
            )
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.num_processes = num_processes
        self.process_index = process_index
        self.split_batches = split_batches

    def set_epoch(self, epoch):
        self.epoch = epoch
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)

    def __len__(self):
        # We will just raise the downstream error if the underlying dataset is not sized
        if self.drop_last:
            return (len(self.dataset) // (self.batch_size * self.num_processes)) * self.batch_size
        else:
            return math.ceil(len(self.dataset) / (self.batch_size * self.num_processes)) * self.batch_size

    def __iter__(self):
        if (
            not hasattr(self.dataset, "set_epoch")
            and hasattr(self.dataset, "generator")
            and isinstance(self.dataset.generator, torch.Generator)
        ):
            self.dataset.generator.manual_seed(self.epoch)
        real_batch_size = self.batch_size if self.split_batches else (self.batch_size * self.num_processes)
        process_batch_size = (self.batch_size // self.num_processes) if self.split_batches else self.batch_size
        process_slice = range(self.process_index * process_batch_size, (self.process_index + 1) * process_batch_size)

        first_batch = None
        current_batch = []
        for element in self.dataset:
            current_batch.append(element)
            # Wait to have a full batch before yielding elements.
            if len(current_batch) == real_batch_size:
                for i in process_slice:
                    yield current_batch[i]
                if first_batch is None:
                    first_batch = current_batch.copy()
                current_batch = []

        # Finished if drop_last is True, otherwise complete the last batch with elements from the beginning.
        if not self.drop_last and len(current_batch) > 0:
            if first_batch is None:
                first_batch = current_batch.copy()
            while len(current_batch) < real_batch_size:
                current_batch += first_batch
            for i in process_slice:
                yield current_batch[i]


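# Example (illustrative sketch, not part of the library source): with
# `split_batches=False` the shard buffers `batch_size * num_processes` elements and
# hands each process its contiguous slice.
#
#     class Stream(IterableDataset):
#         def __iter__(self):
#             yield from range(8)
#
#     list(IterableDatasetShard(Stream(), batch_size=4, num_processes=2, process_index=0))
#     # -> [0, 1, 2, 3]
#     list(IterableDatasetShard(Stream(), batch_size=4, num_processes=2, process_index=1))
#     # -> [4, 5, 6, 7]

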
class DataLoaderStateMixin:
    """
    Mixin class that adds a state to a `DataLoader` to keep track of the status inside the dataloader, such as at the
    end of the iteration, the number of items in the dataset in the last batch relative to the batch size, and other
    useful information that might be needed.

    **Available attributes:**

        - **end_of_dataloader** (`bool`) -- Whether at the last iteration or batch
        - **remainder** (`int`) -- The number of items that are remaining in the last batch, relative to the total
          batch size

    <Tip warning={true}>

    Inheritors of this class should ensure that the class creates a `GradientState()` instance, stored in
    `self.gradient_state`.

    </Tip>

    """

    def __init_subclass__(cls, **kwargs):
        cls.end_of_dataloader = False
        cls.remainder = -1

    def reset(self):
        self.end_of_dataloader = False
        self.remainder = -1

    def begin(self):
        "Prepares the gradient state for the current dataloader"
        self.reset()
        with suppress(Exception):
            if not self._drop_last:
                length = getattr(self.dataset, "total_dataset_length", len(self.dataset))
                self.remainder = length % self.total_batch_size
        self.gradient_state._add_dataloader(self)

    def end(self):
        "Cleans up the gradient state after exiting the dataloader"
        self.gradient_state._remove_dataloader(self)


class DataLoaderAdapter:
    """
    A class which wraps around a PyTorch `DataLoader` (or variants of it) to be used with the `Accelerator`. For
    compatibility reasons, this class mirrors the class it wraps around (via the `__class__` property below), so it
    can be used as a drop-in replacement.
    """

    def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, **kwargs):
        self.use_stateful_dataloader = use_stateful_dataloader
        if is_torchdata_stateful_dataloader_available():
            from torchdata.stateful_dataloader import StatefulDataLoader

        if use_stateful_dataloader and not is_torchdata_stateful_dataloader_available():
            raise ImportError(
                "StatefulDataLoader is not available. Please install torchdata version 0.8.0 or higher to use it."
            )
        if use_stateful_dataloader:
            torchdata_version = version.parse(importlib.metadata.version("torchdata"))
            if (
                "in_order" in kwargs
                and compare_versions(torchdata_version, "<", "0.11")
                and is_torch_version(">=", "2.6.0")
            ):
                kwargs.pop("in_order")
            self.base_dataloader = StatefulDataLoader(dataset, batch_sampler=batch_sampler, **kwargs)
        else:
            self.base_dataloader = DataLoader(dataset, batch_sampler=batch_sampler, **kwargs)

        if hasattr(self.base_dataloader, "state_dict"):
            self.dl_state_dict = self.base_dataloader.state_dict()

    def __getattr__(self, name):
        # Avoid infinite recursion if we try to access a nonexistent base_dataloader attribute.
        if name == "base_dataloader":
            raise AttributeError()
        # Delegate attribute access to the internal dataloader
        return getattr(self.base_dataloader, name)

    def state_dict(self):
        return self.dl_state_dict

    def load_state_dict(self, state_dict):
        self.base_dataloader.load_state_dict(state_dict)

    @property
    def __class__(self):
        """
        In order to maintain backwards compatibility with other code, we need to ensure `isinstance(obj, DataLoader)`
        returns true. This is because some downstream code assumes that the `DataLoader` is the base class of the
        object.
        """
        return self.base_dataloader.__class__

    def __len__(self):
        return len(self.base_dataloader)

    def adjust_state_dict_for_prefetch(self):
        """
        Adjusts the state dict for prefetching. Natively, this will adjust all of the iteration-progress keys in
        `self.dl_state_dict` by a factor of `num_processes - 1`; however, if a custom correction is needed, this can
        be overridden.

        This should modify `self.dl_state_dict` directly.
        """
        # During DDP, the state dict ends up `n - 1` batches too far ahead,
        # so we need to adjust it here
        if PartialState().distributed_type != DistributedType.NO:
            factor = PartialState().num_processes - 1
            if self.dl_state_dict["_sampler_iter_yielded"] > 0:
                self.dl_state_dict["_sampler_iter_yielded"] -= factor
            if self.dl_state_dict["_num_yielded"] > 0:
                self.dl_state_dict["_num_yielded"] -= factor
            if self.dl_state_dict["_index_sampler_state"] is not None:
                if (
                    "samples_yielded" in self.dl_state_dict["_index_sampler_state"]
                    and self.dl_state_dict["_index_sampler_state"]["samples_yielded"] > 0
                ):
                    self.dl_state_dict["_index_sampler_state"]["samples_yielded"] -= self.batch_size * factor

    def _update_state_dict(self):
        # The state_dict of the underlying base_dataloader may be ahead of what is currently being yielded.
        # E.g. the implementation of DataLoaderShard involves having an underlying iterator 1 element ahead of
        # what it wants to yield.
        #
        # _update_state_dict is called to snapshot the state_dict that would properly recover the DataLoaderAdapter.
        if hasattr(self.base_dataloader, "state_dict"):
            self.dl_state_dict = self.base_dataloader.state_dict()
            # Potentially modify the state_dict to adjust for prefetching
            self.adjust_state_dict_for_prefetch()
            # Then tag if we are at the end of the dataloader
            self.dl_state_dict["_iterator_finished"] = self.end_of_dataloader


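# Example (illustrative sketch, not part of the library source): because `__class__`
# is overridden to report the wrapped loader's class, instance checks against the
# stock `DataLoader` still pass on the adapter subclasses defined below.
#
#     loader = DataLoaderShard(list(range(10)), batch_size=2)
#     assert isinstance(loader, DataLoader)  # True, despite the wrapper class

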
class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
    """
    Subclass of `DataLoaderAdapter` that will deal with device placement and current distributed setup.

    Args:
        dataset (`torch.utils.data.dataset.Dataset`):
            The dataset to use to build this dataloader.
        device (`torch.device`, *optional*):
            If passed, the device to put all batches on.
        rng_types (list of `str` or [`~utils.RNGType`]):
            The list of random number generators to synchronize at the beginning of each iteration. Should be one or
            several of:

            - `"torch"`: the base torch random number generator
            - `"cuda"`: the CUDA random number generator (GPU only)
            - `"xla"`: the XLA random number generator (TPU only)
            - `"generator"`: an optional `torch.Generator`
        synchronized_generator (`torch.Generator`, *optional*):
            A random number generator to keep synchronized across processes.
        skip_batches (`int`, *optional*, defaults to 0):
            The number of batches to skip at the beginning.
        use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
            Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.
        **kwargs (additional keyword arguments, *optional*):
            All other keyword arguments to pass to the regular `DataLoader` initialization.

    **Available attributes:**

        - **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes.
          Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the total
          number of processes

        - **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes.
    """

    def __init__(
        self,
        dataset,
        device=None,
        rng_types=None,
        synchronized_generator=None,
        skip_batches=0,
        use_stateful_dataloader=False,
        _drop_last: bool = False,
        _non_blocking: bool = False,
        torch_device_mesh=None,
        **kwargs,
    ):
        super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
        self.device = device
        self.rng_types = rng_types
        self.synchronized_generator = synchronized_generator
        self.skip_batches = skip_batches
        self.gradient_state = GradientState()
        self._drop_last = _drop_last
        self._non_blocking = _non_blocking
        self.iteration = 0

    def __iter__(self):
        if self.rng_types is not None:
            synchronize_rng_states(self.rng_types, self.synchronized_generator)
        self.begin()

        self.set_epoch(self.iteration)
        dataloader_iter = self.base_dataloader.__iter__()
        # We iterate one batch ahead to check when we are at the end
        try:
            current_batch = next(dataloader_iter)
        except StopIteration:
            self.end()
            return

        batch_index = 0
        while True:
            try:
                # But we still move it to the device so it is done before `StopIteration` is reached
                if self.device is not None:
                    current_batch = send_to_device(current_batch, self.device, non_blocking=self._non_blocking)
                self._update_state_dict()
                next_batch = next(dataloader_iter)
                if batch_index >= self.skip_batches:
                    yield current_batch
                batch_index += 1
                current_batch = next_batch
            except StopIteration:
                self.end_of_dataloader = True
                self._update_state_dict()
                if batch_index >= self.skip_batches:
                    yield current_batch
                break

        self.iteration += 1
        self.end()

    def __reduce__(self):
        """
        Define the `__reduce__` method to ensure a `DataLoaderShard` can be pickled and unpickled. This needs to be
        explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
        `__class__` member.
        """
        args = super().__reduce__()
        return (DataLoaderShard, *args[1:])

    def set_epoch(self, epoch: int):
        # In case it is manually passed in, the user can set it to what they like
        if self.iteration != epoch:
            self.iteration = epoch
        if hasattr(self.batch_sampler, "set_epoch"):
            self.batch_sampler.set_epoch(epoch)
        if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, "set_epoch"):
            self.batch_sampler.sampler.set_epoch(epoch)
        if (
            hasattr(self.batch_sampler, "batch_sampler")
            and hasattr(self.batch_sampler.batch_sampler, "sampler")
            and hasattr(self.batch_sampler.batch_sampler.sampler, "set_epoch")
        ):
            self.batch_sampler.batch_sampler.sampler.set_epoch(epoch)
        # We support if a custom `Dataset` implementation has `set_epoch`
        # or in general HF datasets `Datasets`
        elif hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)

    @property
    def total_batch_size(self):
        batch_sampler = self.sampler if isinstance(self.sampler, BatchSampler) else self.batch_sampler
        return (
            batch_sampler.batch_size
            if getattr(batch_sampler, "split_batches", False)
            else (batch_sampler.batch_size * getattr(batch_sampler, "num_processes", 1))
        )

    @property
    def total_dataset_length(self):
        if hasattr(self.dataset, "total_length"):
            return self.dataset.total_length
        else:
            return len(self.dataset)

    def get_sampler(self):
        return get_sampler(self)

    def set_sampler(self, sampler):
        sampler_is_batch_sampler = isinstance(self.sampler, BatchSampler)
        if sampler_is_batch_sampler:
            self.sampler.sampler = sampler
        else:
            self.batch_sampler.sampler = sampler
            if hasattr(self.batch_sampler, "batch_sampler"):
                self.batch_sampler.batch_sampler.sampler = sampler


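# Example (illustrative sketch, not part of the library source; assumes an
# initialized `PartialState`/`Accelerator`): a `DataLoaderShard` iterates one batch
# ahead and moves every batch to the requested device before yielding it.
#
#     loader = DataLoaderShard(list(range(8)), device=torch.device("cpu"), batch_size=4)
#     for batch in loader:
#         print(batch.shape, batch.device)  # batches arrive already on the target device

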
if is_torch_xla_available():
    import torch_xla.distributed.parallel_loader as xpl

    class MpDeviceLoaderWrapper(xpl.MpDeviceLoader):
        """
        Wrapper for the xpl.MpDeviceLoader class that knows the total batch size.

        XLA preloading threads will all call DataLoaderShard's __iter__(). Remove rng_types from DataLoaderShard to
        prevent it from using the XLA device in the preloading threads, and synchronize the RNG once from the main
        thread only.

        **Available attributes:**

        - **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes.
          Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the total
          number of processes

        - **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes.
        """

        def __init__(self, dataloader: DataLoaderShard, device: torch.device):
            super().__init__(dataloader, device)
            self._rng_types = self._loader.rng_types
            self._loader.rng_types = None
            self.device = device

        def __iter__(self):
            if self._rng_types is not None:
                synchronize_rng_states(self._rng_types, self._loader.synchronized_generator)

            return super().__iter__()

        def set_epoch(self, epoch: int):
            if hasattr(self.dataloader, "set_epoch"):
                self.dataloader.set_epoch(epoch)

        @property
        def total_batch_size(self):
            return self._loader.total_batch_size

        @property
        def total_dataset_length(self):
            return self._loader.total_dataset_length

        @property
        def batch_sampler(self):
            return self._loader.batch_sampler

        @property
        def dataloader(self):
            return self._loader


class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
    """
    Subclass of `DataLoaderAdapter` that will iterate and preprocess on process 0 only, then dispatch on each process
    their part of the batch.

    Args:
        split_batches (`bool`, *optional*, defaults to `False`):
            Whether the resulting `DataLoader` should split the batches of the original data loader across devices or
            yield full batches (in which case it will yield batches starting at the `process_index`-th and advancing of
            `num_processes` batches at each iteration). Another way to see this is that the observed batch size will be
            the same as the initial `dataloader` if this option is set to `True`, the batch size of the initial
            `dataloader` multiplied by `num_processes` otherwise. Setting this option to `True` requires that the batch
            size of the `dataloader` is a round multiple of `num_processes`.
        skip_batches (`int`, *optional*, defaults to 0):
            The number of batches to skip at the beginning of an iteration.
        use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
            Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.

    **Available attributes:**

    - **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes.
      Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the total
      number of processes

    - **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes.
    """

    def __init__(
        self,
        dataset,
        split_batches: bool = False,
        skip_batches=0,
        use_stateful_dataloader=False,
        _drop_last: bool = False,
        _non_blocking: bool = False,
        slice_fn=None,
        torch_device_mesh=None,
        **kwargs,
    ):
        shuffle = False
        from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe

        # We need to save the shuffling state of the DataPipe
        if isinstance(dataset, ShufflerIterDataPipe):
            shuffle = dataset._shuffle_enabled
        super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
        self.split_batches = split_batches
        if shuffle:
            torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)

        self.gradient_state = GradientState()
        self.state = PartialState()
        self._drop_last = _drop_last
        self._non_blocking = _non_blocking
        self.skip_batches = skip_batches
        self.torch_device_mesh = torch_device_mesh

        self.slice_fn = slice_tensors if slice_fn is None else slice_fn
        self.iteration = 0

        # If a device mesh is provided, extract each dimension (dp, fsdp, tp).
        # A device mesh may hold any number of dimensions; the code below offers
        # targeted support for dp, fsdp and tp.

        # The device mesh is only used when tp is involved, alone or in any
        # multi-dimensional parallelism combination: (dp, tp), (fsdp, tp), (dp, fsdp, tp).
        # Otherwise the default behaviour without a device mesh is sufficient,
        # since multi-dimensional parallelism without tp needs different batches
        # for each process anyway, regardless of dp or fsdp.
        self.submesh_tp = None
        self.submesh_dp = None
        self.submesh_fsdp = None
        if self.torch_device_mesh:
            if "tp" in self.torch_device_mesh.mesh_dim_names:
                self.submesh_tp = self.torch_device_mesh["tp"]
            if "dp" in self.torch_device_mesh.mesh_dim_names:
                self.submesh_dp = self.torch_device_mesh["dp"]
            if "fsdp" in self.torch_device_mesh.mesh_dim_names:
                self.submesh_fsdp = self.torch_device_mesh["fsdp"]
        if self.submesh_tp and (self.submesh_dp or self.submesh_fsdp):
            raise ValueError("TP + (DP/FSDP) is not yet supported in dispatch mode")

    def _fetch_batches(self, iterator):
        batches, batch = None, None
        # On process 0, we gather the batch to dispatch.
        if self.state.process_index == 0:
            # The procedure to support TP alone is simpler, since we want to
            # dispatch the same batch of samples across all ranks; this removes
            # the complexity of handling multiple tp rank groups when a TP + DP
            # combination is involved.

            try:
                # For the TP case, avoid using split_batches, since it would
                # mean the dataloader has to produce duplicates of batches.
                if self.split_batches:
                    # One batch of the main iterator is dispatched and split.
                    if self.submesh_tp:
                        logger.warning(
                            "Use of split_batches for TP would need the dataloader to produce duplicate batches; "
                            "otherwise, use dispatch_batches=True instead."
                        )
                    self._update_state_dict()
                    batch = next(iterator)
                else:
                    # num_processes batches of the main iterator are concatenated then dispatched and split.
                    # We add the batches one by one so we have the remainder available when drop_last=False.
                    batches = []
                    if self.submesh_tp:
                        # When tp is used, extract a single batch and then replicate it.
                        self._update_state_dict()
                        batch = next(iterator)
                        batches = [batch] * self.state.num_processes
                    else:
                        for _ in range(self.state.num_processes):
                            self._update_state_dict()
                            batches.append(next(iterator))
                    try:
                        batch = concatenate(batches, dim=0)
                    except RuntimeError as e:
                        raise RuntimeError(
                            "You can't use batches of different size with `dispatch_batches=True` or when using an "
                            "`IterableDataset`. Either pass `dispatch_batches=False` and have each process fetch its "
                            "own batch, or pass `split_batches=True`. By doing so, the main process will fetch a full "
                            "batch and slice it into `num_processes` batches for each process."
                        ) from e
                # In both cases, we need to get the structure of the batch that we will broadcast on other
                # processes to initialize the tensors with the right shape.
                # data_structure, stop_iteration
                batch_info = [get_data_structure(batch), False]
            except StopIteration:
                batch_info = [None, True]
        else:
            batch_info = [None, self._stop_iteration]
        # This is inplace, so after this instruction, every process has the same `batch_info` as process 0.
        broadcast_object_list(batch_info)
        self._stop_iteration = batch_info[1]
        if self._stop_iteration:
            # If drop_last is False and split_batches is False, we may have a remainder to take care of.
            if not self.split_batches and not self._drop_last:
                if self.state.process_index == 0 and len(batches) > 0:
                    batch = concatenate(batches, dim=0)
                    batch_info = [get_data_structure(batch), False]
                else:
                    batch_info = [None, True]
                broadcast_object_list(batch_info)
        return batch, batch_info
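    # NOTE (illustrative sketch, not part of the upstream accelerate source): the
    # per-process slicing performed in `__iter__` below reduces to the following
    # when the batch is a plain tensor. With num_processes=4 and a gathered batch
    # of 32 rows, process i keeps rows [i * 8, (i + 1) * 8).
    @staticmethod
    def _example_slice(batch, process_index, num_processes):
        batch_size = batch.shape[0] // num_processes
        data_slice = slice(process_index * batch_size, (process_index + 1) * batch_size)
        return batch[data_slice]
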
    def __iter__(self):
        self.begin()
        self.set_epoch(self.iteration)
        main_iterator = None
        if is_torch_version(">=", "2.0.1"):
            # NOTE: PyTorch DataLoader adds forward compatibility for DataPipes, which broadcasts a
            # shared seed to all dist processes. Thus, we need to create the iterator on all dist
            # processes, but we only iterate through the DataLoader on process 0.
            main_iterator = self.base_dataloader.__iter__()
        elif self.state.process_index == 0:
            main_iterator = self.base_dataloader.__iter__()
        stop_iteration = False
        self._stop_iteration = False
        first_batch = None
        next_batch, next_batch_info = self._fetch_batches(main_iterator)
        batch_index = 0
        while not stop_iteration:
            batch, batch_info = next_batch, next_batch_info

            if self.state.process_index != 0:
                # Initialize tensors on other processes than process 0.
                batch = initialize_tensors(batch_info[0])
            batch = send_to_device(batch, self.state.device, non_blocking=self._non_blocking)
            # Broadcast the batch before splitting it.
            batch = broadcast(batch, from_process=0)

            if not self._drop_last and first_batch is None:
                # We keep at least num processes elements of the first batch to be able to complete the last batch
                first_batch = self.slice_fn(
                    batch,
                    slice(0, self.state.num_processes),
                    process_index=self.state.process_index,
                    num_processes=self.state.num_processes,
                )

            if batch is None:
                raise ValueError(
                    f"Batch does not contain any data (`{batch}`). The end of the iterable data was reached before the expected stop iteration."
                )

            observed_batch_size = find_batch_size(batch)
            batch_size = observed_batch_size // self.state.num_processes

            stop_iteration = self._stop_iteration
            if not stop_iteration:
                # We may still be at the end of the dataloader without knowing it yet: if there is nothing left in
                # the dataloader since the number of batches is a round multiple of the number of processes.
                next_batch, next_batch_info = self._fetch_batches(main_iterator)
                # next_batch_info[0] is None when there are no more batches, otherwise we still need to process them.
                if self._stop_iteration and next_batch_info[0] is None:
                    stop_iteration = True

            if not self._drop_last and stop_iteration and observed_batch_size % self.state.num_processes != 0:
                # If the last batch is not complete, let's add the first batch to it.
                batch = concatenate([batch, first_batch], dim=0)
                # The batch size computation above is off by one in this case, so we fix it.
                batch_size += 1

            data_slice = slice(self.state.process_index * batch_size, (self.state.process_index + 1) * batch_size)
            batch = self.slice_fn(
                batch,
                data_slice,
                process_index=self.state.process_index,
                num_processes=self.state.num_processes,
            )

            if stop_iteration:
                self.end_of_dataloader = True
                self._update_state_dict()
                self.remainder = observed_batch_size
            if batch_index >= self.skip_batches:
                yield batch
            batch_index += 1
        self.iteration += 1
        self.end()

    def set_epoch(self, epoch: int):
        # In case it is manually passed in, the user can set it to what they like
        if self.iteration != epoch:
            self.iteration = epoch
        if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, "set_epoch"):
            self.batch_sampler.sampler.set_epoch(epoch)
        elif hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)

    def __len__(self):
        whole_length = len(self.base_dataloader)
        if self.split_batches:
            return whole_length
        elif self._drop_last:
            return whole_length // self.state.num_processes
        else:
            return math.ceil(whole_length / self.state.num_processes)

    def __reduce__(self):
        """
        Define the `__reduce__` method to ensure a `DataLoaderDispatcher` can be pickled and unpickled. This needs to
        be explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
        `__class__` member.
        """
        args = super().__reduce__()
        return (DataLoaderDispatcher, *args[1:])

    @property
    def total_batch_size(self):
        return (
            self.dataset.batch_size if self.split_batches else (self.dataset.batch_size * self.dataset.num_processes)
        )

    @property
    def total_dataset_length(self):
        return len(self.dataset)

    def get_sampler(self):
        return get_sampler(self)

    def set_sampler(self, sampler):
        sampler_is_batch_sampler = isinstance(self.sampler, BatchSampler)
        if sampler_is_batch_sampler:
            self.sampler.sampler = sampler
        else:
            self.batch_sampler.sampler = sampler
            if hasattr(self.batch_sampler, "batch_sampler"):
                self.batch_sampler.batch_sampler.sampler = sampler


def get_sampler(dataloader):
    """
    Get the sampler associated with the dataloader.

    Args:
        dataloader (`torch.utils.data.dataloader.DataLoader`):
            The data loader to split across several devices.
    Returns:
        `torch.utils.data.Sampler`: The sampler associated with the dataloader.
    """
    sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
    if sampler_is_batch_sampler:
        sampler = getattr(dataloader.sampler, "sampler", None)
    else:
        sampler = getattr(dataloader.batch_sampler, "sampler", None)
    return sampler

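# NOTE (illustrative sketch, not part of the upstream accelerate source):
# `get_sampler` handles both layouts a PyTorch `DataLoader` can have -- a sampler
# wrapped inside a batch sampler, or a `BatchSampler` passed directly as `sampler`.
def _example_get_sampler():
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.arange(10))
    loader = DataLoader(dataset, batch_size=2, shuffle=True)
    return get_sampler(loader)  # -> the underlying RandomSampler
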
def prepare_data_loader(
    dataloader: DataLoader,
    device: Optional[torch.device] = None,
    num_processes: Optional[int] = None,
    process_index: Optional[int] = None,
    split_batches: bool = False,
    put_on_device: bool = False,
    rng_types: Optional[list[Union[str, RNGType]]] = None,
    dispatch_batches: Optional[bool] = None,
    even_batches: bool = True,
    slice_fn_for_dispatch: Optional[Callable] = None,
    use_seedable_sampler: bool = False,
    data_seed: Optional[int] = None,
    non_blocking: bool = False,
    use_stateful_dataloader: bool = False,
    torch_device_mesh=None,
) -> DataLoader:
    """
    Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.

    Depending on the value of the `drop_last` attribute of the `dataloader` passed, it will either stop the iteration
    at the first batch that would be too small / not present on all processes or loop with indices from the beginning.

    Args:
        dataloader (`torch.utils.data.dataloader.DataLoader`):
            The data loader to split across several devices.
        device (`torch.device`):
            The target device for the returned `DataLoader`.
        num_processes (`int`, *optional*):
            The number of processes running concurrently. Will default to the value given by [`~state.PartialState`].
        process_index (`int`, *optional*):
            The index of the current process. Will default to the value given by [`~state.PartialState`].
        split_batches (`bool`, *optional*, defaults to `False`):
            Whether the resulting `DataLoader` should split the batches of the original data loader across devices or
            yield full batches (in which case it will yield batches starting at the `process_index`-th and advancing of
            `num_processes` batches at each iteration).

            Another way to see this is that the observed batch size will be the same as the initial `dataloader` if
            this option is set to `True`, the batch size of the initial `dataloader` multiplied by `num_processes`
            otherwise.

            Setting this option to `True` requires that the batch size of the `dataloader` is a round multiple of
            `num_processes`.
        put_on_device (`bool`, *optional*, defaults to `False`):
            Whether or not to put the batches on `device` (only works if the batches are nested list, tuples or
            dictionaries of tensors).
        rng_types (list of `str` or [`~utils.RNGType`]):
            The list of random number generators to synchronize at the beginning of each iteration. Should be one or
            several of:

            - `"torch"`: the base torch random number generator
            - `"cuda"`: the CUDA random number generator (GPU only)
            - `"xla"`: the XLA random number generator (TPU only)
            - `"generator"`: the `torch.Generator` of the sampler (or batch sampler if there is no sampler in your
              dataloader) or of the iterable dataset (if it exists) if the underlying dataset is of that type.

        dispatch_batches (`bool`, *optional*):
            If set to `True`, the dataloader prepared is only iterated through on the main process and then the batches
            are split and broadcast to each process. Will default to `True` when the underlying dataset is an
            `IterableDataset`, `False` otherwise.
        even_batches (`bool`, *optional*, defaults to `True`):
            If set to `True`, in cases where the total batch size across all processes does not exactly divide the
            dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among
            all workers.
        slice_fn_for_dispatch (`Callable`, *optional*):
            If passed, this function will be used to slice tensors across `num_processes`. Will default to
            [`~utils.slice_tensors`]. This argument is used only when `dispatch_batches` is set to `True` and will be
            ignored otherwise.
        use_seedable_sampler (`bool`, *optional*, defaults to `False`):
            Whether to use the [`~data_loader.SeedableRandomSampler`] instead of a `RandomSampler` for better
            reproducibility. Comes at a cost of potentially different performance due to different shuffling
            algorithms, but ensures results will be the *exact* same. Should be paired with `set_seed()` at every
            `self.set_epoch`.
        data_seed (`int`, *optional*, defaults to `None`):
            The seed to use for the underlying generator when using `use_seedable_sampler`. If `None`, the generator
            will use the current default seed from torch.
        non_blocking (`bool`, *optional*, defaults to `False`):
            If set to `True`, the dataloader will utilize non-blocking host-to-device transfers. If the dataloader has
            `pin_memory` set to `True`, this will help to increase overlap between data transfer and computations.
        use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
            If set to `True`, the dataloader prepared by the Accelerator will be backed by
            [torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader).
            This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed.
        torch_device_mesh (`torch.distributed.DeviceMesh`, *optional*, defaults to `None`):
            PyTorch device mesh.


    Returns:
        `torch.utils.data.dataloader.DataLoader`: A new data loader that will yield the portion of the batches
        assigned to the current process.

    <Tip warning={true}>

    `BatchSampler`s with varying batch sizes are not enabled by default. To enable this behaviour, set `even_batches`
    equal to `False`

    </Tip>
    """
    if dispatch_batches is None:
        if not put_on_device:
            dispatch_batches = False
        else:
            dispatch_batches = isinstance(dataloader.dataset, IterableDataset)

    if dispatch_batches and not put_on_device:
        raise ValueError("Using `dispatch_batches=True` requires `put_on_device=True`.")
    # Grab defaults from PartialState
    state = PartialState()
    if num_processes is None:
        num_processes = state.num_processes

    if process_index is None:
        process_index = state.process_index

    if torch_device_mesh:
        if state.distributed_type == DistributedType.DEEPSPEED:
            # In DeepSpeed, the optimizer sharding level in DP is determined by the config file.
            # Only "dp" and "tp" are considered here.
            # Given a device mesh (dp, tp) = (2, 3):
            # - From the data parallel perspective, ranks should be structured as: 0 0 0 1 1 1
            # - Processes with the same DP rank will receive the same batch.
            submesh_tp_size = 1
            if "tp" in torch_device_mesh.mesh_dim_names:
                submesh_tp_size = torch_device_mesh["tp"].size()
            process_index = process_index // submesh_tp_size
            num_processes = num_processes // submesh_tp_size
        else:
            # When a device mesh is used, specifically with TP, process_index and
            # num_processes need to be updated so that the same batch is generated
            # across TP ranks and different batches across FSDP and DP ranks.
            # Example:
            # if the device mesh is (dp, fsdp, tp) = (2, 2, 3),
            # ranks range from 0...11;
            # from the data angle, ranks should look like 0 0 0 1 1 1 2 2 2 3 3 3,
            # and processes with the same rank/id receive the same batch.
            # For CP, the same as TP applies.
            submesh_fsdp_size = 1
            submesh_dp_size = 1
            submesh_tp_size = 1
            submesh_cp_size = 1
            if "tp" in torch_device_mesh.mesh_dim_names:
                submesh_tp_size = torch_device_mesh["tp"].size()
            if "cp" in torch_device_mesh.mesh_dim_names:
                submesh_cp_size = torch_device_mesh["cp"].size()
            if "dp_replicate" in torch_device_mesh.mesh_dim_names:
                submesh_dp_size = torch_device_mesh["dp_replicate"].size()
            if "dp_shard" in torch_device_mesh.mesh_dim_names:
                submesh_fsdp_size = torch_device_mesh["dp_shard"].size()
            process_index = process_index // (submesh_tp_size * submesh_cp_size)
            num_processes = submesh_fsdp_size * submesh_dp_size

    # Sanity check
    if split_batches:
        if dataloader.batch_size is not None:
            batch_size_for_check = dataloader.batch_size
        else:
            # For custom batch_sampler
            if hasattr(dataloader.batch_sampler, "batch_size"):
                batch_size_for_check = dataloader.batch_sampler.batch_size
            else:
                raise ValueError(
                    "In order to use `split_batches==True` you must have a `batch_size` attribute either in the passed "
                    "`dataloader` or `dataloader.batch_sampler` objects, and it has to return a natural number. "
                    "Your `dataloader.batch_size` is None and `dataloader.batch_sampler` "
                    f"(`{type(dataloader.batch_sampler)}`) does not have the `batch_size` attribute set."
                )

        if batch_size_for_check > 1 and batch_size_for_check % num_processes != 0:
            raise ValueError(
                f"To use a `DataLoader` in `split_batches` mode, the batch size ({dataloader.batch_size}) "
                f"needs to be a round multiple of the number of processes ({num_processes})."
            )

    new_dataset = dataloader.dataset
    # An iterable dataset doesn't like batch_sampler, but a DataLoader creates a default one for it
    new_batch_sampler = dataloader.batch_sampler if not isinstance(new_dataset, IterableDataset) else None
    sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
    synchronized_generator = None

    sampler = get_sampler(dataloader)
    if isinstance(sampler, RandomSampler) and use_seedable_sampler:
        # When iterating through the dataloader during distributed processes
        # we want to ensure that on each process we are iterating through the same
        # samples in the same order if a seed is set. This requires a tweak
        # to the `torch.utils.data.RandomSampler` class (if used).
        sampler = SeedableRandomSampler(
            data_source=sampler.data_source,
            replacement=sampler.replacement,
            num_samples=sampler._num_samples,
            generator=getattr(
                sampler,
                "generator",
                torch.Generator(device=torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu"),
            ),
            data_seed=data_seed,
        )

    if isinstance(dataloader.sampler, RandomSampler) and state.distributed_type == DistributedType.XLA:
        # isinstance(dataloader.sampler, RandomSampler) indicates the original dataloader has `shuffle` enabled.
        generator = torch.Generator(
            device=torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu"
        )
        seed = int(torch.empty((), dtype=torch.int64).random_().item())
        generator.manual_seed(seed)
        dataloader.generator = generator
        dataloader.sampler.generator = generator
    # No change if no multiprocess
    if (num_processes != 1 or state.distributed_type == DistributedType.MEGATRON_LM) and not dispatch_batches:
        if is_datasets_available():
            from datasets import IterableDataset as DatasetsIterableDataset
        if (
            is_datasets_available()
            and isinstance(new_dataset, DatasetsIterableDataset)
            and not split_batches
            and new_dataset.n_shards > num_processes
        ):
            new_dataset = new_dataset.shard(num_shards=num_processes, index=process_index)
        elif isinstance(new_dataset, IterableDataset):
            if getattr(dataloader.dataset, "generator", None) is not None:
                synchronized_generator = dataloader.dataset.generator
            new_dataset = IterableDatasetShard(
                new_dataset,
                batch_size=dataloader.batch_size,
                drop_last=dataloader.drop_last,
                num_processes=num_processes,
                process_index=process_index,
                split_batches=split_batches,
            )
        else:
            if not use_seedable_sampler and hasattr(sampler, "generator"):
                if sampler.generator is None:
                    sampler.generator = torch.Generator(
                        device=torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu"
                    )
                    seed = int(torch.empty((), dtype=torch.int64).random_().item())
                    sampler.generator.manual_seed(seed)
                synchronized_generator = sampler.generator
            batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
            new_batch_sampler = BatchSamplerShard(
                batch_sampler,
                num_processes=num_processes,
                process_index=process_index,
                split_batches=split_batches,
                even_batches=even_batches,
            )

    # We ignore all of those since they are all dealt with by our new_batch_sampler
    ignore_kwargs = [
        "batch_size",
        "shuffle",
        "sampler",
        "batch_sampler",
        "drop_last",
    ]

    if rng_types is not None and synchronized_generator is None and "generator" in rng_types:
        rng_types.remove("generator")

    kwargs = {
        k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k])
        for k in _PYTORCH_DATALOADER_KWARGS
        if k not in ignore_kwargs
    }

    # Need to provide batch_size as batch_sampler is None for Iterable dataset
    if new_batch_sampler is None:
        kwargs["drop_last"] = dataloader.drop_last
        kwargs["batch_size"] = (
            dataloader.batch_size // num_processes if split_batches and not dispatch_batches else dataloader.batch_size
        )
    if dispatch_batches:
        kwargs.pop("generator")
        dataloader = DataLoaderDispatcher(
            new_dataset,
            split_batches=split_batches,
            batch_sampler=new_batch_sampler,
            _drop_last=dataloader.drop_last,
            _non_blocking=non_blocking,
            slice_fn=slice_fn_for_dispatch,
            use_stateful_dataloader=use_stateful_dataloader,
            torch_device_mesh=torch_device_mesh,
            **kwargs,
        )
    elif sampler_is_batch_sampler:
        dataloader = DataLoaderShard(
            new_dataset,
            device=device if put_on_device and state.distributed_type != DistributedType.XLA else None,
            sampler=new_batch_sampler,
            batch_size=dataloader.batch_size,
            rng_types=rng_types,
            _drop_last=dataloader.drop_last,
            _non_blocking=non_blocking,
            synchronized_generator=synchronized_generator,
            use_stateful_dataloader=use_stateful_dataloader,
            **kwargs,
        )
    else:
        dataloader = DataLoaderShard(
            new_dataset,
            device=device if put_on_device and state.distributed_type != DistributedType.XLA else None,
            batch_sampler=new_batch_sampler,
            rng_types=rng_types,
            synchronized_generator=synchronized_generator,
            _drop_last=dataloader.drop_last,
            _non_blocking=non_blocking,
            use_stateful_dataloader=use_stateful_dataloader,
            **kwargs,
        )

    if isinstance(sampler, SeedableRandomSampler) and use_seedable_sampler:
        dataloader.set_sampler(sampler)
    if state.distributed_type == DistributedType.XLA:
        return MpDeviceLoaderWrapper(dataloader, device)
    return dataloader

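# NOTE (illustrative sketch, not part of the upstream accelerate source): direct
# use of `prepare_data_loader`; most users get this behavior via
# `Accelerator.prepare` instead.
def _example_prepare_data_loader():
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.arange(64, dtype=torch.float32))
    loader = DataLoader(dataset, batch_size=8, shuffle=True)
    # The returned dataloader yields only this process's shard of the batches.
    return prepare_data_loader(loader, put_on_device=True)
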
class SkipBatchSampler(BatchSampler):
    """
    A `torch.utils.data.BatchSampler` that skips the first `n` batches of another `torch.utils.data.BatchSampler`.
    Should not be used if the original dataloader is a `StatefulDataLoader`.
    """

    def __init__(self, batch_sampler, skip_batches=0):
        self.batch_sampler = batch_sampler
        self.skip_batches = skip_batches

    def __iter__(self):
        for index, samples in enumerate(self.batch_sampler):
            if index >= self.skip_batches:
                yield samples

    @property
    def total_length(self):
        return len(self.batch_sampler)

    def __len__(self):
        return len(self.batch_sampler) - self.skip_batches

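# NOTE (illustrative sketch, not part of the upstream accelerate source): skipping
# the first two batches of an existing batch sampler.
def _example_skip_batch_sampler():
    from torch.utils.data import BatchSampler, SequentialSampler

    base = BatchSampler(SequentialSampler(range(10)), batch_size=2, drop_last=False)
    skipped = SkipBatchSampler(base, skip_batches=2)
    return list(skipped)  # [[4, 5], [6, 7], [8, 9]]
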
class SkipDataLoader(DataLoaderAdapter, DataLoaderStateMixin):
    """
    Subclass of a PyTorch `DataLoader` that will skip the first batches. Generally it's preferable to use
    `skip_first_batches`/`torchdata.StatefulDataLoader` instead of this class.

    Args:
        dataset (`torch.utils.data.dataset.Dataset`):
            The dataset to use to build this dataloader.
        skip_batches (`int`, *optional*, defaults to 0):
            The number of batches to skip at the beginning.
        kwargs:
            All other keyword arguments to pass to the regular `DataLoader` initialization.
    """

    def __init__(self, dataset, skip_batches=0, use_stateful_dataloader=False, **kwargs):
        super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
        self.skip_batches = skip_batches
        self.gradient_state = GradientState()

    def __iter__(self):
        self.begin()
        for index, batch in enumerate(self.base_dataloader.__iter__()):
            if index >= self.skip_batches:
                self._update_state_dict()
                yield batch
        self.end()

    def __len__(self):
        return len(self.base_dataloader) - self.skip_batches

    def __reduce__(self):
        """
        Define the `__reduce__` method to ensure a `SkipDataLoader` can be pickled and unpickled. This needs to be
        explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
        `__class__` member.
        """
        args = super().__reduce__()
        return (SkipDataLoader, *args[1:])


def skip_first_batches(dataloader, num_batches=0):
    """
    Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`. Should not be used if
    the original dataloader is a `StatefulDataLoader`.
    """
    state = PartialState()
    if state.distributed_type == DistributedType.XLA:
        device = dataloader.device
        dataloader = dataloader.dataloader

    dataset = dataloader.dataset
    sampler_is_batch_sampler = False
    if isinstance(dataset, IterableDataset):
        new_batch_sampler = None
    else:
        sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
        batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
        new_batch_sampler = SkipBatchSampler(batch_sampler, skip_batches=num_batches)

    # We ignore all of those since they are all dealt with by our new_batch_sampler
    ignore_kwargs = [
        "batch_size",
        "shuffle",
        "sampler",
        "batch_sampler",
        "drop_last",
    ]

    kwargs = {
        k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k])
        for k in _PYTORCH_DATALOADER_KWARGS
        if k not in ignore_kwargs
    }

    # Need to provide batch_size as batch_sampler is None for Iterable dataset
    if new_batch_sampler is None:
        kwargs["drop_last"] = dataloader.drop_last
        kwargs["batch_size"] = dataloader.batch_size

    if isinstance(dataloader, DataLoaderDispatcher):
        if new_batch_sampler is None:
            # Need to manually skip batches in the dataloader
            kwargs["skip_batches"] = num_batches
        dataloader = DataLoaderDispatcher(
            dataset,
            split_batches=dataloader.split_batches,
            batch_sampler=new_batch_sampler,
            _drop_last=dataloader._drop_last,
            **kwargs,
        )
    elif isinstance(dataloader, DataLoaderShard):
        if new_batch_sampler is None:
            # Need to manually skip batches in the dataloader
            kwargs["skip_batches"] = num_batches
        elif sampler_is_batch_sampler:
            kwargs["sampler"] = new_batch_sampler
            kwargs["batch_size"] = dataloader.batch_size
        else:
            kwargs["batch_sampler"] = new_batch_sampler
        dataloader = DataLoaderShard(
            dataset,
            device=dataloader.device,
            rng_types=dataloader.rng_types,
            synchronized_generator=dataloader.synchronized_generator,
            **kwargs,
        )
    else:
        if new_batch_sampler is None:
            # Need to manually skip batches in the dataloader
            dataloader = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs)
        else:
            dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs)

    if state.distributed_type == DistributedType.XLA:
        dataloader = MpDeviceLoaderWrapper(dataloader, device)

    return dataloader
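# NOTE (illustrative sketch, not part of the upstream accelerate source): resuming
# mid-epoch after restoring a checkpoint that recorded how many batches had
# already been consumed (`batches_already_seen` is a hypothetical counter).
def _example_resume(dataloader, batches_already_seen):
    resumed = skip_first_batches(dataloader, num_batches=batches_already_seen)
    for batch in resumed:
        ...  # continue training from the first unseen batch
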
pythonProject/.venv/Lib/site-packages/accelerate/hooks.py
ADDED
@@ -0,0 +1,776 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from collections.abc import Mapping
from typing import Optional, Union

import torch
import torch.nn as nn

from .state import PartialState
from .utils import (
    PrefixedDataset,
    find_device,
    named_module_tensors,
    send_to_device,
    set_module_tensor_to_device,
)
from .utils.imports import (
    is_mlu_available,
    is_musa_available,
    is_npu_available,
)
from .utils.memory import clear_device_cache
from .utils.modeling import get_non_persistent_buffers
from .utils.other import recursive_getattr


_accelerate_added_attributes = ["to", "cuda", "npu", "xpu", "mlu", "sdaa", "musa"]


class ModelHook:
    """
    A hook that contains callbacks to be executed just before and after the forward method of a model. The difference
    from existing PyTorch hooks is that these also get passed the kwargs.

    Class attribute:
    - **no_grad** (`bool`, *optional*, defaults to `False`) -- Whether or not to execute the actual forward pass under
      the `torch.no_grad()` context manager.
    """

    no_grad = False

    def init_hook(self, module):
        """
        To be executed when the hook is attached to the module.

        Args:
            module (`torch.nn.Module`): The module attached to this hook.
        """
        return module

    def pre_forward(self, module, *args, **kwargs):
        """
        To be executed just before the forward method of the model.

        Args:
            module (`torch.nn.Module`): The module whose forward pass will be executed just after this event.
            args (`Tuple[Any]`): The positional arguments passed to the module.
            kwargs (`Dict[Str, Any]`): The keyword arguments passed to the module.

        Returns:
            `Tuple[Tuple[Any], Dict[Str, Any]]`: A tuple with the treated `args` and `kwargs`.
        """
        return args, kwargs

    def post_forward(self, module, output):
        """
        To be executed just after the forward method of the model.

        Args:
            module (`torch.nn.Module`): The module whose forward pass has been executed just before this event.
            output (`Any`): The output of the module.

        Returns:
            `Any`: The processed `output`.
        """
        return output

    def detach_hook(self, module):
        """
        To be executed when the hook is detached from a module.

        Args:
            module (`torch.nn.Module`): The module detached from this hook.
        """
        return module


class SequentialHook(ModelHook):
    """
    A hook that can contain several hooks and iterates through them at each event.
    """

    def __init__(self, *hooks):
        self.hooks = hooks

    def init_hook(self, module):
        for hook in self.hooks:
            module = hook.init_hook(module)
        return module

    def pre_forward(self, module, *args, **kwargs):
        for hook in self.hooks:
            args, kwargs = hook.pre_forward(module, *args, **kwargs)
        return args, kwargs

    def post_forward(self, module, output):
        for hook in self.hooks:
            output = hook.post_forward(module, output)
        return output

    def detach_hook(self, module):
        for hook in self.hooks:
            module = hook.detach_hook(module)
        return module


def add_hook_to_module(module: nn.Module, hook: ModelHook, append: bool = False):
    """
    Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook; to remove
    this behavior and restore the original `forward` method, use `remove_hook_from_module`.

    <Tip warning={true}>

    If the module already contains a hook, this will replace it with the new hook passed by default. To chain two hooks
    together, pass `append=True`, so it chains the current and new hook into an instance of the `SequentialHook` class.

    </Tip>

    Args:
        module (`torch.nn.Module`):
            The module to attach a hook to.
        hook (`ModelHook`):
            The hook to attach.
        append (`bool`, *optional*, defaults to `False`):
            Whether the hook should be chained with an existing one (if the module already contains a hook) or not.

    Returns:
        `torch.nn.Module`: The same module, with the hook attached (the module is modified in place, so the result can
        be discarded).
    """
    if append and (getattr(module, "_hf_hook", None) is not None):
        old_hook = module._hf_hook
        remove_hook_from_module(module)
        hook = SequentialHook(old_hook, hook)

    if hasattr(module, "_hf_hook") and hasattr(module, "_old_forward"):
        # If we already put some hook on this module, we replace it with the new one.
        old_forward = module._old_forward
    else:
        old_forward = module.forward
        module._old_forward = old_forward

    module = hook.init_hook(module)
    module._hf_hook = hook

    def new_forward(module, *args, **kwargs):
        args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
        if module._hf_hook.no_grad:
            with torch.no_grad():
                output = module._old_forward(*args, **kwargs)
        else:
            output = module._old_forward(*args, **kwargs)
        return module._hf_hook.post_forward(module, output)

    # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
    # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
    if "GraphModuleImpl" in str(type(module)):
        module.__class__.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)
    else:
        module.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)

    return module

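# NOTE (illustrative sketch, not part of the upstream accelerate source): a minimal
# custom hook that logs the device of the inputs on every forward call, attached
# with `add_hook_to_module`.
class _ExampleLoggingHook(ModelHook):
    def pre_forward(self, module, *args, **kwargs):
        device = find_device([args, kwargs])
        print(f"{module.__class__.__name__} called with inputs on {device}")
        return args, kwargs


def _example_attach_hook():
    linear = nn.Linear(4, 4)
    add_hook_to_module(linear, _ExampleLoggingHook())
    return linear(torch.randn(2, 4))  # prints "Linear called with inputs on cpu"
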
def remove_hook_from_module(module: nn.Module, recurse=False):
    """
    Removes any hook attached to a module via `add_hook_to_module`.

    Args:
        module (`torch.nn.Module`): The module to detach the hook from.
        recurse (`bool`, *optional*): Whether to remove the hooks recursively.

    Returns:
        `torch.nn.Module`: The same module, with the hook detached (the module is modified in place, so the result can
        be discarded).
    """

    if hasattr(module, "_hf_hook"):
        module._hf_hook.detach_hook(module)
        delattr(module, "_hf_hook")

    if hasattr(module, "_old_forward"):
        # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
        # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
        if "GraphModuleImpl" in str(type(module)):
            module.__class__.forward = module._old_forward
        else:
            module.forward = module._old_forward
        delattr(module, "_old_forward")

    # Remove accelerate added warning hooks from dispatch_model
    for attr in _accelerate_added_attributes:
        module.__dict__.pop(attr, None)

    if recurse:
        for child in module.children():
            remove_hook_from_module(child, recurse)

    return module

|
| 226 |
+
"""
|
| 227 |
+
A generic `ModelHook` that ensures inputs and model weights are on the same device for the forward pass of the
|
| 228 |
+
associated module, potentially offloading the weights after the forward pass.
|
| 229 |
+
|
| 230 |
+
Args:
|
| 231 |
+
execution_device (`torch.device`, *optional*):
|
| 232 |
+
The device on which inputs and model weights should be placed before the forward pass.
|
| 233 |
+
offload (`bool`, *optional*, defaults to `False`):
|
| 234 |
+
Whether or not the weights should be offloaded after the forward pass.
|
| 235 |
+
io_same_device (`bool`, *optional*, defaults to `False`):
|
| 236 |
+
Whether or not the output should be placed on the same device as the input was.
|
| 237 |
+
weights_map (`Mapping[str, torch.Tensor]`, *optional*):
|
| 238 |
+
When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.
|
| 239 |
+
offload_buffers (`bool`, *optional*, defaults to `False`):
|
| 240 |
+
Whether or not to include the associated module's buffers when offloading.
|
| 241 |
+
place_submodules (`bool`, *optional*, defaults to `False`):
|
| 242 |
+
Whether to place the submodules on `execution_device` during the `init_hook` event.
|
| 243 |
+
"""
|
| 244 |
+
|
| 245 |
+
def __init__(
|
| 246 |
+
self,
|
| 247 |
+
execution_device: Optional[Union[int, str, torch.device]] = None,
|
| 248 |
+
offload: bool = False,
|
| 249 |
+
io_same_device: bool = False,
|
| 250 |
+
weights_map: Optional[Mapping] = None,
|
| 251 |
+
offload_buffers: bool = False,
|
| 252 |
+
place_submodules: bool = False,
|
| 253 |
+
skip_keys: Optional[Union[str, list[str]]] = None,
|
| 254 |
+
tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,
|
| 255 |
+
):
|
| 256 |
+
self.execution_device = execution_device
|
| 257 |
+
self.offload = offload
|
| 258 |
+
self.io_same_device = io_same_device
|
| 259 |
+
self.weights_map = weights_map
|
| 260 |
+
self.offload_buffers = offload_buffers
|
| 261 |
+
self.place_submodules = place_submodules
|
| 262 |
+
self.skip_keys = skip_keys
|
| 263 |
+
|
| 264 |
+
# Will contain the input device when `io_same_device=True`.
|
| 265 |
+
self.input_device = None
|
| 266 |
+
self.param_original_devices = {}
|
| 267 |
+
self.buffer_original_devices = {}
|
| 268 |
+
self.tied_params_names = set()
|
| 269 |
+
|
| 270 |
+
# The hook pre_forward/post_forward need to have knowledge of this dictionary, as with offloading we want to avoid duplicating memory
|
| 271 |
+
# for tied weights already loaded on the target execution device.
|
| 272 |
+
self.tied_params_map = tied_params_map
|
| 273 |
+
|
| 274 |
+
def __repr__(self):
|
| 275 |
+
return (
|
| 276 |
+
f"AlignDevicesHook(execution_device={self.execution_device}, offload={self.offload}, "
|
| 277 |
+
f"io_same_device={self.io_same_device}, offload_buffers={self.offload_buffers}, "
|
| 278 |
+
f"place_submodules={self.place_submodules}, skip_keys={repr(self.skip_keys)})"
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
def init_hook(self, module):
|
| 282 |
+
# In case the AlignDevicesHook is on meta device, ignore tied weights as data_ptr() is then always zero.
|
| 283 |
+
if self.execution_device == "meta" or self.execution_device == torch.device("meta"):
|
| 284 |
+
self.tied_params_map = None
|
| 285 |
+
|
| 286 |
+
if not self.offload and self.execution_device is not None:
|
| 287 |
+
for name, _ in named_module_tensors(module, recurse=self.place_submodules):
|
| 288 |
+
set_module_tensor_to_device(module, name, self.execution_device, tied_params_map=self.tied_params_map)
|
| 289 |
+
elif self.offload:
|
| 290 |
+
self.original_devices = {
|
| 291 |
+
name: param.device for name, param in named_module_tensors(module, recurse=self.place_submodules)
|
| 292 |
+
}
|
| 293 |
+
if self.weights_map is None:
|
| 294 |
+
self.weights_map = {
|
| 295 |
+
name: param.to("cpu")
|
| 296 |
+
for name, param in named_module_tensors(
|
| 297 |
+
module, include_buffers=self.offload_buffers, recurse=self.place_submodules
|
| 298 |
+
)
|
| 299 |
+
}
|
| 300 |
+
for name, _ in named_module_tensors(
|
| 301 |
+
module, include_buffers=self.offload_buffers, recurse=self.place_submodules, remove_non_persistent=True
|
| 302 |
+
):
|
| 303 |
+
# When using disk offloading, we can not rely on `weights_map[name].data_ptr()` as the reference pointer,
|
| 304 |
+
# as we have no guarantee that safetensors' `file.get_tensor()` will always give the same pointer.
|
| 305 |
+
# As we have no reliable way to track the shared data pointer of tied weights in this case, we use tied_params_names: List[str]
|
| 306 |
+
# to add on the fly pointers to `tied_params_map` in the pre_forward call.
|
| 307 |
+
if (
|
| 308 |
+
self.tied_params_map is not None
|
| 309 |
+
and recursive_getattr(module, name).data_ptr() in self.tied_params_map
|
| 310 |
+
):
|
| 311 |
+
self.tied_params_names.add(name)
|
| 312 |
+
|
| 313 |
+
set_module_tensor_to_device(module, name, "meta")
|
| 314 |
+
|
| 315 |
+
if not self.offload_buffers and self.execution_device is not None:
|
| 316 |
+
for name, _ in module.named_buffers(recurse=self.place_submodules):
|
| 317 |
+
set_module_tensor_to_device(
|
| 318 |
+
module, name, self.execution_device, tied_params_map=self.tied_params_map
|
| 319 |
+
)
|
| 320 |
+
elif self.offload_buffers and self.execution_device is not None:
|
| 321 |
+
for name in get_non_persistent_buffers(module, recurse=self.place_submodules):
|
| 322 |
+
set_module_tensor_to_device(
|
| 323 |
+
module, name, self.execution_device, tied_params_map=self.tied_params_map
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
return module
|
| 327 |
+
|
| 328 |
+
def pre_forward(self, module, *args, **kwargs):
|
| 329 |
+
if self.io_same_device:
|
| 330 |
+
self.input_device = find_device([args, kwargs])
|
| 331 |
+
if self.offload:
|
| 332 |
+
self.tied_pointers_to_remove = set()
|
| 333 |
+
|
| 334 |
+
for name, _ in named_module_tensors(
|
| 335 |
+
module,
|
| 336 |
+
include_buffers=self.offload_buffers,
|
| 337 |
+
recurse=self.place_submodules,
|
| 338 |
+
remove_non_persistent=True,
|
| 339 |
+
):
|
| 340 |
+
fp16_statistics = None
|
| 341 |
+
value = self.weights_map[name]
|
| 342 |
+
if "weight" in name and name.replace("weight", "SCB") in self.weights_map.keys():
|
| 343 |
+
if value.dtype == torch.int8:
|
| 344 |
+
fp16_statistics = self.weights_map[name.replace("weight", "SCB")]
|
| 345 |
+
|
| 346 |
+
# In case we are using offloading with tied weights, we need to keep track of the offloaded weights
|
| 347 |
+
# that are loaded on device at this point, as we will need to remove them as well from the dictionary
|
| 348 |
+
# self.tied_params_map in order to allow to free memory.
|
| 349 |
+
if name in self.tied_params_names and value.data_ptr() not in self.tied_params_map:
|
| 350 |
+
self.tied_params_map[value.data_ptr()] = {}
|
| 351 |
+
|
| 352 |
+
if (
|
| 353 |
+
value is not None
|
| 354 |
+
and self.tied_params_map is not None
|
| 355 |
+
and value.data_ptr() in self.tied_params_map
|
| 356 |
+
and self.execution_device not in self.tied_params_map[value.data_ptr()]
|
| 357 |
+
):
|
| 358 |
+
self.tied_pointers_to_remove.add((value.data_ptr(), self.execution_device))
|
| 359 |
+
|
| 360 |
+
set_module_tensor_to_device(
|
| 361 |
+
module,
|
| 362 |
+
name,
|
| 363 |
+
self.execution_device,
|
| 364 |
+
value=value,
|
| 365 |
+
fp16_statistics=fp16_statistics,
|
| 366 |
+
tied_params_map=self.tied_params_map,
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
return send_to_device(args, self.execution_device), send_to_device(
|
| 370 |
+
kwargs, self.execution_device, skip_keys=self.skip_keys
|
| 371 |
+
)
|
| 372 |
+
|
| 373 |
+
def post_forward(self, module, output):
|
| 374 |
+
if self.offload:
|
| 375 |
+
for name, _ in named_module_tensors(
|
| 376 |
+
module,
|
| 377 |
+
include_buffers=self.offload_buffers,
|
| 378 |
+
recurse=self.place_submodules,
|
| 379 |
+
remove_non_persistent=True,
|
| 380 |
+
):
|
| 381 |
+
set_module_tensor_to_device(module, name, "meta")
|
| 382 |
+
if type(module).__name__ == "Linear8bitLt":
|
| 383 |
+
module.state.SCB = None
|
| 384 |
+
module.state.CxB = None
|
| 385 |
+
|
| 386 |
+
# We may have loaded tied weights into self.tied_params_map (avoiding to load them several times in e.g. submodules): remove them from
|
| 387 |
+
# this dictionary to allow the garbage collector to do its job.
|
| 388 |
+
for value_pointer, device in self.tied_pointers_to_remove:
|
| 389 |
+
if isinstance(device, int):
|
| 390 |
+
if is_npu_available():
|
| 391 |
+
device = f"npu:{device}"
|
| 392 |
+
elif is_mlu_available():
|
| 393 |
+
device = f"mlu:{device}"
|
| 394 |
+
elif is_musa_available():
|
| 395 |
+
device = f"musa:{device}"
|
| 396 |
+
if device in self.tied_params_map[value_pointer]:
|
| 397 |
+
del self.tied_params_map[value_pointer][device]
|
| 398 |
+
self.tied_pointers_to_remove = set()
|
| 399 |
+
if self.io_same_device and self.input_device is not None:
|
| 400 |
+
output = send_to_device(output, self.input_device, skip_keys=self.skip_keys)
|
| 401 |
+
|
| 402 |
+
return output
|
| 403 |
+
|
| 404 |
+
def detach_hook(self, module):
|
| 405 |
+
if self.offload:
|
| 406 |
+
for name, device in self.original_devices.items():
|
| 407 |
+
if device != torch.device("meta"):
|
| 408 |
+
set_module_tensor_to_device(module, name, device, value=self.weights_map.get(name, None))
|
| 409 |
+
return module
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
def attach_execution_device_hook(
|
| 413 |
+
module: torch.nn.Module,
|
| 414 |
+
execution_device: Union[int, str, torch.device],
|
| 415 |
+
skip_keys: Optional[Union[str, list[str]]] = None,
|
| 416 |
+
preload_module_classes: Optional[list[str]] = None,
|
| 417 |
+
tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,
|
| 418 |
+
):
|
| 419 |
+
"""
|
| 420 |
+
Recursively attaches `AlignDevicesHook` to all submodules of a given model to make sure they have the right
|
| 421 |
+
execution device
|
| 422 |
+
|
| 423 |
+
Args:
|
| 424 |
+
module (`torch.nn.Module`):
|
| 425 |
+
The module where we want to attach the hooks.
|
| 426 |
+
execution_device (`int`, `str` or `torch.device`):
|
| 427 |
+
The device on which inputs and model weights should be placed before the forward pass.
|
| 428 |
+
skip_keys (`str` or `List[str]`, *optional*):
|
| 429 |
+
A list of keys to ignore when moving inputs or outputs between devices.
|
| 430 |
+
preload_module_classes (`List[str]`, *optional*):
|
| 431 |
+
A list of classes whose instances should load all their weights (even in the submodules) at the beginning
|
| 432 |
+
of the forward. This should only be used for classes that have submodules which are registered but not
|
| 433 |
+
called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
|
| 434 |
+
`dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
|
| 435 |
+
tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):
|
| 436 |
+
A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution
|
| 437 |
+
device, this parameter is useful to reuse the first available pointer of a shared weight for all others,
|
| 438 |
+
instead of duplicating memory.
|
| 439 |
+
"""
|
| 440 |
+
if not hasattr(module, "_hf_hook") and len(module.state_dict()) > 0:
|
| 441 |
+
add_hook_to_module(
|
| 442 |
+
module,
|
| 443 |
+
AlignDevicesHook(execution_device, skip_keys=skip_keys, tied_params_map=tied_params_map),
|
| 444 |
+
)
|
| 445 |
+
|
| 446 |
+
# Break the recursion if we get to a preload module.
|
| 447 |
+
if preload_module_classes is not None and module.__class__.__name__ in preload_module_classes:
|
| 448 |
+
return
|
| 449 |
+
|
| 450 |
+
for child in module.children():
|
| 451 |
+
attach_execution_device_hook(
|
| 452 |
+
child,
|
| 453 |
+
execution_device,
|
| 454 |
+
skip_keys=skip_keys,
|
| 455 |
+
preload_module_classes=preload_module_classes,
|
| 456 |
+
tied_params_map=tied_params_map,
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
def attach_align_device_hook(
|
| 461 |
+
module: torch.nn.Module,
|
| 462 |
+
execution_device: Optional[torch.device] = None,
|
| 463 |
+
offload: bool = False,
|
| 464 |
+
weights_map: Optional[Mapping] = None,
|
| 465 |
+
offload_buffers: bool = False,
|
| 466 |
+
module_name: str = "",
|
| 467 |
+
skip_keys: Optional[Union[str, list[str]]] = None,
|
| 468 |
+
preload_module_classes: Optional[list[str]] = None,
|
| 469 |
+
tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,
|
| 470 |
+
):
|
| 471 |
+
"""
|
| 472 |
+
Recursively attaches `AlignDevicesHook` to all submodules of a given model that have direct parameters and/or
|
| 473 |
+
buffers.
|
| 474 |
+
|
| 475 |
+
Args:
|
| 476 |
+
module (`torch.nn.Module`):
|
| 477 |
+
The module where we want to attach the hooks.
|
| 478 |
+
execution_device (`torch.device`, *optional*):
|
| 479 |
+
The device on which inputs and model weights should be placed before the forward pass.
|
| 480 |
+
offload (`bool`, *optional*, defaults to `False`):
|
| 481 |
+
Whether or not the weights should be offloaded after the forward pass.
|
| 482 |
+
weights_map (`Mapping[str, torch.Tensor]`, *optional*):
|
| 483 |
+
When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.
|
| 484 |
+
offload_buffers (`bool`, *optional*, defaults to `False`):
|
| 485 |
+
Whether or not to include the associated module's buffers when offloading.
|
| 486 |
+
module_name (`str`, *optional*, defaults to `""`):
|
| 487 |
+
The name of the module.
|
| 488 |
+
skip_keys (`str` or `List[str]`, *optional*):
|
| 489 |
+
A list of keys to ignore when moving inputs or outputs between devices.
|
| 490 |
+
preload_module_classes (`List[str]`, *optional*):
|
| 491 |
+
A list of classes whose instances should load all their weights (even in the submodules) at the beginning
|
| 492 |
+
of the forward. This should only be used for classes that have submodules which are registered but not
|
| 493 |
+
called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
|
| 494 |
+
`dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
|
| 495 |
+
tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):
|
| 496 |
+
A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution
|
| 497 |
+
device, this parameter is useful to reuse the first available pointer of a shared weight for all others,
|
| 498 |
+
instead of duplicating memory.
|
| 499 |
+
"""
|
| 500 |
+
# Attach the hook on this module if it has any direct tensor.
|
| 501 |
+
directs = named_module_tensors(module)
|
| 502 |
+
full_offload = (
|
| 503 |
+
offload and preload_module_classes is not None and module.__class__.__name__ in preload_module_classes
|
| 504 |
+
)
|
| 505 |
+
|
| 506 |
+
if len(list(directs)) > 0 or full_offload:
|
| 507 |
+
if weights_map is not None:
|
| 508 |
+
prefix = f"{module_name}." if len(module_name) > 0 else ""
|
| 509 |
+
prefixed_weights_map = PrefixedDataset(weights_map, prefix)
|
| 510 |
+
else:
|
| 511 |
+
prefixed_weights_map = None
|
| 512 |
+
hook = AlignDevicesHook(
|
| 513 |
+
execution_device=execution_device,
|
| 514 |
+
offload=offload,
|
| 515 |
+
weights_map=prefixed_weights_map,
|
| 516 |
+
offload_buffers=offload_buffers,
|
| 517 |
+
place_submodules=full_offload,
|
| 518 |
+
skip_keys=skip_keys,
|
| 519 |
+
tied_params_map=tied_params_map,
|
| 520 |
+
)
|
| 521 |
+
add_hook_to_module(module, hook, append=True)
|
| 522 |
+
|
| 523 |
+
# We stop the recursion in case we hit the full offload.
|
| 524 |
+
if full_offload:
|
| 525 |
+
return
|
| 526 |
+
|
| 527 |
+
# Recurse on all children of the module.
|
| 528 |
+
for child_name, child in module.named_children():
|
| 529 |
+
child_name = f"{module_name}.{child_name}" if len(module_name) > 0 else child_name
|
| 530 |
+
attach_align_device_hook(
|
| 531 |
+
child,
|
| 532 |
+
execution_device=execution_device,
|
| 533 |
+
offload=offload,
|
| 534 |
+
weights_map=weights_map,
|
| 535 |
+
offload_buffers=offload_buffers,
|
| 536 |
+
module_name=child_name,
|
| 537 |
+
preload_module_classes=preload_module_classes,
|
| 538 |
+
skip_keys=skip_keys,
|
| 539 |
+
tied_params_map=tied_params_map,
|
| 540 |
+
)
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
def remove_hook_from_submodules(module: nn.Module):
|
| 544 |
+
"""
|
| 545 |
+
Recursively removes all hooks attached on the submodules of a given model.
|
| 546 |
+
|
| 547 |
+
Args:
|
| 548 |
+
module (`torch.nn.Module`): The module on which to remove all hooks.
|
| 549 |
+
"""
|
| 550 |
+
remove_hook_from_module(module)
|
| 551 |
+
for child in module.children():
|
| 552 |
+
remove_hook_from_submodules(child)
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
def attach_align_device_hook_on_blocks(
|
| 556 |
+
module: nn.Module,
|
| 557 |
+
execution_device: Optional[Union[torch.device, dict[str, torch.device]]] = None,
|
| 558 |
+
offload: Union[bool, dict[str, bool]] = False,
|
| 559 |
+
weights_map: Mapping = None,
|
| 560 |
+
offload_buffers: bool = False,
|
| 561 |
+
module_name: str = "",
|
| 562 |
+
skip_keys: Optional[Union[str, list[str]]] = None,
|
| 563 |
+
preload_module_classes: Optional[list[str]] = None,
|
| 564 |
+
tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,
|
| 565 |
+
):
|
| 566 |
+
"""
|
| 567 |
+
Attaches `AlignDevicesHook` to all blocks of a given model as needed.
|
| 568 |
+
|
| 569 |
+
Args:
|
| 570 |
+
module (`torch.nn.Module`):
|
| 571 |
+
The module where we want to attach the hooks.
|
| 572 |
+
execution_device (`torch.device` or `Dict[str, torch.device]`, *optional*):
|
| 573 |
+
The device on which inputs and model weights should be placed before the forward pass. It can be one device
|
| 574 |
+
for the whole module, or a dictionary mapping module name to device.
|
| 575 |
+
offload (`bool`, *optional*, defaults to `False`):
|
| 576 |
+
Whether or not the weights should be offloaded after the forward pass. It can be one boolean for the whole
|
| 577 |
+
module, or a dictionary mapping module name to boolean.
|
| 578 |
+
weights_map (`Mapping[str, torch.Tensor]`, *optional*):
|
| 579 |
+
When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.
|
| 580 |
+
offload_buffers (`bool`, *optional*, defaults to `False`):
|
| 581 |
+
Whether or not to include the associated module's buffers when offloading.
|
| 582 |
+
module_name (`str`, *optional*, defaults to `""`):
|
| 583 |
+
The name of the module.
|
| 584 |
+
skip_keys (`str` or `List[str]`, *optional*):
|
| 585 |
+
A list of keys to ignore when moving inputs or outputs between devices.
|
| 586 |
+
preload_module_classes (`List[str]`, *optional*):
|
| 587 |
+
A list of classes whose instances should load all their weights (even in the submodules) at the beginning
|
| 588 |
+
of the forward. This should only be used for classes that have submodules which are registered but not
|
| 589 |
+
called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
|
| 590 |
+
`dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
|
| 591 |
+
tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):
|
| 592 |
+
A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution
|
| 593 |
+
device, this parameter is useful to reuse the first available pointer of a shared weight for all others,
|
| 594 |
+
instead of duplicating memory.
|
| 595 |
+
"""
|
| 596 |
+
# If one device and one offload, we've got one hook.
|
| 597 |
+
if not isinstance(execution_device, Mapping) and not isinstance(offload, dict):
|
| 598 |
+
if not offload:
|
| 599 |
+
hook = AlignDevicesHook(
|
| 600 |
+
execution_device=execution_device,
|
| 601 |
+
io_same_device=True,
|
| 602 |
+
skip_keys=skip_keys,
|
| 603 |
+
place_submodules=True,
|
| 604 |
+
tied_params_map=tied_params_map,
|
| 605 |
+
)
|
| 606 |
+
add_hook_to_module(module, hook)
|
| 607 |
+
else:
|
| 608 |
+
attach_align_device_hook(
|
| 609 |
+
module,
|
| 610 |
+
execution_device=execution_device,
|
| 611 |
+
offload=True,
|
| 612 |
+
weights_map=weights_map,
|
| 613 |
+
offload_buffers=offload_buffers,
|
| 614 |
+
module_name=module_name,
|
| 615 |
+
skip_keys=skip_keys,
|
| 616 |
+
tied_params_map=tied_params_map,
|
| 617 |
+
)
|
| 618 |
+
return
|
| 619 |
+
|
| 620 |
+
if not isinstance(execution_device, Mapping):
|
| 621 |
+
execution_device = {key: execution_device for key in offload.keys()}
|
| 622 |
+
if not isinstance(offload, Mapping):
|
| 623 |
+
offload = {key: offload for key in execution_device.keys()}
|
| 624 |
+
|
| 625 |
+
if module_name in execution_device and module_name in offload and not offload[module_name]:
|
| 626 |
+
hook = AlignDevicesHook(
|
| 627 |
+
execution_device=execution_device[module_name],
|
| 628 |
+
offload_buffers=offload_buffers,
|
| 629 |
+
io_same_device=(module_name == ""),
|
| 630 |
+
place_submodules=True,
|
| 631 |
+
skip_keys=skip_keys,
|
| 632 |
+
tied_params_map=tied_params_map,
|
| 633 |
+
)
|
| 634 |
+
add_hook_to_module(module, hook)
|
| 635 |
+
attach_execution_device_hook(
|
| 636 |
+
module, execution_device[module_name], skip_keys=skip_keys, tied_params_map=tied_params_map
|
| 637 |
+
)
|
| 638 |
+
elif module_name in execution_device and module_name in offload:
|
| 639 |
+
attach_align_device_hook(
|
| 640 |
+
module,
|
| 641 |
+
execution_device=execution_device[module_name],
|
| 642 |
+
offload=True,
|
| 643 |
+
weights_map=weights_map,
|
| 644 |
+
offload_buffers=offload_buffers,
|
| 645 |
+
module_name=module_name,
|
| 646 |
+
skip_keys=skip_keys,
|
| 647 |
+
preload_module_classes=preload_module_classes,
|
| 648 |
+
tied_params_map=tied_params_map,
|
| 649 |
+
)
|
| 650 |
+
if not hasattr(module, "_hf_hook"):
|
| 651 |
+
hook = AlignDevicesHook(
|
| 652 |
+
execution_device=execution_device[module_name],
|
| 653 |
+
io_same_device=(module_name == ""),
|
| 654 |
+
skip_keys=skip_keys,
|
| 655 |
+
tied_params_map=tied_params_map,
|
| 656 |
+
)
|
| 657 |
+
add_hook_to_module(module, hook)
|
| 658 |
+
attach_execution_device_hook(
|
| 659 |
+
module,
|
| 660 |
+
execution_device[module_name],
|
| 661 |
+
preload_module_classes=preload_module_classes,
|
| 662 |
+
skip_keys=skip_keys,
|
| 663 |
+
tied_params_map=tied_params_map,
|
| 664 |
+
)
|
| 665 |
+
elif module_name == "":
|
| 666 |
+
hook = AlignDevicesHook(
|
| 667 |
+
execution_device=execution_device.get(""),
|
| 668 |
+
io_same_device=True,
|
| 669 |
+
skip_keys=skip_keys,
|
| 670 |
+
tied_params_map=tied_params_map,
|
| 671 |
+
)
|
| 672 |
+
add_hook_to_module(module, hook)
|
| 673 |
+
|
| 674 |
+
for child_name, child in module.named_children():
|
| 675 |
+
child_name = f"{module_name}.{child_name}" if len(module_name) > 0 else child_name
|
| 676 |
+
attach_align_device_hook_on_blocks(
|
| 677 |
+
child,
|
| 678 |
+
execution_device=execution_device,
|
| 679 |
+
offload=offload,
|
| 680 |
+
weights_map=weights_map,
|
| 681 |
+
offload_buffers=offload_buffers,
|
| 682 |
+
module_name=child_name,
|
| 683 |
+
preload_module_classes=preload_module_classes,
|
| 684 |
+
skip_keys=skip_keys,
|
| 685 |
+
tied_params_map=tied_params_map,
|
| 686 |
+
)
|
| 687 |
+
|
| 688 |
+
|
| 689 |
+
class CpuOffload(ModelHook):
|
| 690 |
+
"""
|
| 691 |
+
Offloads a model on the CPU until its forward pass is called. The model will not be offloaded back to the CPU after
|
| 692 |
+
the forward, the user needs to call the `init_hook` method again for this.
|
| 693 |
+
|
| 694 |
+
Args:
|
| 695 |
+
execution_device(`str`, `int` or `torch.device`, *optional*):
|
| 696 |
+
The device on which the model should be executed. Will default to the MPS device if it's available, then
|
| 697 |
+
GPU 0 if there is a GPU, and finally to the CPU.
|
| 698 |
+
prev_module_hook (`UserCpuOffloadHook`, *optional*):
|
| 699 |
+
The hook sent back by [`cpu_offload_with_hook`] for a previous model in the pipeline you are running. If
|
| 700 |
+
passed, its offload method will be called just before the forward of the model to which this hook is
|
| 701 |
+
attached.
|
| 702 |
+
"""
|
| 703 |
+
|
| 704 |
+
def __init__(
|
| 705 |
+
self,
|
| 706 |
+
execution_device: Optional[Union[str, int, torch.device]] = None,
|
| 707 |
+
prev_module_hook: Optional["UserCpuOffloadHook"] = None,
|
| 708 |
+
):
|
| 709 |
+
self.prev_module_hook = prev_module_hook
|
| 710 |
+
|
| 711 |
+
self.execution_device = execution_device if execution_device is not None else PartialState().default_device
|
| 712 |
+
|
| 713 |
+
def init_hook(self, module):
|
| 714 |
+
return module.to("cpu")
|
| 715 |
+
|
| 716 |
+
def pre_forward(self, module, *args, **kwargs):
|
| 717 |
+
if self.prev_module_hook is not None and isinstance(self.prev_module_hook, UserCpuOffloadHook):
|
| 718 |
+
prev_module = self.prev_module_hook.model
|
| 719 |
+
prev_device = next(prev_module.parameters()).device
|
| 720 |
+
|
| 721 |
+
# Only offload the previous module if it is not already on CPU.
|
| 722 |
+
if prev_device != torch.device("cpu"):
|
| 723 |
+
self.prev_module_hook.offload()
|
| 724 |
+
clear_device_cache()
|
| 725 |
+
|
| 726 |
+
# If the current device is already the self.execution_device, we can skip the transfer.
|
| 727 |
+
current_device = next(module.parameters()).device
|
| 728 |
+
if current_device == self.execution_device:
|
| 729 |
+
return args, kwargs
|
| 730 |
+
|
| 731 |
+
module.to(self.execution_device)
|
| 732 |
+
return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device)
|
| 733 |
+
|
| 734 |
+
|
| 735 |
+
class UserCpuOffloadHook:
|
| 736 |
+
"""
|
| 737 |
+
A simple hook grouping a model and a `ModelHook`, which provides easy APIs for to call the init method of the hook
|
| 738 |
+
or remove it entirely.
|
| 739 |
+
"""
|
| 740 |
+
|
| 741 |
+
def __init__(self, model, hook):
|
| 742 |
+
self.model = model
|
| 743 |
+
self.hook = hook
|
| 744 |
+
|
| 745 |
+
def offload(self):
|
| 746 |
+
self.hook.init_hook(self.model)
|
| 747 |
+
|
| 748 |
+
def remove(self):
|
| 749 |
+
remove_hook_from_module(self.model)
|
| 750 |
+
|
| 751 |
+
|
| 752 |
+
class LayerwiseCastingHook(ModelHook):
|
| 753 |
+
r"""
|
| 754 |
+
A hook that casts the weights of a module to a high precision dtype for computation, and to a low precision dtype
|
| 755 |
+
for storage. This process may lead to quality loss in the output, but can significantly reduce the memory
|
| 756 |
+
footprint.
|
| 757 |
+
"""
|
| 758 |
+
|
| 759 |
+
_is_stateful = False
|
| 760 |
+
|
| 761 |
+
def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool) -> None:
|
| 762 |
+
self.storage_dtype = storage_dtype
|
| 763 |
+
self.compute_dtype = compute_dtype
|
| 764 |
+
self.non_blocking = non_blocking
|
| 765 |
+
|
| 766 |
+
def init_hook(self, module: torch.nn.Module):
|
| 767 |
+
module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
|
| 768 |
+
return module
|
| 769 |
+
|
| 770 |
+
def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
|
| 771 |
+
module.to(dtype=self.compute_dtype, non_blocking=self.non_blocking)
|
| 772 |
+
return args, kwargs
|
| 773 |
+
|
| 774 |
+
def post_forward(self, module: torch.nn.Module, output):
|
| 775 |
+
module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
|
| 776 |
+
return output
|
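For orientation, here is a minimal usage sketch of the hook machinery above (not part of the uploaded file): it attaches an `AlignDevicesHook` by hand with `add_hook_to_module` and detaches it with `remove_hook_from_submodules`. The toy model and the device choice are illustrative assumptions.

```python
# A minimal sketch of the hooks defined above; the toy model is a placeholder.
import torch
from torch import nn

from accelerate.hooks import AlignDevicesHook, add_hook_to_module, remove_hook_from_submodules

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))

# Run the module on GPU 0 if one exists (CPU otherwise), and move outputs back
# to whatever device the inputs came from (`io_same_device=True`).
execution_device = 0 if torch.cuda.is_available() else "cpu"
hook = AlignDevicesHook(execution_device=execution_device, io_same_device=True, place_submodules=True)
add_hook_to_module(model, hook)

x = torch.randn(4, 8)  # created on CPU; pre_forward moves it to the execution device
out = model(x)
print(out.device)  # matches the input's device thanks to io_same_device

remove_hook_from_submodules(model)  # detach all hooks once done
```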
pythonProject/.venv/Lib/site-packages/accelerate/inference.py
ADDED
@@ -0,0 +1,184 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from types import MethodType
from typing import Any, Optional, Union

from .state import PartialState
from .utils import (
    calculate_maximum_sizes,
    convert_bytes,
    copy_tensor_to_devices,
    ignorant_find_batch_size,
    infer_auto_device_map,
    is_pippy_available,
    pad_input_tensors,
    send_to_device,
)


def generate_device_map(model, num_processes: int = 1, no_split_module_classes=None, max_memory: dict = None):
    """
    Calculates the device map for `model` with an offset for PiPPy.
    """
    if num_processes == 1:
        return infer_auto_device_map(model, no_split_module_classes=no_split_module_classes, clean_result=False)
    if max_memory is None:
        model_size, shared = calculate_maximum_sizes(model)

        # Split into `n` chunks for each GPU
        memory = (model_size + shared[0]) / num_processes
        memory = convert_bytes(memory)
        value, ending = memory.split(" ")

        # Add a chunk to deal with potential extra shared memory instances
        memory = math.ceil(float(value)) * 1.1
        memory = f"{memory} {ending}"
        max_memory = {i: memory for i in range(num_processes)}
    device_map = infer_auto_device_map(
        model,
        max_memory=max_memory,
        no_split_module_classes=no_split_module_classes,
        clean_result=False,
    )
    return device_map


def find_pippy_batch_size(args, kwargs):
    found_batch_size = None
    if args is not None:
        for arg in args:
            found_batch_size = ignorant_find_batch_size(arg)
            if found_batch_size is not None:
                break
    if kwargs is not None and found_batch_size is None:
        for kwarg in kwargs.values():
            found_batch_size = ignorant_find_batch_size(kwarg)
            if found_batch_size is not None:
                break
    return found_batch_size


def build_pipeline(model, split_points, args, kwargs, num_chunks):
    """
    Attaches the split points to the model based on `self.device_map` and generates a `PipelineStage`. Requires
    passing in the `args` and `kwargs` the model needs, on the CPU.

    Users can pass in a custom `num_chunks` as an optional hyper-parameter. By default it will use
    `AcceleratorState.num_processes`.
    """
    # Note: We import here to reduce import time from general modules, and isolate outside dependencies
    from torch.distributed.pipelining import ScheduleGPipe, SplitPoint, pipeline

    # We need to annotate the split points in the model for PiPPy
    state = PartialState()
    split_spec = {split_point: SplitPoint.BEGINNING for split_point in split_points}
    pipe = pipeline(
        model,
        mb_args=args,
        mb_kwargs=kwargs,
        split_spec=split_spec,
    )
    stage = pipe.build_stage(state.local_process_index, device=state.device)
    schedule = ScheduleGPipe(stage, num_chunks)

    return schedule


def pippy_forward(forward, num_chunks, gather_output, *args, **kwargs):
    state = PartialState()
    output = None

    if state.num_processes == 1:
        output = forward(*args, **kwargs)
    elif state.is_local_main_process:
        found_batch_size = find_pippy_batch_size(args, kwargs)
        if found_batch_size is None:
            raise ValueError("Could not find batch size from args or kwargs")
        else:
            if found_batch_size != num_chunks:
                args = pad_input_tensors(args, found_batch_size, num_chunks)
                kwargs = pad_input_tensors(kwargs, found_batch_size, num_chunks)
        forward(*args, **kwargs)
    elif state.is_last_process:
        output = forward()
    else:
        forward()
    if gather_output:
        # Each node will get a copy of the full output which is only on the last GPU
        output = copy_tensor_to_devices(output)
    return output


def prepare_pippy(
    model,
    split_points: Optional[Union[str, list[str]]] = "auto",
    no_split_module_classes: Optional[list[str]] = None,
    example_args: Optional[tuple[Any]] = (),
    example_kwargs: Optional[dict[str, Any]] = None,
    num_chunks: Optional[int] = None,
    gather_output: Optional[bool] = False,
):
    """
    Wraps `model` for pipeline parallel inference.

    Args:
        model (`torch.nn.Module`):
            A model we want to split for pipeline-parallel inference.
        split_points (`str` or `List[str]`, defaults to 'auto'):
            How to generate the split points and chunk the model across each GPU. 'auto' will find the best balanced
            split given any model. Should be a list of layer names in the model to split by otherwise.
        no_split_module_classes (`List[str]`):
            A list of class names for layers we don't want to be split.
        example_args (tuple of model inputs):
            The expected inputs for the model that uses order-based inputs for a *single process*. Recommended to use
            this method if possible.
        example_kwargs (dict of model inputs):
            The expected inputs for the model that uses dictionary-based inputs for a *single process*. This is a
            *highly* limiting structure that requires the same keys be present at *all* inference calls. Not
            recommended unless the prior condition is true for all cases.
        num_chunks (`int`, defaults to the number of available GPUs):
            The number of different stages the Pipeline will have. By default it will assign one chunk per GPU, but
            this can be tuned and played with. In general one should have num_chunks >= num_gpus.
        gather_output (`bool`, defaults to `False`):
            If `True`, the output from the last GPU (which holds the true outputs) is sent across to all GPUs.
    """
    if not is_pippy_available():
        raise ImportError("Using `torch.distributed.pipelining` requires PyTorch 2.4.0 or later.")
    state = PartialState()
    example_args = send_to_device(example_args, "cpu")
    example_kwargs = send_to_device(example_kwargs, "cpu")
    if num_chunks is None:
        num_chunks = state.num_processes
    if split_points == "auto":
        device_map = generate_device_map(model, num_chunks, no_split_module_classes=no_split_module_classes)
        split_points = []
        for i in range(1, num_chunks):
            split_points.append(next(k for k, v in device_map.items() if v == i))
    model.hf_split_points = split_points
    stage = build_pipeline(model, split_points, example_args, example_kwargs, num_chunks)
    model._original_forward = model.forward
    model._original_call = model.__call__
    model.pippy_stage = stage
    model.hf_split_points = split_points

    def forward(*args, **kwargs):
        return pippy_forward(stage.step, num_chunks, gather_output, *args, **kwargs)

    # To act like a decorator so that it can be popped when doing `extract_model_from_parallel`
    # Note: creates an infinite recursion loop with `generate`
    model_forward = MethodType(forward, model)
    forward.__wrapped__ = model_forward
    model.forward = forward
    return model
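As a rough illustration of how `prepare_pippy` is meant to be driven (a sketch, not part of the uploaded file): it assumes the script is started with `accelerate launch` on at least two processes and that PyTorch >= 2.4 is installed; the toy model and tensor shapes are placeholders.

```python
# Sketch: run with `accelerate launch --num_processes 2 script.py`
# (assumptions: >= 2 devices available, torch >= 2.4, accelerate installed).
import torch
from torch import nn

from accelerate.inference import prepare_pippy

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))  # placeholder model
example_input = torch.randn(2, 16)

# "auto" derives split points from the inferred device map; one chunk per process by default.
model = prepare_pippy(model, split_points="auto", example_args=(example_input,), gather_output=True)

with torch.no_grad():
    output = model(example_input)  # available on every rank because gather_output=True
```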
pythonProject/.venv/Lib/site-packages/accelerate/launchers.py
ADDED
@@ -0,0 +1,306 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import tempfile

import torch

from .state import AcceleratorState, PartialState
from .utils import (
    PrecisionType,
    PrepareForLaunch,
    are_libraries_initialized,
    check_cuda_p2p_ib_support,
    get_gpu_info,
    is_mps_available,
    is_torch_version,
    patch_environment,
)
from .utils.constants import ELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION


def test_launch():
    "Verify a `PartialState` can be initialized."
    _ = PartialState()


def notebook_launcher(
    function,
    args=(),
    num_processes=None,
    mixed_precision="no",
    use_port="29500",
    master_addr="127.0.0.1",
    node_rank=0,
    num_nodes=1,
    rdzv_backend="static",
    rdzv_endpoint="",
    rdzv_conf=None,
    rdzv_id="none",
    max_restarts=0,
    monitor_interval=0.1,
    log_line_prefix_template=None,
):
    """
    Launches a training function, using several processes or multiple nodes if it's possible in the current
    environment (TPU with multiple cores for instance).

    <Tip warning={true}>

    To use this function, absolutely zero calls to a device must be made in the notebook session before calling. If
    any have been made, you will need to restart the notebook and make sure no cells use any device capability.

    Setting `ACCELERATE_DEBUG_MODE="1"` in your environment will run a test before truly launching to ensure that none
    of those calls have been made.

    </Tip>

    Args:
        function (`Callable`):
            The training function to execute. If it accepts arguments, the first argument should be the index of the
            process run.
        args (`Tuple`):
            Tuple of arguments to pass to the function (it will receive `*args`).
        num_processes (`int`, *optional*):
            The number of processes to use for training. Will default to 8 in Colab/Kaggle if a TPU is available, to
            the number of devices available otherwise.
        mixed_precision (`str`, *optional*, defaults to `"no"`):
            If `fp16` or `bf16`, will use mixed precision training on multi-device.
        use_port (`str`, *optional*, defaults to `"29500"`):
            The port to use to communicate between processes when launching a multi-device training.
        master_addr (`str`, *optional*, defaults to `"127.0.0.1"`):
            The address to use for communication between processes.
        node_rank (`int`, *optional*, defaults to 0):
            The rank of the current node.
        num_nodes (`int`, *optional*, defaults to 1):
            The number of nodes to use for training.
        rdzv_backend (`str`, *optional*, defaults to `"static"`):
            The rendezvous method to use, such as 'static' (the default) or 'c10d'.
        rdzv_endpoint (`str`, *optional*, defaults to `""`):
            The endpoint of the rdzv sync. storage.
        rdzv_conf (`Dict`, *optional*, defaults to `None`):
            Additional rendezvous configuration.
        rdzv_id (`str`, *optional*, defaults to `"none"`):
            The unique run id of the job.
        max_restarts (`int`, *optional*, defaults to 0):
            The maximum amount of restarts that elastic agent will conduct on workers before failure.
        monitor_interval (`float`, *optional*, defaults to 0.1):
            The interval in seconds that is used by the elastic_agent as a period of monitoring workers.
        log_line_prefix_template (`str`, *optional*, defaults to `None`):
            The prefix template for elastic launch logging. Available from PyTorch 2.2.0.

    Example:

    ```python
    # Assume this is defined in a Jupyter Notebook on an instance with two devices
    from accelerate import notebook_launcher


    def train(*args):
        # Your training function here
        ...


    notebook_launcher(train, args=(arg1, arg2), num_processes=2, mixed_precision="fp16")
    ```
    """
    # Are we in a google colab or a Kaggle Kernel?
    in_colab = False
    in_kaggle = False
    if any(key.startswith("KAGGLE") for key in os.environ.keys()):
        in_kaggle = True
    elif "IPython" in sys.modules:
        in_colab = "google.colab" in str(sys.modules["IPython"].get_ipython())

    try:
        mixed_precision = PrecisionType(mixed_precision.lower())
    except ValueError:
        raise ValueError(
            f"Unknown mixed_precision mode: {mixed_precision.lower()}. Choose between {PrecisionType.list()}."
        )

    if (in_colab or in_kaggle) and (
        (os.environ.get("TPU_NAME", None) is not None) or (os.environ.get("PJRT_DEVICE", "") == "TPU")
    ):
        # TPU launch
        import torch_xla.distributed.xla_multiprocessing as xmp

        if len(AcceleratorState._shared_state) > 0:
            raise ValueError(
                "To train on TPU in Colab or Kaggle Kernel, the `Accelerator` should only be initialized inside "
                "your training function. Restart your notebook and make sure no cells initializes an "
                "`Accelerator`."
            )

        launcher = PrepareForLaunch(function, distributed_type="XLA")
        print("Launching a training on TPU cores.")
        xmp.spawn(launcher, args=args, start_method="fork")
    elif in_colab and get_gpu_info()[1] < 2:
        # No need for a distributed launch otherwise as it's either CPU or one GPU.
        if torch.cuda.is_available():
            print("Launching training on one GPU.")
        else:
            print("Launching training on one CPU.")
        function(*args)
    else:
        if num_processes is None:
            raise ValueError(
                "You have to specify the number of devices you would like to use, add `num_processes=...` to your call."
            )
        if node_rank >= num_nodes:
            raise ValueError("The node_rank must be less than the number of nodes.")
        if num_processes > 1:
            # Multi-device launch
            from torch.distributed.launcher.api import LaunchConfig, elastic_launch
            from torch.multiprocessing import start_processes
            from torch.multiprocessing.spawn import ProcessRaisedException

            if len(AcceleratorState._shared_state) > 0:
                raise ValueError(
                    "To launch a multi-device training from your notebook, the `Accelerator` should only be initialized "
                    "inside your training function. Restart your notebook and make sure no cells initializes an "
                    "`Accelerator`."
                )
            # Check for specific libraries known to initialize device that users constantly use
            problematic_imports = are_libraries_initialized("bitsandbytes")
            if len(problematic_imports) > 0:
                err = (
                    "Could not start distributed process. Libraries known to initialize device upon import have been "
                    "imported already. Please keep these imports inside your training function to try and help with this:"
                )
                for lib_name in problematic_imports:
                    err += f"\n\t* `{lib_name}`"
                raise RuntimeError(err)

            patched_env = dict(
                nproc=num_processes,
                node_rank=node_rank,
                world_size=num_nodes * num_processes,
                master_addr=master_addr,
                master_port=use_port,
                mixed_precision=mixed_precision,
            )

            # Check for CUDA P2P and IB issues
            if not check_cuda_p2p_ib_support():
                patched_env["nccl_p2p_disable"] = "1"
                patched_env["nccl_ib_disable"] = "1"

            # torch.distributed will expect a few environment variables to be here. We set the ones common to each
            # process here (the other ones will be set by the launcher).
            with patch_environment(**patched_env):
                # First dummy launch
                device_type = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
                distributed_type = "MULTI_XPU" if device_type == "xpu" else "MULTI_GPU"
                if os.environ.get("ACCELERATE_DEBUG_MODE", "false").lower() == "true":
                    launcher = PrepareForLaunch(test_launch, distributed_type=distributed_type)
                    try:
                        start_processes(launcher, args=(), nprocs=num_processes, start_method="fork")
                    except ProcessRaisedException as e:
                        err = "An issue was found when verifying a stable environment for the notebook launcher."
                        if f"Cannot re-initialize {device_type.upper()} in forked subprocess" in e.args[0]:
                            raise RuntimeError(
                                f"{err} "
                                "This likely stems from an outside import causing issues once the `notebook_launcher()` is called. "
                                "Please review your imports and test them when running the `notebook_launcher()` to identify "
                                f"which one is problematic and causing {device_type.upper()} to be initialized."
                            ) from e
                        else:
                            raise RuntimeError(f"{err} The following error was raised: {e}") from e
                # Now the actual launch
                launcher = PrepareForLaunch(function, distributed_type=distributed_type)
                print(f"Launching training on {num_processes} {device_type.upper()}s.")
                try:
                    if rdzv_conf is None:
                        rdzv_conf = {}
                    if rdzv_backend == "static":
                        rdzv_conf["rank"] = node_rank
                        if not rdzv_endpoint:
                            rdzv_endpoint = f"{master_addr}:{use_port}"
                    launch_config_kwargs = dict(
                        min_nodes=num_nodes,
                        max_nodes=num_nodes,
                        nproc_per_node=num_processes,
                        run_id=rdzv_id,
                        rdzv_endpoint=rdzv_endpoint,
                        rdzv_backend=rdzv_backend,
                        rdzv_configs=rdzv_conf,
                        max_restarts=max_restarts,
                        monitor_interval=monitor_interval,
                        start_method="fork",
                    )
                    if is_torch_version(">=", ELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION):
                        launch_config_kwargs["log_line_prefix_template"] = log_line_prefix_template
                    elastic_launch(config=LaunchConfig(**launch_config_kwargs), entrypoint=function)(*args)
                except ProcessRaisedException as e:
                    if f"Cannot re-initialize {device_type.upper()} in forked subprocess" in e.args[0]:
                        raise RuntimeError(
                            f"{device_type.upper()} has been initialized before the `notebook_launcher` could create a forked subprocess. "
                            "This likely stems from an outside import causing issues once the `notebook_launcher()` is called. "
                            "Please review your imports and test them when running the `notebook_launcher()` to identify "
                            f"which one is problematic and causing {device_type.upper()} to be initialized."
                        ) from e
                    else:
                        raise RuntimeError(f"An issue was found when launching the training: {e}") from e

        else:
            # No need for a distributed launch otherwise as it's either CPU, GPU, XPU or MPS.
            if is_mps_available():
                os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
                print("Launching training on MPS.")
            elif torch.cuda.is_available():
                print("Launching training on one GPU.")
            elif torch.xpu.is_available():
                print("Launching training on one XPU.")
            else:
                print("Launching training on CPU.")
            function(*args)


def debug_launcher(function, args=(), num_processes=2):
    """
    Launches a training function using several processes on CPU for debugging purposes.

    <Tip warning={true}>

    This function is provided for internal testing and debugging, but it's not intended for real trainings. It will
    only use the CPU.

    </Tip>

    Args:
        function (`Callable`):
            The training function to execute.
        args (`Tuple`):
            Tuple of arguments to pass to the function (it will receive `*args`).
        num_processes (`int`, *optional*, defaults to 2):
            The number of processes to use for training.
    """
    from torch.multiprocessing import start_processes

    with tempfile.NamedTemporaryFile() as tmp_file:
        # torch.distributed will expect a few environment variables to be here. We set the ones common to each
        # process here (the other ones will be set by the launcher).
        with patch_environment(
            world_size=num_processes,
            master_addr="127.0.0.1",
            master_port="29500",
            accelerate_mixed_precision="no",
            accelerate_debug_rdv_file=tmp_file.name,
            accelerate_use_cpu="yes",
        ):
            launcher = PrepareForLaunch(function, debug=True)
            start_processes(launcher, args=args, nprocs=num_processes, start_method="fork")
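A small sketch of `debug_launcher` in use (not part of the uploaded file); the training function below is a hypothetical placeholder.

```python
# Sketch: spawn two CPU processes to sanity-check distributed code paths.
from accelerate import Accelerator, debug_launcher


def training_function():
    accelerator = Accelerator()  # picks up the CPU/multiprocess env set by the launcher
    print(f"Process {accelerator.process_index} of {accelerator.num_processes} is alive")


debug_launcher(training_function, num_processes=2)
```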
pythonProject/.venv/Lib/site-packages/accelerate/local_sgd.py
ADDED
@@ -0,0 +1,106 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from accelerate import Accelerator, DistributedType


class LocalSGD:
    """
    A helper class to support local SGD on top of Accelerator. It simply runs a given number of updates independently
    on each device, and averages model weights every K steps.

    It should be used only in the multi-GPU (or multi-CPU) setup without extensions such as DeepSpeed. In particular,
    this is a simple implementation that cannot support scenarios such as model parallelism.

    Although we are not aware of the true origins of this simple approach, the idea of local SGD is quite old and goes
    back to at least:

    Zhang, J., De Sa, C., Mitliagkas, I., & Ré, C. (2016). [Parallel SGD: When does averaging help?. arXiv preprint
    arXiv:1606.07365.](https://arxiv.org/abs/1606.07365)

    We credit the term Local SGD to the following paper (but there might be earlier references we are not aware of).

    Stich, Sebastian Urban. ["Local SGD Converges Fast and Communicates Little." ICLR 2019-International Conference on
    Learning Representations. No. CONF. 2019.](https://arxiv.org/abs/1805.09767)
    """

    def __enter__(self):
        if self.enabled:
            self.model_sync_obj = self.model.no_sync()
            self.model_sync_obj.__enter__()

        return self

    def __exit__(self, type, value, tb):
        if self.enabled:
            # Average all models on exit
            self._sync_and_avg_model_params()
            self.model_sync_obj.__exit__(type, value, tb)

    def __init__(self, accelerator: Accelerator, model: torch.nn.Module, local_sgd_steps: int, enabled: bool = True):
        """
        Constructor.

        Args:
            accelerator (`Accelerator`):
                Accelerator object.
            model (`torch.nn.Module`):
                The model whose parameters we need to average.
            local_sgd_steps (`int`):
                A number of local SGD steps (before model parameters are synchronized).
            enabled (`bool`, *optional*, defaults to `True`):
                Local SGD is disabled if this parameter is set to `False`.
        """
        if accelerator.distributed_type not in [
            DistributedType.NO,
            DistributedType.MULTI_CPU,
            DistributedType.MULTI_GPU,
            DistributedType.MULTI_XPU,
            DistributedType.MULTI_MLU,
            DistributedType.MULTI_HPU,
            DistributedType.MULTI_SDAA,
            DistributedType.MULTI_MUSA,
            DistributedType.MULTI_NPU,
        ]:
            raise NotImplementedError("LocalSGD is supported only for CPUs and GPUs (no DeepSpeed or MegatronLM)")
        self.enabled = enabled and accelerator.distributed_type != DistributedType.NO
        self.num_steps = 0
        if self.enabled:
            self.accelerator = accelerator
            self.model = model
            self.local_sgd_steps = local_sgd_steps

    def step(self):
        """
        This function makes a "step" and synchronizes model parameters if necessary.
        """
        self.num_steps += 1
        if not self.enabled:
            return

        if self.num_steps % self.local_sgd_steps == 0:
            self._sync_and_avg_model_params()

    def _sync_and_avg_model_params(self):
        """
        Synchronize + average model parameters across all GPUs.
        """

        self.accelerator.wait_for_everyone()
        with self.accelerator.autocast():
            for param in self.model.parameters():
                param.data = self.accelerator.reduce(param.data, reduction="mean")
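A minimal end-to-end sketch of `LocalSGD` in a training loop (not part of the uploaded file); the linear model and random data are stand-ins for real training inputs.

```python
# Sketch: average parameters across processes every 4 optimizer steps.
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator
from accelerate.local_sgd import LocalSGD

accelerator = Accelerator()
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
dataset = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
dataloader = DataLoader(dataset, batch_size=8)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

with LocalSGD(accelerator=accelerator, model=model, local_sgd_steps=4, enabled=True) as local_sgd:
    for x, y in dataloader:
        loss = torch.nn.functional.mse_loss(model(x), y)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
        local_sgd.step()  # every `local_sgd_steps`-th call triggers the parameter averaging
```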
pythonProject/.venv/Lib/site-packages/accelerate/logging.py
ADDED
@@ -0,0 +1,125 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+import logging
+import os
+
+from .state import PartialState
+
+
+class MultiProcessAdapter(logging.LoggerAdapter):
+    """
+    An adapter to assist with logging in multiprocess.
+
+    `log` takes in an additional `main_process_only` kwarg, which dictates whether it should be called on all processes
+    or only the main executed one. Default is `main_process_only=True`.
+
+    Does not require an `Accelerator` object to be created first.
+    """
+
+    @staticmethod
+    def _should_log(main_process_only):
+        "Check if log should be performed"
+        state = PartialState()
+        return not main_process_only or (main_process_only and state.is_main_process)
+
+    def log(self, level, msg, *args, **kwargs):
+        """
+        Delegates logger call after checking if we should log.
+
+        Accepts a new kwarg of `main_process_only`, which will dictate whether it will be logged across all processes
+        or only the main executed one. Default is `True` if not passed.
+
+        Also accepts "in_order", which if `True` makes the processes log one by one, in order. This is much easier to
+        read, but comes at the cost of sometimes needing to wait for the other processes. Default is `False` to not
+        break with the previous behavior.
+
+        `in_order` is ignored if `main_process_only` is passed.
+        """
+        if PartialState._shared_state == {}:
+            raise RuntimeError(
+                "You must initialize the accelerate state by calling either `PartialState()` or `Accelerator()` before using the logging utility."
+            )
+        main_process_only = kwargs.pop("main_process_only", True)
+        in_order = kwargs.pop("in_order", False)
+        # set `stacklevel` to exclude ourself in `Logger.findCaller()` while respecting user's choice
+        kwargs.setdefault("stacklevel", 2)
+
+        if self.isEnabledFor(level):
+            if self._should_log(main_process_only):
+                msg, kwargs = self.process(msg, kwargs)
+                self.logger.log(level, msg, *args, **kwargs)
+
+            elif in_order:
+                state = PartialState()
+                for i in range(state.num_processes):
+                    if i == state.process_index:
+                        msg, kwargs = self.process(msg, kwargs)
+                        self.logger.log(level, msg, *args, **kwargs)
+                    state.wait_for_everyone()
+
+    @functools.lru_cache(None)
+    def warning_once(self, *args, **kwargs):
+        """
+        This method is identical to `logger.warning()`, but will emit the warning with the same message only once
+
+        Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the
+        cache. The assumption here is that all warning messages are unique across the code. If they aren't then we need
+        to switch to another type of cache that includes the caller frame information in the hashing function.
+        """
+        self.warning(*args, **kwargs)
+
+
+def get_logger(name: str, log_level: str = None):
+    """
+    Returns a `logging.Logger` for `name` that can handle multiprocessing.
+
+    If a log should be called on all processes, pass `main_process_only=False`. If a log should be called on all
+    processes and in order, also pass `in_order=True`.
+
+    Args:
+        name (`str`):
+            The name for the logger, such as `__file__`
+        log_level (`str`, *optional*):
+            The log level to use. If not passed, will default to the `ACCELERATE_LOG_LEVEL` environment variable, or `INFO` if not set.
+
+    Example:
+
+    ```python
+    >>> from accelerate.logging import get_logger
+    >>> from accelerate import Accelerator
+
+    >>> logger = get_logger(__name__)
+
+    >>> accelerator = Accelerator()
+    >>> logger.info("My log", main_process_only=False)
+    >>> logger.debug("My log", main_process_only=True)
+
+    >>> logger = get_logger(__name__, log_level="DEBUG")
+    >>> logger.info("My log")
+    >>> logger.debug("My second log")
+
+    >>> array = ["a", "b", "c", "d"]
+    >>> letter_at_rank = array[accelerator.process_index]
+    >>> logger.info(letter_at_rank, in_order=True)
+    ```
+    """
+    if log_level is None:
+        log_level = os.environ.get("ACCELERATE_LOG_LEVEL", None)
+    logger = logging.getLogger(name)
+    if log_level is not None:
+        logger.setLevel(log_level.upper())
+        logger.root.setLevel(log_level.upper())
+    return MultiProcessAdapter(logger, {})
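Beyond the docstring example above, one behavior worth calling out: `warning_once` deduplicates on its call arguments via `functools.lru_cache`, so an identical message fires only once per process. A small sketch (the accelerate state must be initialized first, as `log` enforces with its `RuntimeError`):

```python
from accelerate import Accelerator
from accelerate.logging import get_logger

accelerator = Accelerator()  # creates the PartialState that `log` checks for
logger = get_logger(__name__, log_level="WARNING")

for _ in range(3):
    logger.warning_once("this argument is deprecated")  # emitted once
    logger.warning("loop iteration warning")            # emitted three times
```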
pythonProject/.venv/Lib/site-packages/accelerate/memory_utils.py
ADDED
@@ -0,0 +1,22 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+
+
+warnings.warn(
+    "memory_utils has been reorganized to utils.memory. Import `find_executable_batch_size` from the main `__init__`: "
+    "`from accelerate import find_executable_batch_size` to avoid this warning.",
+    FutureWarning,
+)
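Per the `FutureWarning` above, this module is only a backward-compatibility shim; the utility it used to host is imported from the package root. A sketch of the recommended usage (the training function body is a placeholder):

```python
from accelerate import find_executable_batch_size

@find_executable_batch_size(starting_batch_size=128)
def train(batch_size):
    # On CUDA out-of-memory, the decorator halves `batch_size` and retries.
    print(f"training with batch_size={batch_size}")

train()  # called without arguments; the decorator injects batch_size
```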
pythonProject/.venv/Lib/site-packages/accelerate/optimizer.py
ADDED
@@ -0,0 +1,213 @@
+# Copyright 2021 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+
+import torch
+
+from .state import AcceleratorState, GradientState
+from .utils import DistributedType, honor_type, is_lomo_available, is_torch_xla_available
+
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+    import torch_xla.runtime as xr
+
+
+def move_to_device(state, device):
+    if isinstance(state, (list, tuple)):
+        return honor_type(state, (move_to_device(t, device) for t in state))
+    elif isinstance(state, dict):
+        return type(state)({k: move_to_device(v, device) for k, v in state.items()})
+    elif isinstance(state, torch.Tensor):
+        return state.to(device)
+    return state
+
+
+class AcceleratedOptimizer(torch.optim.Optimizer):
+    """
+    Internal wrapper around a torch optimizer.
+
+    Conditionally will perform `step` and `zero_grad` if gradients should be synchronized when performing gradient
+    accumulation.
+
+    Args:
+        optimizer (`torch.optim.optimizer.Optimizer`):
+            The optimizer to wrap.
+        device_placement (`bool`, *optional*, defaults to `True`):
+            Whether or not the optimizer should handle device placement. If so, it will place the state dictionary of
+            `optimizer` on the right device.
+        scaler (`torch.amp.GradScaler` or `torch.cuda.amp.GradScaler`, *optional*):
+            The scaler to use in the step function if training with mixed precision.
+    """
+
+    def __init__(self, optimizer, device_placement=True, scaler=None):
+        self.optimizer = optimizer
+        self.scaler = scaler
+        self.accelerator_state = AcceleratorState()
+        self.gradient_state = GradientState()
+        self.device_placement = device_placement
+        self._is_overflow = False
+
+        if self.scaler is not None:
+            self._accelerate_step_called = False
+            self._optimizer_original_step_method = self.optimizer.step
+            self._optimizer_patched_step_method = patch_optimizer_step(self, self.optimizer.step)
+
+        # Handle device placement
+        if device_placement:
+            state_dict = self.optimizer.state_dict()
+            if self.accelerator_state.distributed_type == DistributedType.XLA:
+                xm.send_cpu_data_to_device(state_dict, self.accelerator_state.device)
+            else:
+                state_dict = move_to_device(state_dict, self.accelerator_state.device)
+            self.optimizer.load_state_dict(state_dict)
+
+    @property
+    def state(self):
+        return self.optimizer.state
+
+    @state.setter
+    def state(self, state):
+        self.optimizer.state = state
+
+    @property
+    def param_groups(self):
+        return self.optimizer.param_groups
+
+    @param_groups.setter
+    def param_groups(self, param_groups):
+        self.optimizer.param_groups = param_groups
+
+    @property
+    def defaults(self):
+        return self.optimizer.defaults
+
+    @defaults.setter
+    def defaults(self, defaults):
+        self.optimizer.defaults = defaults
+
+    def add_param_group(self, param_group):
+        self.optimizer.add_param_group(param_group)
+
+    def load_state_dict(self, state_dict):
+        if self.accelerator_state.distributed_type == DistributedType.XLA and self.device_placement:
+            xm.send_cpu_data_to_device(state_dict, self.accelerator_state.device)
+        self.optimizer.load_state_dict(state_dict)
+
+    def state_dict(self):
+        return self.optimizer.state_dict()
+
+    def zero_grad(self, set_to_none=None):
+        if self.gradient_state.sync_gradients:
+            accept_arg = "set_to_none" in inspect.signature(self.optimizer.zero_grad).parameters
+            if accept_arg:
+                if set_to_none is None:
+                    set_to_none = True
+                self.optimizer.zero_grad(set_to_none=set_to_none)
+            else:
+                if set_to_none is not None:
+                    raise ValueError("`set_to_none` for `Optimizer.zero_grad` is not supported by this optimizer.")
+                self.optimizer.zero_grad()
+
+    def train(self):
+        """
+        Sets the optimizer to "train" mode. Useful for optimizers like `schedule_free`
+        """
+        if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
+            self.optimizer.train()
+        elif (
+            hasattr(self.optimizer, "optimizer")
+            and hasattr(self.optimizer.optimizer, "train")
+            and callable(self.optimizer.optimizer.train)
+        ):
+            # the deepspeed optimizer further wraps the optimizer
+            self.optimizer.optimizer.train()
+
+    def eval(self):
+        """
+        Sets the optimizer to "eval" mode. Useful for optimizers like `schedule_free`
+        """
+        if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval):
+            self.optimizer.eval()
+
+    def step(self, closure=None):
+        if is_lomo_available():
+            from lomo_optim import AdaLomo, Lomo
+
+        if (
+            not self.gradient_state.is_xla_gradients_synced
+            and self.accelerator_state.distributed_type == DistributedType.XLA
+        ):
+            gradients = xm._fetch_gradients(self.optimizer)
+            xm.all_reduce("sum", gradients, scale=1.0 / xr.world_size())
+            self.gradient_state.is_xla_gradients_synced = True
+
+        if is_lomo_available():
+            # `step` should be a no-op for LOMO optimizers.
+            if isinstance(self.optimizer, (Lomo, AdaLomo)):
+                return
+
+        if self.gradient_state.sync_gradients:
+            if self.scaler is not None:
+                self.optimizer.step = self._optimizer_patched_step_method
+
+                self.scaler.step(self.optimizer, closure)
+                self.scaler.update()
+
+                if not self._accelerate_step_called:
+                    # If the optimizer step was skipped, gradient overflow was detected.
+                    self._is_overflow = True
+                else:
+                    self._is_overflow = False
+                # Reset the step method to the original one
+                self.optimizer.step = self._optimizer_original_step_method
+                # Reset the indicator
+                self._accelerate_step_called = False
+            else:
+                self.optimizer.step(closure)
+        if self.accelerator_state.distributed_type == DistributedType.XLA:
+            self.gradient_state.is_xla_gradients_synced = False
+
+    def _switch_parameters(self, parameters_map):
+        for param_group in self.optimizer.param_groups:
+            param_group["params"] = [parameters_map.get(p, p) for p in param_group["params"]]
+
+    @property
+    def step_was_skipped(self):
+        """Whether or not the optimizer step was skipped."""
+        return self._is_overflow
+
+    def __getstate__(self):
+        _ignored_keys = [
+            "_accelerate_step_called",
+            "_optimizer_original_step_method",
+            "_optimizer_patched_step_method",
+        ]
+        return {k: v for k, v in self.__dict__.items() if k not in _ignored_keys}
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        if self.scaler is not None:
+            self._accelerate_step_called = False
+            self._optimizer_original_step_method = self.optimizer.step
+            self._optimizer_patched_step_method = patch_optimizer_step(self, self.optimizer.step)
+
+
+def patch_optimizer_step(accelerated_optimizer: AcceleratedOptimizer, method):
+    def patched_step(*args, **kwargs):
+        accelerated_optimizer._accelerate_step_called = True
+        return method(*args, **kwargs)
+
+    return patched_step
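`AcceleratedOptimizer` is not meant to be instantiated by hand; `Accelerator.prepare` wraps a plain torch optimizer into it. A minimal sketch of how the wrapper surfaces in user code (passing `mixed_precision="fp16"` on a GPU would additionally exercise the scaler/overflow path shown in `step`):

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model, optimizer = accelerator.prepare(model, optimizer)
print(type(optimizer).__name__)  # AcceleratedOptimizer

x = torch.randn(8, 4, device=accelerator.device)
loss = model(x).sum()
accelerator.backward(loss)
optimizer.step()
print(optimizer.step_was_skipped)  # True only after an fp16 gradient overflow
optimizer.zero_grad()
```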
pythonProject/.venv/Lib/site-packages/accelerate/parallelism_config.py
ADDED
@@ -0,0 +1,322 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import warnings
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional, Union
+
+from accelerate.utils.dataclasses import TorchContextParallelConfig, TorchTensorParallelConfig
+from accelerate.utils.versions import is_torch_version
+
+
+if TYPE_CHECKING:
+    from accelerate import Accelerator
+
+
+@dataclass
+class ParallelismConfig:
+    """
+    A dataclass to configure parallelisms applied to the model. Inspired by torchtitan's `ParallelDims`
+    https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/parallel_dims.py
+
+    Args:
+        dp_replicate_size (`int`, defaults to `1`):
+            The size of the data parallel group. If `dp_replicate_size` is set to 1, the data parallel replication
+            group will not be used.
+        dp_shard_size (`int`, defaults to `1`):
+            The size of the model shard group. If `dp_replicate_size > 1` and `tp_size > 1`, `dp_shard_size` must also
+            be greater than 1, as composing DDP + TP is currently not supported.
+        tp_size (`int`, defaults to `1`):
+            The size of the tensor parallel group. If `tp_size` is set to `1`, the tensor parallel group will not be
+            used.
+        cp_size (`int`, defaults to `1`):
+            The size of the context parallel group. Currently not supported, but reserved for future use and enabled
+            for downstream libraries.
+        tp_handler (`~utils.TorchTensorParallelConfig`, defaults to `None`):
+            The handler for the tensor parallel group.
+
+    You may obtain different distributed data parallel paradigms by configuring `dp_replicate_size` and `dp_shard_size`
+    together:
+    - `dp_replicate_size == 1` and `dp_shard_size > 1`, we obtain Fully Sharded Data Parallel (FSDP).
+    - `dp_replicate_size > 1` and `dp_shard_size > 1`, we obtain Hybrid Sharded Data Parallel (HSDP).
+    - `dp_replicate_size > 1` and `dp_shard_size == 1` is an invalid configuration; to use pure DP, use
+      `DistributedDataParallelKwargs` instead.
+
+    """
+
+    dp_replicate_size: int = None
+    dp_shard_size: int = None
+    tp_size: int = None
+    cp_size: int = None
+
+    # we use Union because we might support other x parallel plugins (i.e. deepspeed, etc)
+    tp_handler: Union[None, TorchTensorParallelConfig] = None
+    cp_handler: Union[None, TorchContextParallelConfig] = None
+
+    device_mesh = None
+
+    def __repr__(self):
+        return (
+            "ParallelismConfig(\n "
+            f"\tdp_replicate_size={self.dp_replicate_size},\n"
+            f"\tdp_shard_size={self.dp_shard_size},\n"
+            f"\ttp_size={self.tp_size},\n"
+            f"\tcp_size={self.cp_size},\n"
+            f"\ttotal_size={self.total_size}\n"
+            f"\ttp_handler={self.tp_handler},\n"
+            f"\tcp_handler={self.cp_handler})\n"
+        )
+
+    def to_json(self):
+        import copy
+
+        _non_serializable_fields = ["device_mesh"]
+
+        return copy.deepcopy(
+            {
+                k: copy.deepcopy(v.__dict__) if hasattr(v, "__dict__") else v
+                for k, v in self.__dict__.items()
+                if k not in _non_serializable_fields
+            }
+        )
+
+    @property
+    def dp_dim_names(self):
+        """Names of enabled dimensions across which data parallelism is applied."""
+        dims = []
+        if self.dp_replicate_enabled:
+            dims += ["dp_replicate"]
+        if self.dp_shard_enabled:
+            dims += ["dp_shard"]
+        return dims
+
+    @property
+    def non_dp_dim_names(self):
+        """Names of enabled dimensions which will receive the same batch (non-data parallel dimensions)."""
+        dims = []
+        if self.tp_enabled:
+            dims += ["tp"]
+        if self.cp_enabled:
+            dims += ["cp"]
+        return dims
+
+    @property
+    def dp_shard_cp_dim_names(self):
+        """Names of enabled dimensions which will be flattened into a joint mesh across which the model is sharded in FSDP."""
+        dims = []
+        if self.dp_shard_enabled:
+            dims += ["dp_shard"]
+        if self.cp_enabled:
+            dims += ["cp"]
+        return dims
+
+    @property
+    def dp_cp_dim_names(self):
+        """Names of enabled dimensions across which loss should be averaged."""
+        dims = []
+        if self.dp_replicate_enabled:
+            dims += ["dp_replicate"]
+        if self.dp_shard_enabled:
+            dims += ["dp_shard"]
+        if self.cp_enabled:
+            dims += ["cp"]
+        return dims
+
+    @property
+    def fsdp_dim_names(self):
+        """Names of enabled dimensions across which FSDP is applied, including data parallel replication."""
+        dims = []
+        if self.dp_replicate_enabled:
+            dims += ["dp_replicate"]
+        dims += ["dp_shard_cp"]
+        return dims
+
+    @property
+    def total_size(self):
+        """The total size of the parallelism configuration, which is the product of all sizes."""
+        return self.dp_replicate_size * self.dp_shard_size * self.tp_size * self.cp_size
+
+    @property
+    def non_data_parallel_size(self):
+        """The size of the non-data parallel dimensions, which is the product of tensor and context parallel sizes."""
+        return self.tp_size * self.cp_size
+
+    @property
+    def data_parallel_size(self):
+        """The size of the data parallel dimensions, which is the product of data parallel replication and sharding sizes."""
+        return self.dp_replicate_size * self.dp_shard_size
+
+    @property
+    def dp_replicate_enabled(self):
+        """True if data parallel replication is enabled, i.e. `dp_replicate_size > 1`."""
+        return self.dp_replicate_size > 1
+
+    @property
+    def dp_shard_enabled(self):
+        """True if data parallel sharding is enabled, i.e. `dp_shard_size > 1`."""
+        return self.dp_shard_size > 1
+
+    @property
+    def tp_enabled(self):
+        """True if tensor parallelism is enabled, i.e. `tp_size > 1`."""
+        return self.tp_size > 1
+
+    @property
+    def cp_enabled(self):
+        """True if context parallelism is enabled, i.e. `cp_size > 1`."""
+        return self.cp_size > 1
+
+    @property
+    def active_mesh_dims(self):
+        """Names of all active mesh dimensions."""
+        return self.dp_dim_names + self.non_dp_dim_names
+
+    def build_device_mesh(self, device_type: str):
+        """Builds a device mesh for the given device type based on the parallelism configuration.
+        This method will also create required joint meshes (e.g. `dp_shard_cp`, `dp_cp`, `dp`).
+
+        Args:
+            device_type (`str`): The type of device for which to build the mesh, e.g. `cuda`.
+        """
+        if is_torch_version(">=", "2.2.0"):
+            from torch.distributed.device_mesh import init_device_mesh
+        else:
+            raise RuntimeError("Building a device_mesh requires torch>=2.2.0")
+
+        mesh = self._get_mesh()
+        if len(mesh) == 0:
+            return None
+        mesh_dim_names, mesh_shape = mesh
+        device_mesh = init_device_mesh(
+            device_type,
+            mesh_shape,
+            mesh_dim_names=mesh_dim_names,
+        )
+        if self.dp_dim_names:
+            device_mesh[self.dp_dim_names]._flatten("dp")
+        if self.dp_shard_cp_dim_names:
+            device_mesh[self.dp_shard_cp_dim_names]._flatten("dp_shard_cp")
+        if self.dp_cp_dim_names:
+            device_mesh[self.dp_cp_dim_names]._flatten("dp_cp")
+
+        return device_mesh
+
+    def get_device_mesh(self, device_type: Optional[str] = None):
+        if self.device_mesh is None:
+            if device_type is not None:
+                self.device_mesh = self.build_device_mesh(device_type)
+            else:
+                raise ValueError("You need to pass a device_type, e.g. `cuda`, to build the device mesh")
+        else:
+            if device_type is not None:
+                if self.device_mesh.device_type != device_type:
+                    raise ValueError(
+                        f"The device_mesh is already created with device type {self.device_mesh.device_type}. However, you are trying to get a device mesh with device_type {device_type}. Please check if you correctly initialized your device_mesh"
+                    )
+        return self.device_mesh
+
+    def _get_mesh(self) -> tuple[tuple[str, ...], tuple[int, ...]]:
+        """Generate mesh dimension names and shape for torch.distributed.init_device_mesh()."""
+
+        # Build mesh dimensions dictionary
+        mesh_dims = {parallelism: self._sizes[parallelism] for parallelism in self.active_mesh_dims}
+
+        # Apply canonical ordering
+        mesh_order = ["dp_replicate", "dp_shard", "cp", "tp"]
+        sorted_items = sorted(
+            mesh_dims.items(),
+            key=lambda x: (mesh_order.index(x[0])),
+        )
+        return tuple(zip(*sorted_items))
+
+    def __post_init__(self):
+        # Basic size validation
+        if self.dp_replicate_size is None:
+            self.dp_replicate_size = int(os.environ.get("PARALLELISM_CONFIG_DP_REPLICATE_SIZE", "1"))
+        if self.dp_shard_size is None:
+            self.dp_shard_size = int(os.environ.get("PARALLELISM_CONFIG_DP_SHARD_SIZE", "1"))
+        if self.tp_size is None:
+            self.tp_size = int(os.environ.get("PARALLELISM_CONFIG_TP_SIZE", "1"))
+        if self.cp_size is None:
+            self.cp_size = int(os.environ.get("PARALLELISM_CONFIG_CP_SIZE", "1"))
+
+        if self.tp_size > 1:
+            if self.tp_handler is None:
+                self.tp_handler = TorchTensorParallelConfig()
+
+        if self.cp_size > 1:
+            if self.cp_handler is None:
+                self.cp_handler = TorchContextParallelConfig()
+
+        if self.dp_replicate_size < 1:
+            raise ValueError(f"dp_replicate_size must be at least 1, but got {self.dp_replicate_size}")
+        if self.dp_shard_size < 1:
+            raise ValueError(f"dp_shard_size must be at least 1, but got {self.dp_shard_size}")
+        if self.tp_size < 1:
+            raise ValueError(f"tp_size must be at least 1, but got {self.tp_size}")
+        if self.cp_size < 1:
+            raise ValueError(f"cp_size must be at least 1, but got {self.cp_size}")
+
+        if (self.tp_size > 1 or self.cp_size > 1) and self.dp_replicate_size > 1 and self.dp_shard_size == 1:
+            raise ValueError(
+                "Tensor/Context parallelism (tp/cp_size > 1) cannot be used with pure data parallelism (dp_replicate_size > 1 and dp_shard_size == 1). "
+                "Please set dp_shard_size > 1 and dp_replicate_size == 1 to compose FSDP + TP/CP for 2D parallel, "
+                "or set dp_replicate_size > 1 and dp_shard_size > 1 to compose HSDP + TP/CP for 3D parallel."
+            )
+        self._sizes = {
+            "dp_replicate": self.dp_replicate_size,
+            "dp_shard": self.dp_shard_size,
+            "tp": self.tp_size,
+            "cp": self.cp_size,
+        }
+
+    def _set_size(self, parallelism: str, size: int):
+        assert parallelism in self._sizes.keys(), f"Parallelism must be one of {self._sizes.keys()}"
+        self._sizes[parallelism] = size
+        setattr(self, f"{parallelism}_size", size)
+
+    def _validate_accelerator(self, accelerator: "Accelerator"):
+        _warnings = set()
+        if not accelerator.multi_device and self.total_size == 1:
+            # No distributed setup, valid parallelism config
+            return
+
+        # We need this to ensure DDP works
+        if self.total_size == 1:
+            self._set_size("dp_replicate", accelerator.num_processes)
+
+        if self.total_size != accelerator.num_processes:
+            raise ValueError(
+                f"ParallelismConfig total_size ({self.total_size}) does not match "
+                f"num_processes ({accelerator.num_processes}). Please adjust dp_replicate_size/ "
+                f"dp_shard_size/tp_size/cp_size."
+            )
+
+        if self.total_size > 1 and not (accelerator.is_fsdp2 or accelerator.multi_device):
+            raise ValueError(
+                f"ParallelismConfig is only compatible with DistributedType.FSDP (version 2) or DistributedType.Multi{{Device}}, but got {accelerator.distributed_type}."
+            )
+
+        for parallelism, size in self._sizes.items():
+            if size == 1 and getattr(self, f"{parallelism}_handler", None) is not None:
+                _warnings.add(
+                    f"ParallelismConfig.{parallelism}_handler is set, but {parallelism}_size is set to 1. This handler will be ignored."
+                )
+
+        if _warnings and accelerator.is_main_process:
+            warnings.warn(
+                "ParallelismConfig has the following warnings:\n" + "\n".join(_warnings),
+                UserWarning,
+            )
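The derived properties above can be inspected without any process group; only `build_device_mesh` needs torch>=2.2 and an initialized distributed runtime. A sketch (assuming an accelerate build recent enough to ship `ParallelismConfig`):

```python
from accelerate.parallelism_config import ParallelismConfig

# HSDP + TP layout for 16 processes: 2 replicas x 4 shards x 2 tensor-parallel ranks
pc = ParallelismConfig(dp_replicate_size=2, dp_shard_size=4, tp_size=2)

print(pc.total_size)        # 16
print(pc.active_mesh_dims)  # ['dp_replicate', 'dp_shard', 'tp']
print(pc.fsdp_dim_names)    # ['dp_replicate', 'dp_shard_cp']

# Inside a launched job (torch>=2.2, process group initialized):
# mesh = pc.build_device_mesh("cuda")
```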
pythonProject/.venv/Lib/site-packages/accelerate/scheduler.py
ADDED
@@ -0,0 +1,98 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# We ignore warnings about stepping the scheduler since we step it ourselves during gradient accumulation
+
+import warnings
+
+from .state import AcceleratorState, GradientState
+
+
+warnings.filterwarnings("ignore", category=UserWarning, module="torch.optim.lr_scheduler")
+
+
+class AcceleratedScheduler:
+    """
+    A wrapper around a learning rate scheduler that will only step when the optimizer(s) have a training step. Useful
+    to avoid making a scheduler step too fast when gradients overflowed and there was no training step (in mixed
+    precision training).
+
+    When performing gradient accumulation, scheduler lengths should not be changed; Accelerate will always
+    step the scheduler to account for it.
+
+    Args:
+        scheduler (`torch.optim.lr_scheduler._LRScheduler`):
+            The scheduler to wrap.
+        optimizers (one or a list of `torch.optim.Optimizer`):
+            The optimizers used.
+        step_with_optimizer (`bool`, *optional*, defaults to `True`):
+            Whether or not the scheduler should be stepped at each optimizer step.
+        split_batches (`bool`, *optional*, defaults to `False`):
+            Whether or not the dataloaders split one batch across the different processes (so batch size is the same
+            regardless of the number of processes) or create batches on each process (so batch size is the original
+            batch size multiplied by the number of processes).
+    """
+
+    def __init__(self, scheduler, optimizers, step_with_optimizer: bool = True, split_batches: bool = False):
+        self.scheduler = scheduler
+        self.optimizers = optimizers if isinstance(optimizers, (list, tuple)) else [optimizers]
+        self.split_batches = split_batches
+        self.step_with_optimizer = step_with_optimizer
+        self.gradient_state = GradientState()
+
+    def step(self, *args, **kwargs):
+        if not self.step_with_optimizer:
+            # No link between scheduler and optimizer -> just step
+            self.scheduler.step(*args, **kwargs)
+            return
+
+        # Otherwise, first make sure the optimizer was stepped.
+        if not self.gradient_state.sync_gradients:
+            if self.gradient_state.adjust_scheduler:
+                self.scheduler._step_count += 1
+            return
+
+        for opt in self.optimizers:
+            if opt.step_was_skipped:
+                return
+        if self.split_batches:
+            # Split batches -> the training dataloader batch size is not changed so one step per training step
+            self.scheduler.step(*args, **kwargs)
+        else:
+            # Otherwise the training dataloader batch size was multiplied by `num_processes`, so we need to do
+            # num_processes steps per training step
+            num_processes = AcceleratorState().num_processes
+            for _ in range(num_processes):
+                # Special case when using OneCycle and `drop_last` was not used
+                if hasattr(self.scheduler, "total_steps"):
+                    if self.scheduler._step_count <= self.scheduler.total_steps:
+                        self.scheduler.step(*args, **kwargs)
+                else:
+                    self.scheduler.step(*args, **kwargs)
+
+    # Passthroughs
+    def get_last_lr(self):
+        return self.scheduler.get_last_lr()
+
+    def state_dict(self):
+        return self.scheduler.state_dict()
+
+    def load_state_dict(self, state_dict):
+        self.scheduler.load_state_dict(state_dict)
+
+    def get_lr(self):
+        return self.scheduler.get_lr()
+
+    def print_lr(self, *args, **kwargs):
+        return self.scheduler.print_lr(*args, **kwargs)
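As with the optimizer wrapper, `AcceleratedScheduler` comes out of `Accelerator.prepare`. A sketch showing the wrapping and the passthrough accessors (single process, so one scheduler step per optimizer step):

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler)

x = torch.randn(8, 4, device=accelerator.device)
loss = model(x).sum()
accelerator.backward(loss)
optimizer.step()
scheduler.step()                # no-op if the optimizer step was skipped (overflow)
print(scheduler.get_last_lr())  # passthrough to the wrapped scheduler
```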
pythonProject/.venv/Lib/site-packages/isympy.py
ADDED
@@ -0,0 +1,342 @@
+"""
+Python shell for SymPy.
+
+This is just a normal Python shell (IPython shell if you have the
+IPython package installed), that executes the following commands for
+the user:
+
+    >>> from __future__ import division
+    >>> from sympy import *
+    >>> x, y, z, t = symbols('x y z t')
+    >>> k, m, n = symbols('k m n', integer=True)
+    >>> f, g, h = symbols('f g h', cls=Function)
+    >>> init_printing()
+
+So starting 'isympy' is equivalent to starting Python (or IPython) and
+executing the above commands by hand. It is intended for easy and quick
+experimentation with SymPy. isympy is a good way to use SymPy as an
+interactive calculator. If you have IPython and Matplotlib installed, then
+interactive plotting is enabled by default.
+
+COMMAND LINE OPTIONS
+--------------------
+
+-c CONSOLE, --console=CONSOLE
+
+    Use the specified shell (Python or IPython) shell as the console
+    backend instead of the default one (IPython if present, Python
+    otherwise), e.g.:
+
+        $isympy -c python
+
+    CONSOLE must be one of 'ipython' or 'python'
+
+-p PRETTY, --pretty PRETTY
+
+    Setup pretty-printing in SymPy. When pretty-printing is enabled,
+    expressions can be printed with Unicode or ASCII. The default is
+    to use pretty-printing (with Unicode if the terminal supports it).
+    When this option is 'no', expressions will not be pretty-printed
+    and ASCII will be used:
+
+        $isympy -p no
+
+    PRETTY must be one of 'unicode', 'ascii', or 'no'
+
+-t TYPES, --types=TYPES
+
+    Setup the ground types for the polys. By default, gmpy ground types
+    are used if gmpy2 or gmpy is installed, otherwise it falls back to python
+    ground types, which are a little bit slower. You can manually
+    choose python ground types even if gmpy is installed (e.g., for
+    testing purposes):
+
+        $isympy -t python
+
+    TYPES must be one of 'gmpy', 'gmpy1' or 'python'
+
+    Note that the ground type gmpy1 is primarily intended for testing; it
+    forces the use of gmpy version 1 even if gmpy2 is available.
+
+    This is the same as setting the environment variable
+    SYMPY_GROUND_TYPES to the given ground type (e.g.,
+    SYMPY_GROUND_TYPES='gmpy')
+
+    The ground types can be determined interactively from the variable
+    sympy.polys.domains.GROUND_TYPES.
+
+-o ORDER, --order ORDER
+
+    Setup the ordering of terms for printing. The default is lex, which
+    orders terms lexicographically (e.g., x**2 + x + 1). You can choose
+    other orderings, such as rev-lex, which will use reverse
+    lexicographic ordering (e.g., 1 + x + x**2):
+
+        $isympy -o rev-lex
+
+    ORDER must be one of 'lex', 'rev-lex', 'grlex', 'rev-grlex',
+    'grevlex', 'rev-grevlex', 'old', or 'none'.
+
+    Note that for very large expressions, ORDER='none' may speed up
+    printing considerably but the terms will have no canonical order.
+
+-q, --quiet
+
+    Print only Python's and SymPy's versions to stdout at startup.
+
+-d, --doctest
+
+    Use the same format that should be used for doctests. This is
+    equivalent to -c python -p no.
+
+-C, --no-cache
+
+    Disable the caching mechanism. Disabling the cache may slow certain
+    operations down considerably. This is useful for testing the cache,
+    or for benchmarking, as the cache can result in deceptive timings.
+
+    This is equivalent to setting the environment variable
+    SYMPY_USE_CACHE to 'no'.
+
+-a, --auto-symbols (requires at least IPython 0.11)
+
+    Automatically create missing symbols. Normally, typing a name of a
+    Symbol that has not been instantiated first would raise NameError,
+    but with this option enabled, any undefined name will be
+    automatically created as a Symbol.
+
+    Note that this is intended only for interactive, calculator style
+    usage. In a script that uses SymPy, Symbols should be instantiated
+    at the top, so that it's clear what they are.
+
+    This will not override any names that are already defined, which
+    includes the single character letters represented by the mnemonic
+    QCOSINE (see the "Gotchas and Pitfalls" document in the
+    documentation). You can delete existing names by executing "del
+    name". If a name is defined, typing "'name' in dir()" will return True.
+
+    The Symbols that are created using this have default assumptions.
+    If you want to place assumptions on symbols, you should create them
+    using symbols() or var().
+
+    Finally, this only works in the top level namespace. So, for
+    example, if you define a function in isympy with an undefined
+    Symbol, it will not work.
+
+    See also the -i and -I options.
+
+-i, --int-to-Integer (requires at least IPython 0.11)
+
+    Automatically wrap int literals with Integer. This makes it so that
+    things like 1/2 will come out as Rational(1, 2), rather than 0.5. This
+    works by preprocessing the source and wrapping all int literals with
+    Integer. Note that this will not change the behavior of int literals
+    assigned to variables, and it also won't change the behavior of functions
+    that return int literals.
+
+    If you want an int, you can wrap the literal in int(), e.g. int(3)/int(2)
+    gives 1.5 (with division imported from __future__).
+
+-I, --interactive (requires at least IPython 0.11)
+
+    This is equivalent to --auto-symbols --int-to-Integer. Future options
+    designed for ease of interactive use may be added to this.
+
+-D, --debug
+
+    Enable debugging output. This is the same as setting the
+    environment variable SYMPY_DEBUG to 'True'. The debug status is set
+    in the variable SYMPY_DEBUG within isympy.
+
+-- IPython options
+
+    Additionally you can pass command line options directly to the IPython
+    interpreter (the standard Python shell is not supported). However you
+    need to add the '--' separator between two types of options, e.g. the
+    startup banner option and the colors option. You need to enter the
+    options as required by the version of IPython that you are using, too:
+
+    in IPython 0.11,
+
+        $isympy -q -- --colors=NoColor
+
+    or older versions of IPython,
+
+        $isympy -q -- -colors NoColor
+
+See also isympy --help.
+"""
+
+import os
+import sys
+
+# DO NOT IMPORT SYMPY HERE! Or the setting of the sympy environment variables
+# by the command line will break.
+
+def main() -> None:
+    from argparse import ArgumentParser, RawDescriptionHelpFormatter
+
+    VERSION = None
+    if '--version' in sys.argv:
+        # We cannot import sympy before this is run, because flags like -C and
+        # -t set environment variables that must be set before SymPy is
+        # imported. The only thing we need to import it for is to get the
+        # version, which only matters with the --version flag.
+        import sympy
+        VERSION = sympy.__version__
+
+    usage = 'isympy [options] -- [ipython options]'
+    parser = ArgumentParser(
+        usage=usage,
+        description=__doc__,
+        formatter_class=RawDescriptionHelpFormatter,
+    )
+
+    parser.add_argument('--version', action='version', version=VERSION)
+
+    parser.add_argument(
+        '-c', '--console',
+        dest='console',
+        action='store',
+        default=None,
+        choices=['ipython', 'python'],
+        metavar='CONSOLE',
+        help='select type of interactive session: ipython | python; defaults '
+        'to ipython if IPython is installed, otherwise python')
+
+    parser.add_argument(
+        '-p', '--pretty',
+        dest='pretty',
+        action='store',
+        default=None,
+        metavar='PRETTY',
+        choices=['unicode', 'ascii', 'no'],
+        help='setup pretty printing: unicode | ascii | no; defaults to '
+        'unicode printing if the terminal supports it, otherwise ascii')
+
+    parser.add_argument(
+        '-t', '--types',
+        dest='types',
+        action='store',
+        default=None,
+        metavar='TYPES',
+        choices=['gmpy', 'gmpy1', 'python'],
+        help='setup ground types: gmpy | gmpy1 | python; defaults to gmpy if gmpy2 '
+        'or gmpy is installed, otherwise python')
+
+    parser.add_argument(
+        '-o', '--order',
+        dest='order',
+        action='store',
+        default=None,
+        metavar='ORDER',
+        choices=['lex', 'grlex', 'grevlex', 'rev-lex', 'rev-grlex', 'rev-grevlex', 'old', 'none'],
+        help='setup ordering of terms: [rev-]lex | [rev-]grlex | [rev-]grevlex | old | none; defaults to lex')
+
+    parser.add_argument(
+        '-q', '--quiet',
+        dest='quiet',
+        action='store_true',
+        default=False,
+        help='print only version information at startup')
+
+    parser.add_argument(
+        '-d', '--doctest',
+        dest='doctest',
+        action='store_true',
+        default=False,
+        help='use the doctest format for output (you can just copy and paste it)')
+
+    parser.add_argument(
+        '-C', '--no-cache',
+        dest='cache',
+        action='store_false',
+        default=True,
+        help='disable caching mechanism')
+
+    parser.add_argument(
+        '-a', '--auto-symbols',
+        dest='auto_symbols',
+        action='store_true',
+        default=False,
+        help='automatically construct missing symbols')
+
+    parser.add_argument(
+        '-i', '--int-to-Integer',
+        dest='auto_int_to_Integer',
+        action='store_true',
+        default=False,
+        help="automatically wrap int literals with Integer")
+
+    parser.add_argument(
+        '-I', '--interactive',
+        dest='interactive',
+        action='store_true',
+        default=False,
+        help="equivalent to -a -i")
+
+    parser.add_argument(
+        '-D', '--debug',
+        dest='debug',
+        action='store_true',
+        default=False,
+        help='enable debugging output')
+
+    (options, ipy_args) = parser.parse_known_args()
+    if '--' in ipy_args:
+        ipy_args.remove('--')
+
+    if not options.cache:
+        os.environ['SYMPY_USE_CACHE'] = 'no'
+
+    if options.types:
+        os.environ['SYMPY_GROUND_TYPES'] = options.types
+
+    if options.debug:
+        os.environ['SYMPY_DEBUG'] = str(options.debug)
+
+    if options.doctest:
+        options.pretty = 'no'
+        options.console = 'python'
+
+    session = options.console
+
+    if session is not None:
+        ipython = session == 'ipython'
+    else:
+        try:
+            import IPython
+            ipython = True
+        except ImportError:
+            if not options.quiet:
+                from sympy.interactive.session import no_ipython
+                print(no_ipython)
+            ipython = False
+
+    args = {
+        'pretty_print': True,
+        'use_unicode': None,
+        'use_latex': None,
+        'order': None,
+        'argv': ipy_args,
+    }
+
+    if options.pretty == 'unicode':
+        args['use_unicode'] = True
+    elif options.pretty == 'ascii':
+        args['use_unicode'] = False
+    elif options.pretty == 'no':
+        args['pretty_print'] = False
+
+    if options.order is not None:
+        args['order'] = options.order
+
+    args['quiet'] = options.quiet
+    args['auto_symbols'] = options.auto_symbols or options.interactive
+    args['auto_int_to_Integer'] = options.auto_int_to_Integer or options.interactive
+
+    from sympy.interactive import init_session
+    init_session(ipython, **args)
+
+if __name__ == "__main__":
+    main()
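Everything `main()` parses is ultimately forwarded to `sympy.interactive.init_session`, so the shell can also be started programmatically. A rough equivalent of `isympy -c python -p no` (a sketch; note this opens an interactive session):

```python
from sympy.interactive import init_session

# Mirrors the -d/--doctest combination: plain Python console, no pretty-printing.
init_session(ipython=False, pretty_print=False, quiet=True)
```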
pythonProject/.venv/Lib/site-packages/numpy-2.2.6-cp310-cp310-win_amd64.whl
ADDED
File without changes
pythonProject/.venv/Lib/site-packages/typing_extensions.py
ADDED
The diff for this file is too large to render. See raw diff
pythonProject/.venv/pyvenv.cfg
ADDED
@@ -0,0 +1,3 @@
+home = C:\Users\ADMIN\AppData\Local\Programs\Python\Python310
+include-system-site-packages = false
+version = 3.10.11