Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/_distutils_hack/__pycache__/__init__.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/_distutils_hack/__pycache__/override.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/tpu.py +157 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/scripts/external_deps/__pycache__/test_ds_multiple_model.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/__init__.py +306 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/ao.py +140 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/bnb.py +469 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/constants.py +106 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/dataclasses.py +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/deepspeed.py +385 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/environment.py +471 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/fsdp_utils.py +829 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/imports.py +564 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/launch.py +781 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/megatron_lm.py +1424 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/memory.py +210 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/modeling.py +2186 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/offload.py +213 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/operations.py +867 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/other.py +561 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/random.py +156 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/rich.py +24 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/torch_xla.py +51 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/tqdm.py +43 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/transformer_engine.py +186 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/versions.py +56 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/annotated_doc/__pycache__/__init__.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/annotated_doc/__pycache__/main.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/cuda_pathfinder-1.4.0.dist-info/licenses/LICENSE +177 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/__init__.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/__version__.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_align.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_align_getter.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_base.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_column.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_common.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_container.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_converter.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_dataproperty.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_extractor.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_formatter.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_function.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_interface.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_line_break.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_preprocessor.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/typing.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/logger/__init__.py +7 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/logger/__pycache__/__init__.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/logger/__pycache__/_logger.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/logger/__pycache__/_null_logger.cpython-312.pyc +0 -0
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/_distutils_hack/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (10.6 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/_distutils_hack/__pycache__/override.cpython-312.pyc
ADDED
|
Binary file (322 Bytes). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/tpu.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import os
|
| 19 |
+
import subprocess
|
| 20 |
+
|
| 21 |
+
from packaging.version import Version, parse
|
| 22 |
+
|
| 23 |
+
from accelerate.commands.config.config_args import default_config_file, load_config_from_file
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
_description = "Run commands across TPU VMs for initial setup before running `accelerate launch`."
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def tpu_command_parser(subparsers=None):
|
| 30 |
+
if subparsers is not None:
|
| 31 |
+
parser = subparsers.add_parser("tpu-config", description=_description)
|
| 32 |
+
else:
|
| 33 |
+
parser = argparse.ArgumentParser("Accelerate tpu-config command", description=_description)
|
| 34 |
+
# Core arguments
|
| 35 |
+
config_args = parser.add_argument_group(
|
| 36 |
+
"Config Arguments", "Arguments that can be configured through `accelerate config`."
|
| 37 |
+
)
|
| 38 |
+
config_args.add_argument(
|
| 39 |
+
"--config_file",
|
| 40 |
+
type=str,
|
| 41 |
+
default=None,
|
| 42 |
+
help="Path to the config file to use for accelerate.",
|
| 43 |
+
)
|
| 44 |
+
config_args.add_argument(
|
| 45 |
+
"--tpu_name",
|
| 46 |
+
default=None,
|
| 47 |
+
help="The name of the TPU to use. If not specified, will use the TPU specified in the config file.",
|
| 48 |
+
)
|
| 49 |
+
config_args.add_argument(
|
| 50 |
+
"--tpu_zone",
|
| 51 |
+
default=None,
|
| 52 |
+
help="The zone of the TPU to use. If not specified, will use the zone specified in the config file.",
|
| 53 |
+
)
|
| 54 |
+
pod_args = parser.add_argument_group("TPU Arguments", "Arguments for options ran inside the TPU.")
|
| 55 |
+
pod_args.add_argument(
|
| 56 |
+
"--use_alpha",
|
| 57 |
+
action="store_true",
|
| 58 |
+
help="Whether to use `gcloud alpha` when running the TPU training script instead of `gcloud`.",
|
| 59 |
+
)
|
| 60 |
+
pod_args.add_argument(
|
| 61 |
+
"--command_file",
|
| 62 |
+
default=None,
|
| 63 |
+
help="The path to the file containing the commands to run on the pod on startup.",
|
| 64 |
+
)
|
| 65 |
+
pod_args.add_argument(
|
| 66 |
+
"--command",
|
| 67 |
+
action="append",
|
| 68 |
+
nargs="+",
|
| 69 |
+
help="A command to run on the pod. Can be passed multiple times.",
|
| 70 |
+
)
|
| 71 |
+
pod_args.add_argument(
|
| 72 |
+
"--install_accelerate",
|
| 73 |
+
action="store_true",
|
| 74 |
+
help="Whether to install accelerate on the pod. Defaults to False.",
|
| 75 |
+
)
|
| 76 |
+
pod_args.add_argument(
|
| 77 |
+
"--accelerate_version",
|
| 78 |
+
default="latest",
|
| 79 |
+
help="The version of accelerate to install on the pod. If not specified, will use the latest pypi version. Specify 'dev' to install from GitHub.",
|
| 80 |
+
)
|
| 81 |
+
pod_args.add_argument(
|
| 82 |
+
"--debug", action="store_true", help="If set, will print the command that would be run instead of running it."
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
if subparsers is not None:
|
| 86 |
+
parser.set_defaults(func=tpu_command_launcher)
|
| 87 |
+
return parser
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def tpu_command_launcher(args):
|
| 91 |
+
defaults = None
|
| 92 |
+
|
| 93 |
+
# Get the default from the config file if it exists.
|
| 94 |
+
if args.config_file is not None or os.path.isfile(default_config_file):
|
| 95 |
+
defaults = load_config_from_file(args.config_file)
|
| 96 |
+
if not args.command_file and defaults.command_file is not None and not args.command:
|
| 97 |
+
args.command_file = defaults.command_file
|
| 98 |
+
if not args.command and defaults.commands is not None:
|
| 99 |
+
args.command = defaults.commands
|
| 100 |
+
if not args.tpu_name:
|
| 101 |
+
args.tpu_name = defaults.tpu_name
|
| 102 |
+
if not args.tpu_zone:
|
| 103 |
+
args.tpu_zone = defaults.tpu_zone
|
| 104 |
+
if args.accelerate_version == "dev":
|
| 105 |
+
args.accelerate_version = "git+https://github.com/huggingface/accelerate.git"
|
| 106 |
+
elif args.accelerate_version == "latest":
|
| 107 |
+
args.accelerate_version = "accelerate -U"
|
| 108 |
+
elif isinstance(parse(args.accelerate_version), Version):
|
| 109 |
+
args.accelerate_version = f"accelerate=={args.accelerate_version}"
|
| 110 |
+
|
| 111 |
+
if not args.command_file and not args.command:
|
| 112 |
+
raise ValueError("You must specify either a command file or a command to run on the pod.")
|
| 113 |
+
|
| 114 |
+
if args.command_file:
|
| 115 |
+
with open(args.command_file) as f:
|
| 116 |
+
args.command = [f.read().splitlines()]
|
| 117 |
+
|
| 118 |
+
# To turn list of lists into list of strings
|
| 119 |
+
if isinstance(args.command[0], list):
|
| 120 |
+
args.command = [line for cmd in args.command for line in cmd]
|
| 121 |
+
# Default to the shared folder and install accelerate
|
| 122 |
+
new_cmd = ["cd /usr/share"]
|
| 123 |
+
if args.install_accelerate:
|
| 124 |
+
new_cmd += [f"pip install {args.accelerate_version}"]
|
| 125 |
+
new_cmd += args.command
|
| 126 |
+
args.command = "; ".join(new_cmd)
|
| 127 |
+
|
| 128 |
+
# Then send it to gcloud
|
| 129 |
+
# Eventually try to use google-api-core to do this instead of subprocess
|
| 130 |
+
cmd = ["gcloud"]
|
| 131 |
+
if args.use_alpha:
|
| 132 |
+
cmd += ["alpha"]
|
| 133 |
+
cmd += [
|
| 134 |
+
"compute",
|
| 135 |
+
"tpus",
|
| 136 |
+
"tpu-vm",
|
| 137 |
+
"ssh",
|
| 138 |
+
args.tpu_name,
|
| 139 |
+
"--zone",
|
| 140 |
+
args.tpu_zone,
|
| 141 |
+
"--command",
|
| 142 |
+
args.command,
|
| 143 |
+
"--worker",
|
| 144 |
+
"all",
|
| 145 |
+
]
|
| 146 |
+
if args.debug:
|
| 147 |
+
print(f"Running {' '.join(cmd)}")
|
| 148 |
+
return
|
| 149 |
+
subprocess.run(cmd)
|
| 150 |
+
print("Successfully setup pod.")
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def main():
|
| 154 |
+
parser = tpu_command_parser()
|
| 155 |
+
args = parser.parse_args()
|
| 156 |
+
|
| 157 |
+
tpu_command_launcher(args)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/scripts/external_deps/__pycache__/test_ds_multiple_model.cpython-312.pyc
ADDED
|
Binary file (14.1 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/__init__.py
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from ..parallelism_config import ParallelismConfig
|
| 15 |
+
from .ao import convert_model_to_fp8_ao, filter_first_and_last_linear_layers, has_ao_layers
|
| 16 |
+
from .constants import (
|
| 17 |
+
MITA_PROFILING_AVAILABLE_PYTORCH_VERSION,
|
| 18 |
+
MODEL_NAME,
|
| 19 |
+
OPTIMIZER_NAME,
|
| 20 |
+
PROFILE_PATTERN_NAME,
|
| 21 |
+
RNG_STATE_NAME,
|
| 22 |
+
SAFE_MODEL_NAME,
|
| 23 |
+
SAFE_WEIGHTS_INDEX_NAME,
|
| 24 |
+
SAFE_WEIGHTS_NAME,
|
| 25 |
+
SAFE_WEIGHTS_PATTERN_NAME,
|
| 26 |
+
SAMPLER_NAME,
|
| 27 |
+
SCALER_NAME,
|
| 28 |
+
SCHEDULER_NAME,
|
| 29 |
+
TORCH_DISTRIBUTED_OPERATION_TYPES,
|
| 30 |
+
TORCH_LAUNCH_PARAMS,
|
| 31 |
+
WEIGHTS_INDEX_NAME,
|
| 32 |
+
WEIGHTS_NAME,
|
| 33 |
+
WEIGHTS_PATTERN_NAME,
|
| 34 |
+
XPU_PROFILING_AVAILABLE_PYTORCH_VERSION,
|
| 35 |
+
)
|
| 36 |
+
from .dataclasses import (
|
| 37 |
+
AORecipeKwargs,
|
| 38 |
+
AutocastKwargs,
|
| 39 |
+
BnbQuantizationConfig,
|
| 40 |
+
ComputeEnvironment,
|
| 41 |
+
CustomDtype,
|
| 42 |
+
DataLoaderConfiguration,
|
| 43 |
+
DDPCommunicationHookType,
|
| 44 |
+
DeepSpeedPlugin,
|
| 45 |
+
DeepSpeedSequenceParallelConfig,
|
| 46 |
+
DistributedDataParallelKwargs,
|
| 47 |
+
DistributedType,
|
| 48 |
+
DynamoBackend,
|
| 49 |
+
FP8RecipeKwargs,
|
| 50 |
+
FullyShardedDataParallelPlugin,
|
| 51 |
+
GradientAccumulationPlugin,
|
| 52 |
+
GradScalerKwargs,
|
| 53 |
+
InitProcessGroupKwargs,
|
| 54 |
+
KwargsHandler,
|
| 55 |
+
LoggerType,
|
| 56 |
+
MegatronLMPlugin,
|
| 57 |
+
MSAMPRecipeKwargs,
|
| 58 |
+
PrecisionType,
|
| 59 |
+
ProfileKwargs,
|
| 60 |
+
ProjectConfiguration,
|
| 61 |
+
RNGType,
|
| 62 |
+
SageMakerDistributedType,
|
| 63 |
+
TensorInformation,
|
| 64 |
+
TERecipeKwargs,
|
| 65 |
+
TorchContextParallelConfig,
|
| 66 |
+
TorchDynamoPlugin,
|
| 67 |
+
TorchTensorParallelConfig,
|
| 68 |
+
TorchTensorParallelPlugin,
|
| 69 |
+
add_model_config_to_megatron_parser,
|
| 70 |
+
)
|
| 71 |
+
from .environment import (
|
| 72 |
+
are_libraries_initialized,
|
| 73 |
+
check_cuda_fp8_capability,
|
| 74 |
+
check_cuda_p2p_ib_support,
|
| 75 |
+
clear_environment,
|
| 76 |
+
convert_dict_to_env_variables,
|
| 77 |
+
get_cpu_distributed_information,
|
| 78 |
+
get_current_device_type,
|
| 79 |
+
get_gpu_info,
|
| 80 |
+
get_int_from_env,
|
| 81 |
+
parse_choice_from_env,
|
| 82 |
+
parse_flag_from_env,
|
| 83 |
+
patch_environment,
|
| 84 |
+
purge_accelerate_environment,
|
| 85 |
+
set_numa_affinity,
|
| 86 |
+
str_to_bool,
|
| 87 |
+
)
|
| 88 |
+
from .imports import (
|
| 89 |
+
deepspeed_required,
|
| 90 |
+
get_ccl_version,
|
| 91 |
+
is_4bit_bnb_available,
|
| 92 |
+
is_8bit_bnb_available,
|
| 93 |
+
is_aim_available,
|
| 94 |
+
is_bf16_available,
|
| 95 |
+
is_bitsandbytes_multi_backend_available,
|
| 96 |
+
is_bnb_available,
|
| 97 |
+
is_boto3_available,
|
| 98 |
+
is_ccl_available,
|
| 99 |
+
is_clearml_available,
|
| 100 |
+
is_comet_ml_available,
|
| 101 |
+
is_cuda_available,
|
| 102 |
+
is_datasets_available,
|
| 103 |
+
is_deepspeed_available,
|
| 104 |
+
is_dvclive_available,
|
| 105 |
+
is_fp8_available,
|
| 106 |
+
is_fp16_available,
|
| 107 |
+
is_habana_gaudi1,
|
| 108 |
+
is_hpu_available,
|
| 109 |
+
is_import_timer_available,
|
| 110 |
+
is_ipex_available,
|
| 111 |
+
is_lomo_available,
|
| 112 |
+
is_matplotlib_available,
|
| 113 |
+
is_megatron_lm_available,
|
| 114 |
+
is_mlflow_available,
|
| 115 |
+
is_mlu_available,
|
| 116 |
+
is_mps_available,
|
| 117 |
+
is_msamp_available,
|
| 118 |
+
is_musa_available,
|
| 119 |
+
is_npu_available,
|
| 120 |
+
is_pandas_available,
|
| 121 |
+
is_peft_available,
|
| 122 |
+
is_pippy_available,
|
| 123 |
+
is_pynvml_available,
|
| 124 |
+
is_pytest_available,
|
| 125 |
+
is_rich_available,
|
| 126 |
+
is_sagemaker_available,
|
| 127 |
+
is_schedulefree_available,
|
| 128 |
+
is_sdaa_available,
|
| 129 |
+
is_swanlab_available,
|
| 130 |
+
is_tensorboard_available,
|
| 131 |
+
is_timm_available,
|
| 132 |
+
is_torch_xla_available,
|
| 133 |
+
is_torchao_available,
|
| 134 |
+
is_torchdata_available,
|
| 135 |
+
is_torchdata_stateful_dataloader_available,
|
| 136 |
+
is_torchvision_available,
|
| 137 |
+
is_trackio_available,
|
| 138 |
+
is_transformer_engine_available,
|
| 139 |
+
is_transformer_engine_mxfp8_available,
|
| 140 |
+
is_transformers_available,
|
| 141 |
+
is_triton_available,
|
| 142 |
+
is_wandb_available,
|
| 143 |
+
is_weights_only_available,
|
| 144 |
+
is_xccl_available,
|
| 145 |
+
is_xpu_available,
|
| 146 |
+
torchao_required,
|
| 147 |
+
)
|
| 148 |
+
from .modeling import (
|
| 149 |
+
align_module_device,
|
| 150 |
+
calculate_maximum_sizes,
|
| 151 |
+
check_device_map,
|
| 152 |
+
check_tied_parameters_in_config,
|
| 153 |
+
check_tied_parameters_on_same_device,
|
| 154 |
+
compute_module_sizes,
|
| 155 |
+
convert_file_size_to_int,
|
| 156 |
+
dtype_byte_size,
|
| 157 |
+
find_tied_parameters,
|
| 158 |
+
get_balanced_memory,
|
| 159 |
+
get_grad_scaler,
|
| 160 |
+
get_max_layer_size,
|
| 161 |
+
get_max_memory,
|
| 162 |
+
get_mixed_precision_context_manager,
|
| 163 |
+
has_offloaded_params,
|
| 164 |
+
id_tensor_storage,
|
| 165 |
+
infer_auto_device_map,
|
| 166 |
+
is_peft_model,
|
| 167 |
+
load_checkpoint_in_model,
|
| 168 |
+
load_offloaded_weights,
|
| 169 |
+
load_state_dict,
|
| 170 |
+
named_module_tensors,
|
| 171 |
+
retie_parameters,
|
| 172 |
+
set_module_tensor_to_device,
|
| 173 |
+
)
|
| 174 |
+
from .offload import (
|
| 175 |
+
OffloadedWeightsLoader,
|
| 176 |
+
PrefixedDataset,
|
| 177 |
+
extract_submodules_state_dict,
|
| 178 |
+
load_offloaded_weight,
|
| 179 |
+
offload_state_dict,
|
| 180 |
+
offload_weight,
|
| 181 |
+
save_offload_index,
|
| 182 |
+
)
|
| 183 |
+
from .operations import (
|
| 184 |
+
CannotPadNestedTensorWarning,
|
| 185 |
+
GatheredParameters,
|
| 186 |
+
broadcast,
|
| 187 |
+
broadcast_object_list,
|
| 188 |
+
concatenate,
|
| 189 |
+
convert_outputs_to_fp32,
|
| 190 |
+
convert_to_fp32,
|
| 191 |
+
copy_tensor_to_devices,
|
| 192 |
+
find_batch_size,
|
| 193 |
+
find_device,
|
| 194 |
+
gather,
|
| 195 |
+
gather_object,
|
| 196 |
+
get_data_structure,
|
| 197 |
+
honor_type,
|
| 198 |
+
ignorant_find_batch_size,
|
| 199 |
+
initialize_tensors,
|
| 200 |
+
is_namedtuple,
|
| 201 |
+
is_tensor_information,
|
| 202 |
+
is_torch_tensor,
|
| 203 |
+
listify,
|
| 204 |
+
pad_across_processes,
|
| 205 |
+
pad_input_tensors,
|
| 206 |
+
recursively_apply,
|
| 207 |
+
reduce,
|
| 208 |
+
send_to_device,
|
| 209 |
+
slice_tensors,
|
| 210 |
+
)
|
| 211 |
+
from .versions import compare_versions, is_torch_version
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
if is_deepspeed_available():
|
| 215 |
+
from .deepspeed import (
|
| 216 |
+
DeepSpeedEngineWrapper,
|
| 217 |
+
DeepSpeedOptimizerWrapper,
|
| 218 |
+
DeepSpeedSchedulerWrapper,
|
| 219 |
+
DummyOptim,
|
| 220 |
+
DummyScheduler,
|
| 221 |
+
HfDeepSpeedConfig,
|
| 222 |
+
get_active_deepspeed_plugin,
|
| 223 |
+
map_pytorch_optim_to_deepspeed,
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
from .bnb import has_4bit_bnb_layers, load_and_quantize_model
|
| 227 |
+
from .fsdp_utils import (
|
| 228 |
+
disable_fsdp_ram_efficient_loading,
|
| 229 |
+
enable_fsdp_ram_efficient_loading,
|
| 230 |
+
ensure_weights_retied,
|
| 231 |
+
fsdp2_apply_ac,
|
| 232 |
+
fsdp2_canonicalize_names,
|
| 233 |
+
fsdp2_load_full_state_dict,
|
| 234 |
+
fsdp2_prepare_model,
|
| 235 |
+
fsdp2_switch_optimizer_parameters,
|
| 236 |
+
get_fsdp2_grad_scaler,
|
| 237 |
+
load_fsdp_model,
|
| 238 |
+
load_fsdp_optimizer,
|
| 239 |
+
merge_fsdp_weights,
|
| 240 |
+
save_fsdp_model,
|
| 241 |
+
save_fsdp_optimizer,
|
| 242 |
+
)
|
| 243 |
+
from .launch import (
|
| 244 |
+
PrepareForLaunch,
|
| 245 |
+
_filter_args,
|
| 246 |
+
prepare_deepspeed_cmd_env,
|
| 247 |
+
prepare_multi_gpu_env,
|
| 248 |
+
prepare_sagemager_args_inputs,
|
| 249 |
+
prepare_simple_launcher_cmd_env,
|
| 250 |
+
prepare_tpu,
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
# For docs
|
| 254 |
+
from .megatron_lm import (
|
| 255 |
+
AbstractTrainStep,
|
| 256 |
+
BertTrainStep,
|
| 257 |
+
GPTTrainStep,
|
| 258 |
+
MegatronLMDummyDataLoader,
|
| 259 |
+
MegatronLMDummyScheduler,
|
| 260 |
+
T5TrainStep,
|
| 261 |
+
avg_losses_across_data_parallel_group,
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
if is_megatron_lm_available():
|
| 266 |
+
from .megatron_lm import (
|
| 267 |
+
MegatronEngine,
|
| 268 |
+
MegatronLMOptimizerWrapper,
|
| 269 |
+
MegatronLMSchedulerWrapper,
|
| 270 |
+
gather_across_data_parallel_groups,
|
| 271 |
+
)
|
| 272 |
+
from .megatron_lm import initialize as megatron_lm_initialize
|
| 273 |
+
from .megatron_lm import prepare_data_loader as megatron_lm_prepare_data_loader
|
| 274 |
+
from .megatron_lm import prepare_model_optimizer_scheduler as megatron_lm_prepare_model_optimizer_scheduler
|
| 275 |
+
from .megatron_lm import prepare_optimizer as megatron_lm_prepare_optimizer
|
| 276 |
+
from .megatron_lm import prepare_scheduler as megatron_lm_prepare_scheduler
|
| 277 |
+
from .memory import find_executable_batch_size, release_memory
|
| 278 |
+
from .other import (
|
| 279 |
+
check_os_kernel,
|
| 280 |
+
clean_state_dict_for_safetensors,
|
| 281 |
+
compile_regions,
|
| 282 |
+
compile_regions_deepspeed,
|
| 283 |
+
convert_bytes,
|
| 284 |
+
extract_model_from_parallel,
|
| 285 |
+
get_module_children_bottom_up,
|
| 286 |
+
get_pretty_name,
|
| 287 |
+
has_compiled_regions,
|
| 288 |
+
is_compiled_module,
|
| 289 |
+
is_port_in_use,
|
| 290 |
+
load,
|
| 291 |
+
merge_dicts,
|
| 292 |
+
model_has_dtensor,
|
| 293 |
+
recursive_getattr,
|
| 294 |
+
save,
|
| 295 |
+
wait_for_everyone,
|
| 296 |
+
write_basic_config,
|
| 297 |
+
)
|
| 298 |
+
from .random import set_seed, synchronize_rng_state, synchronize_rng_states
|
| 299 |
+
from .torch_xla import install_xla
|
| 300 |
+
from .tqdm import tqdm
|
| 301 |
+
from .transformer_engine import (
|
| 302 |
+
apply_fp8_autowrap,
|
| 303 |
+
contextual_fp8_autocast,
|
| 304 |
+
convert_model,
|
| 305 |
+
has_transformer_engine_layers,
|
| 306 |
+
)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/ao.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Needed utilities for torchao FP8 training.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from functools import partial
|
| 20 |
+
from typing import TYPE_CHECKING, Callable, Optional
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
|
| 24 |
+
from .imports import is_torchao_available, torchao_required
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
if TYPE_CHECKING:
|
| 28 |
+
if is_torchao_available():
|
| 29 |
+
from torchao.float8.float8_linear import Float8LinearConfig
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def find_first_last_linear_layers(model: torch.nn.Module):
|
| 33 |
+
"""
|
| 34 |
+
Finds the first and last linear layer names in a model.
|
| 35 |
+
|
| 36 |
+
This is needed during FP8 to avoid issues with instability by keeping the first and last layers unquantized.
|
| 37 |
+
|
| 38 |
+
Ref: https://x.com/xariusrke/status/1826669142604141052
|
| 39 |
+
"""
|
| 40 |
+
first_linear, last_linear = None, None
|
| 41 |
+
for name, module in model.named_modules():
|
| 42 |
+
if isinstance(module, torch.nn.Linear):
|
| 43 |
+
if first_linear is None:
|
| 44 |
+
first_linear = name
|
| 45 |
+
last_linear = name
|
| 46 |
+
return first_linear, last_linear
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def filter_linear_layers(module, fqn: str, layers_to_filter: list[str]) -> bool:
|
| 50 |
+
"""
|
| 51 |
+
A function which will check if `module` is:
|
| 52 |
+
- a `torch.nn.Linear` layer
|
| 53 |
+
- has in_features and out_features divisible by 16
|
| 54 |
+
- is not part of `layers_to_filter`
|
| 55 |
+
|
| 56 |
+
Args:
|
| 57 |
+
module (`torch.nn.Module`):
|
| 58 |
+
The module to check.
|
| 59 |
+
fqn (`str`):
|
| 60 |
+
The fully qualified name of the layer.
|
| 61 |
+
layers_to_filter (`List[str]`):
|
| 62 |
+
The list of layers to filter.
|
| 63 |
+
"""
|
| 64 |
+
if isinstance(module, torch.nn.Linear):
|
| 65 |
+
if module.in_features % 16 != 0 or module.out_features % 16 != 0:
|
| 66 |
+
return False
|
| 67 |
+
if fqn in layers_to_filter:
|
| 68 |
+
return False
|
| 69 |
+
return True
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def filter_first_and_last_linear_layers(module, fqn: str) -> bool:
|
| 73 |
+
"""
|
| 74 |
+
A filter function which will filter out all linear layers except the first and last.
|
| 75 |
+
|
| 76 |
+
<Tip>
|
| 77 |
+
|
| 78 |
+
For stability reasons, we skip the first and last linear layers Otherwise can lead to the model not training or
|
| 79 |
+
converging properly
|
| 80 |
+
|
| 81 |
+
</Tip>
|
| 82 |
+
|
| 83 |
+
Args:
|
| 84 |
+
module (`torch.nn.Module`):
|
| 85 |
+
The module to check.
|
| 86 |
+
fqn (`str`):
|
| 87 |
+
The fully qualified name of the layer.
|
| 88 |
+
"""
|
| 89 |
+
first_linear, last_linear = find_first_last_linear_layers(module)
|
| 90 |
+
return filter_linear_layers(module, fqn, layers_to_filter=[first_linear, last_linear])
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@torchao_required
|
| 94 |
+
def has_ao_layers(model: torch.nn.Module):
|
| 95 |
+
from torchao.float8.float8_linear import Float8Linear
|
| 96 |
+
|
| 97 |
+
for name, module in model.named_modules():
|
| 98 |
+
if isinstance(module, Float8Linear):
|
| 99 |
+
return True
|
| 100 |
+
return False
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
@torchao_required
|
| 104 |
+
def convert_model_to_fp8_ao(
|
| 105 |
+
model: torch.nn.Module,
|
| 106 |
+
config: Optional["Float8LinearConfig"] = None,
|
| 107 |
+
module_filter_func: Optional[Callable] = filter_first_and_last_linear_layers,
|
| 108 |
+
):
|
| 109 |
+
"""
|
| 110 |
+
Converts all `nn.Linear` layers in the model (except the first and last) to torchao's `Float8Linear` layer inplace.
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
model (`torch.nn.Module`):
|
| 114 |
+
The model to convert.
|
| 115 |
+
config (`torchao.float8.Float8LinearConfig`, *optional*):
|
| 116 |
+
The configuration for the FP8 training. Recommended to utilize
|
| 117 |
+
`torchao.float8.recipe_name_to_linear_config` to generate this. In general, the default config should be
|
| 118 |
+
sufficient (what is passed when set to `None`).
|
| 119 |
+
module_filter_func (`Callable`, *optional*, defaults to `filter_linear_layers`):
|
| 120 |
+
Optional function that must take in a module and layer name, and returns a boolean indicating whether the
|
| 121 |
+
module should be converted to FP8. Defaults to `filter_linear_layers`. See it for an example.
|
| 122 |
+
|
| 123 |
+
Example:
|
| 124 |
+
|
| 125 |
+
```python
|
| 126 |
+
from accelerate.utils.ao import convert_model_to_fp8_ao
|
| 127 |
+
|
| 128 |
+
model = MyModel()
|
| 129 |
+
model.to("cuda")
|
| 130 |
+
convert_to_float8_training(model)
|
| 131 |
+
|
| 132 |
+
model.train()
|
| 133 |
+
```
|
| 134 |
+
"""
|
| 135 |
+
from torchao.float8 import convert_to_float8_training
|
| 136 |
+
|
| 137 |
+
first_linear, last_linear = find_first_last_linear_layers(model)
|
| 138 |
+
if module_filter_func is None:
|
| 139 |
+
module_filter_func = partial(filter_linear_layers, layers_to_filter=[first_linear, last_linear])
|
| 140 |
+
convert_to_float8_training(model, module_filter_fn=module_filter_func, config=config)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/bnb.py
ADDED
|
@@ -0,0 +1,469 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
import logging
|
| 17 |
+
import os
|
| 18 |
+
from copy import deepcopy
|
| 19 |
+
from typing import Optional, Union
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
|
| 24 |
+
from accelerate.utils.imports import (
|
| 25 |
+
is_4bit_bnb_available,
|
| 26 |
+
is_8bit_bnb_available,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
from ..big_modeling import dispatch_model, init_empty_weights
|
| 30 |
+
from .dataclasses import BnbQuantizationConfig
|
| 31 |
+
from .modeling import (
|
| 32 |
+
find_tied_parameters,
|
| 33 |
+
get_balanced_memory,
|
| 34 |
+
infer_auto_device_map,
|
| 35 |
+
load_checkpoint_in_model,
|
| 36 |
+
offload_weight,
|
| 37 |
+
set_module_tensor_to_device,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
logger = logging.getLogger(__name__)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def load_and_quantize_model(
|
| 45 |
+
model: torch.nn.Module,
|
| 46 |
+
bnb_quantization_config: BnbQuantizationConfig,
|
| 47 |
+
weights_location: Optional[Union[str, os.PathLike]] = None,
|
| 48 |
+
device_map: Optional[dict[str, Union[int, str, torch.device]]] = None,
|
| 49 |
+
no_split_module_classes: Optional[list[str]] = None,
|
| 50 |
+
max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None,
|
| 51 |
+
offload_folder: Optional[Union[str, os.PathLike]] = None,
|
| 52 |
+
offload_state_dict: bool = False,
|
| 53 |
+
):
|
| 54 |
+
"""
|
| 55 |
+
This function will quantize the input model with the associated config passed in `bnb_quantization_config`. If the
|
| 56 |
+
model is in the meta device, we will load and dispatch the weights according to the `device_map` passed. If the
|
| 57 |
+
model is already loaded, we will quantize the model and put the model on the GPU,
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
model (`torch.nn.Module`):
|
| 61 |
+
Input model. The model can be already loaded or on the meta device
|
| 62 |
+
bnb_quantization_config (`BnbQuantizationConfig`):
|
| 63 |
+
The bitsandbytes quantization parameters
|
| 64 |
+
weights_location (`str` or `os.PathLike`):
|
| 65 |
+
The folder weights_location to load. It can be:
|
| 66 |
+
- a path to a file containing a whole model state dict
|
| 67 |
+
- a path to a `.json` file containing the index to a sharded checkpoint
|
| 68 |
+
- a path to a folder containing a unique `.index.json` file and the shards of a checkpoint.
|
| 69 |
+
- a path to a folder containing a unique pytorch_model.bin file.
|
| 70 |
+
device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*):
|
| 71 |
+
A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer
|
| 72 |
+
name, once a given module name is inside, every submodule of it will be sent to the same device.
|
| 73 |
+
no_split_module_classes (`List[str]`, *optional*):
|
| 74 |
+
A list of layer class names that should never be split across device (for instance any layer that has a
|
| 75 |
+
residual connection).
|
| 76 |
+
max_memory (`Dict`, *optional*):
|
| 77 |
+
A dictionary device identifier to maximum memory. Will default to the maximum memory available if unset.
|
| 78 |
+
offload_folder (`str` or `os.PathLike`, *optional*):
|
| 79 |
+
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
|
| 80 |
+
offload_state_dict (`bool`, *optional*, defaults to `False`):
|
| 81 |
+
If `True`, will temporarily offload the CPU state dict on the hard drive to avoid getting out of CPU RAM if
|
| 82 |
+
the weight of the CPU state dict + the biggest shard does not fit.
|
| 83 |
+
|
| 84 |
+
Returns:
|
| 85 |
+
`torch.nn.Module`: The quantized model
|
| 86 |
+
"""
|
| 87 |
+
|
| 88 |
+
load_in_4bit = bnb_quantization_config.load_in_4bit
|
| 89 |
+
load_in_8bit = bnb_quantization_config.load_in_8bit
|
| 90 |
+
|
| 91 |
+
if load_in_8bit and not is_8bit_bnb_available():
|
| 92 |
+
raise ImportError(
|
| 93 |
+
"You have a version of `bitsandbytes` that is not compatible with 8bit quantization,"
|
| 94 |
+
" make sure you have the latest version of `bitsandbytes` installed."
|
| 95 |
+
)
|
| 96 |
+
if load_in_4bit and not is_4bit_bnb_available():
|
| 97 |
+
raise ValueError(
|
| 98 |
+
"You have a version of `bitsandbytes` that is not compatible with 4bit quantization,"
|
| 99 |
+
"make sure you have the latest version of `bitsandbytes` installed."
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
modules_on_cpu = []
|
| 103 |
+
# custom device map
|
| 104 |
+
if isinstance(device_map, dict) and len(device_map.keys()) > 1:
|
| 105 |
+
modules_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
|
| 106 |
+
|
| 107 |
+
# We keep some modules such as the lm_head in their original dtype for numerical stability reasons
|
| 108 |
+
if bnb_quantization_config.skip_modules is None:
|
| 109 |
+
bnb_quantization_config.skip_modules = get_keys_to_not_convert(model)
|
| 110 |
+
|
| 111 |
+
# add cpu modules to skip modules only for 4-bit modules
|
| 112 |
+
if load_in_4bit:
|
| 113 |
+
bnb_quantization_config.skip_modules.extend(modules_on_cpu)
|
| 114 |
+
modules_to_not_convert = bnb_quantization_config.skip_modules
|
| 115 |
+
|
| 116 |
+
# We add the modules we want to keep in full precision
|
| 117 |
+
if bnb_quantization_config.keep_in_fp32_modules is None:
|
| 118 |
+
bnb_quantization_config.keep_in_fp32_modules = []
|
| 119 |
+
keep_in_fp32_modules = bnb_quantization_config.keep_in_fp32_modules
|
| 120 |
+
modules_to_not_convert.extend(keep_in_fp32_modules)
|
| 121 |
+
|
| 122 |
+
# compatibility with peft
|
| 123 |
+
model.is_loaded_in_4bit = load_in_4bit
|
| 124 |
+
model.is_loaded_in_8bit = load_in_8bit
|
| 125 |
+
|
| 126 |
+
model_device = get_parameter_device(model)
|
| 127 |
+
if model_device.type != "meta":
|
| 128 |
+
# quantization of an already loaded model
|
| 129 |
+
logger.warning(
|
| 130 |
+
"It is not recommended to quantize a loaded model. "
|
| 131 |
+
"The model should be instantiated under the `init_empty_weights` context manager."
|
| 132 |
+
)
|
| 133 |
+
model = replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=modules_to_not_convert)
|
| 134 |
+
# convert param to the right dtype
|
| 135 |
+
dtype = bnb_quantization_config.torch_dtype
|
| 136 |
+
for name, param in model.state_dict().items():
|
| 137 |
+
if any(module_to_keep_in_fp32 in name for module_to_keep_in_fp32 in keep_in_fp32_modules):
|
| 138 |
+
param.to(torch.float32)
|
| 139 |
+
if param.dtype != torch.float32:
|
| 140 |
+
name = name.replace(".weight", "").replace(".bias", "")
|
| 141 |
+
param = getattr(model, name, None)
|
| 142 |
+
if param is not None:
|
| 143 |
+
param.to(torch.float32)
|
| 144 |
+
elif torch.is_floating_point(param):
|
| 145 |
+
param.to(dtype)
|
| 146 |
+
if model_device.type == "cuda":
|
| 147 |
+
model.cuda(torch.cuda.current_device())
|
| 148 |
+
torch.cuda.empty_cache()
|
| 149 |
+
elif torch.cuda.is_available():
|
| 150 |
+
model.to(torch.cuda.current_device())
|
| 151 |
+
elif torch.xpu.is_available():
|
| 152 |
+
model.to(torch.xpu.current_device())
|
| 153 |
+
else:
|
| 154 |
+
raise RuntimeError("No GPU or Intel XPU found. A GPU or Intel XPU is needed for quantization.")
|
| 155 |
+
logger.info(
|
| 156 |
+
f"The model device type is {model_device.type}. However, gpu or intel xpu is needed for quantization."
|
| 157 |
+
"We move the model to it."
|
| 158 |
+
)
|
| 159 |
+
return model
|
| 160 |
+
|
| 161 |
+
elif weights_location is None:
|
| 162 |
+
raise RuntimeError(
|
| 163 |
+
f"`weights_location` needs to be the folder path containing the weights of the model, but we found {weights_location} "
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
else:
|
| 167 |
+
with init_empty_weights():
|
| 168 |
+
model = replace_with_bnb_layers(
|
| 169 |
+
model, bnb_quantization_config, modules_to_not_convert=modules_to_not_convert
|
| 170 |
+
)
|
| 171 |
+
device_map = get_quantized_model_device_map(
|
| 172 |
+
model,
|
| 173 |
+
bnb_quantization_config,
|
| 174 |
+
device_map,
|
| 175 |
+
max_memory=max_memory,
|
| 176 |
+
no_split_module_classes=no_split_module_classes,
|
| 177 |
+
)
|
| 178 |
+
if offload_state_dict is None and device_map is not None and "disk" in device_map.values():
|
| 179 |
+
offload_state_dict = True
|
| 180 |
+
|
| 181 |
+
offload = any(x in list(device_map.values()) for x in ["cpu", "disk"])
|
| 182 |
+
|
| 183 |
+
load_checkpoint_in_model(
|
| 184 |
+
model,
|
| 185 |
+
weights_location,
|
| 186 |
+
device_map,
|
| 187 |
+
dtype=bnb_quantization_config.torch_dtype,
|
| 188 |
+
offload_folder=offload_folder,
|
| 189 |
+
offload_state_dict=offload_state_dict,
|
| 190 |
+
keep_in_fp32_modules=bnb_quantization_config.keep_in_fp32_modules,
|
| 191 |
+
offload_8bit_bnb=load_in_8bit and offload,
|
| 192 |
+
)
|
| 193 |
+
return dispatch_model(model, device_map=device_map, offload_dir=offload_folder)
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def get_quantized_model_device_map(
|
| 197 |
+
model, bnb_quantization_config, device_map=None, max_memory=None, no_split_module_classes=None
|
| 198 |
+
):
|
| 199 |
+
if device_map is None:
|
| 200 |
+
if torch.cuda.is_available():
|
| 201 |
+
device_map = {"": torch.cuda.current_device()}
|
| 202 |
+
elif torch.xpu.is_available():
|
| 203 |
+
device_map = {"": torch.xpu.current_device()}
|
| 204 |
+
else:
|
| 205 |
+
raise RuntimeError("No GPU found. A GPU is needed for quantization.")
|
| 206 |
+
logger.info("The device_map was not initialized.Setting device_map to `{'':torch.cuda.current_device()}`.")
|
| 207 |
+
|
| 208 |
+
if isinstance(device_map, str):
|
| 209 |
+
if device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
|
| 210 |
+
raise ValueError(
|
| 211 |
+
"If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or "
|
| 212 |
+
"'sequential'."
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
special_dtypes = {}
|
| 216 |
+
special_dtypes.update(
|
| 217 |
+
{
|
| 218 |
+
name: bnb_quantization_config.torch_dtype
|
| 219 |
+
for name, _ in model.named_parameters()
|
| 220 |
+
if any(m in name for m in bnb_quantization_config.skip_modules)
|
| 221 |
+
}
|
| 222 |
+
)
|
| 223 |
+
special_dtypes.update(
|
| 224 |
+
{
|
| 225 |
+
name: torch.float32
|
| 226 |
+
for name, _ in model.named_parameters()
|
| 227 |
+
if any(m in name for m in bnb_quantization_config.keep_in_fp32_modules)
|
| 228 |
+
}
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
kwargs = {}
|
| 232 |
+
kwargs["special_dtypes"] = special_dtypes
|
| 233 |
+
kwargs["no_split_module_classes"] = no_split_module_classes
|
| 234 |
+
kwargs["dtype"] = bnb_quantization_config.target_dtype
|
| 235 |
+
|
| 236 |
+
# get max_memory for each device.
|
| 237 |
+
if device_map != "sequential":
|
| 238 |
+
max_memory = get_balanced_memory(
|
| 239 |
+
model,
|
| 240 |
+
low_zero=(device_map == "balanced_low_0"),
|
| 241 |
+
max_memory=max_memory,
|
| 242 |
+
**kwargs,
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
kwargs["max_memory"] = max_memory
|
| 246 |
+
device_map = infer_auto_device_map(model, **kwargs)
|
| 247 |
+
|
| 248 |
+
if isinstance(device_map, dict):
|
| 249 |
+
# check if don't have any quantized module on the cpu
|
| 250 |
+
modules_not_to_convert = bnb_quantization_config.skip_modules + bnb_quantization_config.keep_in_fp32_modules
|
| 251 |
+
|
| 252 |
+
device_map_without_some_modules = {
|
| 253 |
+
key: device_map[key] for key in device_map.keys() if key not in modules_not_to_convert
|
| 254 |
+
}
|
| 255 |
+
for device in ["cpu", "disk"]:
|
| 256 |
+
if device in device_map_without_some_modules.values():
|
| 257 |
+
if bnb_quantization_config.load_in_4bit:
|
| 258 |
+
raise ValueError(
|
| 259 |
+
"""
|
| 260 |
+
Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit
|
| 261 |
+
the quantized model. If you want to dispatch the model on the CPU or the disk while keeping
|
| 262 |
+
these modules in `torch_dtype`, you need to pass a custom `device_map` to
|
| 263 |
+
`load_and_quantize_model`. Check
|
| 264 |
+
https://huggingface.co/docs/accelerate/main/en/usage_guides/quantization#offload-modules-to-cpu-and-disk
|
| 265 |
+
for more details.
|
| 266 |
+
"""
|
| 267 |
+
)
|
| 268 |
+
else:
|
| 269 |
+
logger.info(
|
| 270 |
+
"Some modules are are offloaded to the CPU or the disk. Note that these modules will be converted to 8-bit"
|
| 271 |
+
)
|
| 272 |
+
del device_map_without_some_modules
|
| 273 |
+
return device_map
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=None, current_key_name=None):
|
| 277 |
+
"""
|
| 278 |
+
A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules or by `bnb.nn.Linear4bit`
|
| 279 |
+
modules from the `bitsandbytes`library. The function will be run recursively and replace `torch.nn.Linear` modules.
|
| 280 |
+
|
| 281 |
+
Parameters:
|
| 282 |
+
model (`torch.nn.Module`):
|
| 283 |
+
Input model or `torch.nn.Module` as the function is run recursively.
|
| 284 |
+
modules_to_not_convert (`List[str]`):
|
| 285 |
+
Names of the modules to not quantize convert. In practice we keep the `lm_head` in full precision for
|
| 286 |
+
numerical stability reasons.
|
| 287 |
+
current_key_name (`List[str]`, *optional*):
|
| 288 |
+
An array to track the current key of the recursion. This is used to check whether the current key (part of
|
| 289 |
+
it) is not in the list of modules to not convert.
|
| 290 |
+
"""
|
| 291 |
+
|
| 292 |
+
if modules_to_not_convert is None:
|
| 293 |
+
modules_to_not_convert = []
|
| 294 |
+
|
| 295 |
+
model, has_been_replaced = _replace_with_bnb_layers(
|
| 296 |
+
model, bnb_quantization_config, modules_to_not_convert, current_key_name
|
| 297 |
+
)
|
| 298 |
+
if not has_been_replaced:
|
| 299 |
+
logger.warning(
|
| 300 |
+
"You are loading your model in 8bit or 4bit but no linear modules were found in your model."
|
| 301 |
+
" this can happen for some architectures such as gpt2 that uses Conv1D instead of Linear layers."
|
| 302 |
+
" Please double check your model architecture, or submit an issue on github if you think this is"
|
| 303 |
+
" a bug."
|
| 304 |
+
)
|
| 305 |
+
return model
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def _replace_with_bnb_layers(
|
| 309 |
+
model,
|
| 310 |
+
bnb_quantization_config,
|
| 311 |
+
modules_to_not_convert=None,
|
| 312 |
+
current_key_name=None,
|
| 313 |
+
):
|
| 314 |
+
"""
|
| 315 |
+
Private method that wraps the recursion for module replacement.
|
| 316 |
+
|
| 317 |
+
Returns the converted model and a boolean that indicates if the conversion has been successful or not.
|
| 318 |
+
"""
|
| 319 |
+
# bitsandbytes will initialize CUDA on import, so it needs to be imported lazily
|
| 320 |
+
import bitsandbytes as bnb
|
| 321 |
+
|
| 322 |
+
has_been_replaced = False
|
| 323 |
+
for name, module in model.named_children():
|
| 324 |
+
if current_key_name is None:
|
| 325 |
+
current_key_name = []
|
| 326 |
+
current_key_name.append(name)
|
| 327 |
+
if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
|
| 328 |
+
# Check if the current key is not in the `modules_to_not_convert`
|
| 329 |
+
current_key_name_str = ".".join(current_key_name)
|
| 330 |
+
proceed = True
|
| 331 |
+
for key in modules_to_not_convert:
|
| 332 |
+
if (
|
| 333 |
+
(key in current_key_name_str) and (key + "." in current_key_name_str)
|
| 334 |
+
) or key == current_key_name_str:
|
| 335 |
+
proceed = False
|
| 336 |
+
break
|
| 337 |
+
if proceed:
|
| 338 |
+
# Load bnb module with empty weight and replace ``nn.Linear` module
|
| 339 |
+
if bnb_quantization_config.load_in_8bit:
|
| 340 |
+
bnb_module = bnb.nn.Linear8bitLt(
|
| 341 |
+
module.in_features,
|
| 342 |
+
module.out_features,
|
| 343 |
+
module.bias is not None,
|
| 344 |
+
has_fp16_weights=False,
|
| 345 |
+
threshold=bnb_quantization_config.llm_int8_threshold,
|
| 346 |
+
)
|
| 347 |
+
elif bnb_quantization_config.load_in_4bit:
|
| 348 |
+
bnb_module = bnb.nn.Linear4bit(
|
| 349 |
+
module.in_features,
|
| 350 |
+
module.out_features,
|
| 351 |
+
module.bias is not None,
|
| 352 |
+
bnb_quantization_config.bnb_4bit_compute_dtype,
|
| 353 |
+
compress_statistics=bnb_quantization_config.bnb_4bit_use_double_quant,
|
| 354 |
+
quant_type=bnb_quantization_config.bnb_4bit_quant_type,
|
| 355 |
+
)
|
| 356 |
+
else:
|
| 357 |
+
raise ValueError("load_in_8bit and load_in_4bit can't be both False")
|
| 358 |
+
bnb_module.weight.data = module.weight.data
|
| 359 |
+
if module.bias is not None:
|
| 360 |
+
bnb_module.bias.data = module.bias.data
|
| 361 |
+
bnb_module.requires_grad_(False)
|
| 362 |
+
setattr(model, name, bnb_module)
|
| 363 |
+
has_been_replaced = True
|
| 364 |
+
if len(list(module.children())) > 0:
|
| 365 |
+
_, _has_been_replaced = _replace_with_bnb_layers(
|
| 366 |
+
module, bnb_quantization_config, modules_to_not_convert, current_key_name
|
| 367 |
+
)
|
| 368 |
+
has_been_replaced = has_been_replaced | _has_been_replaced
|
| 369 |
+
# Remove the last key for recursion
|
| 370 |
+
current_key_name.pop(-1)
|
| 371 |
+
return model, has_been_replaced
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def get_keys_to_not_convert(model):
|
| 375 |
+
r"""
|
| 376 |
+
An utility function to get the key of the module to keep in full precision if any For example for CausalLM modules
|
| 377 |
+
we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want
|
| 378 |
+
to keep the tied weights of the model. The function will return a list of the keys of the modules to not convert in
|
| 379 |
+
int8.
|
| 380 |
+
|
| 381 |
+
Parameters:
|
| 382 |
+
model (`torch.nn.Module`):
|
| 383 |
+
Input model
|
| 384 |
+
"""
|
| 385 |
+
# Create a copy of the model
|
| 386 |
+
with init_empty_weights():
|
| 387 |
+
tied_model = deepcopy(model) # this has 0 cost since it is done inside `init_empty_weights` context manager`
|
| 388 |
+
|
| 389 |
+
tied_params = find_tied_parameters(tied_model)
|
| 390 |
+
# For compatibility with Accelerate < 0.18
|
| 391 |
+
if isinstance(tied_params, dict):
|
| 392 |
+
tied_keys = sum(list(tied_params.values()), []) + list(tied_params.keys())
|
| 393 |
+
else:
|
| 394 |
+
tied_keys = sum(tied_params, [])
|
| 395 |
+
has_tied_params = len(tied_keys) > 0
|
| 396 |
+
|
| 397 |
+
# Check if it is a base model
|
| 398 |
+
is_base_model = False
|
| 399 |
+
if hasattr(model, "base_model_prefix"):
|
| 400 |
+
is_base_model = not hasattr(model, model.base_model_prefix)
|
| 401 |
+
|
| 402 |
+
# Ignore this for base models (BertModel, GPT2Model, etc.)
|
| 403 |
+
if (not has_tied_params) and is_base_model:
|
| 404 |
+
return []
|
| 405 |
+
|
| 406 |
+
# otherwise they have an attached head
|
| 407 |
+
list_modules = list(model.named_children())
|
| 408 |
+
list_last_module = [list_modules[-1][0]]
|
| 409 |
+
|
| 410 |
+
# add last module together with tied weights
|
| 411 |
+
intersection = set(list_last_module) - set(tied_keys)
|
| 412 |
+
list_untouched = list(set(tied_keys)) + list(intersection)
|
| 413 |
+
|
| 414 |
+
# remove ".weight" from the keys
|
| 415 |
+
names_to_remove = [".weight", ".bias"]
|
| 416 |
+
filtered_module_names = []
|
| 417 |
+
for name in list_untouched:
|
| 418 |
+
for name_to_remove in names_to_remove:
|
| 419 |
+
if name_to_remove in name:
|
| 420 |
+
name = name.replace(name_to_remove, "")
|
| 421 |
+
filtered_module_names.append(name)
|
| 422 |
+
|
| 423 |
+
return filtered_module_names
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
def has_4bit_bnb_layers(model):
|
| 427 |
+
"""Check if we have `bnb.nn.Linear4bit` or `bnb.nn.Linear8bitLt` layers inside our model"""
|
| 428 |
+
# bitsandbytes will initialize CUDA on import, so it needs to be imported lazily
|
| 429 |
+
import bitsandbytes as bnb
|
| 430 |
+
|
| 431 |
+
for m in model.modules():
|
| 432 |
+
if isinstance(m, bnb.nn.Linear4bit):
|
| 433 |
+
return True
|
| 434 |
+
return False
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
def get_parameter_device(parameter: nn.Module):
|
| 438 |
+
return next(parameter.parameters()).device
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
def quantize_and_offload_8bit(model, param, param_name, new_dtype, offload_folder, offload_index, fp16_statistics):
|
| 442 |
+
# if it is not quantized, we quantize and offload the quantized weights and the SCB stats
|
| 443 |
+
if fp16_statistics is None:
|
| 444 |
+
set_module_tensor_to_device(model, param_name, 0, dtype=new_dtype, value=param)
|
| 445 |
+
tensor_name = param_name
|
| 446 |
+
module = model
|
| 447 |
+
if "." in tensor_name:
|
| 448 |
+
splits = tensor_name.split(".")
|
| 449 |
+
for split in splits[:-1]:
|
| 450 |
+
new_module = getattr(module, split)
|
| 451 |
+
if new_module is None:
|
| 452 |
+
raise ValueError(f"{module} has no attribute {split}.")
|
| 453 |
+
module = new_module
|
| 454 |
+
tensor_name = splits[-1]
|
| 455 |
+
# offload weights
|
| 456 |
+
module._parameters[tensor_name].requires_grad = False
|
| 457 |
+
offload_weight(module._parameters[tensor_name], param_name, offload_folder, index=offload_index)
|
| 458 |
+
if hasattr(module._parameters[tensor_name], "SCB"):
|
| 459 |
+
offload_weight(
|
| 460 |
+
module._parameters[tensor_name].SCB,
|
| 461 |
+
param_name.replace("weight", "SCB"),
|
| 462 |
+
offload_folder,
|
| 463 |
+
index=offload_index,
|
| 464 |
+
)
|
| 465 |
+
else:
|
| 466 |
+
offload_weight(param, param_name, offload_folder, index=offload_index)
|
| 467 |
+
offload_weight(fp16_statistics, param_name.replace("weight", "SCB"), offload_folder, index=offload_index)
|
| 468 |
+
|
| 469 |
+
set_module_tensor_to_device(model, param_name, "meta", dtype=new_dtype, value=torch.empty(*param.size()))
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/constants.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import operator as op
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
SCALER_NAME = "scaler.pt"
|
| 21 |
+
MODEL_NAME = "pytorch_model"
|
| 22 |
+
SAFE_MODEL_NAME = "model"
|
| 23 |
+
RNG_STATE_NAME = "random_states"
|
| 24 |
+
OPTIMIZER_NAME = "optimizer"
|
| 25 |
+
SCHEDULER_NAME = "scheduler"
|
| 26 |
+
SAMPLER_NAME = "sampler"
|
| 27 |
+
PROFILE_PATTERN_NAME = "profile_{suffix}.json"
|
| 28 |
+
WEIGHTS_NAME = f"{MODEL_NAME}.bin"
|
| 29 |
+
WEIGHTS_PATTERN_NAME = "pytorch_model{suffix}.bin"
|
| 30 |
+
WEIGHTS_INDEX_NAME = f"{WEIGHTS_NAME}.index.json"
|
| 31 |
+
SAFE_WEIGHTS_NAME = f"{SAFE_MODEL_NAME}.safetensors"
|
| 32 |
+
SAFE_WEIGHTS_PATTERN_NAME = "model{suffix}.safetensors"
|
| 33 |
+
SAFE_WEIGHTS_INDEX_NAME = f"{SAFE_WEIGHTS_NAME}.index.json"
|
| 34 |
+
SAGEMAKER_PYTORCH_VERSION = "1.10.2"
|
| 35 |
+
SAGEMAKER_PYTHON_VERSION = "py38"
|
| 36 |
+
SAGEMAKER_TRANSFORMERS_VERSION = "4.17.0"
|
| 37 |
+
SAGEMAKER_PARALLEL_EC2_INSTANCES = ["ml.p3.16xlarge", "ml.p3dn.24xlarge", "ml.p4dn.24xlarge"]
|
| 38 |
+
FSDP_SHARDING_STRATEGY = ["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD", "HYBRID_SHARD_ZERO2"]
|
| 39 |
+
FSDP_AUTO_WRAP_POLICY = ["TRANSFORMER_BASED_WRAP", "SIZE_BASED_WRAP", "NO_WRAP"]
|
| 40 |
+
FSDP_BACKWARD_PREFETCH = ["BACKWARD_PRE", "BACKWARD_POST", "NO_PREFETCH"]
|
| 41 |
+
FSDP_STATE_DICT_TYPE = ["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"]
|
| 42 |
+
FSDP2_STATE_DICT_TYPE = ["SHARDED_STATE_DICT", "FULL_STATE_DICT"]
|
| 43 |
+
FSDP_PYTORCH_VERSION = (
|
| 44 |
+
"2.1.0.a0+32f93b1" # Technically should be 2.1.0, but MS-AMP uses this specific prerelease in their Docker image.
|
| 45 |
+
)
|
| 46 |
+
FSDP2_PYTORCH_VERSION = "2.6.0"
|
| 47 |
+
FSDP_MODEL_NAME = "pytorch_model_fsdp"
|
| 48 |
+
DEEPSPEED_MULTINODE_LAUNCHERS = ["pdsh", "standard", "openmpi", "mvapich", "mpich", "nossh", "slurm"]
|
| 49 |
+
TORCH_DYNAMO_MODES = ["default", "reduce-overhead", "max-autotune"]
|
| 50 |
+
ELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION = "2.2.0"
|
| 51 |
+
XPU_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.4.0"
|
| 52 |
+
MITA_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.1.0"
|
| 53 |
+
BETA_TP_AVAILABLE_PYTORCH_VERSION = "2.3.0"
|
| 54 |
+
|
| 55 |
+
BETA_TP_AVAILABLE_TRANSFORMERS_VERSION = "4.52.0"
|
| 56 |
+
BETA_CP_AVAILABLE_PYTORCH_VERSION = "2.6.0"
|
| 57 |
+
BETA_SP_AVAILABLE_DEEPSPEED_VERSION = "0.18.2"
|
| 58 |
+
|
| 59 |
+
STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}
|
| 60 |
+
|
| 61 |
+
# These are the args for `torch.distributed.launch` for pytorch < 1.9
|
| 62 |
+
TORCH_LAUNCH_PARAMS = [
|
| 63 |
+
"nnodes",
|
| 64 |
+
"nproc_per_node",
|
| 65 |
+
"rdzv_backend",
|
| 66 |
+
"rdzv_endpoint",
|
| 67 |
+
"rdzv_id",
|
| 68 |
+
"rdzv_conf",
|
| 69 |
+
"standalone",
|
| 70 |
+
"max_restarts",
|
| 71 |
+
"monitor_interval",
|
| 72 |
+
"start_method",
|
| 73 |
+
"role",
|
| 74 |
+
"module",
|
| 75 |
+
"m",
|
| 76 |
+
"no_python",
|
| 77 |
+
"run_path",
|
| 78 |
+
"log_dir",
|
| 79 |
+
"r",
|
| 80 |
+
"redirects",
|
| 81 |
+
"t",
|
| 82 |
+
"tee",
|
| 83 |
+
"node_rank",
|
| 84 |
+
"master_addr",
|
| 85 |
+
"master_port",
|
| 86 |
+
]
|
| 87 |
+
|
| 88 |
+
CUDA_DISTRIBUTED_TYPES = ["DEEPSPEED", "MULTI_GPU", "FSDP", "MEGATRON_LM", "TP"]
|
| 89 |
+
TORCH_DISTRIBUTED_OPERATION_TYPES = CUDA_DISTRIBUTED_TYPES + [
|
| 90 |
+
"MULTI_NPU",
|
| 91 |
+
"MULTI_MLU",
|
| 92 |
+
"MULTI_SDAA",
|
| 93 |
+
"MULTI_MUSA",
|
| 94 |
+
"MULTI_XPU",
|
| 95 |
+
"MULTI_CPU",
|
| 96 |
+
"MULTI_HPU",
|
| 97 |
+
]
|
| 98 |
+
SUPPORTED_PYTORCH_LAYERS_FOR_UPCASTING = (
|
| 99 |
+
torch.nn.Conv1d,
|
| 100 |
+
torch.nn.Conv2d,
|
| 101 |
+
torch.nn.Conv3d,
|
| 102 |
+
torch.nn.ConvTranspose1d,
|
| 103 |
+
torch.nn.ConvTranspose2d,
|
| 104 |
+
torch.nn.ConvTranspose3d,
|
| 105 |
+
torch.nn.Linear,
|
| 106 |
+
)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/dataclasses.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/deepspeed.py
ADDED
|
@@ -0,0 +1,385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import base64
|
| 16 |
+
import json
|
| 17 |
+
import os
|
| 18 |
+
from copy import deepcopy
|
| 19 |
+
|
| 20 |
+
from torch import optim
|
| 21 |
+
|
| 22 |
+
from ..optimizer import AcceleratedOptimizer
|
| 23 |
+
from ..scheduler import AcceleratedScheduler
|
| 24 |
+
from .dataclasses import DistributedType
|
| 25 |
+
from .imports import is_bnb_available
|
| 26 |
+
from .versions import compare_versions
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def map_pytorch_optim_to_deepspeed(optimizer):
|
| 30 |
+
"""
|
| 31 |
+
Args:
|
| 32 |
+
optimizer: torch.optim.Optimizer
|
| 33 |
+
|
| 34 |
+
Returns the DeepSeedCPUOptimizer (deepspeed.ops) version of the optimizer.
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
defaults = {k: v for k, v in optimizer.defaults.items() if k in ["lr", "weight_decay"]}
|
| 38 |
+
|
| 39 |
+
# Select the DeepSpeedCPUOptimizer based on the original optimizer class.
|
| 40 |
+
# DeepSpeedCPUAdam is the default
|
| 41 |
+
from deepspeed.ops.adam import DeepSpeedCPUAdam
|
| 42 |
+
|
| 43 |
+
optimizer_class = DeepSpeedCPUAdam
|
| 44 |
+
|
| 45 |
+
# For DeepSpeedCPUAdam (adamw_mode)
|
| 46 |
+
if compare_versions("deepspeed", ">=", "0.3.1"):
|
| 47 |
+
defaults["adamw_mode"] = False
|
| 48 |
+
is_adaw = isinstance(optimizer, optim.AdamW)
|
| 49 |
+
|
| 50 |
+
if is_bnb_available() and not is_adaw:
|
| 51 |
+
import bitsandbytes.optim as bnb_opt
|
| 52 |
+
|
| 53 |
+
if isinstance(optimizer, (bnb_opt.AdamW, bnb_opt.AdamW32bit)):
|
| 54 |
+
try:
|
| 55 |
+
is_adaw = optimizer.optim_bits == 32
|
| 56 |
+
except AttributeError:
|
| 57 |
+
is_adaw = optimizer.args.optim_bits == 32
|
| 58 |
+
else:
|
| 59 |
+
is_adaw = False
|
| 60 |
+
|
| 61 |
+
if is_adaw:
|
| 62 |
+
defaults["adamw_mode"] = True
|
| 63 |
+
|
| 64 |
+
# For DeepSpeedCPUAdagrad
|
| 65 |
+
if compare_versions("deepspeed", ">=", "0.5.5"):
|
| 66 |
+
# Check if the optimizer is PyTorch's Adagrad.
|
| 67 |
+
is_ada = isinstance(optimizer, optim.Adagrad)
|
| 68 |
+
# If not, and bitsandbytes is available,
|
| 69 |
+
# # check if the optimizer is the 32-bit bitsandbytes Adagrad.
|
| 70 |
+
if is_bnb_available() and not is_ada:
|
| 71 |
+
import bitsandbytes.optim as bnb_opt
|
| 72 |
+
|
| 73 |
+
if isinstance(optimizer, (bnb_opt.Adagrad, bnb_opt.Adagrad32bit)):
|
| 74 |
+
try:
|
| 75 |
+
is_ada = optimizer.optim_bits == 32
|
| 76 |
+
except AttributeError:
|
| 77 |
+
is_ada = optimizer.args.optim_bits == 32
|
| 78 |
+
if is_ada:
|
| 79 |
+
from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad
|
| 80 |
+
|
| 81 |
+
optimizer_class = DeepSpeedCPUAdagrad
|
| 82 |
+
|
| 83 |
+
# For DeepSpeedCPULion
|
| 84 |
+
if is_bnb_available(min_version="0.38.0") and compare_versions("deepspeed", ">=", "0.11.0"):
|
| 85 |
+
from bitsandbytes.optim import Lion, Lion32bit
|
| 86 |
+
|
| 87 |
+
if isinstance(optimizer, (Lion, Lion32bit)):
|
| 88 |
+
try:
|
| 89 |
+
is_bnb_32bits = optimizer.optim_bits == 32
|
| 90 |
+
except AttributeError:
|
| 91 |
+
is_bnb_32bits = optimizer.args.optim_bits == 32
|
| 92 |
+
if is_bnb_32bits:
|
| 93 |
+
from deepspeed.ops.lion import DeepSpeedCPULion
|
| 94 |
+
|
| 95 |
+
optimizer_class = DeepSpeedCPULion
|
| 96 |
+
|
| 97 |
+
return optimizer_class(optimizer.param_groups, **defaults)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def get_active_deepspeed_plugin(state):
|
| 101 |
+
"""
|
| 102 |
+
Returns the currently active DeepSpeedPlugin.
|
| 103 |
+
|
| 104 |
+
Raises:
|
| 105 |
+
ValueError: If DeepSpeed was not enabled and this function is called.
|
| 106 |
+
"""
|
| 107 |
+
if state.distributed_type != DistributedType.DEEPSPEED:
|
| 108 |
+
raise ValueError(
|
| 109 |
+
"Couldn't retrieve the active `DeepSpeedPlugin` as none were enabled. "
|
| 110 |
+
"Please make sure that either `Accelerator` is configured for `deepspeed` "
|
| 111 |
+
"or make sure that the desired `DeepSpeedPlugin` has been enabled (`AcceleratorState().select_deepspeed_plugin(name)`) "
|
| 112 |
+
"before calling this function."
|
| 113 |
+
)
|
| 114 |
+
if not isinstance(state.deepspeed_plugins, dict):
|
| 115 |
+
return state.deepspeed_plugins
|
| 116 |
+
return next(plugin for plugin in state.deepspeed_plugins.values() if plugin.selected)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class HfDeepSpeedConfig:
|
| 120 |
+
"""
|
| 121 |
+
This object contains a DeepSpeed configuration dictionary and can be quickly queried for things like zero stage.
|
| 122 |
+
|
| 123 |
+
A `weakref` of this object is stored in the module's globals to be able to access the config from areas where
|
| 124 |
+
things like the Trainer object is not available (e.g. `from_pretrained` and `_get_resized_embeddings`). Therefore
|
| 125 |
+
it's important that this object remains alive while the program is still running.
|
| 126 |
+
|
| 127 |
+
[`Trainer`] uses the `HfTrainerDeepSpeedConfig` subclass instead. That subclass has logic to sync the configuration
|
| 128 |
+
with values of [`TrainingArguments`] by replacing special placeholder values: `"auto"`. Without this special logic
|
| 129 |
+
the DeepSpeed configuration is not modified in any way.
|
| 130 |
+
|
| 131 |
+
Args:
|
| 132 |
+
config_file_or_dict (`Union[str, Dict]`): path to DeepSpeed config file or dict.
|
| 133 |
+
|
| 134 |
+
"""
|
| 135 |
+
|
| 136 |
+
def __init__(self, config_file_or_dict):
|
| 137 |
+
if isinstance(config_file_or_dict, dict):
|
| 138 |
+
# Don't modify user's data should they want to reuse it (e.g. in tests), because once we
|
| 139 |
+
# modified it, it will not be accepted here again, since `auto` values would have been overridden
|
| 140 |
+
config = deepcopy(config_file_or_dict)
|
| 141 |
+
elif os.path.exists(config_file_or_dict):
|
| 142 |
+
with open(config_file_or_dict, encoding="utf-8") as f:
|
| 143 |
+
config = json.load(f)
|
| 144 |
+
else:
|
| 145 |
+
try:
|
| 146 |
+
try:
|
| 147 |
+
# First try parsing as JSON directly
|
| 148 |
+
config = json.loads(config_file_or_dict)
|
| 149 |
+
except json.JSONDecodeError:
|
| 150 |
+
# If that fails, try base64 decoding
|
| 151 |
+
config_decoded = base64.urlsafe_b64decode(config_file_or_dict).decode("utf-8")
|
| 152 |
+
config = json.loads(config_decoded)
|
| 153 |
+
except (UnicodeDecodeError, AttributeError, ValueError):
|
| 154 |
+
raise ValueError(
|
| 155 |
+
f"Expected a string path to an existing deepspeed config, or a dictionary, or a base64 encoded string. Received: {config_file_or_dict}"
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
self.config = config
|
| 159 |
+
|
| 160 |
+
self.set_stage_and_offload()
|
| 161 |
+
|
| 162 |
+
def set_stage_and_offload(self):
|
| 163 |
+
# zero stage - this is done as early as possible, before model is created, to allow
|
| 164 |
+
# ``is_deepspeed_zero3_enabled`` query and getting to the early deepspeed config object
|
| 165 |
+
# during ``zero.Init()`` which needs to know the dtype, and some other hparams.
|
| 166 |
+
self._stage = self.get_value("zero_optimization.stage", -1)
|
| 167 |
+
|
| 168 |
+
# offload
|
| 169 |
+
self._offload = False
|
| 170 |
+
if self.is_zero2() or self.is_zero3():
|
| 171 |
+
offload_devices_valid = set(["cpu", "nvme"])
|
| 172 |
+
offload_devices = set(
|
| 173 |
+
[
|
| 174 |
+
self.get_value("zero_optimization.offload_optimizer.device"),
|
| 175 |
+
self.get_value("zero_optimization.offload_param.device"),
|
| 176 |
+
]
|
| 177 |
+
)
|
| 178 |
+
if len(offload_devices & offload_devices_valid) > 0:
|
| 179 |
+
self._offload = True
|
| 180 |
+
|
| 181 |
+
def find_config_node(self, ds_key_long):
|
| 182 |
+
config = self.config
|
| 183 |
+
|
| 184 |
+
# find the config node of interest if it exists
|
| 185 |
+
nodes = ds_key_long.split(".")
|
| 186 |
+
ds_key = nodes.pop()
|
| 187 |
+
for node in nodes:
|
| 188 |
+
config = config.get(node)
|
| 189 |
+
if config is None:
|
| 190 |
+
return None, ds_key
|
| 191 |
+
|
| 192 |
+
return config, ds_key
|
| 193 |
+
|
| 194 |
+
def get_value(self, ds_key_long, default=None):
|
| 195 |
+
"""
|
| 196 |
+
Returns the set value or `default` if no value is set
|
| 197 |
+
"""
|
| 198 |
+
config, ds_key = self.find_config_node(ds_key_long)
|
| 199 |
+
if config is None:
|
| 200 |
+
return default
|
| 201 |
+
return config.get(ds_key, default)
|
| 202 |
+
|
| 203 |
+
def del_config_sub_tree(self, ds_key_long, must_exist=False):
|
| 204 |
+
"""
|
| 205 |
+
Deletes a sub-section of the config file if it's found.
|
| 206 |
+
|
| 207 |
+
Unless `must_exist` is `True` the section doesn't have to exist.
|
| 208 |
+
"""
|
| 209 |
+
config = self.config
|
| 210 |
+
|
| 211 |
+
# find the config node of interest if it exists
|
| 212 |
+
nodes = ds_key_long.split(".")
|
| 213 |
+
for node in nodes:
|
| 214 |
+
parent_config = config
|
| 215 |
+
config = config.get(node)
|
| 216 |
+
if config is None:
|
| 217 |
+
if must_exist:
|
| 218 |
+
raise ValueError(f"Can't find {ds_key_long} entry in the config: {self.config}")
|
| 219 |
+
else:
|
| 220 |
+
return
|
| 221 |
+
|
| 222 |
+
# if found remove it
|
| 223 |
+
if parent_config is not None:
|
| 224 |
+
parent_config.pop(node)
|
| 225 |
+
|
| 226 |
+
def is_true(self, ds_key_long):
|
| 227 |
+
"""
|
| 228 |
+
Returns `True`/``False` only if the value is set, always `False` otherwise. So use this method to ask the very
|
| 229 |
+
specific question of whether the value is set to `True` (and it's not set to `False`` or isn't set).
|
| 230 |
+
|
| 231 |
+
"""
|
| 232 |
+
value = self.get_value(ds_key_long)
|
| 233 |
+
return False if value is None else bool(value)
|
| 234 |
+
|
| 235 |
+
def is_false(self, ds_key_long):
|
| 236 |
+
"""
|
| 237 |
+
Returns `True`/``False` only if the value is set, always `False` otherwise. So use this method to ask the very
|
| 238 |
+
specific question of whether the value is set to `False` (and it's not set to `True`` or isn't set).
|
| 239 |
+
"""
|
| 240 |
+
value = self.get_value(ds_key_long)
|
| 241 |
+
return False if value is None else not bool(value)
|
| 242 |
+
|
| 243 |
+
def is_zero2(self):
|
| 244 |
+
return self._stage == 2
|
| 245 |
+
|
| 246 |
+
def is_zero3(self):
|
| 247 |
+
return self._stage == 3
|
| 248 |
+
|
| 249 |
+
def is_offload(self):
|
| 250 |
+
return self._offload
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
class DeepSpeedEngineWrapper:
|
| 254 |
+
"""
|
| 255 |
+
Internal wrapper for deepspeed.runtime.engine.DeepSpeedEngine. This is used to follow conventional training loop.
|
| 256 |
+
|
| 257 |
+
Args:
|
| 258 |
+
engine (deepspeed.runtime.engine.DeepSpeedEngine): deepspeed engine to wrap
|
| 259 |
+
"""
|
| 260 |
+
|
| 261 |
+
def __init__(self, engine):
|
| 262 |
+
self.engine = engine
|
| 263 |
+
|
| 264 |
+
def backward(self, loss, sync_gradients=True, **kwargs):
|
| 265 |
+
# Set gradient accumulation boundary based on Accelerate's sync_gradients state
|
| 266 |
+
# This tells DeepSpeed whether this is the final micro-batch before gradient sync
|
| 267 |
+
self.engine.set_gradient_accumulation_boundary(is_boundary=sync_gradients)
|
| 268 |
+
|
| 269 |
+
# runs backpropagation and handles mixed precision
|
| 270 |
+
self.engine.backward(loss, **kwargs)
|
| 271 |
+
|
| 272 |
+
# Only perform step and related operations at gradient accumulation boundaries
|
| 273 |
+
if sync_gradients:
|
| 274 |
+
# Deepspeed's `engine.step` performs the following operations:
|
| 275 |
+
# - gradient accumulation check
|
| 276 |
+
# - gradient clipping
|
| 277 |
+
# - optimizer step
|
| 278 |
+
# - zero grad
|
| 279 |
+
# - checking overflow
|
| 280 |
+
# - lr_scheduler step (only if engine.lr_scheduler is not None)
|
| 281 |
+
self.engine.step()
|
| 282 |
+
# and this plugin overrides the above calls with no-ops when Accelerate runs under
|
| 283 |
+
# Deepspeed, but allows normal functionality for non-Deepspeed cases thus enabling a simple
|
| 284 |
+
# training loop that works transparently under many training regimes.
|
| 285 |
+
|
| 286 |
+
def get_global_grad_norm(self):
|
| 287 |
+
"""Get the global gradient norm from DeepSpeed engine."""
|
| 288 |
+
grad_norm = self.engine.get_global_grad_norm()
|
| 289 |
+
# Convert to scalar if it's a tensor
|
| 290 |
+
if hasattr(grad_norm, "item"):
|
| 291 |
+
return grad_norm.item()
|
| 292 |
+
return grad_norm
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
class DeepSpeedOptimizerWrapper(AcceleratedOptimizer):
|
| 296 |
+
"""
|
| 297 |
+
Internal wrapper around a deepspeed optimizer.
|
| 298 |
+
|
| 299 |
+
Args:
|
| 300 |
+
optimizer (`torch.optim.optimizer.Optimizer`):
|
| 301 |
+
The optimizer to wrap.
|
| 302 |
+
"""
|
| 303 |
+
|
| 304 |
+
def __init__(self, optimizer):
|
| 305 |
+
super().__init__(optimizer, device_placement=False, scaler=None)
|
| 306 |
+
self.__has_overflow__ = hasattr(self.optimizer, "overflow")
|
| 307 |
+
|
| 308 |
+
def zero_grad(self, set_to_none=None):
|
| 309 |
+
pass # `accelerator.backward(loss)` is doing that automatically. Therefore, its implementation is not needed
|
| 310 |
+
|
| 311 |
+
def step(self):
|
| 312 |
+
pass # `accelerator.backward(loss)` is doing that automatically. Therefore, its implementation is not needed
|
| 313 |
+
|
| 314 |
+
@property
|
| 315 |
+
def step_was_skipped(self):
|
| 316 |
+
"""Whether or not the optimizer step was done, or skipped because of gradient overflow."""
|
| 317 |
+
if self.__has_overflow__:
|
| 318 |
+
return self.optimizer.overflow
|
| 319 |
+
return False
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
class DeepSpeedSchedulerWrapper(AcceleratedScheduler):
|
| 323 |
+
"""
|
| 324 |
+
Internal wrapper around a deepspeed scheduler.
|
| 325 |
+
|
| 326 |
+
Args:
|
| 327 |
+
scheduler (`torch.optim.lr_scheduler.LambdaLR`):
|
| 328 |
+
The scheduler to wrap.
|
| 329 |
+
optimizers (one or a list of `torch.optim.Optimizer`):
|
| 330 |
+
"""
|
| 331 |
+
|
| 332 |
+
def __init__(self, scheduler, optimizers):
|
| 333 |
+
super().__init__(scheduler, optimizers)
|
| 334 |
+
|
| 335 |
+
def step(self):
|
| 336 |
+
pass # `accelerator.backward(loss)` is doing that automatically. Therefore, its implementation is not needed
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
class DummyOptim:
|
| 340 |
+
"""
|
| 341 |
+
Dummy optimizer presents model parameters or param groups, this is primarily used to follow conventional training
|
| 342 |
+
loop when optimizer config is specified in the deepspeed config file.
|
| 343 |
+
|
| 344 |
+
Args:
|
| 345 |
+
lr (float):
|
| 346 |
+
Learning rate.
|
| 347 |
+
params (iterable): iterable of parameters to optimize or dicts defining
|
| 348 |
+
parameter groups
|
| 349 |
+
weight_decay (float):
|
| 350 |
+
Weight decay.
|
| 351 |
+
**kwargs (additional keyword arguments, *optional*):
|
| 352 |
+
Other arguments.
|
| 353 |
+
"""
|
| 354 |
+
|
| 355 |
+
def __init__(self, params, lr=0.001, weight_decay=0, **kwargs):
|
| 356 |
+
self.params = params
|
| 357 |
+
self.lr = lr
|
| 358 |
+
self.weight_decay = weight_decay
|
| 359 |
+
self.kwargs = kwargs
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
class DummyScheduler:
|
| 363 |
+
"""
|
| 364 |
+
Dummy scheduler presents model parameters or param groups, this is primarily used to follow conventional training
|
| 365 |
+
loop when scheduler config is specified in the deepspeed config file.
|
| 366 |
+
|
| 367 |
+
Args:
|
| 368 |
+
optimizer (`torch.optim.optimizer.Optimizer`):
|
| 369 |
+
The optimizer to wrap.
|
| 370 |
+
total_num_steps (int, *optional*):
|
| 371 |
+
Total number of steps.
|
| 372 |
+
warmup_num_steps (int, *optional*):
|
| 373 |
+
Number of steps for warmup.
|
| 374 |
+
lr_scheduler_callable (callable, *optional*):
|
| 375 |
+
A callable function that creates an LR Scheduler. It accepts only one argument `optimizer`.
|
| 376 |
+
**kwargs (additional keyword arguments, *optional*):
|
| 377 |
+
Other arguments.
|
| 378 |
+
"""
|
| 379 |
+
|
| 380 |
+
def __init__(self, optimizer, total_num_steps=None, warmup_num_steps=0, lr_scheduler_callable=None, **kwargs):
|
| 381 |
+
self.optimizer = optimizer
|
| 382 |
+
self.total_num_steps = total_num_steps
|
| 383 |
+
self.warmup_num_steps = warmup_num_steps
|
| 384 |
+
self.lr_scheduler_callable = lr_scheduler_callable
|
| 385 |
+
self.kwargs = kwargs
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/environment.py
ADDED
|
@@ -0,0 +1,471 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import logging
|
| 16 |
+
import math
|
| 17 |
+
import os
|
| 18 |
+
import platform
|
| 19 |
+
import subprocess
|
| 20 |
+
import sys
|
| 21 |
+
from contextlib import contextmanager
|
| 22 |
+
from dataclasses import dataclass, field
|
| 23 |
+
from functools import lru_cache, wraps
|
| 24 |
+
from shutil import which
|
| 25 |
+
from typing import Optional, Union
|
| 26 |
+
|
| 27 |
+
import torch
|
| 28 |
+
from packaging.version import parse
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
logger = logging.getLogger(__name__)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def convert_dict_to_env_variables(current_env: dict):
|
| 35 |
+
"""
|
| 36 |
+
Verifies that all keys and values in `current_env` do not contain illegal keys or values, and returns a list of
|
| 37 |
+
strings as the result.
|
| 38 |
+
|
| 39 |
+
Example:
|
| 40 |
+
```python
|
| 41 |
+
>>> from accelerate.utils.environment import verify_env
|
| 42 |
+
|
| 43 |
+
>>> env = {"ACCELERATE_DEBUG_MODE": "1", "BAD_ENV_NAME": "<mything", "OTHER_ENV": "2"}
|
| 44 |
+
>>> valid_env_items = verify_env(env)
|
| 45 |
+
>>> print(valid_env_items)
|
| 46 |
+
["ACCELERATE_DEBUG_MODE=1\n", "OTHER_ENV=2\n"]
|
| 47 |
+
```
|
| 48 |
+
"""
|
| 49 |
+
forbidden_chars = [";", "\n", "<", ">", " "]
|
| 50 |
+
valid_env_items = []
|
| 51 |
+
for key, value in current_env.items():
|
| 52 |
+
if all(char not in (key + value) for char in forbidden_chars) and len(key) >= 1 and len(value) >= 1:
|
| 53 |
+
valid_env_items.append(f"{key}={value}\n")
|
| 54 |
+
else:
|
| 55 |
+
logger.warning(f"WARNING: Skipping {key}={value} as it contains forbidden characters or missing values.")
|
| 56 |
+
return valid_env_items
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def str_to_bool(value, to_bool: bool = False) -> Union[int, bool]:
|
| 60 |
+
"""
|
| 61 |
+
Converts a string representation of truth to `True` (1) or `False` (0).
|
| 62 |
+
|
| 63 |
+
True values are `y`, `yes`, `t`, `true`, `on`, and `1`; False value are `n`, `no`, `f`, `false`, `off`, and `0`;
|
| 64 |
+
"""
|
| 65 |
+
value = value.lower()
|
| 66 |
+
if value in ("y", "yes", "t", "true", "on", "1"):
|
| 67 |
+
return 1 if not to_bool else True
|
| 68 |
+
elif value in ("n", "no", "f", "false", "off", "0"):
|
| 69 |
+
return 0 if not to_bool else False
|
| 70 |
+
else:
|
| 71 |
+
raise ValueError(f"invalid truth value {value}")
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def get_int_from_env(env_keys, default):
|
| 75 |
+
"""Returns the first positive env value found in the `env_keys` list or the default."""
|
| 76 |
+
for e in env_keys:
|
| 77 |
+
val = int(os.environ.get(e, -1))
|
| 78 |
+
if val >= 0:
|
| 79 |
+
return val
|
| 80 |
+
return default
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def parse_flag_from_env(key, default=False):
|
| 84 |
+
"""Returns truthy value for `key` from the env if available else the default."""
|
| 85 |
+
value = os.environ.get(key, str(default))
|
| 86 |
+
return str_to_bool(value) == 1 # As its name indicates `str_to_bool` actually returns an int...
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def parse_choice_from_env(key, default="no"):
|
| 90 |
+
value = os.environ.get(key, str(default))
|
| 91 |
+
return value
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def are_libraries_initialized(*library_names: str) -> list[str]:
|
| 95 |
+
"""
|
| 96 |
+
Checks if any of `library_names` are imported in the environment. Will return any names that are.
|
| 97 |
+
"""
|
| 98 |
+
return [lib_name for lib_name in library_names if lib_name in sys.modules.keys()]
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def get_current_device_type() -> tuple[str, str]:
|
| 102 |
+
"""
|
| 103 |
+
Determines the current device type and distributed type without initializing any device.
|
| 104 |
+
|
| 105 |
+
This is particularly important when using fork-based multiprocessing, as device initialization
|
| 106 |
+
before forking can cause errors.
|
| 107 |
+
|
| 108 |
+
The device detection order follows the same priority as state.py:_prepare_backend():
|
| 109 |
+
MLU -> SDAA -> MUSA -> NPU -> HPU -> CUDA -> XPU
|
| 110 |
+
|
| 111 |
+
Returns:
|
| 112 |
+
tuple[str, str]: A tuple of (device_type, distributed_type)
|
| 113 |
+
- device_type: The device string (e.g., "cuda", "npu", "xpu")
|
| 114 |
+
- distributed_type: The distributed type string (e.g., "MULTI_GPU", "MULTI_NPU")
|
| 115 |
+
|
| 116 |
+
Example:
|
| 117 |
+
```python
|
| 118 |
+
>>> device_type, distributed_type = get_current_device_type()
|
| 119 |
+
>>> print(device_type) # "cuda"
|
| 120 |
+
>>> print(distributed_type) # "MULTI_GPU"
|
| 121 |
+
```
|
| 122 |
+
"""
|
| 123 |
+
from .imports import (
|
| 124 |
+
is_hpu_available,
|
| 125 |
+
is_mlu_available,
|
| 126 |
+
is_musa_available,
|
| 127 |
+
is_npu_available,
|
| 128 |
+
is_sdaa_available,
|
| 129 |
+
is_xpu_available,
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
if is_mlu_available():
|
| 133 |
+
return "mlu", "MULTI_MLU"
|
| 134 |
+
elif is_sdaa_available():
|
| 135 |
+
return "sdaa", "MULTI_SDAA"
|
| 136 |
+
elif is_musa_available():
|
| 137 |
+
return "musa", "MULTI_MUSA"
|
| 138 |
+
elif is_npu_available():
|
| 139 |
+
return "npu", "MULTI_NPU"
|
| 140 |
+
elif is_hpu_available():
|
| 141 |
+
return "hpu", "MULTI_HPU"
|
| 142 |
+
elif torch.cuda.is_available():
|
| 143 |
+
return "cuda", "MULTI_GPU"
|
| 144 |
+
elif is_xpu_available():
|
| 145 |
+
return "xpu", "MULTI_XPU"
|
| 146 |
+
else:
|
| 147 |
+
# Default to CUDA even if not available (for CPU-only scenarios where CUDA code paths are still used)
|
| 148 |
+
return "cuda", "MULTI_GPU"
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def _nvidia_smi():
|
| 152 |
+
"""
|
| 153 |
+
Returns the right nvidia-smi command based on the system.
|
| 154 |
+
"""
|
| 155 |
+
if platform.system() == "Windows":
|
| 156 |
+
# If platform is Windows and nvidia-smi can't be found in path
|
| 157 |
+
# try from systemd drive with default installation path
|
| 158 |
+
command = which("nvidia-smi")
|
| 159 |
+
if command is None:
|
| 160 |
+
command = f"{os.environ['systemdrive']}\\Program Files\\NVIDIA Corporation\\NVSMI\\nvidia-smi.exe"
|
| 161 |
+
else:
|
| 162 |
+
command = "nvidia-smi"
|
| 163 |
+
return command
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def get_gpu_info():
|
| 167 |
+
"""
|
| 168 |
+
Gets GPU count and names using `nvidia-smi` instead of torch to not initialize CUDA.
|
| 169 |
+
|
| 170 |
+
Largely based on the `gputil` library.
|
| 171 |
+
"""
|
| 172 |
+
# Returns as list of `n` GPUs and their names
|
| 173 |
+
output = subprocess.check_output(
|
| 174 |
+
[_nvidia_smi(), "--query-gpu=count,name", "--format=csv,noheader"], universal_newlines=True
|
| 175 |
+
)
|
| 176 |
+
output = output.strip()
|
| 177 |
+
gpus = output.split(os.linesep)
|
| 178 |
+
# Get names from output
|
| 179 |
+
gpu_count = len(gpus)
|
| 180 |
+
gpu_names = [gpu.split(",")[1].strip() for gpu in gpus]
|
| 181 |
+
return gpu_names, gpu_count
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def get_driver_version():
|
| 185 |
+
"""
|
| 186 |
+
Returns the driver version
|
| 187 |
+
|
| 188 |
+
In the case of multiple GPUs, will return the first.
|
| 189 |
+
"""
|
| 190 |
+
output = subprocess.check_output(
|
| 191 |
+
[_nvidia_smi(), "--query-gpu=driver_version", "--format=csv,noheader"], universal_newlines=True
|
| 192 |
+
)
|
| 193 |
+
output = output.strip()
|
| 194 |
+
return output.split(os.linesep)[0]
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def check_cuda_p2p_ib_support():
|
| 198 |
+
"""
|
| 199 |
+
Checks if the devices being used have issues with P2P and IB communications, namely any consumer GPU hardware after
|
| 200 |
+
the 3090.
|
| 201 |
+
|
| 202 |
+
Notably uses `nvidia-smi` instead of torch to not initialize CUDA.
|
| 203 |
+
"""
|
| 204 |
+
try:
|
| 205 |
+
device_names, device_count = get_gpu_info()
|
| 206 |
+
# As new consumer GPUs get released, add them to `unsupported_devices``
|
| 207 |
+
unsupported_devices = {"RTX 40"}
|
| 208 |
+
if device_count > 1:
|
| 209 |
+
if any(
|
| 210 |
+
unsupported_device in device_name
|
| 211 |
+
for device_name in device_names
|
| 212 |
+
for unsupported_device in unsupported_devices
|
| 213 |
+
):
|
| 214 |
+
# Check if they have the right driver version
|
| 215 |
+
acceptable_driver_version = "550.40.07"
|
| 216 |
+
current_driver_version = get_driver_version()
|
| 217 |
+
if parse(current_driver_version) < parse(acceptable_driver_version):
|
| 218 |
+
return False
|
| 219 |
+
return True
|
| 220 |
+
except Exception:
|
| 221 |
+
pass
|
| 222 |
+
return True
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
@lru_cache
|
| 226 |
+
def check_cuda_fp8_capability():
|
| 227 |
+
"""
|
| 228 |
+
Checks if the current GPU available supports FP8.
|
| 229 |
+
|
| 230 |
+
Notably might initialize `torch.cuda` to check.
|
| 231 |
+
"""
|
| 232 |
+
|
| 233 |
+
try:
|
| 234 |
+
# try to get the compute capability from nvidia-smi
|
| 235 |
+
output = subprocess.check_output(
|
| 236 |
+
[_nvidia_smi(), "--query-gpu=compute_capability", "--format=csv,noheader"], universal_newlines=True
|
| 237 |
+
)
|
| 238 |
+
output = output.strip()
|
| 239 |
+
# we take the first GPU's compute capability
|
| 240 |
+
compute_capability = tuple(map(int, output.split(os.linesep)[0].split(".")))
|
| 241 |
+
except Exception:
|
| 242 |
+
compute_capability = torch.cuda.get_device_capability()
|
| 243 |
+
|
| 244 |
+
return compute_capability >= (8, 9)
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
@dataclass
|
| 248 |
+
class CPUInformation:
|
| 249 |
+
"""
|
| 250 |
+
Stores information about the CPU in a distributed environment. It contains the following attributes:
|
| 251 |
+
- rank: The rank of the current process.
|
| 252 |
+
- world_size: The total number of processes in the world.
|
| 253 |
+
- local_rank: The rank of the current process on the local node.
|
| 254 |
+
- local_world_size: The total number of processes on the local node.
|
| 255 |
+
"""
|
| 256 |
+
|
| 257 |
+
rank: int = field(default=0, metadata={"help": "The rank of the current process."})
|
| 258 |
+
world_size: int = field(default=1, metadata={"help": "The total number of processes in the world."})
|
| 259 |
+
local_rank: int = field(default=0, metadata={"help": "The rank of the current process on the local node."})
|
| 260 |
+
local_world_size: int = field(default=1, metadata={"help": "The total number of processes on the local node."})
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def get_cpu_distributed_information() -> CPUInformation:
|
| 264 |
+
"""
|
| 265 |
+
Returns various information about the environment in relation to CPU distributed training as a `CPUInformation`
|
| 266 |
+
dataclass.
|
| 267 |
+
"""
|
| 268 |
+
information = {}
|
| 269 |
+
information["rank"] = get_int_from_env(["RANK", "PMI_RANK", "OMPI_COMM_WORLD_RANK", "MV2_COMM_WORLD_RANK"], 0)
|
| 270 |
+
information["world_size"] = get_int_from_env(
|
| 271 |
+
["WORLD_SIZE", "PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE"], 1
|
| 272 |
+
)
|
| 273 |
+
information["local_rank"] = get_int_from_env(
|
| 274 |
+
["LOCAL_RANK", "MPI_LOCALRANKID", "OMPI_COMM_WORLD_LOCAL_RANK", "MV2_COMM_WORLD_LOCAL_RANK"], 0
|
| 275 |
+
)
|
| 276 |
+
information["local_world_size"] = get_int_from_env(
|
| 277 |
+
["LOCAL_WORLD_SIZE", "MPI_LOCALNRANKS", "OMPI_COMM_WORLD_LOCAL_SIZE", "MV2_COMM_WORLD_LOCAL_SIZE"],
|
| 278 |
+
1,
|
| 279 |
+
)
|
| 280 |
+
return CPUInformation(**information)
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def override_numa_affinity(local_process_index: int, verbose: Optional[bool] = None) -> None:
|
| 284 |
+
"""
|
| 285 |
+
Overrides whatever NUMA affinity is set for the current process. This is very taxing and requires recalculating the
|
| 286 |
+
affinity to set, ideally you should use `utils.environment.set_numa_affinity` instead.
|
| 287 |
+
|
| 288 |
+
Args:
|
| 289 |
+
local_process_index (int):
|
| 290 |
+
The index of the current process on the current server.
|
| 291 |
+
verbose (bool, *optional*):
|
| 292 |
+
Whether to log out the assignment of each CPU. If `ACCELERATE_DEBUG_MODE` is enabled, will default to True.
|
| 293 |
+
"""
|
| 294 |
+
if verbose is None:
|
| 295 |
+
verbose = parse_flag_from_env("ACCELERATE_DEBUG_MODE", False)
|
| 296 |
+
if torch.cuda.is_available():
|
| 297 |
+
from accelerate.utils import is_pynvml_available
|
| 298 |
+
|
| 299 |
+
if not is_pynvml_available():
|
| 300 |
+
raise ImportError(
|
| 301 |
+
"To set CPU affinity on CUDA GPUs the `nvidia-ml-py` package must be available. (`pip install nvidia-ml-py`)"
|
| 302 |
+
)
|
| 303 |
+
import pynvml as nvml
|
| 304 |
+
|
| 305 |
+
# The below code is based on https://github.com/NVIDIA/DeepLearningExamples/blob/master/TensorFlow2/LanguageModeling/BERT/gpu_affinity.py
|
| 306 |
+
nvml.nvmlInit()
|
| 307 |
+
num_elements = math.ceil(os.cpu_count() / 64)
|
| 308 |
+
handle = nvml.nvmlDeviceGetHandleByIndex(local_process_index)
|
| 309 |
+
affinity_string = ""
|
| 310 |
+
for j in nvml.nvmlDeviceGetCpuAffinity(handle, num_elements):
|
| 311 |
+
# assume nvml returns list of 64 bit ints
|
| 312 |
+
affinity_string = f"{j:064b}{affinity_string}"
|
| 313 |
+
affinity_list = [int(x) for x in affinity_string]
|
| 314 |
+
affinity_list.reverse() # so core 0 is the 0th element
|
| 315 |
+
affinity_to_set = [i for i, e in enumerate(affinity_list) if e != 0]
|
| 316 |
+
os.sched_setaffinity(0, affinity_to_set)
|
| 317 |
+
if verbose:
|
| 318 |
+
cpu_cores = os.sched_getaffinity(0)
|
| 319 |
+
logger.info(f"Assigning {len(cpu_cores)} cpu cores to process {local_process_index}: {cpu_cores}")
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
@lru_cache
|
| 323 |
+
def set_numa_affinity(local_process_index: int, verbose: Optional[bool] = None) -> None:
|
| 324 |
+
"""
|
| 325 |
+
Assigns the current process to a specific NUMA node. Ideally most efficient when having at least 2 cpus per node.
|
| 326 |
+
|
| 327 |
+
This result is cached between calls. If you want to override it, please use
|
| 328 |
+
`accelerate.utils.environment.override_numa_afifnity`.
|
| 329 |
+
|
| 330 |
+
Args:
|
| 331 |
+
local_process_index (int):
|
| 332 |
+
The index of the current process on the current server.
|
| 333 |
+
verbose (bool, *optional*):
|
| 334 |
+
Whether to print the new cpu cores assignment for each process. If `ACCELERATE_DEBUG_MODE` is enabled, will
|
| 335 |
+
default to True.
|
| 336 |
+
"""
|
| 337 |
+
override_numa_affinity(local_process_index=local_process_index, verbose=verbose)
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
@contextmanager
|
| 341 |
+
def clear_environment():
|
| 342 |
+
"""
|
| 343 |
+
A context manager that will temporarily clear environment variables.
|
| 344 |
+
|
| 345 |
+
When this context exits, the previous environment variables will be back.
|
| 346 |
+
|
| 347 |
+
Example:
|
| 348 |
+
|
| 349 |
+
```python
|
| 350 |
+
>>> import os
|
| 351 |
+
>>> from accelerate.utils import clear_environment
|
| 352 |
+
|
| 353 |
+
>>> os.environ["FOO"] = "bar"
|
| 354 |
+
>>> with clear_environment():
|
| 355 |
+
... print(os.environ)
|
| 356 |
+
... os.environ["FOO"] = "new_bar"
|
| 357 |
+
... print(os.environ["FOO"])
|
| 358 |
+
{}
|
| 359 |
+
new_bar
|
| 360 |
+
|
| 361 |
+
>>> print(os.environ["FOO"])
|
| 362 |
+
bar
|
| 363 |
+
```
|
| 364 |
+
"""
|
| 365 |
+
_old_os_environ = os.environ.copy()
|
| 366 |
+
os.environ.clear()
|
| 367 |
+
|
| 368 |
+
try:
|
| 369 |
+
yield
|
| 370 |
+
finally:
|
| 371 |
+
os.environ.clear() # clear any added keys,
|
| 372 |
+
os.environ.update(_old_os_environ) # then restore previous environment
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
@contextmanager
|
| 376 |
+
def patch_environment(**kwargs):
|
| 377 |
+
"""
|
| 378 |
+
A context manager that will add each keyword argument passed to `os.environ` and remove them when exiting.
|
| 379 |
+
|
| 380 |
+
Will convert the values in `kwargs` to strings and upper-case all the keys.
|
| 381 |
+
|
| 382 |
+
Example:
|
| 383 |
+
|
| 384 |
+
```python
|
| 385 |
+
>>> import os
|
| 386 |
+
>>> from accelerate.utils import patch_environment
|
| 387 |
+
|
| 388 |
+
>>> with patch_environment(FOO="bar"):
|
| 389 |
+
... print(os.environ["FOO"]) # prints "bar"
|
| 390 |
+
>>> print(os.environ["FOO"]) # raises KeyError
|
| 391 |
+
```
|
| 392 |
+
"""
|
| 393 |
+
existing_vars = {}
|
| 394 |
+
for key, value in kwargs.items():
|
| 395 |
+
key = key.upper()
|
| 396 |
+
if key in os.environ:
|
| 397 |
+
existing_vars[key] = os.environ[key]
|
| 398 |
+
os.environ[key] = str(value)
|
| 399 |
+
|
| 400 |
+
try:
|
| 401 |
+
yield
|
| 402 |
+
finally:
|
| 403 |
+
for key in kwargs:
|
| 404 |
+
key = key.upper()
|
| 405 |
+
if key in existing_vars:
|
| 406 |
+
# restore previous value
|
| 407 |
+
os.environ[key] = existing_vars[key]
|
| 408 |
+
else:
|
| 409 |
+
os.environ.pop(key, None)
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
def purge_accelerate_environment(func_or_cls):
|
| 413 |
+
"""Decorator to clean up accelerate environment variables set by the decorated class or function.
|
| 414 |
+
|
| 415 |
+
In some circumstances, calling certain classes or functions can result in accelerate env vars being set and not
|
| 416 |
+
being cleaned up afterwards. As an example, when calling:
|
| 417 |
+
|
| 418 |
+
TrainingArguments(fp16=True, ...)
|
| 419 |
+
|
| 420 |
+
The following env var will be set:
|
| 421 |
+
|
| 422 |
+
ACCELERATE_MIXED_PRECISION=fp16
|
| 423 |
+
|
| 424 |
+
This can affect subsequent code, since the env var takes precedence over TrainingArguments(fp16=False). This is
|
| 425 |
+
especially relevant for unit testing, where we want to avoid the individual tests to have side effects on one
|
| 426 |
+
another. Decorate the unit test function or whole class with this decorator to ensure that after each test, the env
|
| 427 |
+
vars are cleaned up. This works for both unittest.TestCase and normal classes (pytest); it also works when
|
| 428 |
+
decorating the parent class.
|
| 429 |
+
|
| 430 |
+
"""
|
| 431 |
+
prefix = "ACCELERATE_"
|
| 432 |
+
|
| 433 |
+
@contextmanager
|
| 434 |
+
def env_var_context():
|
| 435 |
+
# Store existing accelerate env vars
|
| 436 |
+
existing_vars = {k: v for k, v in os.environ.items() if k.startswith(prefix)}
|
| 437 |
+
try:
|
| 438 |
+
yield
|
| 439 |
+
finally:
|
| 440 |
+
# Restore original env vars or remove new ones
|
| 441 |
+
for key in [k for k in os.environ if k.startswith(prefix)]:
|
| 442 |
+
if key in existing_vars:
|
| 443 |
+
os.environ[key] = existing_vars[key]
|
| 444 |
+
else:
|
| 445 |
+
os.environ.pop(key, None)
|
| 446 |
+
|
| 447 |
+
def wrap_function(func):
|
| 448 |
+
@wraps(func)
|
| 449 |
+
def wrapper(*args, **kwargs):
|
| 450 |
+
with env_var_context():
|
| 451 |
+
return func(*args, **kwargs)
|
| 452 |
+
|
| 453 |
+
wrapper._accelerate_is_purged_environment_wrapped = True
|
| 454 |
+
return wrapper
|
| 455 |
+
|
| 456 |
+
if not isinstance(func_or_cls, type):
|
| 457 |
+
return wrap_function(func_or_cls)
|
| 458 |
+
|
| 459 |
+
# Handle classes by wrapping test methods
|
| 460 |
+
def wrap_test_methods(test_class_instance):
|
| 461 |
+
for name in dir(test_class_instance):
|
| 462 |
+
if name.startswith("test"):
|
| 463 |
+
method = getattr(test_class_instance, name)
|
| 464 |
+
if callable(method) and not hasattr(method, "_accelerate_is_purged_environment_wrapped"):
|
| 465 |
+
setattr(test_class_instance, name, wrap_function(method))
|
| 466 |
+
return test_class_instance
|
| 467 |
+
|
| 468 |
+
# Handle inheritance
|
| 469 |
+
wrap_test_methods(func_or_cls)
|
| 470 |
+
func_or_cls.__init_subclass__ = classmethod(lambda cls, **kw: wrap_test_methods(cls))
|
| 471 |
+
return func_or_cls
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/fsdp_utils.py
ADDED
|
@@ -0,0 +1,829 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import copy
|
| 15 |
+
import functools
|
| 16 |
+
import os
|
| 17 |
+
import re
|
| 18 |
+
import shutil
|
| 19 |
+
import warnings
|
| 20 |
+
from collections import defaultdict
|
| 21 |
+
from collections.abc import Iterable
|
| 22 |
+
from contextlib import nullcontext
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
from typing import Callable, Union
|
| 25 |
+
|
| 26 |
+
import torch
|
| 27 |
+
|
| 28 |
+
from ..logging import get_logger
|
| 29 |
+
from .constants import FSDP_MODEL_NAME, OPTIMIZER_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_NAME
|
| 30 |
+
from .dataclasses import get_module_class_from_name
|
| 31 |
+
from .modeling import get_non_persistent_buffers, is_peft_model
|
| 32 |
+
from .other import get_module_children_bottom_up, is_compiled_module, save
|
| 33 |
+
from .versions import is_torch_version
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
logger = get_logger(__name__)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def enable_fsdp_ram_efficient_loading():
|
| 40 |
+
"""
|
| 41 |
+
Enables RAM efficient loading of Hugging Face models for FSDP in the environment.
|
| 42 |
+
"""
|
| 43 |
+
# Sets values for `transformers.modeling_utils.is_fsdp_enabled`
|
| 44 |
+
if "ACCELERATE_USE_FSDP" not in os.environ:
|
| 45 |
+
os.environ["ACCELERATE_USE_FSDP"] = "True"
|
| 46 |
+
os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "True"
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def disable_fsdp_ram_efficient_loading():
|
| 50 |
+
"""
|
| 51 |
+
Disables RAM efficient loading of Hugging Face models for FSDP in the environment.
|
| 52 |
+
"""
|
| 53 |
+
os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "False"
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _get_model_state_dict(model, adapter_only=False, sd_options=None):
|
| 57 |
+
if adapter_only and is_peft_model(model):
|
| 58 |
+
from peft import get_peft_model_state_dict
|
| 59 |
+
|
| 60 |
+
return get_peft_model_state_dict(model, adapter_name=model.active_adapter)
|
| 61 |
+
|
| 62 |
+
# Invariant: `sd_options` is not None only for FSDP2
|
| 63 |
+
if sd_options is not None:
|
| 64 |
+
from torch.distributed.checkpoint.state_dict import get_model_state_dict
|
| 65 |
+
|
| 66 |
+
return get_model_state_dict(model, options=sd_options)
|
| 67 |
+
else:
|
| 68 |
+
return model.state_dict()
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def _set_model_state_dict(model, state_dict, adapter_only=False, sd_options=None):
|
| 72 |
+
if adapter_only and is_peft_model(model):
|
| 73 |
+
from peft import set_peft_model_state_dict
|
| 74 |
+
|
| 75 |
+
return set_peft_model_state_dict(model, state_dict, adapter_name=model.active_adapter)
|
| 76 |
+
|
| 77 |
+
# Invariant: `sd_options` is not None only for FSDP2
|
| 78 |
+
if sd_options is not None:
|
| 79 |
+
from torch.distributed.checkpoint.state_dict import set_model_state_dict
|
| 80 |
+
|
| 81 |
+
return set_model_state_dict(model, state_dict, options=sd_options)
|
| 82 |
+
else:
|
| 83 |
+
return model.load_state_dict(state_dict)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def _prepare_sd_options(fsdp_plugin):
|
| 87 |
+
sd_options = None
|
| 88 |
+
|
| 89 |
+
# we use this only for FSDP2, as it requires torch >= 2.6.0 and this api requires torch >= 2.2.0
|
| 90 |
+
if fsdp_plugin.fsdp_version == 2:
|
| 91 |
+
from torch.distributed.checkpoint.state_dict import StateDictOptions
|
| 92 |
+
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
|
| 93 |
+
|
| 94 |
+
sd_options = StateDictOptions(
|
| 95 |
+
full_state_dict=fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT,
|
| 96 |
+
cpu_offload=getattr(fsdp_plugin.state_dict_config, "offload_to_cpu", False),
|
| 97 |
+
broadcast_from_rank0=getattr(fsdp_plugin.state_dict_config, "rank0_only", False),
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
return sd_options
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def save_fsdp_model(fsdp_plugin, accelerator, model, output_dir, model_index=0, adapter_only=False):
|
| 104 |
+
# Note: We import here to reduce import time from general modules, and isolate outside dependencies
|
| 105 |
+
import torch.distributed.checkpoint as dist_cp
|
| 106 |
+
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
|
| 107 |
+
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
|
| 108 |
+
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
|
| 109 |
+
|
| 110 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 111 |
+
if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
|
| 112 |
+
# FSDP raises error when single GPU is used with `offload_to_cpu=True` for FULL_STATE_DICT
|
| 113 |
+
# so, only enable it when num_processes>1
|
| 114 |
+
is_multi_process = accelerator.num_processes > 1
|
| 115 |
+
fsdp_plugin.state_dict_config.offload_to_cpu = is_multi_process
|
| 116 |
+
fsdp_plugin.state_dict_config.rank0_only = is_multi_process
|
| 117 |
+
|
| 118 |
+
ctx = (
|
| 119 |
+
FSDP.state_dict_type(
|
| 120 |
+
model, fsdp_plugin.state_dict_type, fsdp_plugin.state_dict_config, fsdp_plugin.optim_state_dict_config
|
| 121 |
+
)
|
| 122 |
+
if fsdp_plugin.fsdp_version == 1
|
| 123 |
+
else nullcontext()
|
| 124 |
+
)
|
| 125 |
+
sd_options = _prepare_sd_options(fsdp_plugin)
|
| 126 |
+
|
| 127 |
+
with ctx:
|
| 128 |
+
state_dict = _get_model_state_dict(model, adapter_only=adapter_only, sd_options=sd_options)
|
| 129 |
+
if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
|
| 130 |
+
weights_name = f"{FSDP_MODEL_NAME}.bin" if model_index == 0 else f"{FSDP_MODEL_NAME}_{model_index}.bin"
|
| 131 |
+
output_model_file = os.path.join(output_dir, weights_name)
|
| 132 |
+
if accelerator.process_index == 0:
|
| 133 |
+
logger.info(f"Saving model to {output_model_file}")
|
| 134 |
+
torch.save(state_dict, output_model_file)
|
| 135 |
+
logger.info(f"Model saved to {output_model_file}")
|
| 136 |
+
# Invariant: `LOCAL_STATE_DICT` is never possible with `FSDP2`
|
| 137 |
+
elif fsdp_plugin.state_dict_type == StateDictType.LOCAL_STATE_DICT:
|
| 138 |
+
weights_name = (
|
| 139 |
+
f"{FSDP_MODEL_NAME}_rank{accelerator.process_index}.bin"
|
| 140 |
+
if model_index == 0
|
| 141 |
+
else f"{FSDP_MODEL_NAME}_{model_index}_rank{accelerator.process_index}.bin"
|
| 142 |
+
)
|
| 143 |
+
output_model_file = os.path.join(output_dir, weights_name)
|
| 144 |
+
logger.info(f"Saving model to {output_model_file}")
|
| 145 |
+
torch.save(state_dict, output_model_file)
|
| 146 |
+
logger.info(f"Model saved to {output_model_file}")
|
| 147 |
+
elif fsdp_plugin.state_dict_type == StateDictType.SHARDED_STATE_DICT:
|
| 148 |
+
ckpt_dir = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{model_index}")
|
| 149 |
+
os.makedirs(ckpt_dir, exist_ok=True)
|
| 150 |
+
logger.info(f"Saving model to {ckpt_dir}")
|
| 151 |
+
state_dict = {"model": state_dict}
|
| 152 |
+
|
| 153 |
+
dist_cp.save(
|
| 154 |
+
state_dict=state_dict,
|
| 155 |
+
storage_writer=dist_cp.FileSystemWriter(ckpt_dir),
|
| 156 |
+
planner=DefaultSavePlanner(),
|
| 157 |
+
)
|
| 158 |
+
logger.info(f"Model saved to {ckpt_dir}")
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def load_fsdp_model(fsdp_plugin, accelerator, model, input_dir, model_index=0, adapter_only=False):
|
| 162 |
+
# Note: We import here to reduce import time from general modules, and isolate outside dependencies
|
| 163 |
+
import torch.distributed.checkpoint as dist_cp
|
| 164 |
+
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
|
| 165 |
+
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
|
| 166 |
+
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
|
| 167 |
+
|
| 168 |
+
accelerator.wait_for_everyone()
|
| 169 |
+
if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
|
| 170 |
+
# FSDP raises error when single GPU is used with `offload_to_cpu=True` for FULL_STATE_DICT
|
| 171 |
+
# so, only enable it when num_processes>1
|
| 172 |
+
is_multi_process = accelerator.num_processes > 1
|
| 173 |
+
fsdp_plugin.state_dict_config.offload_to_cpu = is_multi_process
|
| 174 |
+
fsdp_plugin.state_dict_config.rank0_only = is_multi_process
|
| 175 |
+
|
| 176 |
+
ctx = (
|
| 177 |
+
FSDP.state_dict_type(
|
| 178 |
+
model, fsdp_plugin.state_dict_type, fsdp_plugin.state_dict_config, fsdp_plugin.optim_state_dict_config
|
| 179 |
+
)
|
| 180 |
+
if fsdp_plugin.fsdp_version == 1
|
| 181 |
+
else nullcontext()
|
| 182 |
+
)
|
| 183 |
+
sd_options = _prepare_sd_options(fsdp_plugin)
|
| 184 |
+
with ctx:
|
| 185 |
+
if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
|
| 186 |
+
if type(model) is not FSDP and accelerator.process_index != 0 and not accelerator.is_fsdp2:
|
| 187 |
+
if not fsdp_plugin.sync_module_states and fsdp_plugin.fsdp_version == 1:
|
| 188 |
+
raise ValueError(
|
| 189 |
+
"Set the `sync_module_states` flag to `True` so that model states are synced across processes when "
|
| 190 |
+
"initializing FSDP object"
|
| 191 |
+
)
|
| 192 |
+
return
|
| 193 |
+
weights_name = f"{FSDP_MODEL_NAME}.bin" if model_index == 0 else f"{FSDP_MODEL_NAME}_{model_index}.bin"
|
| 194 |
+
input_model_file = os.path.join(input_dir, weights_name)
|
| 195 |
+
logger.info(f"Loading model from {input_model_file}")
|
| 196 |
+
# we want an empty state dict for FSDP2 as we use `broadcast_from_rank0`
|
| 197 |
+
load_model = not accelerator.is_fsdp2 or accelerator.is_main_process
|
| 198 |
+
if load_model:
|
| 199 |
+
state_dict = torch.load(input_model_file, weights_only=True)
|
| 200 |
+
else:
|
| 201 |
+
state_dict = {}
|
| 202 |
+
logger.info(f"Model loaded from {input_model_file}")
|
| 203 |
+
elif fsdp_plugin.state_dict_type == StateDictType.LOCAL_STATE_DICT:
|
| 204 |
+
weights_name = (
|
| 205 |
+
f"{FSDP_MODEL_NAME}_rank{accelerator.process_index}.bin"
|
| 206 |
+
if model_index == 0
|
| 207 |
+
else f"{FSDP_MODEL_NAME}_{model_index}_rank{accelerator.process_index}.bin"
|
| 208 |
+
)
|
| 209 |
+
input_model_file = os.path.join(input_dir, weights_name)
|
| 210 |
+
logger.info(f"Loading model from {input_model_file}")
|
| 211 |
+
state_dict = torch.load(input_model_file, weights_only=True)
|
| 212 |
+
logger.info(f"Model loaded from {input_model_file}")
|
| 213 |
+
elif fsdp_plugin.state_dict_type == StateDictType.SHARDED_STATE_DICT:
|
| 214 |
+
ckpt_dir = (
|
| 215 |
+
os.path.join(input_dir, f"{FSDP_MODEL_NAME}_{model_index}")
|
| 216 |
+
if f"{FSDP_MODEL_NAME}" not in input_dir
|
| 217 |
+
else input_dir
|
| 218 |
+
)
|
| 219 |
+
logger.info(f"Loading model from {ckpt_dir}")
|
| 220 |
+
state_dict = {"model": _get_model_state_dict(model, adapter_only=adapter_only, sd_options=sd_options)}
|
| 221 |
+
dist_cp.load(
|
| 222 |
+
state_dict=state_dict,
|
| 223 |
+
storage_reader=dist_cp.FileSystemReader(ckpt_dir),
|
| 224 |
+
planner=DefaultLoadPlanner(),
|
| 225 |
+
)
|
| 226 |
+
state_dict = state_dict["model"]
|
| 227 |
+
logger.info(f"Model loaded from {ckpt_dir}")
|
| 228 |
+
|
| 229 |
+
load_result = _set_model_state_dict(model, state_dict, adapter_only=adapter_only, sd_options=sd_options)
|
| 230 |
+
return load_result
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def save_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, output_dir, optimizer_index=0):
|
| 234 |
+
# Note: We import here to reduce import time from general modules, and isolate outside dependencies
|
| 235 |
+
import torch.distributed.checkpoint as dist_cp
|
| 236 |
+
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
|
| 237 |
+
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
|
| 238 |
+
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
|
| 239 |
+
|
| 240 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 241 |
+
|
| 242 |
+
ctx = (
|
| 243 |
+
FSDP.state_dict_type(
|
| 244 |
+
model, fsdp_plugin.state_dict_type, fsdp_plugin.state_dict_config, fsdp_plugin.optim_state_dict_config
|
| 245 |
+
)
|
| 246 |
+
if fsdp_plugin.fsdp_version == 1
|
| 247 |
+
else nullcontext()
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
sd_options = _prepare_sd_options(fsdp_plugin)
|
| 251 |
+
|
| 252 |
+
with ctx:
|
| 253 |
+
if fsdp_plugin.fsdp_version == 2:
|
| 254 |
+
from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict
|
| 255 |
+
|
| 256 |
+
optim_state = get_optimizer_state_dict(model, optimizer, options=sd_options)
|
| 257 |
+
else:
|
| 258 |
+
optim_state = FSDP.optim_state_dict(model, optimizer)
|
| 259 |
+
|
| 260 |
+
if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
|
| 261 |
+
if accelerator.process_index == 0:
|
| 262 |
+
optim_state_name = (
|
| 263 |
+
f"{OPTIMIZER_NAME}.bin" if optimizer_index == 0 else f"{OPTIMIZER_NAME}_{optimizer_index}.bin"
|
| 264 |
+
)
|
| 265 |
+
output_optimizer_file = os.path.join(output_dir, optim_state_name)
|
| 266 |
+
logger.info(f"Saving Optimizer state to {output_optimizer_file}")
|
| 267 |
+
torch.save(optim_state, output_optimizer_file)
|
| 268 |
+
logger.info(f"Optimizer state saved in {output_optimizer_file}")
|
| 269 |
+
else:
|
| 270 |
+
ckpt_dir = os.path.join(output_dir, f"{OPTIMIZER_NAME}_{optimizer_index}")
|
| 271 |
+
os.makedirs(ckpt_dir, exist_ok=True)
|
| 272 |
+
logger.info(f"Saving Optimizer state to {ckpt_dir}")
|
| 273 |
+
dist_cp.save(
|
| 274 |
+
state_dict={"optimizer": optim_state},
|
| 275 |
+
storage_writer=dist_cp.FileSystemWriter(ckpt_dir),
|
| 276 |
+
planner=DefaultSavePlanner(),
|
| 277 |
+
)
|
| 278 |
+
logger.info(f"Optimizer state saved in {ckpt_dir}")
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def load_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, input_dir, optimizer_index=0, adapter_only=False):
|
| 282 |
+
# Note: We import here to reduce import time from general modules, and isolate outside dependencies
|
| 283 |
+
import torch.distributed.checkpoint as dist_cp
|
| 284 |
+
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
|
| 285 |
+
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
|
| 286 |
+
|
| 287 |
+
accelerator.wait_for_everyone()
|
| 288 |
+
ctx = (
|
| 289 |
+
FSDP.state_dict_type(
|
| 290 |
+
model, fsdp_plugin.state_dict_type, fsdp_plugin.state_dict_config, fsdp_plugin.optim_state_dict_config
|
| 291 |
+
)
|
| 292 |
+
if fsdp_plugin.fsdp_version == 1
|
| 293 |
+
else nullcontext()
|
| 294 |
+
)
|
| 295 |
+
sd_options = _prepare_sd_options(fsdp_plugin)
|
| 296 |
+
with ctx:
|
| 297 |
+
if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
|
| 298 |
+
optim_state = None
|
| 299 |
+
if accelerator.process_index == 0 or not fsdp_plugin.optim_state_dict_config.rank0_only:
|
| 300 |
+
optimizer_name = (
|
| 301 |
+
f"{OPTIMIZER_NAME}.bin" if optimizer_index == 0 else f"{OPTIMIZER_NAME}_{optimizer_index}.bin"
|
| 302 |
+
)
|
| 303 |
+
input_optimizer_file = os.path.join(input_dir, optimizer_name)
|
| 304 |
+
logger.info(f"Loading Optimizer state from {input_optimizer_file}")
|
| 305 |
+
optim_state = torch.load(input_optimizer_file, weights_only=True)
|
| 306 |
+
logger.info(f"Optimizer state loaded from {input_optimizer_file}")
|
| 307 |
+
else:
|
| 308 |
+
ckpt_dir = (
|
| 309 |
+
os.path.join(input_dir, f"{OPTIMIZER_NAME}_{optimizer_index}")
|
| 310 |
+
if f"{OPTIMIZER_NAME}" not in input_dir
|
| 311 |
+
else input_dir
|
| 312 |
+
)
|
| 313 |
+
logger.info(f"Loading Optimizer from {ckpt_dir}")
|
| 314 |
+
optim_state = {"optimizer": optimizer.state_dict()}
|
| 315 |
+
dist_cp.load(
|
| 316 |
+
optim_state,
|
| 317 |
+
checkpoint_id=ckpt_dir,
|
| 318 |
+
storage_reader=dist_cp.FileSystemReader(ckpt_dir),
|
| 319 |
+
)
|
| 320 |
+
optim_state = optim_state["optimizer"]
|
| 321 |
+
logger.info(f"Optimizer loaded from {ckpt_dir}")
|
| 322 |
+
|
| 323 |
+
if fsdp_plugin.fsdp_version == 1:
|
| 324 |
+
flattened_osd = FSDP.optim_state_dict_to_load(model=model, optim=optimizer, optim_state_dict=optim_state)
|
| 325 |
+
optimizer.load_state_dict(flattened_osd)
|
| 326 |
+
else:
|
| 327 |
+
from torch.distributed.checkpoint.state_dict import set_optimizer_state_dict
|
| 328 |
+
|
| 329 |
+
set_optimizer_state_dict(model, optimizer, optim_state, options=sd_options)
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def _distributed_checkpoint_to_merged_weights(checkpoint_dir: str, save_path: str, safe_serialization: bool = True):
|
| 333 |
+
"""
|
| 334 |
+
Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save`
|
| 335 |
+
|
| 336 |
+
Will save under `save_path` as either `model.safetensors` or `pytorch_model.bin`.
|
| 337 |
+
"""
|
| 338 |
+
# Note: We import here to reduce import time from general modules, and isolate outside dependencies
|
| 339 |
+
import torch.distributed.checkpoint as dist_cp
|
| 340 |
+
import torch.distributed.checkpoint.format_utils as dist_cp_format_utils
|
| 341 |
+
|
| 342 |
+
state_dict = {}
|
| 343 |
+
save_path = Path(save_path)
|
| 344 |
+
save_path.mkdir(exist_ok=True)
|
| 345 |
+
dist_cp_format_utils._load_state_dict(
|
| 346 |
+
state_dict,
|
| 347 |
+
storage_reader=dist_cp.FileSystemReader(checkpoint_dir),
|
| 348 |
+
planner=dist_cp_format_utils._EmptyStateDictLoadPlanner(),
|
| 349 |
+
no_dist=True,
|
| 350 |
+
)
|
| 351 |
+
save_path = save_path / SAFE_WEIGHTS_NAME if safe_serialization else save_path / WEIGHTS_NAME
|
| 352 |
+
|
| 353 |
+
# To handle if state is a dict like {model: {...}}
|
| 354 |
+
if len(state_dict.keys()) == 1:
|
| 355 |
+
state_dict = state_dict[list(state_dict)[0]]
|
| 356 |
+
save(state_dict, save_path, safe_serialization=safe_serialization)
|
| 357 |
+
return save_path
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def merge_fsdp_weights(
|
| 361 |
+
checkpoint_dir: str, output_path: str, safe_serialization: bool = True, remove_checkpoint_dir: bool = False
|
| 362 |
+
):
|
| 363 |
+
"""
|
| 364 |
+
Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if
|
| 365 |
+
`SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}/model.safetensors` if
|
| 366 |
+
`safe_serialization` else `pytorch_model.bin`.
|
| 367 |
+
|
| 368 |
+
Note: this is a CPU-bound process.
|
| 369 |
+
|
| 370 |
+
Args:
|
| 371 |
+
checkpoint_dir (`str`):
|
| 372 |
+
The directory containing the FSDP checkpoints (can be either the model or optimizer).
|
| 373 |
+
output_path (`str`):
|
| 374 |
+
The path to save the merged checkpoint.
|
| 375 |
+
safe_serialization (`bool`, *optional*, defaults to `True`):
|
| 376 |
+
Whether to save the merged weights with safetensors (recommended).
|
| 377 |
+
remove_checkpoint_dir (`bool`, *optional*, defaults to `False`):
|
| 378 |
+
Whether to remove the checkpoint directory after merging.
|
| 379 |
+
"""
|
| 380 |
+
checkpoint_dir = Path(checkpoint_dir)
|
| 381 |
+
from accelerate.state import PartialState
|
| 382 |
+
|
| 383 |
+
if not is_torch_version(">=", "2.3.0"):
|
| 384 |
+
raise ValueError("`merge_fsdp_weights` requires PyTorch >= 2.3.0`")
|
| 385 |
+
|
| 386 |
+
# Verify that the checkpoint directory exists
|
| 387 |
+
if not checkpoint_dir.exists():
|
| 388 |
+
model_path_exists = (checkpoint_dir / "pytorch_model_fsdp_0").exists()
|
| 389 |
+
optimizer_path_exists = (checkpoint_dir / "optimizer_0").exists()
|
| 390 |
+
err = f"Tried to load from {checkpoint_dir} but couldn't find a valid metadata file."
|
| 391 |
+
if model_path_exists and optimizer_path_exists:
|
| 392 |
+
err += " However, potential model and optimizer checkpoint directories exist."
|
| 393 |
+
err += f"Please pass in either {checkpoint_dir}/pytorch_model_fsdp_0 or {checkpoint_dir}/optimizer_0"
|
| 394 |
+
err += "instead."
|
| 395 |
+
elif model_path_exists:
|
| 396 |
+
err += " However, a potential model checkpoint directory exists."
|
| 397 |
+
err += f"Please try passing in {checkpoint_dir}/pytorch_model_fsdp_0 instead."
|
| 398 |
+
elif optimizer_path_exists:
|
| 399 |
+
err += " However, a potential optimizer checkpoint directory exists."
|
| 400 |
+
err += f"Please try passing in {checkpoint_dir}/optimizer_0 instead."
|
| 401 |
+
raise ValueError(err)
|
| 402 |
+
|
| 403 |
+
# To setup `save` to work
|
| 404 |
+
state = PartialState()
|
| 405 |
+
if state.is_main_process:
|
| 406 |
+
logger.info(f"Merging FSDP weights from {checkpoint_dir}")
|
| 407 |
+
save_path = _distributed_checkpoint_to_merged_weights(checkpoint_dir, output_path, safe_serialization)
|
| 408 |
+
logger.info(f"Successfully merged FSDP weights and saved to {save_path}")
|
| 409 |
+
if remove_checkpoint_dir:
|
| 410 |
+
logger.info(f"Removing old checkpoint directory {checkpoint_dir}")
|
| 411 |
+
shutil.rmtree(checkpoint_dir)
|
| 412 |
+
state.wait_for_everyone()
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def ensure_weights_retied(param_init_fn, model: torch.nn.Module, device: torch.device):
|
| 416 |
+
_tied_names = getattr(model, "_tied_weights_keys", None)
|
| 417 |
+
if not _tied_names:
|
| 418 |
+
# if no tied names just passthrough
|
| 419 |
+
return param_init_fn
|
| 420 |
+
|
| 421 |
+
# get map of parameter instances to params.
|
| 422 |
+
# - needed for replacement later
|
| 423 |
+
_tied_params = {}
|
| 424 |
+
for name in _tied_names:
|
| 425 |
+
name = name.split(".")
|
| 426 |
+
name, param_name = ".".join(name[:-1]), name[-1]
|
| 427 |
+
mod = model.get_submodule(name)
|
| 428 |
+
param = getattr(mod, param_name)
|
| 429 |
+
|
| 430 |
+
_tied_params[id(param)] = None # placeholder for the param first
|
| 431 |
+
|
| 432 |
+
# build param_init_fn for the case with tied params
|
| 433 |
+
def param_init_fn_tied_param(module: torch.nn.Module):
|
| 434 |
+
# track which params to tie
|
| 435 |
+
# - usually only 1, but for completeness consider > 1
|
| 436 |
+
params_to_tie = defaultdict(list)
|
| 437 |
+
for n, param in module.named_parameters(recurse=False):
|
| 438 |
+
if id(param) in _tied_params:
|
| 439 |
+
params_to_tie[id(param)].append(n)
|
| 440 |
+
|
| 441 |
+
# call the param init fn, which potentially re-allocates the
|
| 442 |
+
# parameters
|
| 443 |
+
module = param_init_fn(module)
|
| 444 |
+
|
| 445 |
+
# search the parameters again and tie them up again
|
| 446 |
+
for id_key, _param_names in params_to_tie.items():
|
| 447 |
+
for param_name in _param_names:
|
| 448 |
+
param = _tied_params[id_key]
|
| 449 |
+
if param is None:
|
| 450 |
+
# everything will be tied to the first time the
|
| 451 |
+
# param is observed
|
| 452 |
+
_tied_params[id_key] = getattr(module, param_name)
|
| 453 |
+
else:
|
| 454 |
+
setattr(module, param_name, param) # tie
|
| 455 |
+
|
| 456 |
+
return module
|
| 457 |
+
|
| 458 |
+
return param_init_fn_tied_param
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict):
|
| 462 |
+
"""
|
| 463 |
+
Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the
|
| 464 |
+
parameters from rank 0 to all other ranks. This function modifies the model in-place.
|
| 465 |
+
|
| 466 |
+
Args:
|
| 467 |
+
accelerator (`Accelerator`): The accelerator instance
|
| 468 |
+
model (`torch.nn.Module`):
|
| 469 |
+
The model to load the state dict into, expected to be on meta device or a VRAM spike can occur
|
| 470 |
+
full_sd (`dict`): The full state dict to load, can only be on rank 0
|
| 471 |
+
"""
|
| 472 |
+
import torch.distributed as dist
|
| 473 |
+
from torch.distributed.tensor import DTensor, distribute_tensor
|
| 474 |
+
|
| 475 |
+
# Model was previously copied to meta device
|
| 476 |
+
meta_sharded_sd = model.state_dict()
|
| 477 |
+
sharded_sd = {}
|
| 478 |
+
|
| 479 |
+
# Rank 0 distributes the full state dict to other ranks
|
| 480 |
+
def _infer_parameter_dtype(model, param_name, empty_param):
|
| 481 |
+
try:
|
| 482 |
+
old_param = model.get_parameter_or_buffer(param_name)
|
| 483 |
+
except AttributeError:
|
| 484 |
+
# Need this for LORA, as there some params are not *parameters* of sorts
|
| 485 |
+
base_param_name, local_param_name = param_name.rsplit(".", 1)
|
| 486 |
+
submodule = model.get_submodule(base_param_name)
|
| 487 |
+
old_param = getattr(submodule, local_param_name)
|
| 488 |
+
|
| 489 |
+
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
|
| 490 |
+
casting_dtype = None
|
| 491 |
+
is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn
|
| 492 |
+
|
| 493 |
+
if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
|
| 494 |
+
casting_dtype = old_param.dtype
|
| 495 |
+
|
| 496 |
+
return old_param is not None and old_param.is_contiguous(), casting_dtype
|
| 497 |
+
|
| 498 |
+
def _cast_and_contiguous(tensor, to_contiguous, dtype):
|
| 499 |
+
if dtype is not None:
|
| 500 |
+
tensor = tensor.to(dtype=dtype)
|
| 501 |
+
if to_contiguous:
|
| 502 |
+
tensor = tensor.contiguous()
|
| 503 |
+
return tensor
|
| 504 |
+
|
| 505 |
+
if accelerator.is_main_process:
|
| 506 |
+
for (param_name, full_param), sharded_param in zip(full_sd.items(), meta_sharded_sd.values()):
|
| 507 |
+
device_mesh = sharded_param.device_mesh
|
| 508 |
+
full_param = full_param.detach().to(device_mesh.device_type)
|
| 509 |
+
if isinstance(full_param, DTensor):
|
| 510 |
+
# dist.broadcast() only supports torch.Tensor.
|
| 511 |
+
# After prepare_tp(), model parameters may become DTensor.
|
| 512 |
+
# To broadcast such a parameter, convert it to a local tensor first.
|
| 513 |
+
full_param = full_param.to_local()
|
| 514 |
+
dist.broadcast(full_param, src=0, group=dist.group.WORLD)
|
| 515 |
+
sharded_tensor = distribute_tensor(full_param, device_mesh, sharded_param.placements)
|
| 516 |
+
to_contiguous, casting_dtype = _infer_parameter_dtype(
|
| 517 |
+
model,
|
| 518 |
+
param_name,
|
| 519 |
+
full_param,
|
| 520 |
+
)
|
| 521 |
+
sharded_tensor = _cast_and_contiguous(sharded_tensor, to_contiguous, casting_dtype)
|
| 522 |
+
sharded_sd[param_name] = sharded_tensor
|
| 523 |
+
# We need this else to have a matching `broadcast` for all of the ranks, else we deadlock
|
| 524 |
+
else:
|
| 525 |
+
for param_name, sharded_param in meta_sharded_sd.items():
|
| 526 |
+
device_mesh = sharded_param.device_mesh
|
| 527 |
+
full_tensor = torch.empty(sharded_param.size(), device=device_mesh.device_type, dtype=sharded_param.dtype)
|
| 528 |
+
dist.broadcast(full_tensor, src=0, group=dist.group.WORLD)
|
| 529 |
+
sharded_tensor = distribute_tensor(full_tensor, device_mesh, sharded_param.placements)
|
| 530 |
+
to_contiguous, casting_dtype = _infer_parameter_dtype(
|
| 531 |
+
model,
|
| 532 |
+
param_name,
|
| 533 |
+
full_tensor,
|
| 534 |
+
)
|
| 535 |
+
sharded_tensor = _cast_and_contiguous(sharded_tensor, to_contiguous, casting_dtype)
|
| 536 |
+
sharded_sd[param_name] = sharded_tensor
|
| 537 |
+
|
| 538 |
+
# we set `assign=True` because our params are on meta device
|
| 539 |
+
model.load_state_dict(sharded_sd, assign=True)
|
| 540 |
+
return model
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
def fsdp2_switch_optimizer_parameters(optimizer: torch.optim.Optimizer, mapping: dict):
|
| 544 |
+
"""
|
| 545 |
+
Switches the parameters of the optimizer to new ones (sharded parameters in usual case). This function modifies the
|
| 546 |
+
optimizer in-place.
|
| 547 |
+
|
| 548 |
+
Args:
|
| 549 |
+
optimizer (`torch.optim.Optimizer`): Optimizer instance which contains the original model parameters
|
| 550 |
+
mapping (`dict`): Mapping from the original parameter (specified by `data_ptr`) to the sharded parameter
|
| 551 |
+
|
| 552 |
+
Raises:
|
| 553 |
+
KeyError:
|
| 554 |
+
If a parameter in the optimizer couldn't be switched to its sharded version. This should never happen and
|
| 555 |
+
indicates a bug. If we kept the original params instead of raising, the training wouldn't be numerically
|
| 556 |
+
correct and weights wouldn't get updated.
|
| 557 |
+
"""
|
| 558 |
+
from torch.distributed.tensor import DTensor
|
| 559 |
+
|
| 560 |
+
accessor_mapping = {}
|
| 561 |
+
|
| 562 |
+
accessor_mapping[DTensor] = "_local_tensor"
|
| 563 |
+
try:
|
| 564 |
+
for param_group in optimizer.param_groups:
|
| 565 |
+
param_group["params"] = [mapping[p.data_ptr] for p in param_group["params"]]
|
| 566 |
+
except KeyError:
|
| 567 |
+
# This shouldn't ever happen, but we want to fail here else training wouldn't be numerically correct
|
| 568 |
+
# This basically means that we're missing a mapping from the original parameter to the sharded parameter
|
| 569 |
+
raise KeyError(
|
| 570 |
+
"A parameter in the optimizer couldn't be switched to its sharded version. This breaks the training. Please raise an issue on GitHub."
|
| 571 |
+
)
|
| 572 |
+
|
| 573 |
+
|
| 574 |
+
def fsdp2_apply_ac(accelerator, model: torch.nn.Module):
|
| 575 |
+
"""
|
| 576 |
+
Applies the activation checkpointing to the model.
|
| 577 |
+
|
| 578 |
+
Args:
|
| 579 |
+
accelerator (`Accelerator`): The accelerator instance
|
| 580 |
+
model (`torch.nn.Module`): The model to apply the activation checkpointing to
|
| 581 |
+
|
| 582 |
+
Returns:
|
| 583 |
+
`torch.nn.Module`: The model with the activation checkpointing applied
|
| 584 |
+
"""
|
| 585 |
+
|
| 586 |
+
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
| 587 |
+
checkpoint_wrapper,
|
| 588 |
+
)
|
| 589 |
+
|
| 590 |
+
auto_wrap_policy_func = fsdp2_prepare_auto_wrap_policy(accelerator.state.fsdp_plugin, model)
|
| 591 |
+
|
| 592 |
+
for layer_name, layer in get_module_children_bottom_up(model, return_fqns=True)[:-1]:
|
| 593 |
+
if len(layer_name.split(".")) > 1:
|
| 594 |
+
parent_name, child_name = layer_name.rsplit(".", 1)
|
| 595 |
+
else:
|
| 596 |
+
parent_name = None
|
| 597 |
+
child_name = layer_name
|
| 598 |
+
|
| 599 |
+
parent_module = model.get_submodule(parent_name) if parent_name else model
|
| 600 |
+
if auto_wrap_policy_func(parent_module):
|
| 601 |
+
layer = checkpoint_wrapper(layer, preserve_rng_state=False)
|
| 602 |
+
parent_module.register_module(child_name, layer)
|
| 603 |
+
|
| 604 |
+
return model
|
| 605 |
+
|
| 606 |
+
|
| 607 |
+
def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
|
| 608 |
+
"""Prepares the model for FSDP2 in-place. Also returns the model to avoid misuse of the original model.
|
| 609 |
+
|
| 610 |
+
Args:
|
| 611 |
+
accelerator (`Accelerator`): The accelerator instance
|
| 612 |
+
model (`torch.nn.Module`): The model to prepare
|
| 613 |
+
|
| 614 |
+
Returns:
|
| 615 |
+
`torch.nn.Module`: Prepared model
|
| 616 |
+
"""
|
| 617 |
+
from torch.distributed.fsdp import FSDPModule, MixedPrecisionPolicy, fully_shard
|
| 618 |
+
|
| 619 |
+
is_type_fsdp = isinstance(model, FSDPModule) or (
|
| 620 |
+
is_compiled_module(model) and isinstance(model._orig_mod, FSDPModule)
|
| 621 |
+
)
|
| 622 |
+
if is_type_fsdp:
|
| 623 |
+
return model
|
| 624 |
+
|
| 625 |
+
fsdp2_plugin = accelerator.state.fsdp_plugin
|
| 626 |
+
|
| 627 |
+
fsdp2_plugin.set_auto_wrap_policy(model)
|
| 628 |
+
|
| 629 |
+
original_sd = model.state_dict()
|
| 630 |
+
mesh = getattr(accelerator, "torch_device_mesh", None)
|
| 631 |
+
|
| 632 |
+
fsdp2_kwargs = {
|
| 633 |
+
"reshard_after_forward": fsdp2_plugin.reshard_after_forward,
|
| 634 |
+
"offload_policy": fsdp2_plugin.cpu_offload,
|
| 635 |
+
# `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
|
| 636 |
+
"mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
|
| 637 |
+
"mesh": mesh[tuple(accelerator.parallelism_config.fsdp_dim_names)] if mesh is not None else None,
|
| 638 |
+
"ignored_params": get_parameters_from_modules(fsdp2_plugin.ignored_modules, model, accelerator.device),
|
| 639 |
+
}
|
| 640 |
+
|
| 641 |
+
model_has_params4bit = False
|
| 642 |
+
for name, param in model.named_parameters():
|
| 643 |
+
# this is a temporary fix whereby loading models with bnb params cannot be moved from
|
| 644 |
+
# GPU to a meta device due with FSDP2 because torch operations don't return the original class type
|
| 645 |
+
# bypassing the move to meta will still cause the VRAM spike, but at least it still will load
|
| 646 |
+
if param.__class__.__name__ == "Params4bit":
|
| 647 |
+
model_has_params4bit = True
|
| 648 |
+
break
|
| 649 |
+
|
| 650 |
+
if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
|
| 651 |
+
# Context: `fully_shard` moves the model to GPU if it was on CPU, however it can also be on `meta` and then it stays there even after `fully_shard`
|
| 652 |
+
# For this reason, we need to move the model to `meta` device, as then sharding happens on `meta` device
|
| 653 |
+
# If we kept the model on CPU (`cpu_ram_efficient_loading` has model be on CPU on all ranks, though non-main ranks only have `torch.empty`), `fully_shard` would move it to GPU
|
| 654 |
+
# Afterwards, when we call `fsdp2_load_full_state_dict`, us creating the state_dict would result into briefly having two copies of model state_dict on the GPU -> VRAM spike
|
| 655 |
+
|
| 656 |
+
# We need to keep the original non-persistent buffers, as those MAY not be in the state_dict, resulting in them staying on meta device
|
| 657 |
+
# Also, these buffers aren't getting sharded by default
|
| 658 |
+
# We get the FQNs of all non-persistent buffers, to re-register them after
|
| 659 |
+
non_persistent_buffer_fqns = get_non_persistent_buffers(model, recurse=True, fqns=True)
|
| 660 |
+
original_non_persistent_buffers = copy.deepcopy(
|
| 661 |
+
{k: v for k, v in model.named_buffers() if k in non_persistent_buffer_fqns}
|
| 662 |
+
)
|
| 663 |
+
# We move the model to meta device, as then sharding happens on meta device
|
| 664 |
+
model = model.to(torch.device("meta"))
|
| 665 |
+
# We need to re-tie the weights, not exactly sure why, but if we don't do this, reference to `lm_head/embed_tokens` stay hanging -> more VRAM usage
|
| 666 |
+
# We assume `transformers` models have a `tie_weights` method if they support it
|
| 667 |
+
if hasattr(model, "tie_weights"):
|
| 668 |
+
model.tie_weights()
|
| 669 |
+
|
| 670 |
+
auto_wrap_policy_func = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
|
| 671 |
+
if auto_wrap_policy_func is not None:
|
| 672 |
+
# We skip the model itself, as that one is always wrapped
|
| 673 |
+
for module in get_module_children_bottom_up(model)[:-1]:
|
| 674 |
+
if auto_wrap_policy_func(module) and not isinstance(module, FSDPModule):
|
| 675 |
+
fully_shard(module, **fsdp2_kwargs)
|
| 676 |
+
|
| 677 |
+
if not isinstance(model, FSDPModule):
|
| 678 |
+
fully_shard(model, **fsdp2_kwargs)
|
| 679 |
+
|
| 680 |
+
if fsdp2_plugin.cpu_ram_efficient_loading:
|
| 681 |
+
# If `cpu_ram_efficient_loading` is enabled, only rank 0 loads the weights
|
| 682 |
+
# Other ranks have an empty model on `meta` device, so we need to distribute the weights properly
|
| 683 |
+
fsdp2_load_full_state_dict(accelerator, model, original_sd)
|
| 684 |
+
|
| 685 |
+
if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
|
| 686 |
+
# We re-register the buffers, as they may not be in the state_dict
|
| 687 |
+
for fqn, buffer_tensor in original_non_persistent_buffers.items():
|
| 688 |
+
buffer_tensor = buffer_tensor.to(accelerator.device)
|
| 689 |
+
|
| 690 |
+
if "." in fqn:
|
| 691 |
+
parent_fqn, local_buffer_name = fqn.rsplit(".", 1)
|
| 692 |
+
parent_module = model.get_submodule(parent_fqn)
|
| 693 |
+
else:
|
| 694 |
+
local_buffer_name = fqn
|
| 695 |
+
parent_module = model
|
| 696 |
+
|
| 697 |
+
parent_module.register_buffer(local_buffer_name, buffer_tensor, persistent=False)
|
| 698 |
+
|
| 699 |
+
# We need to tie the weights again, as call to `load_full_state_dict` breaks the tie
|
| 700 |
+
# Needs to be called both here and above
|
| 701 |
+
# removing this call makes the have slightly different loss
|
| 702 |
+
# removing the call above leads to extra memory usage as explained in the comment above
|
| 703 |
+
if hasattr(model, "tie_weights"):
|
| 704 |
+
model.tie_weights()
|
| 705 |
+
|
| 706 |
+
# There is no `dtype` attribution for nn.Module
|
| 707 |
+
# Set it to None if it doesn't exist and do the upcast always
|
| 708 |
+
model_dtype = getattr(model, "dtype", None)
|
| 709 |
+
if accelerator.mixed_precision != "no" and (model_dtype is None or model_dtype != torch.float32):
|
| 710 |
+
# We upcast the model according to `deepspeed`'s implementation
|
| 711 |
+
# More info about this can be found in `accelerator.py:prepare_model`s FSDP1 section
|
| 712 |
+
model = model.to(torch.float32)
|
| 713 |
+
if accelerator.is_main_process:
|
| 714 |
+
# TODO(siro1): Add a warning for each parameter that was upcasted
|
| 715 |
+
warnings.warn(
|
| 716 |
+
"FSDP upcast of low precision parameters to fp32 (since mixed_precision != 'no') may affect the precision of model checkpoints."
|
| 717 |
+
)
|
| 718 |
+
return model
|
| 719 |
+
|
| 720 |
+
|
| 721 |
+
def fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model: torch.nn.Module) -> Callable[[torch.nn.Module], bool]:
|
| 722 |
+
"""Prepares the auto wrap policy based on its type, done to mimic the behaviour of FSDP1 auto wrap policy.
|
| 723 |
+
|
| 724 |
+
Args:
|
| 725 |
+
fsdp2_plugin (`FullyShardedDataParallelPlugin`):
|
| 726 |
+
Instance of `FullyShardedDataParallelPlugin` containing the configuration options
|
| 727 |
+
auto_wrap_policy_type (`str`):
|
| 728 |
+
Either `transformer` or `size`
|
| 729 |
+
model (`torch.nn.Module`):
|
| 730 |
+
The model to wrap
|
| 731 |
+
|
| 732 |
+
Returns:
|
| 733 |
+
`Callable[[torch.nn.Module], bool]`:
|
| 734 |
+
The auto wrap policy function to be applied to the model
|
| 735 |
+
"""
|
| 736 |
+
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
|
| 737 |
+
|
| 738 |
+
fn = fsdp2_plugin.auto_wrap_policy
|
| 739 |
+
|
| 740 |
+
if isinstance(fn, functools.partial):
|
| 741 |
+
fn = fn.func
|
| 742 |
+
|
| 743 |
+
if fn is transformer_auto_wrap_policy:
|
| 744 |
+
no_split_modules = getattr(model, "_no_split_modules", None)
|
| 745 |
+
if no_split_modules is None:
|
| 746 |
+
no_split_modules = []
|
| 747 |
+
transformer_cls_names_to_wrap = list(no_split_modules)
|
| 748 |
+
if fsdp2_plugin.transformer_cls_names_to_wrap is not None:
|
| 749 |
+
transformer_cls_names_to_wrap = fsdp2_plugin.transformer_cls_names_to_wrap
|
| 750 |
+
transformer_cls_to_wrap = set()
|
| 751 |
+
|
| 752 |
+
for layer_class in transformer_cls_names_to_wrap:
|
| 753 |
+
transformer_cls = get_module_class_from_name(model, layer_class)
|
| 754 |
+
if transformer_cls is None:
|
| 755 |
+
raise ValueError(f"Could not find the transformer layer class {layer_class} in the model.")
|
| 756 |
+
transformer_cls_to_wrap.add(transformer_cls)
|
| 757 |
+
|
| 758 |
+
def policy(module: torch.nn.Module) -> bool:
|
| 759 |
+
if fsdp2_plugin.transformer_cls_names_to_wrap is None:
|
| 760 |
+
return False
|
| 761 |
+
return isinstance(module, tuple(transformer_cls_to_wrap))
|
| 762 |
+
|
| 763 |
+
elif fn is size_based_auto_wrap_policy:
|
| 764 |
+
|
| 765 |
+
def policy(module: torch.nn.Module) -> bool:
|
| 766 |
+
module_num_params = sum(p.numel() for p in module.parameters())
|
| 767 |
+
return module_num_params > fsdp2_plugin.min_num_params
|
| 768 |
+
else:
|
| 769 |
+
return None
|
| 770 |
+
|
| 771 |
+
return policy
|
| 772 |
+
|
| 773 |
+
|
| 774 |
+
def get_fsdp2_grad_scaler(**kwargs):
|
| 775 |
+
"""
|
| 776 |
+
Returns a `GradScaler` for FSDP2, as the current implementation of `get_grad_scaler` doesn't accept other args. We
|
| 777 |
+
need this as current `get_grad_scaler` accepts only `distributed_type` as arg, which doesn't differentiate between
|
| 778 |
+
FSDP1 and FSDP2
|
| 779 |
+
"""
|
| 780 |
+
from torch.amp.grad_scaler import GradScaler
|
| 781 |
+
|
| 782 |
+
return GradScaler(**kwargs)
|
| 783 |
+
|
| 784 |
+
|
| 785 |
+
def fsdp2_canonicalize_names(named_params: dict) -> dict:
|
| 786 |
+
"""Removes parameter name modifiers in order to map them back to their original names.
|
| 787 |
+
|
| 788 |
+
See huggingface/accelerate#3554 for more context.
|
| 789 |
+
|
| 790 |
+
Args:
|
| 791 |
+
named_params (`dict`): The named parameters dictionary to canonicalize.
|
| 792 |
+
|
| 793 |
+
Returns:
|
| 794 |
+
`dict`: The canonicalized named parameters dictionary
|
| 795 |
+
"""
|
| 796 |
+
named_params = {k.replace("._checkpoint_wrapped_module", ""): v for k, v in named_params.items()}
|
| 797 |
+
named_params = {
|
| 798 |
+
k.replace("_orig_mod.", "") if k.startswith("_orig_mod.") else k: v for k, v in named_params.items()
|
| 799 |
+
}
|
| 800 |
+
named_params = {k.replace("._orig_mod", ""): v for k, v in named_params.items()}
|
| 801 |
+
return named_params
|
| 802 |
+
|
| 803 |
+
|
| 804 |
+
def get_parameters_from_modules(
|
| 805 |
+
modules: Union[Iterable[torch.nn.Module], str], model, device
|
| 806 |
+
) -> set[torch.nn.Parameter]:
|
| 807 |
+
"""Converts modules to parameters where modules can be a string or list of torch.nn.Module
|
| 808 |
+
|
| 809 |
+
Args:
|
| 810 |
+
modules (`Union[Iterable[torch.nn.Module], str]`): List of modules
|
| 811 |
+
|
| 812 |
+
Returns:
|
| 813 |
+
`set[torch.nn.Parameter]`: List of parameters
|
| 814 |
+
"""
|
| 815 |
+
if modules is None:
|
| 816 |
+
return set()
|
| 817 |
+
parameters = []
|
| 818 |
+
# code taken from accelerate while preparing kwargs for FSDP
|
| 819 |
+
if isinstance(modules, str):
|
| 820 |
+
reg = re.compile(modules)
|
| 821 |
+
mapped_modules = []
|
| 822 |
+
for name, module in model.named_modules():
|
| 823 |
+
if reg.fullmatch(name):
|
| 824 |
+
module.to(device)
|
| 825 |
+
mapped_modules.append(module)
|
| 826 |
+
modules = mapped_modules
|
| 827 |
+
for module in modules:
|
| 828 |
+
parameters.extend(list(module.parameters()))
|
| 829 |
+
return set(parameters)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/imports.py
ADDED
|
@@ -0,0 +1,564 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import importlib
|
| 16 |
+
import importlib.metadata
|
| 17 |
+
import os
|
| 18 |
+
import sys
|
| 19 |
+
import warnings
|
| 20 |
+
from functools import lru_cache, wraps
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
from packaging import version
|
| 24 |
+
from packaging.version import parse
|
| 25 |
+
|
| 26 |
+
from .environment import parse_flag_from_env, patch_environment, str_to_bool
|
| 27 |
+
from .versions import compare_versions, is_torch_version
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# Try to run Torch native job in an environment with TorchXLA installed by setting this value to 0.
|
| 31 |
+
USE_TORCH_XLA = parse_flag_from_env("USE_TORCH_XLA", default=True)
|
| 32 |
+
|
| 33 |
+
_torch_xla_available = False
|
| 34 |
+
if USE_TORCH_XLA:
|
| 35 |
+
try:
|
| 36 |
+
import torch_xla.core.xla_model as xm # noqa: F401
|
| 37 |
+
import torch_xla.runtime
|
| 38 |
+
|
| 39 |
+
_torch_xla_available = True
|
| 40 |
+
except ImportError:
|
| 41 |
+
pass
|
| 42 |
+
|
| 43 |
+
# Keep it for is_tpu_available. It will be removed along with is_tpu_available.
|
| 44 |
+
_tpu_available = _torch_xla_available
|
| 45 |
+
|
| 46 |
+
# Cache this result has it's a C FFI call which can be pretty time-consuming
|
| 47 |
+
_torch_distributed_available = torch.distributed.is_available()
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _is_package_available(pkg_name, metadata_name=None):
|
| 51 |
+
# Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version
|
| 52 |
+
package_exists = importlib.util.find_spec(pkg_name) is not None
|
| 53 |
+
if package_exists:
|
| 54 |
+
try:
|
| 55 |
+
# Some libraries have different names in the metadata
|
| 56 |
+
_ = importlib.metadata.metadata(pkg_name if metadata_name is None else metadata_name)
|
| 57 |
+
return True
|
| 58 |
+
except importlib.metadata.PackageNotFoundError:
|
| 59 |
+
return False
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def is_torch_distributed_available() -> bool:
|
| 63 |
+
return _torch_distributed_available
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def is_xccl_available():
|
| 67 |
+
if is_torch_version(">=", "2.7.0"):
|
| 68 |
+
return torch.distributed.distributed_c10d.is_xccl_available()
|
| 69 |
+
if is_ipex_available():
|
| 70 |
+
return False
|
| 71 |
+
return False
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def is_ccl_available():
|
| 75 |
+
try:
|
| 76 |
+
pass
|
| 77 |
+
except ImportError:
|
| 78 |
+
print(
|
| 79 |
+
"Intel(R) oneCCL Bindings for PyTorch* is required to run DDP on Intel(R) XPUs, but it is not"
|
| 80 |
+
" detected. If you see \"ValueError: Invalid backend: 'ccl'\" error, please install Intel(R) oneCCL"
|
| 81 |
+
" Bindings for PyTorch*."
|
| 82 |
+
)
|
| 83 |
+
return importlib.util.find_spec("oneccl_bindings_for_pytorch") is not None
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def get_ccl_version():
|
| 87 |
+
return importlib.metadata.version("oneccl_bind_pt")
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def is_import_timer_available():
|
| 91 |
+
return _is_package_available("import_timer")
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def is_pynvml_available():
|
| 95 |
+
return _is_package_available("pynvml") or _is_package_available("pynvml", "nvidia-ml-py")
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def is_pytest_available():
|
| 99 |
+
return _is_package_available("pytest")
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def is_msamp_available():
|
| 103 |
+
return _is_package_available("msamp", "ms-amp")
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def is_schedulefree_available():
|
| 107 |
+
return _is_package_available("schedulefree")
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def is_transformer_engine_available():
|
| 111 |
+
if is_hpu_available():
|
| 112 |
+
return _is_package_available("intel_transformer_engine", "intel-transformer-engine")
|
| 113 |
+
else:
|
| 114 |
+
return _is_package_available("transformer_engine", "transformer-engine")
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def is_transformer_engine_mxfp8_available():
|
| 118 |
+
if _is_package_available("transformer_engine", "transformer-engine"):
|
| 119 |
+
import transformer_engine.pytorch as te
|
| 120 |
+
|
| 121 |
+
return te.fp8.check_mxfp8_support()[0]
|
| 122 |
+
return False
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def is_lomo_available():
|
| 126 |
+
return _is_package_available("lomo_optim")
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def is_cuda_available():
|
| 130 |
+
"""
|
| 131 |
+
Checks if `cuda` is available via an `nvml-based` check which won't trigger the drivers and leave cuda
|
| 132 |
+
uninitialized.
|
| 133 |
+
"""
|
| 134 |
+
with patch_environment(PYTORCH_NVML_BASED_CUDA_CHECK="1"):
|
| 135 |
+
available = torch.cuda.is_available()
|
| 136 |
+
|
| 137 |
+
return available
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
@lru_cache
|
| 141 |
+
def is_torch_xla_available(check_is_tpu=False, check_is_gpu=False):
|
| 142 |
+
"""
|
| 143 |
+
Check if `torch_xla` is available. To train a native pytorch job in an environment with torch xla installed, set
|
| 144 |
+
the USE_TORCH_XLA to false.
|
| 145 |
+
"""
|
| 146 |
+
assert not (check_is_tpu and check_is_gpu), "The check_is_tpu and check_is_gpu cannot both be true."
|
| 147 |
+
|
| 148 |
+
if not _torch_xla_available:
|
| 149 |
+
return False
|
| 150 |
+
elif check_is_gpu:
|
| 151 |
+
return torch_xla.runtime.device_type() in ["GPU", "CUDA"]
|
| 152 |
+
elif check_is_tpu:
|
| 153 |
+
return torch_xla.runtime.device_type() == "TPU"
|
| 154 |
+
|
| 155 |
+
return True
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def is_torchao_available():
|
| 159 |
+
package_exists = _is_package_available("torchao")
|
| 160 |
+
if package_exists:
|
| 161 |
+
torchao_version = version.parse(importlib.metadata.version("torchao"))
|
| 162 |
+
return compare_versions(torchao_version, ">=", "0.6.1")
|
| 163 |
+
return False
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def is_deepspeed_available():
|
| 167 |
+
return _is_package_available("deepspeed")
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def is_pippy_available():
|
| 171 |
+
return is_torch_version(">=", "2.4.0")
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def is_bf16_available(ignore_tpu=False):
|
| 175 |
+
"Checks if bf16 is supported, optionally ignoring the TPU"
|
| 176 |
+
if is_torch_xla_available(check_is_tpu=True):
|
| 177 |
+
return not ignore_tpu
|
| 178 |
+
if is_cuda_available():
|
| 179 |
+
return torch.cuda.is_bf16_supported()
|
| 180 |
+
if is_mlu_available():
|
| 181 |
+
return torch.mlu.is_bf16_supported()
|
| 182 |
+
if is_xpu_available():
|
| 183 |
+
return torch.xpu.is_bf16_supported()
|
| 184 |
+
if is_mps_available():
|
| 185 |
+
return torch.backends.mps.is_macos_or_newer(14, 0)
|
| 186 |
+
return True
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def is_fp16_available():
|
| 190 |
+
"Checks if fp16 is supported"
|
| 191 |
+
if is_habana_gaudi1():
|
| 192 |
+
return False
|
| 193 |
+
|
| 194 |
+
return True
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def is_fp8_available():
|
| 198 |
+
"Checks if fp8 is supported"
|
| 199 |
+
return is_msamp_available() or is_transformer_engine_available() or is_torchao_available()
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def is_4bit_bnb_available():
|
| 203 |
+
package_exists = _is_package_available("bitsandbytes")
|
| 204 |
+
if package_exists:
|
| 205 |
+
bnb_version = version.parse(importlib.metadata.version("bitsandbytes"))
|
| 206 |
+
return compare_versions(bnb_version, ">=", "0.39.0")
|
| 207 |
+
return False
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def is_8bit_bnb_available():
|
| 211 |
+
package_exists = _is_package_available("bitsandbytes")
|
| 212 |
+
if package_exists:
|
| 213 |
+
bnb_version = version.parse(importlib.metadata.version("bitsandbytes"))
|
| 214 |
+
return compare_versions(bnb_version, ">=", "0.37.2")
|
| 215 |
+
return False
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def is_bnb_available(min_version=None):
|
| 219 |
+
package_exists = _is_package_available("bitsandbytes")
|
| 220 |
+
if package_exists and min_version is not None:
|
| 221 |
+
bnb_version = version.parse(importlib.metadata.version("bitsandbytes"))
|
| 222 |
+
return compare_versions(bnb_version, ">=", min_version)
|
| 223 |
+
else:
|
| 224 |
+
return package_exists
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def is_bitsandbytes_multi_backend_available():
|
| 228 |
+
if not is_bnb_available():
|
| 229 |
+
return False
|
| 230 |
+
import bitsandbytes as bnb
|
| 231 |
+
|
| 232 |
+
return "multi_backend" in getattr(bnb, "features", set())
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def is_torchvision_available():
|
| 236 |
+
return _is_package_available("torchvision")
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def is_megatron_lm_available():
|
| 240 |
+
if str_to_bool(os.environ.get("ACCELERATE_USE_MEGATRON_LM", "False")) == 1:
|
| 241 |
+
if importlib.util.find_spec("megatron") is not None:
|
| 242 |
+
try:
|
| 243 |
+
megatron_version = parse(importlib.metadata.version("megatron-core"))
|
| 244 |
+
if compare_versions(megatron_version, ">=", "0.8.0"):
|
| 245 |
+
return importlib.util.find_spec(".training", "megatron")
|
| 246 |
+
except Exception as e:
|
| 247 |
+
warnings.warn(f"Parse Megatron version failed. Exception:{e}")
|
| 248 |
+
return False
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def is_transformers_available():
|
| 252 |
+
return _is_package_available("transformers")
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def is_datasets_available():
|
| 256 |
+
return _is_package_available("datasets")
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def is_peft_available():
|
| 260 |
+
return _is_package_available("peft")
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def is_timm_available():
|
| 264 |
+
return _is_package_available("timm")
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def is_triton_available():
|
| 268 |
+
if is_xpu_available():
|
| 269 |
+
return _is_package_available("triton", "pytorch-triton-xpu")
|
| 270 |
+
return _is_package_available("triton")
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def is_aim_available():
|
| 274 |
+
package_exists = _is_package_available("aim")
|
| 275 |
+
if package_exists:
|
| 276 |
+
aim_version = version.parse(importlib.metadata.version("aim"))
|
| 277 |
+
return compare_versions(aim_version, "<", "4.0.0")
|
| 278 |
+
return False
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def is_tensorboard_available():
|
| 282 |
+
return _is_package_available("tensorboard") or _is_package_available("tensorboardX")
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def is_wandb_available():
|
| 286 |
+
return _is_package_available("wandb")
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def is_comet_ml_available():
|
| 290 |
+
return _is_package_available("comet_ml")
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def is_swanlab_available():
|
| 294 |
+
return _is_package_available("swanlab")
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def is_trackio_available():
|
| 298 |
+
return sys.version_info >= (3, 10) and _is_package_available("trackio")
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
def is_boto3_available():
|
| 302 |
+
return _is_package_available("boto3")
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def is_rich_available():
|
| 306 |
+
if _is_package_available("rich"):
|
| 307 |
+
return parse_flag_from_env("ACCELERATE_ENABLE_RICH", False)
|
| 308 |
+
return False
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def is_sagemaker_available():
|
| 312 |
+
return _is_package_available("sagemaker")
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def is_tqdm_available():
|
| 316 |
+
return _is_package_available("tqdm")
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def is_clearml_available():
|
| 320 |
+
return _is_package_available("clearml")
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def is_pandas_available():
|
| 324 |
+
return _is_package_available("pandas")
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def is_matplotlib_available():
|
| 328 |
+
return _is_package_available("matplotlib")
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def is_mlflow_available():
|
| 332 |
+
if _is_package_available("mlflow"):
|
| 333 |
+
return True
|
| 334 |
+
|
| 335 |
+
if importlib.util.find_spec("mlflow") is not None:
|
| 336 |
+
try:
|
| 337 |
+
_ = importlib.metadata.metadata("mlflow-skinny")
|
| 338 |
+
return True
|
| 339 |
+
except importlib.metadata.PackageNotFoundError:
|
| 340 |
+
return False
|
| 341 |
+
return False
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
def is_mps_available(min_version="1.12"):
|
| 345 |
+
"Checks if MPS device is available. The minimum version required is 1.12."
|
| 346 |
+
# With torch 1.12, you can use torch.backends.mps
|
| 347 |
+
# With torch 2.0.0, you can use torch.mps
|
| 348 |
+
return is_torch_version(">=", min_version) and torch.backends.mps.is_available() and torch.backends.mps.is_built()
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def is_ipex_available():
|
| 352 |
+
"Checks if ipex is installed."
|
| 353 |
+
|
| 354 |
+
def get_major_and_minor_from_version(full_version):
|
| 355 |
+
return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor)
|
| 356 |
+
|
| 357 |
+
_torch_version = importlib.metadata.version("torch")
|
| 358 |
+
if importlib.util.find_spec("intel_extension_for_pytorch") is None:
|
| 359 |
+
return False
|
| 360 |
+
_ipex_version = "N/A"
|
| 361 |
+
try:
|
| 362 |
+
_ipex_version = importlib.metadata.version("intel_extension_for_pytorch")
|
| 363 |
+
except importlib.metadata.PackageNotFoundError:
|
| 364 |
+
return False
|
| 365 |
+
torch_major_and_minor = get_major_and_minor_from_version(_torch_version)
|
| 366 |
+
ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version)
|
| 367 |
+
if torch_major_and_minor != ipex_major_and_minor:
|
| 368 |
+
warnings.warn(
|
| 369 |
+
f"Intel Extension for PyTorch {ipex_major_and_minor} needs to work with PyTorch {ipex_major_and_minor}.*,"
|
| 370 |
+
f" but PyTorch {_torch_version} is found. Please switch to the matching version and run again."
|
| 371 |
+
)
|
| 372 |
+
return False
|
| 373 |
+
return True
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
@lru_cache
|
| 377 |
+
def is_mlu_available(check_device=False):
|
| 378 |
+
"""
|
| 379 |
+
Checks if `mlu` is available via an `cndev-based` check which won't trigger the drivers and leave mlu
|
| 380 |
+
uninitialized.
|
| 381 |
+
"""
|
| 382 |
+
if importlib.util.find_spec("torch_mlu") is None:
|
| 383 |
+
return False
|
| 384 |
+
|
| 385 |
+
import torch_mlu # noqa: F401
|
| 386 |
+
|
| 387 |
+
with patch_environment(PYTORCH_CNDEV_BASED_MLU_CHECK="1"):
|
| 388 |
+
available = torch.mlu.is_available()
|
| 389 |
+
|
| 390 |
+
return available
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
@lru_cache
|
| 394 |
+
def is_musa_available(check_device=False):
|
| 395 |
+
"Checks if `torch_musa` is installed and potentially if a MUSA is in the environment"
|
| 396 |
+
if importlib.util.find_spec("torch_musa") is None:
|
| 397 |
+
return False
|
| 398 |
+
|
| 399 |
+
import torch_musa # noqa: F401
|
| 400 |
+
|
| 401 |
+
if check_device:
|
| 402 |
+
try:
|
| 403 |
+
# Will raise a RuntimeError if no MUSA is found
|
| 404 |
+
_ = torch.musa.device_count()
|
| 405 |
+
return torch.musa.is_available()
|
| 406 |
+
except RuntimeError:
|
| 407 |
+
return False
|
| 408 |
+
return hasattr(torch, "musa") and torch.musa.is_available()
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
@lru_cache
|
| 412 |
+
def is_npu_available(check_device=False):
|
| 413 |
+
"Checks if `torch_npu` is installed and potentially if a NPU is in the environment"
|
| 414 |
+
if importlib.util.find_spec("torch_npu") is None:
|
| 415 |
+
return False
|
| 416 |
+
|
| 417 |
+
# NOTE: importing torch_npu may raise error in some envs
|
| 418 |
+
# e.g. inside cpu-only container with torch_npu installed
|
| 419 |
+
try:
|
| 420 |
+
import torch_npu # noqa: F401
|
| 421 |
+
except Exception:
|
| 422 |
+
return False
|
| 423 |
+
|
| 424 |
+
if check_device:
|
| 425 |
+
try:
|
| 426 |
+
# Will raise a RuntimeError if no NPU is found
|
| 427 |
+
_ = torch.npu.device_count()
|
| 428 |
+
return torch.npu.is_available()
|
| 429 |
+
except RuntimeError:
|
| 430 |
+
return False
|
| 431 |
+
return hasattr(torch, "npu") and torch.npu.is_available()
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
@lru_cache
|
| 435 |
+
def is_sdaa_available(check_device=False):
|
| 436 |
+
"Checks if `torch_sdaa` is installed and potentially if a SDAA is in the environment"
|
| 437 |
+
if importlib.util.find_spec("torch_sdaa") is None:
|
| 438 |
+
return False
|
| 439 |
+
|
| 440 |
+
import torch_sdaa # noqa: F401
|
| 441 |
+
|
| 442 |
+
if check_device:
|
| 443 |
+
try:
|
| 444 |
+
# Will raise a RuntimeError if no NPU is found
|
| 445 |
+
_ = torch.sdaa.device_count()
|
| 446 |
+
return torch.sdaa.is_available()
|
| 447 |
+
except RuntimeError:
|
| 448 |
+
return False
|
| 449 |
+
return hasattr(torch, "sdaa") and torch.sdaa.is_available()
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
@lru_cache
|
| 453 |
+
def is_hpu_available(init_hccl=False):
|
| 454 |
+
"Checks if `torch.hpu` is installed and potentially if a HPU is in the environment"
|
| 455 |
+
if (
|
| 456 |
+
importlib.util.find_spec("habana_frameworks") is None
|
| 457 |
+
or importlib.util.find_spec("habana_frameworks.torch") is None
|
| 458 |
+
):
|
| 459 |
+
return False
|
| 460 |
+
|
| 461 |
+
import habana_frameworks.torch # noqa: F401
|
| 462 |
+
|
| 463 |
+
if init_hccl:
|
| 464 |
+
import habana_frameworks.torch.distributed.hccl as hccl # noqa: F401
|
| 465 |
+
|
| 466 |
+
return hasattr(torch, "hpu") and torch.hpu.is_available()
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
def is_habana_gaudi1():
|
| 470 |
+
if is_hpu_available():
|
| 471 |
+
import habana_frameworks.torch.utils.experimental as htexp # noqa: F401
|
| 472 |
+
|
| 473 |
+
if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi:
|
| 474 |
+
return True
|
| 475 |
+
|
| 476 |
+
return False
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
@lru_cache
|
| 480 |
+
def is_xpu_available(check_device=False):
|
| 481 |
+
"""
|
| 482 |
+
Checks if XPU acceleration is available either via `intel_extension_for_pytorch` or via stock PyTorch (>=2.4) and
|
| 483 |
+
potentially if a XPU is in the environment
|
| 484 |
+
"""
|
| 485 |
+
|
| 486 |
+
if is_ipex_available():
|
| 487 |
+
import intel_extension_for_pytorch # noqa: F401
|
| 488 |
+
else:
|
| 489 |
+
if is_torch_version("<=", "2.3"):
|
| 490 |
+
return False
|
| 491 |
+
|
| 492 |
+
if check_device:
|
| 493 |
+
try:
|
| 494 |
+
# Will raise a RuntimeError if no XPU is found
|
| 495 |
+
_ = torch.xpu.device_count()
|
| 496 |
+
return torch.xpu.is_available()
|
| 497 |
+
except RuntimeError:
|
| 498 |
+
return False
|
| 499 |
+
return hasattr(torch, "xpu") and torch.xpu.is_available()
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
def is_dvclive_available():
|
| 503 |
+
return _is_package_available("dvclive")
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
def is_torchdata_available():
|
| 507 |
+
return _is_package_available("torchdata")
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
# TODO: Remove this function once stateful_dataloader is a stable feature in torchdata.
|
| 511 |
+
def is_torchdata_stateful_dataloader_available():
|
| 512 |
+
package_exists = _is_package_available("torchdata")
|
| 513 |
+
if package_exists:
|
| 514 |
+
torchdata_version = version.parse(importlib.metadata.version("torchdata"))
|
| 515 |
+
return compare_versions(torchdata_version, ">=", "0.8.0")
|
| 516 |
+
return False
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
def torchao_required(func):
|
| 520 |
+
"""
|
| 521 |
+
A decorator that ensures the decorated function is only called when torchao is available.
|
| 522 |
+
"""
|
| 523 |
+
|
| 524 |
+
@wraps(func)
|
| 525 |
+
def wrapper(*args, **kwargs):
|
| 526 |
+
if not is_torchao_available():
|
| 527 |
+
raise ImportError(
|
| 528 |
+
"`torchao` is not available, please install it before calling this function via `pip install torchao`."
|
| 529 |
+
)
|
| 530 |
+
return func(*args, **kwargs)
|
| 531 |
+
|
| 532 |
+
return wrapper
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
# TODO: Rework this into `utils.deepspeed` and migrate the "core" chunks into `accelerate.deepspeed`
|
| 536 |
+
def deepspeed_required(func):
|
| 537 |
+
"""
|
| 538 |
+
A decorator that ensures the decorated function is only called when deepspeed is enabled.
|
| 539 |
+
"""
|
| 540 |
+
|
| 541 |
+
@wraps(func)
|
| 542 |
+
def wrapper(*args, **kwargs):
|
| 543 |
+
from accelerate.state import AcceleratorState
|
| 544 |
+
from accelerate.utils.dataclasses import DistributedType
|
| 545 |
+
|
| 546 |
+
if AcceleratorState._shared_state != {} and AcceleratorState().distributed_type != DistributedType.DEEPSPEED:
|
| 547 |
+
raise ValueError(
|
| 548 |
+
"DeepSpeed is not enabled, please make sure that an `Accelerator` is configured for `deepspeed` "
|
| 549 |
+
"before calling this function."
|
| 550 |
+
)
|
| 551 |
+
return func(*args, **kwargs)
|
| 552 |
+
|
| 553 |
+
return wrapper
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
def is_weights_only_available():
|
| 557 |
+
# Weights only with allowlist was added in 2.4.0
|
| 558 |
+
# ref: https://github.com/pytorch/pytorch/pull/124331
|
| 559 |
+
return is_torch_version(">=", "2.4.0")
|
| 560 |
+
|
| 561 |
+
|
| 562 |
+
def is_numpy_available(min_version="1.25.0"):
|
| 563 |
+
numpy_version = parse(importlib.metadata.version("numpy"))
|
| 564 |
+
return compare_versions(numpy_version, ">=", min_version)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/launch.py
ADDED
|
@@ -0,0 +1,781 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
import os
|
| 17 |
+
import subprocess
|
| 18 |
+
import sys
|
| 19 |
+
import warnings
|
| 20 |
+
from ast import literal_eval
|
| 21 |
+
from shutil import which
|
| 22 |
+
from typing import Any
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
|
| 26 |
+
from ..commands.config.config_args import SageMakerConfig
|
| 27 |
+
from ..utils import (
|
| 28 |
+
DynamoBackend,
|
| 29 |
+
PrecisionType,
|
| 30 |
+
is_ccl_available,
|
| 31 |
+
is_fp8_available,
|
| 32 |
+
is_hpu_available,
|
| 33 |
+
is_ipex_available,
|
| 34 |
+
is_mlu_available,
|
| 35 |
+
is_musa_available,
|
| 36 |
+
is_npu_available,
|
| 37 |
+
is_sdaa_available,
|
| 38 |
+
is_torch_xla_available,
|
| 39 |
+
is_xpu_available,
|
| 40 |
+
)
|
| 41 |
+
from ..utils.constants import DEEPSPEED_MULTINODE_LAUNCHERS
|
| 42 |
+
from ..utils.other import get_free_port, is_port_in_use, merge_dicts
|
| 43 |
+
from ..utils.versions import compare_versions
|
| 44 |
+
from .dataclasses import DistributedType, SageMakerDistributedType
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _filter_args(args, parser, default_args=[]):
|
| 48 |
+
"""
|
| 49 |
+
Filters out all `accelerate` specific args
|
| 50 |
+
"""
|
| 51 |
+
new_args, _ = parser.parse_known_args(default_args)
|
| 52 |
+
for key, value in vars(args).items():
|
| 53 |
+
if key in vars(new_args).keys():
|
| 54 |
+
setattr(new_args, key, value)
|
| 55 |
+
return new_args
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _get_mpirun_args():
|
| 59 |
+
"""
|
| 60 |
+
Determines the executable and argument names for mpirun, based on the type of install. The supported MPI programs
|
| 61 |
+
are: OpenMPI, Intel MPI, or MVAPICH.
|
| 62 |
+
|
| 63 |
+
Returns: Program name and arg names for hostfile, num processes, and processes per node
|
| 64 |
+
"""
|
| 65 |
+
# Find the MPI program name
|
| 66 |
+
mpi_apps = [x for x in ["mpirun", "mpiexec"] if which(x)]
|
| 67 |
+
|
| 68 |
+
if len(mpi_apps) == 0:
|
| 69 |
+
raise OSError("mpirun or mpiexec were not found. Ensure that Intel MPI, Open MPI, or MVAPICH are installed.")
|
| 70 |
+
|
| 71 |
+
# Call the app with the --version flag to determine which MPI app is installed
|
| 72 |
+
mpi_app = mpi_apps[0]
|
| 73 |
+
mpirun_version = subprocess.check_output([mpi_app, "--version"])
|
| 74 |
+
|
| 75 |
+
if b"Open MPI" in mpirun_version:
|
| 76 |
+
return mpi_app, "--hostfile", "-n", "--npernode", "--bind-to"
|
| 77 |
+
else:
|
| 78 |
+
# Intel MPI and MVAPICH both use the same arg names
|
| 79 |
+
return mpi_app, "-f", "-n", "-ppn", ""
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def setup_fp8_env(args: argparse.Namespace, current_env: dict[str, str]):
|
| 83 |
+
"""
|
| 84 |
+
Setup the FP8 environment variables.
|
| 85 |
+
"""
|
| 86 |
+
prefix = "ACCELERATE_"
|
| 87 |
+
for arg in vars(args):
|
| 88 |
+
if arg.startswith("fp8_"):
|
| 89 |
+
value = getattr(args, arg)
|
| 90 |
+
if value is not None:
|
| 91 |
+
if arg == "fp8_override_linear_precision":
|
| 92 |
+
current_env[prefix + "FP8_OVERRIDE_FPROP"] = str(value[0])
|
| 93 |
+
current_env[prefix + "FP8_OVERRIDE_DGRAD"] = str(value[1])
|
| 94 |
+
current_env[prefix + "FP8_OVERRIDE_WGRAD"] = str(value[2])
|
| 95 |
+
else:
|
| 96 |
+
current_env[f"{prefix}{arg.upper()}"] = str(getattr(args, arg))
|
| 97 |
+
return current_env
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def prepare_simple_launcher_cmd_env(args: argparse.Namespace) -> tuple[list[str], dict[str, str]]:
|
| 101 |
+
"""
|
| 102 |
+
Prepares and returns the command list and an environment with the correct simple launcher environment variables.
|
| 103 |
+
"""
|
| 104 |
+
cmd = []
|
| 105 |
+
if args.no_python and args.module:
|
| 106 |
+
raise ValueError("--module and --no_python cannot be used together")
|
| 107 |
+
|
| 108 |
+
num_processes = getattr(args, "num_processes", None)
|
| 109 |
+
num_machines = args.num_machines
|
| 110 |
+
if args.mpirun_hostfile is not None:
|
| 111 |
+
mpi_app_name, hostfile_arg, num_proc_arg, proc_per_node_arg, bind_to_arg = _get_mpirun_args()
|
| 112 |
+
bind_to = getattr(args, "bind-to", "socket")
|
| 113 |
+
nproc_per_node = str(num_processes // num_machines) if num_processes and num_machines else "1"
|
| 114 |
+
cmd += [
|
| 115 |
+
mpi_app_name,
|
| 116 |
+
hostfile_arg,
|
| 117 |
+
args.mpirun_hostfile,
|
| 118 |
+
proc_per_node_arg,
|
| 119 |
+
nproc_per_node,
|
| 120 |
+
]
|
| 121 |
+
if num_processes:
|
| 122 |
+
cmd += [num_proc_arg, str(num_processes)]
|
| 123 |
+
if bind_to_arg:
|
| 124 |
+
cmd += [bind_to_arg, bind_to]
|
| 125 |
+
if not args.no_python:
|
| 126 |
+
cmd.append(sys.executable)
|
| 127 |
+
if args.module:
|
| 128 |
+
cmd.append("-m")
|
| 129 |
+
cmd.append(args.training_script)
|
| 130 |
+
cmd.extend(args.training_script_args)
|
| 131 |
+
|
| 132 |
+
current_env = os.environ.copy()
|
| 133 |
+
current_env["ACCELERATE_USE_CPU"] = str(args.cpu or args.use_cpu)
|
| 134 |
+
if args.debug:
|
| 135 |
+
current_env["ACCELERATE_DEBUG_MODE"] = "true"
|
| 136 |
+
if args.gpu_ids != "all" and args.gpu_ids is not None:
|
| 137 |
+
if is_xpu_available():
|
| 138 |
+
current_env["ZE_AFFINITY_MASK"] = args.gpu_ids
|
| 139 |
+
elif is_mlu_available():
|
| 140 |
+
current_env["MLU_VISIBLE_DEVICES"] = args.gpu_ids
|
| 141 |
+
elif is_sdaa_available():
|
| 142 |
+
current_env["SDAA_VISIBLE_DEVICES"] = args.gpu_ids
|
| 143 |
+
elif is_musa_available():
|
| 144 |
+
current_env["MUSA_VISIBLE_DEVICES"] = args.gpu_ids
|
| 145 |
+
elif is_npu_available():
|
| 146 |
+
current_env["ASCEND_RT_VISIBLE_DEVICES"] = args.gpu_ids
|
| 147 |
+
elif is_hpu_available():
|
| 148 |
+
current_env["HABANA_VISIBLE_MODULES"] = args.gpu_ids
|
| 149 |
+
else:
|
| 150 |
+
current_env["CUDA_VISIBLE_DEVICES"] = args.gpu_ids
|
| 151 |
+
if num_machines > 1:
|
| 152 |
+
assert args.main_process_ip is not None, (
|
| 153 |
+
"When using multiple machines, you need to specify the main process IP."
|
| 154 |
+
)
|
| 155 |
+
assert args.main_process_port is not None, (
|
| 156 |
+
"When using multiple machines, you need to specify the main process port."
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
ccl_worker_count = getattr(args, "mpirun_ccl", 0) if is_ccl_available() else 0
|
| 160 |
+
if (num_processes is not None and num_processes > 1) or num_machines > 1:
|
| 161 |
+
current_env["MASTER_ADDR"] = args.main_process_ip if args.main_process_ip is not None else "127.0.0.1"
|
| 162 |
+
current_env["MASTER_PORT"] = str(args.main_process_port) if args.main_process_port is not None else "29500"
|
| 163 |
+
current_env["CCL_WORKER_COUNT"] = str(ccl_worker_count)
|
| 164 |
+
if current_env["ACCELERATE_USE_CPU"]:
|
| 165 |
+
current_env["KMP_AFFINITY"] = "granularity=fine,compact,1,0"
|
| 166 |
+
current_env["KMP_BLOCKTIME"] = str(1)
|
| 167 |
+
|
| 168 |
+
try:
|
| 169 |
+
mixed_precision = PrecisionType(args.mixed_precision.lower())
|
| 170 |
+
except ValueError:
|
| 171 |
+
raise ValueError(
|
| 172 |
+
f"Unknown mixed_precision mode: {args.mixed_precision.lower()}. Choose between {PrecisionType.list()}."
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
current_env["ACCELERATE_MIXED_PRECISION"] = str(mixed_precision)
|
| 176 |
+
if args.mixed_precision.lower() == "fp8":
|
| 177 |
+
if not is_fp8_available():
|
| 178 |
+
raise RuntimeError(
|
| 179 |
+
"FP8 is not available on this machine. Please ensure that either Transformer Engine, MSAMP or torchao is installed."
|
| 180 |
+
)
|
| 181 |
+
current_env = setup_fp8_env(args, current_env)
|
| 182 |
+
|
| 183 |
+
try:
|
| 184 |
+
dynamo_backend = DynamoBackend(args.dynamo_backend.upper())
|
| 185 |
+
except ValueError:
|
| 186 |
+
raise ValueError(
|
| 187 |
+
f"Unknown dynamo backend: {args.dynamo_backend.upper()}. Choose between {DynamoBackend.list()}."
|
| 188 |
+
)
|
| 189 |
+
current_env["ACCELERATE_DYNAMO_BACKEND"] = dynamo_backend.value
|
| 190 |
+
current_env["ACCELERATE_DYNAMO_MODE"] = args.dynamo_mode
|
| 191 |
+
current_env["ACCELERATE_DYNAMO_USE_FULLGRAPH"] = str(args.dynamo_use_fullgraph)
|
| 192 |
+
current_env["ACCELERATE_DYNAMO_USE_DYNAMIC"] = str(args.dynamo_use_dynamic)
|
| 193 |
+
current_env["ACCELERATE_DYNAMO_USE_REGIONAL_COMPILATION"] = str(args.dynamo_use_regional_compilation)
|
| 194 |
+
|
| 195 |
+
current_env["OMP_NUM_THREADS"] = str(args.num_cpu_threads_per_process)
|
| 196 |
+
if is_ipex_available():
|
| 197 |
+
current_env["ACCELERATE_USE_IPEX"] = str(args.ipex).lower()
|
| 198 |
+
if args.enable_cpu_affinity:
|
| 199 |
+
current_env["ACCELERATE_CPU_AFFINITY"] = "1"
|
| 200 |
+
return cmd, current_env
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def prepare_multi_gpu_env(args: argparse.Namespace) -> dict[str, str]:
|
| 204 |
+
"""
|
| 205 |
+
Prepares and returns an environment with the correct multi-GPU environment variables.
|
| 206 |
+
"""
|
| 207 |
+
# get free port and update configurations
|
| 208 |
+
if args.main_process_port == 0:
|
| 209 |
+
args.main_process_port = get_free_port()
|
| 210 |
+
|
| 211 |
+
elif args.main_process_port is None:
|
| 212 |
+
args.main_process_port = 29500
|
| 213 |
+
|
| 214 |
+
num_processes = args.num_processes
|
| 215 |
+
num_machines = args.num_machines
|
| 216 |
+
main_process_ip = args.main_process_ip
|
| 217 |
+
main_process_port = args.main_process_port
|
| 218 |
+
if num_machines > 1:
|
| 219 |
+
args.nproc_per_node = str(num_processes // num_machines)
|
| 220 |
+
args.nnodes = str(num_machines)
|
| 221 |
+
args.node_rank = int(args.machine_rank)
|
| 222 |
+
if getattr(args, "same_network", False):
|
| 223 |
+
args.master_addr = str(main_process_ip)
|
| 224 |
+
args.master_port = str(main_process_port)
|
| 225 |
+
else:
|
| 226 |
+
args.rdzv_endpoint = f"{main_process_ip}:{main_process_port}"
|
| 227 |
+
else:
|
| 228 |
+
args.nproc_per_node = str(num_processes)
|
| 229 |
+
if main_process_port is not None:
|
| 230 |
+
args.master_port = str(main_process_port)
|
| 231 |
+
|
| 232 |
+
# only need to check port availability in main process, in case we have to start multiple launchers on the same machine
|
| 233 |
+
# for some reasons like splitting log files.
|
| 234 |
+
need_port_check = num_machines <= 1 or int(args.machine_rank) == 0
|
| 235 |
+
if need_port_check and is_port_in_use(main_process_port):
|
| 236 |
+
if num_machines <= 1:
|
| 237 |
+
args.standalone = True
|
| 238 |
+
warnings.warn(
|
| 239 |
+
f"Port `{main_process_port}` is already in use. "
|
| 240 |
+
"Accelerate will attempt to launch in a standalone-like mode by finding an open port automatically for this session. "
|
| 241 |
+
"If this current attempt fails, or for more control in future runs, please specify a different port "
|
| 242 |
+
"(e.g., `--main_process_port <your_chosen_port>`) or use `--main_process_port 0` for automatic selection "
|
| 243 |
+
"in your launch command or Accelerate config file."
|
| 244 |
+
)
|
| 245 |
+
else:
|
| 246 |
+
raise ConnectionError(
|
| 247 |
+
f"Tried to launch distributed communication on port `{main_process_port}`, but another process is utilizing it. "
|
| 248 |
+
"Please specify a different port (such as using the `--main_process_port` flag or specifying a different `main_process_port` in your config file)"
|
| 249 |
+
" and rerun your script. To automatically use the next open port (on a single node), you can set this to `0`."
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
if args.module and args.no_python:
|
| 253 |
+
raise ValueError("--module and --no_python cannot be used together")
|
| 254 |
+
elif args.module:
|
| 255 |
+
args.module = True
|
| 256 |
+
elif args.no_python:
|
| 257 |
+
args.no_python = True
|
| 258 |
+
|
| 259 |
+
current_env = os.environ.copy()
|
| 260 |
+
if args.debug:
|
| 261 |
+
current_env["ACCELERATE_DEBUG_MODE"] = "true"
|
| 262 |
+
gpu_ids = getattr(args, "gpu_ids", "all")
|
| 263 |
+
if gpu_ids != "all" and args.gpu_ids is not None:
|
| 264 |
+
if is_xpu_available():
|
| 265 |
+
current_env["ZE_AFFINITY_MASK"] = gpu_ids
|
| 266 |
+
elif is_mlu_available():
|
| 267 |
+
current_env["MLU_VISIBLE_DEVICES"] = gpu_ids
|
| 268 |
+
elif is_sdaa_available():
|
| 269 |
+
current_env["SDAA_VISIBLE_DEVICES"] = gpu_ids
|
| 270 |
+
elif is_musa_available():
|
| 271 |
+
current_env["MUSA_VISIBLE_DEVICES"] = gpu_ids
|
| 272 |
+
elif is_npu_available():
|
| 273 |
+
current_env["ASCEND_RT_VISIBLE_DEVICES"] = gpu_ids
|
| 274 |
+
elif is_hpu_available():
|
| 275 |
+
current_env["HABANA_VISIBLE_MODULES"] = gpu_ids
|
| 276 |
+
else:
|
| 277 |
+
current_env["CUDA_VISIBLE_DEVICES"] = gpu_ids
|
| 278 |
+
mixed_precision = args.mixed_precision.lower()
|
| 279 |
+
try:
|
| 280 |
+
mixed_precision = PrecisionType(mixed_precision)
|
| 281 |
+
except ValueError:
|
| 282 |
+
raise ValueError(f"Unknown mixed_precision mode: {mixed_precision}. Choose between {PrecisionType.list()}.")
|
| 283 |
+
|
| 284 |
+
current_env["ACCELERATE_MIXED_PRECISION"] = str(mixed_precision)
|
| 285 |
+
if args.mixed_precision.lower() == "fp8":
|
| 286 |
+
if not is_fp8_available():
|
| 287 |
+
raise RuntimeError(
|
| 288 |
+
"FP8 is not available on this machine. Please ensure that either Transformer Engine, MSAMP or torchao is installed."
|
| 289 |
+
)
|
| 290 |
+
current_env = setup_fp8_env(args, current_env)
|
| 291 |
+
|
| 292 |
+
try:
|
| 293 |
+
dynamo_backend = DynamoBackend(args.dynamo_backend.upper())
|
| 294 |
+
except ValueError:
|
| 295 |
+
raise ValueError(
|
| 296 |
+
f"Unknown dynamo backend: {args.dynamo_backend.upper()}. Choose between {DynamoBackend.list()}."
|
| 297 |
+
)
|
| 298 |
+
current_env["ACCELERATE_DYNAMO_BACKEND"] = dynamo_backend.value
|
| 299 |
+
current_env["ACCELERATE_DYNAMO_MODE"] = args.dynamo_mode
|
| 300 |
+
current_env["ACCELERATE_DYNAMO_USE_FULLGRAPH"] = str(args.dynamo_use_fullgraph)
|
| 301 |
+
current_env["ACCELERATE_DYNAMO_USE_DYNAMIC"] = str(args.dynamo_use_dynamic)
|
| 302 |
+
current_env["ACCELERATE_DYNAMO_USE_REGIONAL_COMPILATION"] = str(args.dynamo_use_regional_compilation)
|
| 303 |
+
|
| 304 |
+
if args.use_fsdp:
|
| 305 |
+
current_env["ACCELERATE_USE_FSDP"] = "true"
|
| 306 |
+
if args.fsdp_cpu_ram_efficient_loading and not args.fsdp_sync_module_states:
|
| 307 |
+
raise ValueError("When using `--fsdp_cpu_ram_efficient_loading` set `--fsdp_sync_module_states` to `True`")
|
| 308 |
+
|
| 309 |
+
current_env["FSDP_VERSION"] = str(args.fsdp_version) if hasattr(args, "fsdp_version") else "1"
|
| 310 |
+
|
| 311 |
+
# For backwards compatibility, we support this in launched scripts,
|
| 312 |
+
# however, we do not ask users for this in `accelerate config` CLI
|
| 313 |
+
current_env["FSDP_SHARDING_STRATEGY"] = str(args.fsdp_sharding_strategy)
|
| 314 |
+
|
| 315 |
+
current_env["FSDP_RESHARD_AFTER_FORWARD"] = str(args.fsdp_reshard_after_forward).lower()
|
| 316 |
+
current_env["FSDP_OFFLOAD_PARAMS"] = str(args.fsdp_offload_params).lower()
|
| 317 |
+
current_env["FSDP_MIN_NUM_PARAMS"] = str(args.fsdp_min_num_params)
|
| 318 |
+
if args.fsdp_auto_wrap_policy is not None:
|
| 319 |
+
current_env["FSDP_AUTO_WRAP_POLICY"] = str(args.fsdp_auto_wrap_policy)
|
| 320 |
+
if args.fsdp_transformer_layer_cls_to_wrap is not None:
|
| 321 |
+
current_env["FSDP_TRANSFORMER_CLS_TO_WRAP"] = str(args.fsdp_transformer_layer_cls_to_wrap)
|
| 322 |
+
if args.fsdp_backward_prefetch is not None:
|
| 323 |
+
current_env["FSDP_BACKWARD_PREFETCH"] = str(args.fsdp_backward_prefetch)
|
| 324 |
+
if args.fsdp_state_dict_type is not None:
|
| 325 |
+
current_env["FSDP_STATE_DICT_TYPE"] = str(args.fsdp_state_dict_type)
|
| 326 |
+
current_env["FSDP_FORWARD_PREFETCH"] = str(args.fsdp_forward_prefetch).lower()
|
| 327 |
+
current_env["FSDP_USE_ORIG_PARAMS"] = str(args.fsdp_use_orig_params).lower()
|
| 328 |
+
current_env["FSDP_CPU_RAM_EFFICIENT_LOADING"] = str(args.fsdp_cpu_ram_efficient_loading).lower()
|
| 329 |
+
current_env["FSDP_SYNC_MODULE_STATES"] = str(args.fsdp_sync_module_states).lower()
|
| 330 |
+
current_env["FSDP_ACTIVATION_CHECKPOINTING"] = str(args.fsdp_activation_checkpointing).lower()
|
| 331 |
+
if getattr(args, "fsdp_ignored_modules", None) is not None:
|
| 332 |
+
current_env["FSDP_IGNORED_MODULES"] = str(args.fsdp_ignored_modules)
|
| 333 |
+
|
| 334 |
+
if args.use_megatron_lm:
|
| 335 |
+
prefix = "MEGATRON_LM_"
|
| 336 |
+
current_env["ACCELERATE_USE_MEGATRON_LM"] = "true"
|
| 337 |
+
current_env[prefix + "TP_DEGREE"] = str(args.megatron_lm_tp_degree)
|
| 338 |
+
current_env[prefix + "PP_DEGREE"] = str(args.megatron_lm_pp_degree)
|
| 339 |
+
current_env[prefix + "GRADIENT_CLIPPING"] = str(args.megatron_lm_gradient_clipping)
|
| 340 |
+
if args.megatron_lm_num_micro_batches is not None:
|
| 341 |
+
current_env[prefix + "NUM_MICRO_BATCHES"] = str(args.megatron_lm_num_micro_batches)
|
| 342 |
+
if args.megatron_lm_sequence_parallelism is not None:
|
| 343 |
+
current_env[prefix + "SEQUENCE_PARALLELISM"] = str(args.megatron_lm_sequence_parallelism)
|
| 344 |
+
if args.megatron_lm_recompute_activations is not None:
|
| 345 |
+
current_env[prefix + "RECOMPUTE_ACTIVATIONS"] = str(args.megatron_lm_recompute_activations)
|
| 346 |
+
if args.megatron_lm_use_distributed_optimizer is not None:
|
| 347 |
+
current_env[prefix + "USE_DISTRIBUTED_OPTIMIZER"] = str(args.megatron_lm_use_distributed_optimizer)
|
| 348 |
+
|
| 349 |
+
current_env["OMP_NUM_THREADS"] = str(args.num_cpu_threads_per_process)
|
| 350 |
+
if args.enable_cpu_affinity:
|
| 351 |
+
current_env["ACCELERATE_CPU_AFFINITY"] = "1"
|
| 352 |
+
|
| 353 |
+
if args.use_parallelism_config:
|
| 354 |
+
current_env = prepare_extend_env_parallelism_config(args, current_env)
|
| 355 |
+
|
| 356 |
+
return current_env
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
def prepare_extend_env_parallelism_config(
|
| 360 |
+
args: argparse.Namespace, current_env: dict
|
| 361 |
+
) -> tuple[list[str], dict[str, str]]:
|
| 362 |
+
"""
|
| 363 |
+
Extends `current_env` with context parallelism env vars if any have been set
|
| 364 |
+
"""
|
| 365 |
+
|
| 366 |
+
prefix = "PARALLELISM_CONFIG_"
|
| 367 |
+
|
| 368 |
+
current_env["ACCELERATE_USE_PARALLELISM_CONFIG"] = "true"
|
| 369 |
+
current_env[prefix + "DP_REPLICATE_SIZE"] = str(args.parallelism_config_dp_replicate_size)
|
| 370 |
+
current_env[prefix + "DP_SHARD_SIZE"] = str(args.parallelism_config_dp_shard_size)
|
| 371 |
+
current_env[prefix + "TP_SIZE"] = str(args.parallelism_config_tp_size)
|
| 372 |
+
current_env[prefix + "CP_SIZE"] = str(args.parallelism_config_cp_size)
|
| 373 |
+
current_env[prefix + "CP_BACKEND"] = str(args.parallelism_config_cp_backend)
|
| 374 |
+
current_env[prefix + "SP_SIZE"] = str(args.parallelism_config_sp_size)
|
| 375 |
+
current_env[prefix + "SP_BACKEND"] = str(args.parallelism_config_sp_backend)
|
| 376 |
+
if args.parallelism_config_cp_size > 1:
|
| 377 |
+
current_env[prefix + "CP_COMM_STRATEGY"] = str(args.parallelism_config_cp_comm_strategy)
|
| 378 |
+
if args.parallelism_config_sp_size > 1:
|
| 379 |
+
current_env[prefix + "SP_SEQ_LENGTH"] = str(args.parallelism_config_sp_seq_length)
|
| 380 |
+
current_env[prefix + "SP_SEQ_LENGTH_IS_VARIABLE"] = str(args.parallelism_config_sp_seq_length_is_variable)
|
| 381 |
+
current_env[prefix + "SP_ATTN_IMPLEMENTATION"] = str(args.parallelism_config_sp_attn_implementation)
|
| 382 |
+
|
| 383 |
+
return current_env
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def prepare_deepspeed_cmd_env(args: argparse.Namespace) -> tuple[list[str], dict[str, str]]:
|
| 387 |
+
"""
|
| 388 |
+
Prepares and returns the command list and an environment with the correct DeepSpeed environment variables.
|
| 389 |
+
"""
|
| 390 |
+
# get free port and update configurations
|
| 391 |
+
if args.main_process_port == 0:
|
| 392 |
+
args.main_process_port = get_free_port()
|
| 393 |
+
|
| 394 |
+
elif args.main_process_port is None:
|
| 395 |
+
args.main_process_port = 29500
|
| 396 |
+
|
| 397 |
+
num_processes = args.num_processes
|
| 398 |
+
num_machines = args.num_machines
|
| 399 |
+
main_process_ip = args.main_process_ip
|
| 400 |
+
main_process_port = args.main_process_port
|
| 401 |
+
cmd = None
|
| 402 |
+
|
| 403 |
+
# make sure launcher is not None
|
| 404 |
+
if args.deepspeed_multinode_launcher is None:
|
| 405 |
+
# set to default pdsh
|
| 406 |
+
args.deepspeed_multinode_launcher = DEEPSPEED_MULTINODE_LAUNCHERS[0]
|
| 407 |
+
|
| 408 |
+
if num_machines > 1 and args.deepspeed_multinode_launcher != DEEPSPEED_MULTINODE_LAUNCHERS[1]:
|
| 409 |
+
cmd = ["deepspeed"]
|
| 410 |
+
cmd.extend(["--hostfile", str(args.deepspeed_hostfile)])
|
| 411 |
+
if args.deepspeed_multinode_launcher == "nossh":
|
| 412 |
+
if compare_versions("deepspeed", "<", "0.14.5"):
|
| 413 |
+
raise ValueError("nossh launcher requires DeepSpeed >= 0.14.5")
|
| 414 |
+
cmd.extend(["--node_rank", str(args.machine_rank), "--no_ssh"])
|
| 415 |
+
else:
|
| 416 |
+
cmd.extend(["--no_local_rank", "--launcher", str(args.deepspeed_multinode_launcher)])
|
| 417 |
+
if args.deepspeed_exclusion_filter is not None:
|
| 418 |
+
cmd.extend(
|
| 419 |
+
[
|
| 420 |
+
"--exclude",
|
| 421 |
+
str(args.deepspeed_exclusion_filter),
|
| 422 |
+
]
|
| 423 |
+
)
|
| 424 |
+
elif args.deepspeed_inclusion_filter is not None:
|
| 425 |
+
cmd.extend(
|
| 426 |
+
[
|
| 427 |
+
"--include",
|
| 428 |
+
str(args.deepspeed_inclusion_filter),
|
| 429 |
+
]
|
| 430 |
+
)
|
| 431 |
+
else:
|
| 432 |
+
cmd.extend(["--num_gpus", str(args.num_processes // args.num_machines)])
|
| 433 |
+
if main_process_ip:
|
| 434 |
+
cmd.extend(["--master_addr", str(main_process_ip)])
|
| 435 |
+
cmd.extend(["--master_port", str(main_process_port)])
|
| 436 |
+
if args.module and args.no_python:
|
| 437 |
+
raise ValueError("--module and --no_python cannot be used together")
|
| 438 |
+
elif args.module:
|
| 439 |
+
cmd.append("--module")
|
| 440 |
+
elif args.no_python:
|
| 441 |
+
cmd.append("--no_python")
|
| 442 |
+
cmd.append(args.training_script)
|
| 443 |
+
cmd.extend(args.training_script_args)
|
| 444 |
+
elif num_machines > 1 and args.deepspeed_multinode_launcher == DEEPSPEED_MULTINODE_LAUNCHERS[1]:
|
| 445 |
+
args.nproc_per_node = str(num_processes // num_machines)
|
| 446 |
+
args.nnodes = str(num_machines)
|
| 447 |
+
args.node_rank = int(args.machine_rank)
|
| 448 |
+
if getattr(args, "same_network", False):
|
| 449 |
+
args.master_addr = str(main_process_ip)
|
| 450 |
+
args.master_port = str(main_process_port)
|
| 451 |
+
else:
|
| 452 |
+
args.rdzv_endpoint = f"{main_process_ip}:{main_process_port}"
|
| 453 |
+
else:
|
| 454 |
+
args.nproc_per_node = str(num_processes)
|
| 455 |
+
if main_process_port is not None:
|
| 456 |
+
args.master_port = str(main_process_port)
|
| 457 |
+
|
| 458 |
+
# only need to check port availability in main process, in case we have to start multiple launchers on the same machine
|
| 459 |
+
# for some reasons like splitting log files.
|
| 460 |
+
need_port_check = num_machines <= 1 or int(args.machine_rank) == 0
|
| 461 |
+
if need_port_check and is_port_in_use(main_process_port):
|
| 462 |
+
if num_machines <= 1:
|
| 463 |
+
args.standalone = True
|
| 464 |
+
warnings.warn(
|
| 465 |
+
f"Port `{main_process_port}` is already in use. "
|
| 466 |
+
"Accelerate will attempt to launch in a standalone-like mode by finding an open port automatically for this session. "
|
| 467 |
+
"If this current attempt fails, or for more control in future runs, please specify a different port "
|
| 468 |
+
"(e.g., `--main_process_port <your_chosen_port>`) or use `--main_process_port 0` for automatic selection "
|
| 469 |
+
"in your launch command or Accelerate config file."
|
| 470 |
+
)
|
| 471 |
+
else:
|
| 472 |
+
raise ConnectionError(
|
| 473 |
+
f"Tried to launch distributed communication on port `{main_process_port}`, but another process is utilizing it. "
|
| 474 |
+
"Please specify a different port (such as using the `--main_process_port` flag or specifying a different `main_process_port` in your config file)"
|
| 475 |
+
" and rerun your script. To automatically use the next open port (on a single node), you can set this to `0`."
|
| 476 |
+
)
|
| 477 |
+
|
| 478 |
+
if args.module and args.no_python:
|
| 479 |
+
raise ValueError("--module and --no_python cannot be used together")
|
| 480 |
+
elif args.module:
|
| 481 |
+
args.module = True
|
| 482 |
+
elif args.no_python:
|
| 483 |
+
args.no_python = True
|
| 484 |
+
|
| 485 |
+
current_env = os.environ.copy()
|
| 486 |
+
if args.debug:
|
| 487 |
+
current_env["ACCELERATE_DEBUG_MODE"] = "true"
|
| 488 |
+
gpu_ids = getattr(args, "gpu_ids", "all")
|
| 489 |
+
if gpu_ids != "all" and args.gpu_ids is not None:
|
| 490 |
+
if is_xpu_available():
|
| 491 |
+
current_env["ZE_AFFINITY_MASK"] = gpu_ids
|
| 492 |
+
elif is_mlu_available():
|
| 493 |
+
current_env["MLU_VISIBLE_DEVICES"] = gpu_ids
|
| 494 |
+
elif is_sdaa_available():
|
| 495 |
+
current_env["SDAA_VISIBLE_DEVICES"] = gpu_ids
|
| 496 |
+
elif is_musa_available():
|
| 497 |
+
current_env["MUSA_VISIBLE_DEVICES"] = gpu_ids
|
| 498 |
+
elif is_npu_available():
|
| 499 |
+
current_env["ASCEND_RT_VISIBLE_DEVICES"] = gpu_ids
|
| 500 |
+
elif is_hpu_available():
|
| 501 |
+
current_env["HABANA_VISIBLE_MODULES"] = gpu_ids
|
| 502 |
+
else:
|
| 503 |
+
current_env["CUDA_VISIBLE_DEVICES"] = gpu_ids
|
| 504 |
+
try:
|
| 505 |
+
mixed_precision = PrecisionType(args.mixed_precision.lower())
|
| 506 |
+
except ValueError:
|
| 507 |
+
raise ValueError(
|
| 508 |
+
f"Unknown mixed_precision mode: {args.mixed_precision.lower()}. Choose between {PrecisionType.list()}."
|
| 509 |
+
)
|
| 510 |
+
|
| 511 |
+
current_env["PYTHONPATH"] = env_var_path_add("PYTHONPATH", os.path.abspath("."))
|
| 512 |
+
current_env["ACCELERATE_MIXED_PRECISION"] = str(mixed_precision)
|
| 513 |
+
if args.mixed_precision.lower() == "fp8":
|
| 514 |
+
if not is_fp8_available():
|
| 515 |
+
raise RuntimeError(
|
| 516 |
+
"FP8 is not available on this machine. Please ensure that either Transformer Engine, MSAMP or torchao is installed."
|
| 517 |
+
)
|
| 518 |
+
current_env = setup_fp8_env(args, current_env)
|
| 519 |
+
current_env["ACCELERATE_CONFIG_DS_FIELDS"] = str(args.deepspeed_fields_from_accelerate_config).lower()
|
| 520 |
+
current_env["ACCELERATE_USE_DEEPSPEED"] = "true"
|
| 521 |
+
if args.zero_stage is not None:
|
| 522 |
+
current_env["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(args.zero_stage)
|
| 523 |
+
if args.gradient_accumulation_steps is not None:
|
| 524 |
+
current_env["ACCELERATE_GRADIENT_ACCUMULATION_STEPS"] = str(args.gradient_accumulation_steps)
|
| 525 |
+
if args.gradient_clipping is not None:
|
| 526 |
+
current_env["ACCELERATE_GRADIENT_CLIPPING"] = str(args.gradient_clipping).lower()
|
| 527 |
+
if args.offload_optimizer_device is not None:
|
| 528 |
+
current_env["ACCELERATE_DEEPSPEED_OFFLOAD_OPTIMIZER_DEVICE"] = str(args.offload_optimizer_device).lower()
|
| 529 |
+
if args.offload_param_device is not None:
|
| 530 |
+
current_env["ACCELERATE_DEEPSPEED_OFFLOAD_PARAM_DEVICE"] = str(args.offload_param_device).lower()
|
| 531 |
+
if args.zero3_init_flag is not None:
|
| 532 |
+
current_env["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = str(args.zero3_init_flag).lower()
|
| 533 |
+
if args.zero3_save_16bit_model is not None:
|
| 534 |
+
current_env["ACCELERATE_DEEPSPEED_ZERO3_SAVE_16BIT_MODEL"] = str(args.zero3_save_16bit_model).lower()
|
| 535 |
+
if args.deepspeed_config_file is not None:
|
| 536 |
+
current_env["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = str(args.deepspeed_config_file)
|
| 537 |
+
if args.enable_cpu_affinity:
|
| 538 |
+
current_env["ACCELERATE_CPU_AFFINITY"] = "1"
|
| 539 |
+
if args.deepspeed_moe_layer_cls_names is not None:
|
| 540 |
+
current_env["ACCELERATE_DEEPSPEED_MOE_LAYER_CLS_NAMES"] = str(args.deepspeed_moe_layer_cls_names)
|
| 541 |
+
|
| 542 |
+
if args.use_parallelism_config:
|
| 543 |
+
current_env = prepare_extend_env_parallelism_config(args, current_env)
|
| 544 |
+
|
| 545 |
+
return cmd, current_env
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
def prepare_tpu(
|
| 549 |
+
args: argparse.Namespace, current_env: dict[str, str], pod: bool = False
|
| 550 |
+
) -> tuple[argparse.Namespace, dict[str, str]]:
|
| 551 |
+
"""
|
| 552 |
+
Prepares and returns an environment with the correct TPU environment variables.
|
| 553 |
+
"""
|
| 554 |
+
if args.mixed_precision == "bf16" and is_torch_xla_available(check_is_tpu=True):
|
| 555 |
+
if args.downcast_bf16:
|
| 556 |
+
current_env["XLA_DOWNCAST_BF16"] = "1"
|
| 557 |
+
else:
|
| 558 |
+
current_env["XLA_USE_BF16"] = "1"
|
| 559 |
+
if args.debug:
|
| 560 |
+
current_env["ACCELERATE_DEBUG_MODE"] = "true"
|
| 561 |
+
if pod:
|
| 562 |
+
# Take explicit args and set them up for XLA
|
| 563 |
+
args.vm = args.tpu_vm
|
| 564 |
+
args.tpu = args.tpu_name
|
| 565 |
+
return args, current_env
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
def _convert_nargs_to_dict(nargs: list[str]) -> dict[str, str]:
|
| 569 |
+
if len(nargs) < 0:
|
| 570 |
+
return {}
|
| 571 |
+
# helper function to infer type for argsparser
|
| 572 |
+
|
| 573 |
+
def _infer_type(s):
|
| 574 |
+
try:
|
| 575 |
+
s = float(s)
|
| 576 |
+
|
| 577 |
+
if s // 1 == s:
|
| 578 |
+
return int(s)
|
| 579 |
+
return s
|
| 580 |
+
except ValueError:
|
| 581 |
+
return s
|
| 582 |
+
|
| 583 |
+
parser = argparse.ArgumentParser()
|
| 584 |
+
_, unknown = parser.parse_known_args(nargs)
|
| 585 |
+
for index, argument in enumerate(unknown):
|
| 586 |
+
if argument.startswith(("-", "--")):
|
| 587 |
+
action = None
|
| 588 |
+
if index + 1 < len(unknown): # checks if next index would be in list
|
| 589 |
+
if unknown[index + 1].startswith(("-", "--")): # checks if next element is an key
|
| 590 |
+
# raise an error if element is store_true or store_false
|
| 591 |
+
raise ValueError(
|
| 592 |
+
"SageMaker doesn’t support argparse actions for `store_true` or `store_false`. Please define explicit types"
|
| 593 |
+
)
|
| 594 |
+
else: # raise an error if last element is store_true or store_false
|
| 595 |
+
raise ValueError(
|
| 596 |
+
"SageMaker doesn’t support argparse actions for `store_true` or `store_false`. Please define explicit types"
|
| 597 |
+
)
|
| 598 |
+
# adds argument to parser based on action_store true
|
| 599 |
+
if action is None:
|
| 600 |
+
parser.add_argument(argument, type=_infer_type)
|
| 601 |
+
else:
|
| 602 |
+
parser.add_argument(argument, action=action)
|
| 603 |
+
|
| 604 |
+
return {
|
| 605 |
+
key: (literal_eval(value) if value in ("True", "False") else value)
|
| 606 |
+
for key, value in parser.parse_args(nargs).__dict__.items()
|
| 607 |
+
}
|
| 608 |
+
|
| 609 |
+
|
| 610 |
+
def prepare_sagemager_args_inputs(
|
| 611 |
+
sagemaker_config: SageMakerConfig, args: argparse.Namespace
|
| 612 |
+
) -> tuple[argparse.Namespace, dict[str, Any]]:
|
| 613 |
+
# configure environment
|
| 614 |
+
print("Configuring Amazon SageMaker environment")
|
| 615 |
+
os.environ["AWS_DEFAULT_REGION"] = sagemaker_config.region
|
| 616 |
+
|
| 617 |
+
# configure credentials
|
| 618 |
+
if sagemaker_config.profile is not None:
|
| 619 |
+
os.environ["AWS_PROFILE"] = sagemaker_config.profile
|
| 620 |
+
elif args.aws_access_key_id is not None and args.aws_secret_access_key is not None:
|
| 621 |
+
os.environ["AWS_ACCESS_KEY_ID"] = args.aws_access_key_id
|
| 622 |
+
os.environ["AWS_SECRET_ACCESS_KEY"] = args.aws_secret_access_key
|
| 623 |
+
else:
|
| 624 |
+
raise OSError("You need to provide an aws_access_key_id and aws_secret_access_key when not using aws_profile")
|
| 625 |
+
|
| 626 |
+
# extract needed arguments
|
| 627 |
+
source_dir = os.path.dirname(args.training_script)
|
| 628 |
+
if not source_dir: # checks if string is empty
|
| 629 |
+
source_dir = "."
|
| 630 |
+
entry_point = os.path.basename(args.training_script)
|
| 631 |
+
if not entry_point.endswith(".py"):
|
| 632 |
+
raise ValueError(f'Your training script should be a python script and not "{entry_point}"')
|
| 633 |
+
|
| 634 |
+
print("Converting Arguments to Hyperparameters")
|
| 635 |
+
hyperparameters = _convert_nargs_to_dict(args.training_script_args)
|
| 636 |
+
|
| 637 |
+
try:
|
| 638 |
+
mixed_precision = PrecisionType(args.mixed_precision.lower())
|
| 639 |
+
except ValueError:
|
| 640 |
+
raise ValueError(
|
| 641 |
+
f"Unknown mixed_precision mode: {args.mixed_precision.lower()}. Choose between {PrecisionType.list()}."
|
| 642 |
+
)
|
| 643 |
+
|
| 644 |
+
try:
|
| 645 |
+
dynamo_backend = DynamoBackend(args.dynamo_backend.upper())
|
| 646 |
+
except ValueError:
|
| 647 |
+
raise ValueError(
|
| 648 |
+
f"Unknown dynamo backend: {args.dynamo_backend.upper()}. Choose between {DynamoBackend.list()}."
|
| 649 |
+
)
|
| 650 |
+
|
| 651 |
+
# Environment variables to be set for use during training job
|
| 652 |
+
environment = {
|
| 653 |
+
"ACCELERATE_USE_SAGEMAKER": "true",
|
| 654 |
+
"ACCELERATE_MIXED_PRECISION": str(mixed_precision),
|
| 655 |
+
"ACCELERATE_DYNAMO_BACKEND": dynamo_backend.value,
|
| 656 |
+
"ACCELERATE_DYNAMO_MODE": args.dynamo_mode,
|
| 657 |
+
"ACCELERATE_DYNAMO_USE_FULLGRAPH": str(args.dynamo_use_fullgraph),
|
| 658 |
+
"ACCELERATE_DYNAMO_USE_DYNAMIC": str(args.dynamo_use_dynamic),
|
| 659 |
+
"ACCELERATE_DYNAMO_USE_REGIONAL_COMPILATION": str(args.dynamo_use_regional_compilation),
|
| 660 |
+
"ACCELERATE_SAGEMAKER_DISTRIBUTED_TYPE": sagemaker_config.distributed_type.value,
|
| 661 |
+
}
|
| 662 |
+
if args.mixed_precision.lower() == "fp8":
|
| 663 |
+
if not is_fp8_available():
|
| 664 |
+
raise RuntimeError(
|
| 665 |
+
"FP8 is not available on this machine. Please ensure that either Transformer Engine, MSAMP or torchao is installed."
|
| 666 |
+
)
|
| 667 |
+
environment = setup_fp8_env(args, environment)
|
| 668 |
+
# configure distribution set up
|
| 669 |
+
distribution = None
|
| 670 |
+
if sagemaker_config.distributed_type == SageMakerDistributedType.DATA_PARALLEL:
|
| 671 |
+
distribution = {"smdistributed": {"dataparallel": {"enabled": True}}}
|
| 672 |
+
|
| 673 |
+
# configure sagemaker inputs
|
| 674 |
+
sagemaker_inputs = None
|
| 675 |
+
if sagemaker_config.sagemaker_inputs_file is not None:
|
| 676 |
+
print(f"Loading SageMaker Inputs from {sagemaker_config.sagemaker_inputs_file} file")
|
| 677 |
+
sagemaker_inputs = {}
|
| 678 |
+
with open(sagemaker_config.sagemaker_inputs_file) as file:
|
| 679 |
+
for i, line in enumerate(file):
|
| 680 |
+
if i == 0:
|
| 681 |
+
continue
|
| 682 |
+
l = line.split("\t")
|
| 683 |
+
sagemaker_inputs[l[0]] = l[1].strip()
|
| 684 |
+
print(f"Loaded SageMaker Inputs: {sagemaker_inputs}")
|
| 685 |
+
|
| 686 |
+
# configure sagemaker metrics
|
| 687 |
+
sagemaker_metrics = None
|
| 688 |
+
if sagemaker_config.sagemaker_metrics_file is not None:
|
| 689 |
+
print(f"Loading SageMaker Metrics from {sagemaker_config.sagemaker_metrics_file} file")
|
| 690 |
+
sagemaker_metrics = []
|
| 691 |
+
with open(sagemaker_config.sagemaker_metrics_file) as file:
|
| 692 |
+
for i, line in enumerate(file):
|
| 693 |
+
if i == 0:
|
| 694 |
+
continue
|
| 695 |
+
l = line.split("\t")
|
| 696 |
+
metric_dict = {
|
| 697 |
+
"Name": l[0],
|
| 698 |
+
"Regex": l[1].strip(),
|
| 699 |
+
}
|
| 700 |
+
sagemaker_metrics.append(metric_dict)
|
| 701 |
+
print(f"Loaded SageMaker Metrics: {sagemaker_metrics}")
|
| 702 |
+
|
| 703 |
+
# configure session
|
| 704 |
+
print("Creating Estimator")
|
| 705 |
+
args = {
|
| 706 |
+
"image_uri": sagemaker_config.image_uri,
|
| 707 |
+
"entry_point": entry_point,
|
| 708 |
+
"source_dir": source_dir,
|
| 709 |
+
"role": sagemaker_config.iam_role_name,
|
| 710 |
+
"transformers_version": sagemaker_config.transformers_version,
|
| 711 |
+
"pytorch_version": sagemaker_config.pytorch_version,
|
| 712 |
+
"py_version": sagemaker_config.py_version,
|
| 713 |
+
"base_job_name": sagemaker_config.base_job_name,
|
| 714 |
+
"instance_count": sagemaker_config.num_machines,
|
| 715 |
+
"instance_type": sagemaker_config.ec2_instance_type,
|
| 716 |
+
"debugger_hook_config": False,
|
| 717 |
+
"distribution": distribution,
|
| 718 |
+
"hyperparameters": hyperparameters,
|
| 719 |
+
"environment": environment,
|
| 720 |
+
"metric_definitions": sagemaker_metrics,
|
| 721 |
+
}
|
| 722 |
+
|
| 723 |
+
if sagemaker_config.additional_args is not None:
|
| 724 |
+
args = merge_dicts(sagemaker_config.additional_args, args)
|
| 725 |
+
return args, sagemaker_inputs
|
| 726 |
+
|
| 727 |
+
|
| 728 |
+
def env_var_path_add(env_var_name, path_to_add):
|
| 729 |
+
"""
|
| 730 |
+
Extends a path-based environment variable's value with a new path and returns the updated value. It's up to the
|
| 731 |
+
caller to set it in os.environ.
|
| 732 |
+
"""
|
| 733 |
+
paths = [p for p in os.environ.get(env_var_name, "").split(":") if len(p) > 0]
|
| 734 |
+
paths.append(str(path_to_add))
|
| 735 |
+
return ":".join(paths)
|
| 736 |
+
|
| 737 |
+
|
| 738 |
+
class PrepareForLaunch:
|
| 739 |
+
"""
|
| 740 |
+
Prepare a function that will launched in a distributed setup.
|
| 741 |
+
|
| 742 |
+
Args:
|
| 743 |
+
launcher (`Callable`):
|
| 744 |
+
The function to launch.
|
| 745 |
+
distributed_type ([`~state.DistributedType`]):
|
| 746 |
+
The distributed type to prepare for.
|
| 747 |
+
debug (`bool`, *optional*, defaults to `False`):
|
| 748 |
+
Whether or not this is a debug launch.
|
| 749 |
+
"""
|
| 750 |
+
|
| 751 |
+
def __init__(self, launcher, distributed_type="NO", debug=False):
|
| 752 |
+
self.launcher = launcher
|
| 753 |
+
self.distributed_type = DistributedType(distributed_type)
|
| 754 |
+
self.debug = debug
|
| 755 |
+
|
| 756 |
+
def __call__(self, index, *args):
|
| 757 |
+
if self.debug:
|
| 758 |
+
world_size = int(os.environ.get("WORLD_SIZE"))
|
| 759 |
+
rdv_file = os.environ.get("ACCELERATE_DEBUG_RDV_FILE")
|
| 760 |
+
torch.distributed.init_process_group(
|
| 761 |
+
"gloo",
|
| 762 |
+
rank=index,
|
| 763 |
+
store=torch.distributed.FileStore(rdv_file, world_size),
|
| 764 |
+
world_size=world_size,
|
| 765 |
+
)
|
| 766 |
+
elif self.distributed_type in (
|
| 767 |
+
DistributedType.MULTI_GPU,
|
| 768 |
+
DistributedType.MULTI_MLU,
|
| 769 |
+
DistributedType.MULTI_MUSA,
|
| 770 |
+
DistributedType.MULTI_NPU,
|
| 771 |
+
DistributedType.MULTI_XPU,
|
| 772 |
+
DistributedType.MULTI_CPU,
|
| 773 |
+
):
|
| 774 |
+
# Prepare the environment for torch.distributed
|
| 775 |
+
os.environ["LOCAL_RANK"] = str(index)
|
| 776 |
+
nproc = int(os.environ.get("NPROC", 1))
|
| 777 |
+
node_rank = int(os.environ.get("NODE_RANK", 0))
|
| 778 |
+
os.environ["RANK"] = str(nproc * node_rank + index)
|
| 779 |
+
|
| 780 |
+
os.environ["FORK_LAUNCHED"] = str(1)
|
| 781 |
+
self.launcher(*args)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/megatron_lm.py
ADDED
|
@@ -0,0 +1,1424 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
import math
|
| 17 |
+
import os
|
| 18 |
+
from abc import ABC
|
| 19 |
+
from functools import partial
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 24 |
+
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
|
| 25 |
+
|
| 26 |
+
from ..optimizer import AcceleratedOptimizer
|
| 27 |
+
from ..scheduler import AcceleratedScheduler
|
| 28 |
+
from .imports import is_megatron_lm_available
|
| 29 |
+
from .operations import recursively_apply, send_to_device
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
if is_megatron_lm_available():
|
| 33 |
+
from megatron.core import mpu, tensor_parallel
|
| 34 |
+
from megatron.core.distributed import DistributedDataParallel as LocalDDP
|
| 35 |
+
from megatron.core.distributed import finalize_model_grads
|
| 36 |
+
from megatron.core.enums import ModelType
|
| 37 |
+
from megatron.core.num_microbatches_calculator import get_num_microbatches
|
| 38 |
+
from megatron.core.optimizer import get_megatron_optimizer
|
| 39 |
+
from megatron.core.parallel_state import get_tensor_model_parallel_group, get_tensor_model_parallel_src_rank
|
| 40 |
+
from megatron.core.pipeline_parallel import get_forward_backward_func
|
| 41 |
+
from megatron.core.utils import get_model_config
|
| 42 |
+
from megatron.inference.text_generation.communication import broadcast_int_list, broadcast_tensor
|
| 43 |
+
from megatron.inference.text_generation.generation import (
|
| 44 |
+
beam_search_and_return_on_first_stage,
|
| 45 |
+
generate_tokens_probs_and_return_on_first_stage,
|
| 46 |
+
)
|
| 47 |
+
from megatron.legacy.data.dataset_utils import build_train_valid_test_datasets
|
| 48 |
+
from megatron.legacy.model import BertModel, Float16Module, GPTModel, T5Model
|
| 49 |
+
from megatron.legacy.model.classification import Classification
|
| 50 |
+
from megatron.training import (
|
| 51 |
+
get_args,
|
| 52 |
+
get_tensorboard_writer,
|
| 53 |
+
get_tokenizer,
|
| 54 |
+
print_rank_last,
|
| 55 |
+
)
|
| 56 |
+
from megatron.training.arguments import (
|
| 57 |
+
_add_data_args,
|
| 58 |
+
_add_validation_args,
|
| 59 |
+
core_transformer_config_from_args,
|
| 60 |
+
parse_args,
|
| 61 |
+
validate_args,
|
| 62 |
+
)
|
| 63 |
+
from megatron.training.checkpointing import load_args_from_checkpoint, load_checkpoint, save_checkpoint
|
| 64 |
+
from megatron.training.global_vars import set_global_variables
|
| 65 |
+
from megatron.training.initialize import (
|
| 66 |
+
_compile_dependencies,
|
| 67 |
+
_init_autoresume,
|
| 68 |
+
_initialize_distributed,
|
| 69 |
+
_set_random_seed,
|
| 70 |
+
set_jit_fusion_options,
|
| 71 |
+
write_args_to_tensorboard,
|
| 72 |
+
)
|
| 73 |
+
from megatron.training.tokenizer.tokenizer import _vocab_size_with_padding
|
| 74 |
+
from megatron.training.training import (
|
| 75 |
+
build_train_valid_test_data_iterators,
|
| 76 |
+
get_optimizer_param_scheduler,
|
| 77 |
+
num_floating_point_operations,
|
| 78 |
+
setup_model_and_optimizer,
|
| 79 |
+
train_step,
|
| 80 |
+
training_log,
|
| 81 |
+
)
|
| 82 |
+
from megatron.training.utils import (
|
| 83 |
+
average_losses_across_data_parallel_group,
|
| 84 |
+
calc_params_l2_norm,
|
| 85 |
+
get_ltor_masks_and_position_ids,
|
| 86 |
+
unwrap_model,
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# model utilities
|
| 91 |
+
def model_provider_func(pre_process=True, post_process=True, add_encoder=True, add_decoder=True):
|
| 92 |
+
"""Build the model."""
|
| 93 |
+
args = get_args()
|
| 94 |
+
mode = "pre-training" if args.pretraining_flag else "fine-tuning"
|
| 95 |
+
if args.rank == 0:
|
| 96 |
+
print(f"Building {args.model_type_name} model in the {mode} mode.")
|
| 97 |
+
print(
|
| 98 |
+
"The Megatron LM model weights are initialized at random in `accelerator.prepare`. "
|
| 99 |
+
"Please use `accelerator.load_checkpoint` to load a pre-trained checkpoint matching the distributed setup."
|
| 100 |
+
)
|
| 101 |
+
config = core_transformer_config_from_args(args)
|
| 102 |
+
if args.model_type_name == "bert":
|
| 103 |
+
if args.pretraining_flag:
|
| 104 |
+
num_tokentypes = 2 if args.bert_binary_head else 0
|
| 105 |
+
model = BertModel(
|
| 106 |
+
config=config,
|
| 107 |
+
num_tokentypes=num_tokentypes,
|
| 108 |
+
add_binary_head=args.bert_binary_head,
|
| 109 |
+
parallel_output=True,
|
| 110 |
+
pre_process=pre_process,
|
| 111 |
+
post_process=post_process,
|
| 112 |
+
)
|
| 113 |
+
else:
|
| 114 |
+
model = Classification(
|
| 115 |
+
config=config,
|
| 116 |
+
num_classes=args.num_labels,
|
| 117 |
+
num_tokentypes=2,
|
| 118 |
+
pre_process=pre_process,
|
| 119 |
+
post_process=post_process,
|
| 120 |
+
)
|
| 121 |
+
elif args.model_type_name == "gpt":
|
| 122 |
+
model = GPTModel(
|
| 123 |
+
config=config,
|
| 124 |
+
num_tokentypes=0,
|
| 125 |
+
parallel_output=True,
|
| 126 |
+
pre_process=pre_process,
|
| 127 |
+
post_process=post_process,
|
| 128 |
+
)
|
| 129 |
+
elif args.model_type_name == "t5":
|
| 130 |
+
model = T5Model(
|
| 131 |
+
config=config,
|
| 132 |
+
num_tokentypes=0,
|
| 133 |
+
parallel_output=True,
|
| 134 |
+
pre_process=pre_process,
|
| 135 |
+
post_process=post_process,
|
| 136 |
+
add_encoder=add_encoder,
|
| 137 |
+
add_decoder=add_decoder,
|
| 138 |
+
)
|
| 139 |
+
else:
|
| 140 |
+
raise ValueError(f"Unsupported model type: {args.model_type_name}")
|
| 141 |
+
return model
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def prepare_model_optimizer_scheduler(accelerator):
|
| 145 |
+
accelerator.print("Preparing model optimizer scheduler")
|
| 146 |
+
args = get_args()
|
| 147 |
+
if accelerator.state.megatron_lm_plugin.custom_prepare_model_function is not None:
|
| 148 |
+
if accelerator.state.megatron_lm_plugin.custom_model_provider_function is None:
|
| 149 |
+
raise ValueError(
|
| 150 |
+
"You must provide a `custom_model_provider_function` when using a `custom_prepare_model_function`."
|
| 151 |
+
)
|
| 152 |
+
custom_model_provider_func = accelerator.state.megatron_lm_plugin.custom_model_provider_function
|
| 153 |
+
model = accelerator.state.megatron_lm_plugin.custom_prepare_model_function(custom_model_provider_func)
|
| 154 |
+
optimizer = prepare_optimizer(accelerator, model)
|
| 155 |
+
scheduler = prepare_scheduler(accelerator, optimizer, scheduler=None)
|
| 156 |
+
else:
|
| 157 |
+
model_type = ModelType.encoder_or_decoder
|
| 158 |
+
if args.model_type_name == "t5":
|
| 159 |
+
model_type = ModelType.encoder_and_decoder
|
| 160 |
+
model_provider_func_ = model_provider_func
|
| 161 |
+
if accelerator.state.megatron_lm_plugin.custom_model_provider_function is not None:
|
| 162 |
+
model_provider_func_ = accelerator.state.megatron_lm_plugin.custom_model_provider_function
|
| 163 |
+
(model, optimizer, scheduler) = setup_model_and_optimizer(
|
| 164 |
+
model_provider_func_,
|
| 165 |
+
model_type,
|
| 166 |
+
no_wd_decay_cond=args.no_wd_decay_cond,
|
| 167 |
+
scale_lr_cond=args.scale_lr_cond,
|
| 168 |
+
lr_mult=args.lr_mult,
|
| 169 |
+
)
|
| 170 |
+
args.model_len = len(model)
|
| 171 |
+
return model, optimizer, scheduler
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
# dataloader utilities
|
| 175 |
+
class MegatronLMDummyDataLoader:
|
| 176 |
+
"""
|
| 177 |
+
Dummy dataloader presents model parameters or param groups, this is primarily used to follow conventional training
|
| 178 |
+
|
| 179 |
+
Args:
|
| 180 |
+
**dataset_kwargs: Megatron data arguments.
|
| 181 |
+
"""
|
| 182 |
+
|
| 183 |
+
def __init__(self, **dataset_kwargs):
|
| 184 |
+
parser = argparse.ArgumentParser()
|
| 185 |
+
parser = _add_data_args(parser)
|
| 186 |
+
parser = _add_validation_args(parser)
|
| 187 |
+
data_args = parser.parse_known_args()
|
| 188 |
+
self.dataset_args = vars(data_args[0])
|
| 189 |
+
self.dataset_args.update(dataset_kwargs)
|
| 190 |
+
self.dataset_args["megatron_dataset_flag"] = True
|
| 191 |
+
|
| 192 |
+
def set_megatron_data_args(self):
|
| 193 |
+
args = get_args()
|
| 194 |
+
for key, value in self.dataset_args.items():
|
| 195 |
+
old_value = getattr(args, key, "")
|
| 196 |
+
if old_value != value:
|
| 197 |
+
print(
|
| 198 |
+
f"WARNING: MegatronLMDummyDataLoader overriding arguments for {key}:{old_value} with {key}:{value}"
|
| 199 |
+
)
|
| 200 |
+
setattr(args, key, value)
|
| 201 |
+
|
| 202 |
+
def get_train_valid_test_datasets_provider(self, accelerator):
|
| 203 |
+
def train_valid_test_datasets_provider(train_val_test_num_samples):
|
| 204 |
+
"""Build train, valid, and test datasets."""
|
| 205 |
+
args = get_args()
|
| 206 |
+
dataset_args = {
|
| 207 |
+
"data_prefix": args.data_path if isinstance(args.data_path, (list, tuple)) else [args.data_path],
|
| 208 |
+
"splits_string": args.split,
|
| 209 |
+
"train_valid_test_num_samples": train_val_test_num_samples,
|
| 210 |
+
"seed": args.seed,
|
| 211 |
+
}
|
| 212 |
+
if args.model_type_name == "bert":
|
| 213 |
+
dataset_args.update(
|
| 214 |
+
{
|
| 215 |
+
"max_seq_length": args.seq_length,
|
| 216 |
+
"binary_head": args.bert_binary_head,
|
| 217 |
+
}
|
| 218 |
+
)
|
| 219 |
+
elif args.model_type_name == "gpt":
|
| 220 |
+
dataset_args.update(
|
| 221 |
+
{
|
| 222 |
+
"max_seq_length": args.seq_length,
|
| 223 |
+
}
|
| 224 |
+
)
|
| 225 |
+
elif args.model_type_name == "t5":
|
| 226 |
+
dataset_args.update(
|
| 227 |
+
{
|
| 228 |
+
"max_seq_length": args.encoder_seq_length,
|
| 229 |
+
"max_seq_length_dec": args.decoder_seq_length,
|
| 230 |
+
"dataset_type": "t5",
|
| 231 |
+
}
|
| 232 |
+
)
|
| 233 |
+
else:
|
| 234 |
+
raise ValueError(f"Unsupported model type: {args.model_type_name}")
|
| 235 |
+
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(**dataset_args)
|
| 236 |
+
return train_ds, valid_ds, test_ds
|
| 237 |
+
|
| 238 |
+
if accelerator.state.megatron_lm_plugin.custom_megatron_datasets_provider_function is not None:
|
| 239 |
+
return accelerator.state.megatron_lm_plugin.custom_megatron_datasets_provider_function
|
| 240 |
+
try:
|
| 241 |
+
args = get_args()
|
| 242 |
+
# Use '--no-use-pep517 -e' to pip install nvidia's megatron from source
|
| 243 |
+
if args.model_type_name == "bert":
|
| 244 |
+
from pretrain_bert import train_valid_test_datasets_provider
|
| 245 |
+
|
| 246 |
+
train_valid_test_datasets_provider.is_distributed = True
|
| 247 |
+
return train_valid_test_datasets_provider
|
| 248 |
+
elif args.model_type_name == "gpt":
|
| 249 |
+
from pretrain_gpt import train_valid_test_datasets_provider
|
| 250 |
+
|
| 251 |
+
train_valid_test_datasets_provider.is_distributed = True
|
| 252 |
+
return train_valid_test_datasets_provider
|
| 253 |
+
elif args.model_type_name == "t5":
|
| 254 |
+
from pretrain_t5 import train_valid_test_datasets_provider
|
| 255 |
+
|
| 256 |
+
train_valid_test_datasets_provider.is_distributed = True
|
| 257 |
+
return train_valid_test_datasets_provider
|
| 258 |
+
except ImportError:
|
| 259 |
+
pass
|
| 260 |
+
return train_valid_test_datasets_provider
|
| 261 |
+
|
| 262 |
+
def build_train_valid_test_data_iterators(self, accelerator):
|
| 263 |
+
args = get_args()
|
| 264 |
+
|
| 265 |
+
train_valid_test_dataset_provider = self.get_train_valid_test_datasets_provider(accelerator)
|
| 266 |
+
if args.virtual_pipeline_model_parallel_size is not None:
|
| 267 |
+
train_data_iterator = []
|
| 268 |
+
valid_data_iterator = []
|
| 269 |
+
test_data_iterator = []
|
| 270 |
+
for i in range(getattr(args, "model_len", 0)):
|
| 271 |
+
mpu.set_virtual_pipeline_model_parallel_rank(i)
|
| 272 |
+
iterators = build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
|
| 273 |
+
train_data_iterator.append(iterators[0])
|
| 274 |
+
valid_data_iterator.append(iterators[1])
|
| 275 |
+
test_data_iterator.append(iterators[2])
|
| 276 |
+
else:
|
| 277 |
+
train_data_iterator, valid_data_iterator, test_data_iterator = build_train_valid_test_data_iterators(
|
| 278 |
+
train_valid_test_dataset_provider
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
return train_data_iterator, valid_data_iterator, test_data_iterator
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def _handle_megatron_data_iterator(accelerator, data_iterator):
|
| 285 |
+
class DummyMegatronDataloader:
|
| 286 |
+
def __iter__(self):
|
| 287 |
+
return self
|
| 288 |
+
|
| 289 |
+
def __next__(self):
|
| 290 |
+
return {}
|
| 291 |
+
|
| 292 |
+
is_data_iterator_empty = data_iterator is None
|
| 293 |
+
is_src_data_iterator_empty = torch.tensor(is_data_iterator_empty, dtype=torch.bool, device=accelerator.device)
|
| 294 |
+
torch.distributed.broadcast(
|
| 295 |
+
is_src_data_iterator_empty, get_tensor_model_parallel_src_rank(), group=get_tensor_model_parallel_group()
|
| 296 |
+
)
|
| 297 |
+
if not is_src_data_iterator_empty and is_data_iterator_empty:
|
| 298 |
+
return DummyMegatronDataloader()
|
| 299 |
+
return data_iterator
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def prepare_data_loader(accelerator, dataloader):
|
| 303 |
+
accelerator.print("Preparing dataloader")
|
| 304 |
+
args = get_args()
|
| 305 |
+
if not args.megatron_dataset_flag:
|
| 306 |
+
from ..data_loader import _PYTORCH_DATALOADER_KWARGS, prepare_data_loader
|
| 307 |
+
|
| 308 |
+
micro_batch_size = args.micro_batch_size * args.num_micro_batches
|
| 309 |
+
kwargs = {k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k]) for k in _PYTORCH_DATALOADER_KWARGS}
|
| 310 |
+
if kwargs["batch_size"] is None:
|
| 311 |
+
if isinstance(kwargs["sampler"], torch.utils.data.BatchSampler):
|
| 312 |
+
kwargs["sampler"].batch_size = micro_batch_size
|
| 313 |
+
else:
|
| 314 |
+
del kwargs["sampler"]
|
| 315 |
+
del kwargs["shuffle"]
|
| 316 |
+
del kwargs["batch_size"]
|
| 317 |
+
kwargs["batch_sampler"].batch_size = micro_batch_size
|
| 318 |
+
else:
|
| 319 |
+
del kwargs["batch_sampler"]
|
| 320 |
+
kwargs["batch_size"] = micro_batch_size
|
| 321 |
+
|
| 322 |
+
dataloader = torch.utils.data.DataLoader(dataloader.dataset, **kwargs)
|
| 323 |
+
# split_batches:
|
| 324 |
+
# Megatron only needs to fetch different data between different dp groups,
|
| 325 |
+
# and does not need to split the data within the dp group.
|
| 326 |
+
return prepare_data_loader(
|
| 327 |
+
dataloader,
|
| 328 |
+
accelerator.device,
|
| 329 |
+
num_processes=mpu.get_data_parallel_world_size(),
|
| 330 |
+
process_index=mpu.get_data_parallel_rank(),
|
| 331 |
+
split_batches=False,
|
| 332 |
+
put_on_device=True,
|
| 333 |
+
rng_types=accelerator.rng_types.copy(),
|
| 334 |
+
dispatch_batches=accelerator.dispatch_batches,
|
| 335 |
+
)
|
| 336 |
+
else:
|
| 337 |
+
if args.consumed_samples is not None:
|
| 338 |
+
(
|
| 339 |
+
args.consumed_train_samples,
|
| 340 |
+
args.consumed_valid_samples,
|
| 341 |
+
args.consumed_test_samples,
|
| 342 |
+
) = args.consumed_samples
|
| 343 |
+
else:
|
| 344 |
+
args.consumed_train_samples, args.consumed_valid_samples, args.consumed_test_samples = 0, 0, 0
|
| 345 |
+
args.micro_batch_size = args.micro_batch_size * args.num_micro_batches
|
| 346 |
+
# In order to be compatible with data in transform format,
|
| 347 |
+
# it needs to increase the size of mbs first,
|
| 348 |
+
# and then split the large batch data into some mbs.
|
| 349 |
+
(
|
| 350 |
+
train_data_iterator,
|
| 351 |
+
valid_data_iterator,
|
| 352 |
+
test_data_iterator,
|
| 353 |
+
) = dataloader.build_train_valid_test_data_iterators(accelerator)
|
| 354 |
+
args.micro_batch_size = args.micro_batch_size // args.num_micro_batches
|
| 355 |
+
|
| 356 |
+
train_data_iterator = _handle_megatron_data_iterator(
|
| 357 |
+
accelerator=accelerator, data_iterator=train_data_iterator
|
| 358 |
+
)
|
| 359 |
+
valid_data_iterator = _handle_megatron_data_iterator(
|
| 360 |
+
accelerator=accelerator, data_iterator=valid_data_iterator
|
| 361 |
+
)
|
| 362 |
+
test_data_iterator = _handle_megatron_data_iterator(accelerator=accelerator, data_iterator=test_data_iterator)
|
| 363 |
+
|
| 364 |
+
return train_data_iterator, valid_data_iterator, test_data_iterator
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
# optimizer utilities
|
| 368 |
+
class MegatronLMOptimizerWrapper(AcceleratedOptimizer):
|
| 369 |
+
def __init__(self, optimizer):
|
| 370 |
+
super().__init__(optimizer, device_placement=False, scaler=None)
|
| 371 |
+
|
| 372 |
+
def zero_grad(self, set_to_none=None):
|
| 373 |
+
pass # `model(**batch)` is doing that automatically. Therefore, its implementation is not needed
|
| 374 |
+
|
| 375 |
+
def step(self):
|
| 376 |
+
pass # `model(**batch)` is doing that automatically. Therefore, its implementation is not needed
|
| 377 |
+
|
| 378 |
+
@property
|
| 379 |
+
def step_was_skipped(self):
|
| 380 |
+
"""Whether or not the optimizer step was done, or skipped because of gradient overflow."""
|
| 381 |
+
return self.optimizer.skipped_iter
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
def prepare_optimizer(accelerator, model):
|
| 385 |
+
accelerator.print("Preparing optimizer")
|
| 386 |
+
args = get_args()
|
| 387 |
+
return get_megatron_optimizer(model, args.no_wd_decay_cond, args.scale_lr_cond, args.lr_mult)
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
# scheduler utilities
|
| 391 |
+
class MegatronLMDummyScheduler:
|
| 392 |
+
"""
|
| 393 |
+
Dummy scheduler presents model parameters or param groups, this is primarily used to follow conventional training
|
| 394 |
+
loop when scheduler config is specified in the deepspeed config file.
|
| 395 |
+
|
| 396 |
+
Args:
|
| 397 |
+
optimizer (`torch.optim.optimizer.Optimizer`):
|
| 398 |
+
The optimizer to wrap.
|
| 399 |
+
total_num_steps (int):
|
| 400 |
+
Total number of steps.
|
| 401 |
+
warmup_num_steps (int):
|
| 402 |
+
Number of steps for warmup.
|
| 403 |
+
**kwargs (additional keyword arguments, *optional*):
|
| 404 |
+
Other arguments.
|
| 405 |
+
"""
|
| 406 |
+
|
| 407 |
+
def __init__(self, optimizer, total_num_steps=None, warmup_num_steps=0, **kwargs):
|
| 408 |
+
self.optimizer = optimizer
|
| 409 |
+
self.total_num_steps = total_num_steps
|
| 410 |
+
self.warmup_num_steps = warmup_num_steps
|
| 411 |
+
self.kwargs = kwargs
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
class MegatronLMSchedulerWrapper(AcceleratedScheduler):
|
| 415 |
+
def __init__(self, scheduler, optimizers):
|
| 416 |
+
super().__init__(scheduler, optimizers)
|
| 417 |
+
|
| 418 |
+
def step(self, *args, **kwargs):
|
| 419 |
+
return # `model(**batch)` is doing that automatically. Therefore, its implementation is not needed
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
def prepare_scheduler(accelerator, optimizer, scheduler):
|
| 423 |
+
accelerator.print("Preparing scheduler")
|
| 424 |
+
scheduler = get_optimizer_param_scheduler(optimizer)
|
| 425 |
+
return scheduler
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
class AbstractTrainStep(ABC):
|
| 429 |
+
"""Abstract class for batching, forward pass and loss handler."""
|
| 430 |
+
|
| 431 |
+
def __init__(self, name):
|
| 432 |
+
super().__init__()
|
| 433 |
+
self.name = name
|
| 434 |
+
|
| 435 |
+
def get_batch_func(self, accelerator, megatron_dataset_flag):
|
| 436 |
+
pass
|
| 437 |
+
|
| 438 |
+
def get_forward_step_func(self):
|
| 439 |
+
pass
|
| 440 |
+
|
| 441 |
+
def get_loss_func(self, accelerator):
|
| 442 |
+
pass
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
class BertTrainStep(AbstractTrainStep):
|
| 446 |
+
"""
|
| 447 |
+
Bert train step class.
|
| 448 |
+
|
| 449 |
+
Args:
|
| 450 |
+
args (`argparse.Namespace`): Megatron-LM arguments.
|
| 451 |
+
"""
|
| 452 |
+
|
| 453 |
+
def __init__(self, accelerator, args):
|
| 454 |
+
super().__init__("BertTrainStep")
|
| 455 |
+
self.get_batch = self.get_batch_func(accelerator, args.megatron_dataset_flag)
|
| 456 |
+
self.loss_func = self.get_loss_func(accelerator, args.pretraining_flag, args.num_labels)
|
| 457 |
+
self.forward_step = self.get_forward_step_func(args.pretraining_flag, args.bert_binary_head)
|
| 458 |
+
if not args.model_return_dict:
|
| 459 |
+
self.model_output_class = None
|
| 460 |
+
else:
|
| 461 |
+
from transformers.modeling_outputs import SequenceClassifierOutput
|
| 462 |
+
|
| 463 |
+
self.model_output_class = SequenceClassifierOutput
|
| 464 |
+
|
| 465 |
+
def get_batch_func(self, accelerator, megatron_dataset_flag):
|
| 466 |
+
def get_batch_megatron(data_iterator):
|
| 467 |
+
"""Build the batch."""
|
| 468 |
+
|
| 469 |
+
# Items and their type.
|
| 470 |
+
keys = ["text", "types", "labels", "is_random", "loss_mask", "padding_mask"]
|
| 471 |
+
datatype = torch.int64
|
| 472 |
+
|
| 473 |
+
# Broadcast data.
|
| 474 |
+
if data_iterator is not None:
|
| 475 |
+
data = next(data_iterator)
|
| 476 |
+
else:
|
| 477 |
+
data = None
|
| 478 |
+
data_b = tensor_parallel.broadcast_data(keys, data, datatype)
|
| 479 |
+
|
| 480 |
+
# Unpack.
|
| 481 |
+
tokens = data_b["text"].long()
|
| 482 |
+
types = data_b["types"].long()
|
| 483 |
+
sentence_order = data_b["is_random"].long()
|
| 484 |
+
loss_mask = data_b["loss_mask"].float()
|
| 485 |
+
lm_labels = data_b["labels"].long()
|
| 486 |
+
padding_mask = data_b["padding_mask"].long()
|
| 487 |
+
|
| 488 |
+
return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
|
| 489 |
+
|
| 490 |
+
def get_batch_transformer(data_iterator):
|
| 491 |
+
"""Build the batch."""
|
| 492 |
+
data = next(data_iterator)
|
| 493 |
+
data = send_to_device(data, torch.cuda.current_device())
|
| 494 |
+
|
| 495 |
+
# Unpack.
|
| 496 |
+
tokens = data["input_ids"].long()
|
| 497 |
+
padding_mask = data["attention_mask"].long()
|
| 498 |
+
if "token_type_ids" in data:
|
| 499 |
+
types = data["token_type_ids"].long()
|
| 500 |
+
else:
|
| 501 |
+
types = None
|
| 502 |
+
if "labels" in data:
|
| 503 |
+
lm_labels = data["labels"].long()
|
| 504 |
+
loss_mask = (data["labels"] != -100).to(torch.float)
|
| 505 |
+
else:
|
| 506 |
+
lm_labels = None
|
| 507 |
+
loss_mask = None
|
| 508 |
+
if "next_sentence_label" in data:
|
| 509 |
+
sentence_order = data["next_sentence_label"].long()
|
| 510 |
+
else:
|
| 511 |
+
sentence_order = None
|
| 512 |
+
|
| 513 |
+
return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
|
| 514 |
+
|
| 515 |
+
if accelerator.state.megatron_lm_plugin.custom_get_batch_function is not None:
|
| 516 |
+
return accelerator.state.megatron_lm_plugin.custom_get_batch_function
|
| 517 |
+
if megatron_dataset_flag:
|
| 518 |
+
try:
|
| 519 |
+
# Use '--no-use-pep517 -e' to pip install nvidia's megatron from source
|
| 520 |
+
from pretrain_bert import get_batch
|
| 521 |
+
|
| 522 |
+
return get_batch
|
| 523 |
+
except ImportError:
|
| 524 |
+
pass
|
| 525 |
+
return get_batch_megatron
|
| 526 |
+
else:
|
| 527 |
+
return get_batch_transformer
|
| 528 |
+
|
| 529 |
+
def get_loss_func(self, accelerator, pretraining_flag, num_labels):
|
| 530 |
+
def loss_func_pretrain(loss_mask, sentence_order, output_tensor):
|
| 531 |
+
lm_loss_, sop_logits = output_tensor
|
| 532 |
+
|
| 533 |
+
lm_loss_ = lm_loss_.float()
|
| 534 |
+
loss_mask = loss_mask.float()
|
| 535 |
+
lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
|
| 536 |
+
|
| 537 |
+
if sop_logits is not None:
|
| 538 |
+
sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(), sentence_order.view(-1), ignore_index=-1)
|
| 539 |
+
sop_loss = sop_loss.float()
|
| 540 |
+
loss = lm_loss + sop_loss
|
| 541 |
+
averaged_losses = average_losses_across_data_parallel_group([lm_loss, sop_loss])
|
| 542 |
+
return loss, {"lm loss": averaged_losses[0], "sop loss": averaged_losses[1]}
|
| 543 |
+
|
| 544 |
+
else:
|
| 545 |
+
loss = lm_loss
|
| 546 |
+
averaged_losses = average_losses_across_data_parallel_group([lm_loss])
|
| 547 |
+
return loss, {"lm loss": averaged_losses[0]}
|
| 548 |
+
|
| 549 |
+
def loss_func_finetune(labels, logits):
|
| 550 |
+
if num_labels == 1:
|
| 551 |
+
# We are doing regression
|
| 552 |
+
loss_fct = MSELoss()
|
| 553 |
+
loss = loss_fct(logits.view(-1), labels.view(-1))
|
| 554 |
+
elif self.num_labels > 1 and (labels.dtype in (torch.long, torch.int)):
|
| 555 |
+
loss_fct = CrossEntropyLoss()
|
| 556 |
+
loss = loss_fct(logits.view(-1, num_labels), labels.view(-1))
|
| 557 |
+
else:
|
| 558 |
+
loss_fct = BCEWithLogitsLoss()
|
| 559 |
+
loss = loss_fct(logits, labels)
|
| 560 |
+
averaged_losses = average_losses_across_data_parallel_group([loss])
|
| 561 |
+
return loss, {"loss": averaged_losses[0]}
|
| 562 |
+
|
| 563 |
+
if accelerator.state.megatron_lm_plugin.custom_loss_function is not None:
|
| 564 |
+
return accelerator.state.megatron_lm_plugin.custom_loss_function
|
| 565 |
+
if pretraining_flag:
|
| 566 |
+
return loss_func_pretrain
|
| 567 |
+
else:
|
| 568 |
+
return loss_func_finetune
|
| 569 |
+
|
| 570 |
+
def get_forward_step_func(self, pretraining_flag, bert_binary_head):
|
| 571 |
+
def forward_step(data_iterator, model):
|
| 572 |
+
"""Forward step."""
|
| 573 |
+
tokens, types, sentence_order, loss_mask, labels, padding_mask = self.get_batch(data_iterator)
|
| 574 |
+
if not bert_binary_head:
|
| 575 |
+
types = None
|
| 576 |
+
# Forward pass through the model.
|
| 577 |
+
if pretraining_flag:
|
| 578 |
+
output_tensor = model(tokens, padding_mask, tokentype_ids=types, lm_labels=labels)
|
| 579 |
+
return output_tensor, partial(self.loss_func, loss_mask, sentence_order)
|
| 580 |
+
else:
|
| 581 |
+
logits = model(tokens, padding_mask, tokentype_ids=types)
|
| 582 |
+
return logits, partial(self.loss_func, labels)
|
| 583 |
+
|
| 584 |
+
return forward_step
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
class GPTTrainStep(AbstractTrainStep):
|
| 588 |
+
"""
|
| 589 |
+
GPT train step class.
|
| 590 |
+
|
| 591 |
+
Args:
|
| 592 |
+
args (`argparse.Namespace`): Megatron-LM arguments.
|
| 593 |
+
"""
|
| 594 |
+
|
| 595 |
+
def __init__(self, accelerator, args):
|
| 596 |
+
super().__init__("GPTTrainStep")
|
| 597 |
+
self.get_batch = self.get_batch_func(accelerator, args.megatron_dataset_flag)
|
| 598 |
+
self.loss_func = self.get_loss_func(accelerator)
|
| 599 |
+
self.forward_step = self.get_forward_step_func()
|
| 600 |
+
self.eod_token = args.padded_vocab_size - 1
|
| 601 |
+
if args.vocab_file is not None:
|
| 602 |
+
tokenizer = get_tokenizer()
|
| 603 |
+
self.eod_token = tokenizer.eod
|
| 604 |
+
self.reset_position_ids = args.reset_position_ids
|
| 605 |
+
self.reset_attention_mask = args.reset_attention_mask
|
| 606 |
+
self.eod_mask_loss = args.eod_mask_loss
|
| 607 |
+
if not args.model_return_dict:
|
| 608 |
+
self.model_output_class = None
|
| 609 |
+
else:
|
| 610 |
+
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
| 611 |
+
|
| 612 |
+
self.model_output_class = CausalLMOutputWithCrossAttentions
|
| 613 |
+
|
| 614 |
+
def get_batch_func(self, accelerator, megatron_dataset_flag):
|
| 615 |
+
def get_batch_megatron(data_iterator):
|
| 616 |
+
"""Generate a batch"""
|
| 617 |
+
# Items and their type.
|
| 618 |
+
keys = ["text"]
|
| 619 |
+
datatype = torch.int64
|
| 620 |
+
|
| 621 |
+
# Broadcast data.
|
| 622 |
+
if data_iterator is not None:
|
| 623 |
+
data = next(data_iterator)
|
| 624 |
+
else:
|
| 625 |
+
data = None
|
| 626 |
+
data_b = tensor_parallel.broadcast_data(keys, data, datatype)
|
| 627 |
+
|
| 628 |
+
# Unpack.
|
| 629 |
+
tokens_ = data_b["text"].long()
|
| 630 |
+
labels = tokens_[:, 1:].contiguous()
|
| 631 |
+
tokens = tokens_[:, :-1].contiguous()
|
| 632 |
+
|
| 633 |
+
# Get the masks and position ids.
|
| 634 |
+
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
|
| 635 |
+
tokens, self.eod_token, self.reset_position_ids, self.reset_attention_mask, self.eod_mask_loss
|
| 636 |
+
)
|
| 637 |
+
|
| 638 |
+
return tokens, labels, loss_mask, attention_mask, position_ids
|
| 639 |
+
|
| 640 |
+
def get_batch_transformer(data_iterator):
|
| 641 |
+
data = next(data_iterator)
|
| 642 |
+
data = {"input_ids": data["input_ids"]}
|
| 643 |
+
data = send_to_device(data, torch.cuda.current_device())
|
| 644 |
+
|
| 645 |
+
tokens_ = data["input_ids"].long()
|
| 646 |
+
padding = torch.zeros((tokens_.shape[0], 1), dtype=tokens_.dtype, device=tokens_.device) + self.eod_token
|
| 647 |
+
tokens_ = torch.concat([tokens_, padding], dim=1)
|
| 648 |
+
labels = tokens_[:, 1:].contiguous()
|
| 649 |
+
tokens = tokens_[:, :-1].contiguous()
|
| 650 |
+
# Get the masks and position ids.
|
| 651 |
+
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
|
| 652 |
+
tokens, self.eod_token, self.reset_position_ids, self.reset_attention_mask, True
|
| 653 |
+
)
|
| 654 |
+
return tokens, labels, loss_mask, attention_mask, position_ids
|
| 655 |
+
|
| 656 |
+
if accelerator.state.megatron_lm_plugin.custom_get_batch_function is not None:
|
| 657 |
+
return accelerator.state.megatron_lm_plugin.custom_get_batch_function
|
| 658 |
+
if megatron_dataset_flag:
|
| 659 |
+
try:
|
| 660 |
+
# Use '--no-use-pep517 -e' to pip install nvidia's megatron from source
|
| 661 |
+
from pretrain_gpt import get_batch
|
| 662 |
+
|
| 663 |
+
return get_batch
|
| 664 |
+
except ImportError:
|
| 665 |
+
pass
|
| 666 |
+
return get_batch_megatron
|
| 667 |
+
else:
|
| 668 |
+
return get_batch_transformer
|
| 669 |
+
|
| 670 |
+
def get_loss_func(self, accelerator):
|
| 671 |
+
args = get_args()
|
| 672 |
+
|
| 673 |
+
def loss_func(loss_mask, output_tensor):
|
| 674 |
+
if args.return_logits:
|
| 675 |
+
losses, logits = output_tensor
|
| 676 |
+
else:
|
| 677 |
+
losses = output_tensor
|
| 678 |
+
losses = losses.float()
|
| 679 |
+
loss_mask = loss_mask.view(-1).float()
|
| 680 |
+
if args.context_parallel_size > 1:
|
| 681 |
+
loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), loss_mask.sum().view(1)])
|
| 682 |
+
torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group())
|
| 683 |
+
loss = loss[0] / loss[1]
|
| 684 |
+
else:
|
| 685 |
+
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
|
| 686 |
+
|
| 687 |
+
# Check individual rank losses are not NaN prior to DP all-reduce.
|
| 688 |
+
if args.check_for_nan_in_loss_and_grad:
|
| 689 |
+
global_rank = torch.distributed.get_rank()
|
| 690 |
+
assert not loss.isnan(), (
|
| 691 |
+
f"Rank {global_rank}: found NaN in local forward loss calculation. "
|
| 692 |
+
f"Device: {torch.cuda.current_device()}, node: {os.uname()[1]}"
|
| 693 |
+
)
|
| 694 |
+
|
| 695 |
+
# Reduce loss for logging.
|
| 696 |
+
averaged_loss = average_losses_across_data_parallel_group([loss])
|
| 697 |
+
|
| 698 |
+
output_dict = {"lm loss": averaged_loss[0]}
|
| 699 |
+
if args.return_logits:
|
| 700 |
+
output_dict.update({"logits": logits})
|
| 701 |
+
return loss, output_dict
|
| 702 |
+
|
| 703 |
+
if accelerator.state.megatron_lm_plugin.custom_loss_function is not None:
|
| 704 |
+
return accelerator.state.megatron_lm_plugin.custom_loss_function
|
| 705 |
+
return loss_func
|
| 706 |
+
|
| 707 |
+
def get_forward_step_func(self):
|
| 708 |
+
def forward_step(data_iterator, model):
|
| 709 |
+
"""Forward step."""
|
| 710 |
+
# Get the batch.
|
| 711 |
+
tokens, labels, loss_mask, attention_mask, position_ids = self.get_batch(data_iterator)
|
| 712 |
+
output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
|
| 713 |
+
|
| 714 |
+
return output_tensor, partial(self.loss_func, loss_mask)
|
| 715 |
+
|
| 716 |
+
return forward_step
|
| 717 |
+
|
| 718 |
+
|
| 719 |
+
class T5TrainStep(AbstractTrainStep):
|
| 720 |
+
"""
|
| 721 |
+
T5 train step class.
|
| 722 |
+
|
| 723 |
+
Args:
|
| 724 |
+
args (`argparse.Namespace`): Megatron-LM arguments.
|
| 725 |
+
"""
|
| 726 |
+
|
| 727 |
+
def __init__(self, accelerator, args):
|
| 728 |
+
super().__init__("T5TrainStep")
|
| 729 |
+
self.get_batch = self.get_batch_func(accelerator, args.megatron_dataset_flag)
|
| 730 |
+
self.loss_func = self.get_loss_func(accelerator)
|
| 731 |
+
self.forward_step = self.get_forward_step_func()
|
| 732 |
+
if not args.model_return_dict:
|
| 733 |
+
self.model_output_class = None
|
| 734 |
+
else:
|
| 735 |
+
from transformers.modeling_outputs import Seq2SeqLMOutput
|
| 736 |
+
|
| 737 |
+
self.model_output_class = Seq2SeqLMOutput
|
| 738 |
+
|
| 739 |
+
@staticmethod
|
| 740 |
+
def attn_mask_postprocess(attention_mask):
|
| 741 |
+
# We create a 3D attention mask from a 2D tensor mask.
|
| 742 |
+
# [b, 1, s]
|
| 743 |
+
attention_mask_b1s = attention_mask.unsqueeze(1)
|
| 744 |
+
# [b, s, 1]
|
| 745 |
+
attention_mask_bs1 = attention_mask.unsqueeze(2)
|
| 746 |
+
# [b, s, s]
|
| 747 |
+
attention_mask_bss = attention_mask_b1s * attention_mask_bs1
|
| 748 |
+
# Convert attention mask to binary:
|
| 749 |
+
extended_attention_mask = attention_mask_bss < 0.5
|
| 750 |
+
return extended_attention_mask
|
| 751 |
+
|
| 752 |
+
@staticmethod
|
| 753 |
+
def get_decoder_mask(seq_length, device):
|
| 754 |
+
attention_mask = torch.tril(torch.ones((1, seq_length, seq_length), device=device))
|
| 755 |
+
attention_mask = attention_mask < 0.5
|
| 756 |
+
return attention_mask
|
| 757 |
+
|
| 758 |
+
@staticmethod
|
| 759 |
+
def get_enc_dec_mask(attention_mask, dec_seq_length, device):
|
| 760 |
+
batch_size, _ = attention_mask.shape
|
| 761 |
+
# We create a 3D attention mask from a 2D tensor mask.
|
| 762 |
+
# [b, 1, s]
|
| 763 |
+
attention_mask_b1s = attention_mask.unsqueeze(1)
|
| 764 |
+
# [b, s, 1]
|
| 765 |
+
attention_mask_bs1 = torch.ones((batch_size, dec_seq_length, 1), device=device)
|
| 766 |
+
attention_mask_bss = attention_mask_bs1 * attention_mask_b1s
|
| 767 |
+
extended_attention_mask = attention_mask_bss < 0.5
|
| 768 |
+
return extended_attention_mask
|
| 769 |
+
|
| 770 |
+
def get_batch_func(self, accelerator, megatron_dataset_flag):
|
| 771 |
+
def get_batch_megatron(data_iterator):
|
| 772 |
+
"""Build the batch."""
|
| 773 |
+
|
| 774 |
+
keys = ["text_enc", "text_dec", "labels", "loss_mask", "enc_mask", "dec_mask", "enc_dec_mask"]
|
| 775 |
+
datatype = torch.int64
|
| 776 |
+
|
| 777 |
+
# Broadcast data.
|
| 778 |
+
if data_iterator is not None:
|
| 779 |
+
data = next(data_iterator)
|
| 780 |
+
else:
|
| 781 |
+
data = None
|
| 782 |
+
data_b = tensor_parallel.broadcast_data(keys, data, datatype)
|
| 783 |
+
|
| 784 |
+
# Unpack.
|
| 785 |
+
tokens_enc = data_b["text_enc"].long()
|
| 786 |
+
tokens_dec = data_b["text_dec"].long()
|
| 787 |
+
labels = data_b["labels"].long()
|
| 788 |
+
loss_mask = data_b["loss_mask"].float()
|
| 789 |
+
|
| 790 |
+
enc_mask = data_b["enc_mask"] < 0.5
|
| 791 |
+
dec_mask = data_b["dec_mask"] < 0.5
|
| 792 |
+
enc_dec_mask = data_b["enc_dec_mask"] < 0.5
|
| 793 |
+
|
| 794 |
+
return tokens_enc, tokens_dec, loss_mask, labels, enc_mask, dec_mask, enc_dec_mask
|
| 795 |
+
|
| 796 |
+
def get_batch_transformer(data_iterator):
|
| 797 |
+
"""Build the batch."""
|
| 798 |
+
data = next(data_iterator)
|
| 799 |
+
data = send_to_device(data, torch.cuda.current_device())
|
| 800 |
+
|
| 801 |
+
tokens_enc = data["input_ids"].long()
|
| 802 |
+
labels = data["labels"].long()
|
| 803 |
+
loss_mask = (labels != -100).to(torch.float)
|
| 804 |
+
if "decoder_input_ids" in data:
|
| 805 |
+
tokens_dec = data["decoder_input_ids"].long()
|
| 806 |
+
else:
|
| 807 |
+
tokens_dec = labels.new_zeros(labels.shape, device=labels.device, dtype=torch.long)
|
| 808 |
+
tokens_dec[..., 1:] = labels[..., :-1].clone()
|
| 809 |
+
tokens_dec[..., 0] = 0
|
| 810 |
+
tokens_dec.masked_fill_(tokens_dec == -100, 0)
|
| 811 |
+
enc_mask = T5TrainStep.attn_mask_postprocess(data["attention_mask"].long())
|
| 812 |
+
dec_mask = T5TrainStep.get_decoder_mask(tokens_dec.shape[1], tokens_dec.device)
|
| 813 |
+
enc_dec_mask = T5TrainStep.get_enc_dec_mask(
|
| 814 |
+
data["attention_mask"].long(), tokens_dec.shape[1], tokens_dec.device
|
| 815 |
+
)
|
| 816 |
+
|
| 817 |
+
return tokens_enc, tokens_dec, loss_mask, labels, enc_mask, dec_mask, enc_dec_mask
|
| 818 |
+
|
| 819 |
+
if accelerator.state.megatron_lm_plugin.custom_get_batch_function is not None:
|
| 820 |
+
return accelerator.state.megatron_lm_plugin.custom_get_batch_function
|
| 821 |
+
if megatron_dataset_flag:
|
| 822 |
+
try:
|
| 823 |
+
# Use '--no-use-pep517 -e' to pip install nvidia's megatron from source
|
| 824 |
+
from pretrain_t5 import get_batch
|
| 825 |
+
|
| 826 |
+
return get_batch
|
| 827 |
+
except ImportError:
|
| 828 |
+
pass
|
| 829 |
+
return get_batch_megatron
|
| 830 |
+
else:
|
| 831 |
+
return get_batch_transformer
|
| 832 |
+
|
| 833 |
+
def get_loss_func(self, accelerator):
|
| 834 |
+
def loss_func(loss_mask, output_tensor):
|
| 835 |
+
lm_loss_ = output_tensor.float()
|
| 836 |
+
lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
|
| 837 |
+
|
| 838 |
+
loss = lm_loss
|
| 839 |
+
averaged_losses = average_losses_across_data_parallel_group([lm_loss])
|
| 840 |
+
|
| 841 |
+
return loss, {"lm loss": averaged_losses[0]}
|
| 842 |
+
|
| 843 |
+
if accelerator.state.megatron_lm_plugin.custom_loss_function is not None:
|
| 844 |
+
return accelerator.state.megatron_lm_plugin.custom_loss_function
|
| 845 |
+
return loss_func
|
| 846 |
+
|
| 847 |
+
def get_forward_step_func(self):
|
| 848 |
+
def forward_step(data_iterator, model):
|
| 849 |
+
"""Forward step."""
|
| 850 |
+
# Get the batch.
|
| 851 |
+
tokens_enc, tokens_dec, loss_mask, lm_labels, enc_mask, dec_mask, enc_dec_mask = self.get_batch(
|
| 852 |
+
data_iterator
|
| 853 |
+
)
|
| 854 |
+
# Forward model lm_labels
|
| 855 |
+
output_tensor = model(
|
| 856 |
+
tokens_enc, tokens_dec, enc_mask, dec_mask, enc_dec_mask, tokentype_ids=None, lm_labels=lm_labels
|
| 857 |
+
)
|
| 858 |
+
|
| 859 |
+
return output_tensor, partial(self.loss_func, loss_mask)
|
| 860 |
+
|
| 861 |
+
return forward_step
|
| 862 |
+
|
| 863 |
+
|
| 864 |
+
def finish_mpu_init():
|
| 865 |
+
# torch.distributed initialization
|
| 866 |
+
args = get_args()
|
| 867 |
+
# Pytorch distributed.
|
| 868 |
+
_initialize_distributed()
|
| 869 |
+
|
| 870 |
+
# Random seeds for reproducibility.
|
| 871 |
+
if args.rank == 0:
|
| 872 |
+
print(f"> setting random seeds to {args.seed} ...")
|
| 873 |
+
_set_random_seed(args.seed, args.data_parallel_random_init)
|
| 874 |
+
|
| 875 |
+
|
| 876 |
+
# initialize megatron setup
|
| 877 |
+
def initialize(accelerator, extra_args_provider=None, args_defaults={}):
|
| 878 |
+
accelerator.print("Initializing Megatron-LM")
|
| 879 |
+
assert torch.cuda.is_available(), "Megatron requires CUDA."
|
| 880 |
+
|
| 881 |
+
# Parse arguments
|
| 882 |
+
args = parse_args(extra_args_provider, ignore_unknown_args=True)
|
| 883 |
+
|
| 884 |
+
# Set defaults
|
| 885 |
+
for key, value in args_defaults.items():
|
| 886 |
+
if getattr(args, key, None) is not None:
|
| 887 |
+
if args.rank == 0:
|
| 888 |
+
print(
|
| 889 |
+
f"WARNING: overriding default arguments for {key}:{getattr(args, key)} with {key}:{value}",
|
| 890 |
+
flush=True,
|
| 891 |
+
)
|
| 892 |
+
setattr(args, key, value)
|
| 893 |
+
|
| 894 |
+
if args.use_checkpoint_args or args_defaults.get("use_checkpoint_args", False):
|
| 895 |
+
assert args.load is not None, "--use-checkpoints-args requires --load argument"
|
| 896 |
+
load_args_from_checkpoint(args)
|
| 897 |
+
|
| 898 |
+
validate_args(args)
|
| 899 |
+
|
| 900 |
+
# set global args, build tokenizer, and set adlr-autoresume,
|
| 901 |
+
# tensorboard-writer, and timers.
|
| 902 |
+
set_global_variables(args)
|
| 903 |
+
|
| 904 |
+
# Megatron's MPU is the master. Complete initialization right away.
|
| 905 |
+
finish_mpu_init()
|
| 906 |
+
|
| 907 |
+
# Autoresume.
|
| 908 |
+
_init_autoresume()
|
| 909 |
+
|
| 910 |
+
# Compile dependencies.
|
| 911 |
+
_compile_dependencies()
|
| 912 |
+
|
| 913 |
+
# Set pytorch JIT layer fusion options and warmup JIT functions.
|
| 914 |
+
set_jit_fusion_options()
|
| 915 |
+
args = get_args()
|
| 916 |
+
if getattr(args, "padded_vocab_size", None) is None:
|
| 917 |
+
args.padded_vocab_size = _vocab_size_with_padding(args.orig_vocab_size, args)
|
| 918 |
+
if args.model_type_name == "bert" and args.pretraining_flag and args.num_labels == 2:
|
| 919 |
+
args.bert_binary_head = True
|
| 920 |
+
else:
|
| 921 |
+
args.bert_binary_head = False
|
| 922 |
+
args.iteration = 0
|
| 923 |
+
|
| 924 |
+
|
| 925 |
+
class MegatronEngine(torch.nn.Module):
|
| 926 |
+
"""
|
| 927 |
+
Megatron-LM model wrapper
|
| 928 |
+
|
| 929 |
+
Args:
|
| 930 |
+
accelerator (:class:`~accelerate.Accelerator`): The accelerator object to use.
|
| 931 |
+
model: Megatron-LM model
|
| 932 |
+
optimizer: Megatron-LM optimizer
|
| 933 |
+
lr_scheduler: Megatron-LM lr scheduler
|
| 934 |
+
"""
|
| 935 |
+
|
| 936 |
+
def __init__(self, accelerator, model, optimizer, scheduler):
|
| 937 |
+
super().__init__()
|
| 938 |
+
self.module = model
|
| 939 |
+
self.base_model = model[0]
|
| 940 |
+
self.optimizer = optimizer
|
| 941 |
+
self.scheduler = scheduler
|
| 942 |
+
args = get_args()
|
| 943 |
+
if accelerator.state.megatron_lm_plugin.custom_train_step_class is not None:
|
| 944 |
+
self.train_step_handler = accelerator.state.megatron_lm_plugin.custom_train_step_class(
|
| 945 |
+
args, **accelerator.state.megatron_lm_plugin.custom_train_step_kwargs
|
| 946 |
+
)
|
| 947 |
+
elif args.model_type_name == "bert":
|
| 948 |
+
self.train_step_handler = BertTrainStep(accelerator, args)
|
| 949 |
+
elif args.model_type_name == "gpt":
|
| 950 |
+
self.train_step_handler = GPTTrainStep(accelerator, args)
|
| 951 |
+
elif args.model_type_name == "t5":
|
| 952 |
+
self.train_step_handler = T5TrainStep(accelerator, args)
|
| 953 |
+
else:
|
| 954 |
+
raise ValueError(f"Unsupported model type: {args.model_type_name}")
|
| 955 |
+
self.optimizer.skipped_iter = False
|
| 956 |
+
|
| 957 |
+
# Tracking loss.
|
| 958 |
+
self.total_loss_dict = {}
|
| 959 |
+
self.eval_total_loss_dict = {}
|
| 960 |
+
self.iteration = 0
|
| 961 |
+
self.report_memory_flag = True
|
| 962 |
+
self.num_floating_point_operations_so_far = 0
|
| 963 |
+
self.module_config = None
|
| 964 |
+
if args.tensorboard_dir is not None:
|
| 965 |
+
write_args_to_tensorboard()
|
| 966 |
+
|
| 967 |
+
def get_module_config(self):
|
| 968 |
+
args = get_args()
|
| 969 |
+
config = get_model_config(self.module[0])
|
| 970 |
+
# Setup some training config params
|
| 971 |
+
config.grad_scale_func = self.optimizer.scale_loss
|
| 972 |
+
if isinstance(self.module[0], LocalDDP) and args.overlap_grad_reduce:
|
| 973 |
+
assert config.no_sync_func is None, (
|
| 974 |
+
"When overlap_grad_reduce is True, config.no_sync_func must be None; "
|
| 975 |
+
"a custom no_sync_func is not supported when overlapping grad-reduce"
|
| 976 |
+
)
|
| 977 |
+
config.no_sync_func = [model_chunk.no_sync for model_chunk in self.module]
|
| 978 |
+
if len(self.module) == 1:
|
| 979 |
+
config.no_sync_func = config.no_sync_func[0]
|
| 980 |
+
if args.delay_grad_reduce:
|
| 981 |
+
config.grad_sync_func = [model_chunk.start_grad_sync for model_chunk in self.module]
|
| 982 |
+
if len(self.module) == 1:
|
| 983 |
+
config.grad_sync_func = config.grad_sync_func[0]
|
| 984 |
+
if args.overlap_param_gather and args.delay_param_gather:
|
| 985 |
+
config.param_sync_func = [
|
| 986 |
+
lambda x: self.optimizer.finish_param_sync(model_index, x) for model_index in range(len(self.module))
|
| 987 |
+
]
|
| 988 |
+
if len(self.module) == 1:
|
| 989 |
+
config.param_sync_func = config.param_sync_func[0]
|
| 990 |
+
config.finalize_model_grads_func = finalize_model_grads
|
| 991 |
+
return config
|
| 992 |
+
|
| 993 |
+
def train(self):
|
| 994 |
+
for model_module in self.module:
|
| 995 |
+
model_module.train()
|
| 996 |
+
|
| 997 |
+
if self.module_config is None:
|
| 998 |
+
self.module_config = self.get_module_config()
|
| 999 |
+
|
| 1000 |
+
self.log_eval_results()
|
| 1001 |
+
|
| 1002 |
+
def eval(self):
|
| 1003 |
+
for model_module in self.module:
|
| 1004 |
+
model_module.eval()
|
| 1005 |
+
|
| 1006 |
+
if self.module_config is None:
|
| 1007 |
+
self.module_config = self.get_module_config()
|
| 1008 |
+
|
| 1009 |
+
def get_batch_data_iterator(self, batch_data):
|
| 1010 |
+
args = get_args()
|
| 1011 |
+
data_chunks = []
|
| 1012 |
+
if len(batch_data) > 0:
|
| 1013 |
+
if args.num_micro_batches > 1:
|
| 1014 |
+
for i in range(0, args.num_micro_batches):
|
| 1015 |
+
data_chunks.append(
|
| 1016 |
+
{
|
| 1017 |
+
k: v[i * args.micro_batch_size : (i + 1) * args.micro_batch_size]
|
| 1018 |
+
for k, v in batch_data.items()
|
| 1019 |
+
}
|
| 1020 |
+
)
|
| 1021 |
+
else:
|
| 1022 |
+
data_chunks = [batch_data]
|
| 1023 |
+
|
| 1024 |
+
if len(self.module) > 1:
|
| 1025 |
+
batch_data_iterator = (
|
| 1026 |
+
[iter(data_chunks) for _ in range(len(self.module))]
|
| 1027 |
+
if len(batch_data) > 0
|
| 1028 |
+
else [None] * len(self.module)
|
| 1029 |
+
)
|
| 1030 |
+
else:
|
| 1031 |
+
batch_data_iterator = iter(data_chunks) if len(batch_data) > 0 else None
|
| 1032 |
+
return batch_data_iterator
|
| 1033 |
+
|
| 1034 |
+
def train_step(self, **batch_data):
|
| 1035 |
+
"""
|
| 1036 |
+
Training step for Megatron-LM
|
| 1037 |
+
|
| 1038 |
+
Args:
|
| 1039 |
+
batch_data (:obj:`dict`): The batch data to train on.
|
| 1040 |
+
"""
|
| 1041 |
+
|
| 1042 |
+
batch_data_iterator = self.get_batch_data_iterator(batch_data)
|
| 1043 |
+
|
| 1044 |
+
loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad = train_step(
|
| 1045 |
+
forward_step_func=self.train_step_handler.forward_step,
|
| 1046 |
+
data_iterator=batch_data_iterator,
|
| 1047 |
+
model=self.module,
|
| 1048 |
+
optimizer=self.optimizer,
|
| 1049 |
+
opt_param_scheduler=self.scheduler,
|
| 1050 |
+
config=self.module_config,
|
| 1051 |
+
)
|
| 1052 |
+
|
| 1053 |
+
self.optimizer.skipped_iter = skipped_iter == 1
|
| 1054 |
+
|
| 1055 |
+
return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad
|
| 1056 |
+
|
| 1057 |
+
def eval_step(self, **batch_data):
|
| 1058 |
+
"""
|
| 1059 |
+
Evaluation step for Megatron-LM
|
| 1060 |
+
|
| 1061 |
+
Args:
|
| 1062 |
+
batch_data (:obj:`dict`): The batch data to evaluate on.
|
| 1063 |
+
"""
|
| 1064 |
+
|
| 1065 |
+
args = get_args()
|
| 1066 |
+
batch_data_iterator = self.get_batch_data_iterator(batch_data)
|
| 1067 |
+
forward_backward_func = get_forward_backward_func()
|
| 1068 |
+
loss_dicts = forward_backward_func(
|
| 1069 |
+
forward_step_func=self.train_step_handler.forward_step,
|
| 1070 |
+
data_iterator=batch_data_iterator,
|
| 1071 |
+
model=self.module,
|
| 1072 |
+
num_microbatches=get_num_microbatches(),
|
| 1073 |
+
seq_length=args.seq_length,
|
| 1074 |
+
micro_batch_size=args.micro_batch_size,
|
| 1075 |
+
forward_only=True,
|
| 1076 |
+
)
|
| 1077 |
+
# Empty unused memory
|
| 1078 |
+
if args.empty_unused_memory_level >= 1:
|
| 1079 |
+
torch.cuda.empty_cache()
|
| 1080 |
+
|
| 1081 |
+
args.consumed_valid_samples += (
|
| 1082 |
+
mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches()
|
| 1083 |
+
)
|
| 1084 |
+
|
| 1085 |
+
if mpu.is_pipeline_last_stage(ignore_virtual=True):
|
| 1086 |
+
# Average loss across microbatches.
|
| 1087 |
+
loss_reduced = {}
|
| 1088 |
+
for key in loss_dicts[0]:
|
| 1089 |
+
losses_reduced_for_key = [x[key] for x in loss_dicts]
|
| 1090 |
+
if len(losses_reduced_for_key[0].shape) == 0:
|
| 1091 |
+
loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key)
|
| 1092 |
+
else:
|
| 1093 |
+
loss_reduced[key] = torch.concat(losses_reduced_for_key)
|
| 1094 |
+
return loss_reduced
|
| 1095 |
+
return {}
|
| 1096 |
+
|
| 1097 |
+
def forward(self, **batch_data):
|
| 1098 |
+
# During training, we use train_step()
|
| 1099 |
+
# model(**batch_data) performs following operations by delegating it to `self.train_step`:
|
| 1100 |
+
# 1. Prepare **batch_data for Tendor, Pipeline and Model Parallelism
|
| 1101 |
+
# 2. Set grad to zero.
|
| 1102 |
+
# 3. forward pass and backward pass using Pipeline Parallelism
|
| 1103 |
+
# 4. Empty unused memory.
|
| 1104 |
+
# 5. Reduce gradients.
|
| 1105 |
+
# 6. Update parameters.
|
| 1106 |
+
# 7. Gather params when using Distributed Optimizer (Data Parallelism).
|
| 1107 |
+
# 8. Update learning rate if scheduler is specified.
|
| 1108 |
+
# 9. Empty unused memory.
|
| 1109 |
+
# 10. Average loss across microbatches and across DP ranks.
|
| 1110 |
+
#
|
| 1111 |
+
# During evaluation, we use eval_step()
|
| 1112 |
+
args = get_args()
|
| 1113 |
+
if self.module[0].training:
|
| 1114 |
+
loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = self.train_step(**batch_data)
|
| 1115 |
+
self.iteration += 1
|
| 1116 |
+
batch_size = mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches()
|
| 1117 |
+
args.consumed_train_samples += batch_size
|
| 1118 |
+
self.num_floating_point_operations_so_far += num_floating_point_operations(args, batch_size)
|
| 1119 |
+
if args.tensorboard_dir is not None:
|
| 1120 |
+
# Logging.
|
| 1121 |
+
loss_scale = self.optimizer.get_loss_scale().item()
|
| 1122 |
+
params_norm = None
|
| 1123 |
+
if args.log_params_norm:
|
| 1124 |
+
params_norm = calc_params_l2_norm(self.model)
|
| 1125 |
+
self.report_memory_flag = training_log(
|
| 1126 |
+
loss_dict,
|
| 1127 |
+
self.total_loss_dict,
|
| 1128 |
+
self.optimizer.param_groups[0]["lr"],
|
| 1129 |
+
self.iteration,
|
| 1130 |
+
loss_scale,
|
| 1131 |
+
self.report_memory_flag,
|
| 1132 |
+
skipped_iter,
|
| 1133 |
+
grad_norm,
|
| 1134 |
+
params_norm,
|
| 1135 |
+
num_zeros_in_grad,
|
| 1136 |
+
)
|
| 1137 |
+
else:
|
| 1138 |
+
loss_dict = self.eval_step(**batch_data)
|
| 1139 |
+
if args.tensorboard_dir is not None:
|
| 1140 |
+
for key in loss_dict:
|
| 1141 |
+
self.eval_total_loss_dict[key] = (
|
| 1142 |
+
self.eval_total_loss_dict.get(key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
|
| 1143 |
+
)
|
| 1144 |
+
self.eval_total_loss_dict[key + "_num_iters"] = self.eval_total_loss_dict.get(
|
| 1145 |
+
key + "_num_iters", torch.cuda.FloatTensor([0.0])
|
| 1146 |
+
) + torch.cuda.FloatTensor([1.0])
|
| 1147 |
+
|
| 1148 |
+
loss = torch.tensor(0.0, device=torch.cuda.current_device())
|
| 1149 |
+
for key in loss_dict:
|
| 1150 |
+
if len(loss_dict[key].shape) == 0:
|
| 1151 |
+
loss += loss_dict[key]
|
| 1152 |
+
|
| 1153 |
+
logits = None
|
| 1154 |
+
if "logits" in loss_dict:
|
| 1155 |
+
logits = loss_dict["logits"]
|
| 1156 |
+
if self.train_step_handler.model_output_class is not None:
|
| 1157 |
+
return self.train_step_handler.model_output_class(loss=loss, logits=logits)
|
| 1158 |
+
return loss
|
| 1159 |
+
|
| 1160 |
+
def log_eval_results(self):
|
| 1161 |
+
args = get_args()
|
| 1162 |
+
if args.tensorboard_dir is None or self.iteration == 0:
|
| 1163 |
+
return
|
| 1164 |
+
args = get_args()
|
| 1165 |
+
writer = get_tensorboard_writer()
|
| 1166 |
+
string = f"validation loss at iteration {self.iteration} | "
|
| 1167 |
+
for key in self.eval_total_loss_dict:
|
| 1168 |
+
if key.endswith("_num_iters"):
|
| 1169 |
+
continue
|
| 1170 |
+
value = self.eval_total_loss_dict[key] / self.eval_total_loss_dict[key + "_num_iters"]
|
| 1171 |
+
string += f"{key} value: {value} | "
|
| 1172 |
+
ppl = math.exp(min(20, value.item()))
|
| 1173 |
+
if args.pretraining_flag:
|
| 1174 |
+
string += f"{key} PPL: {ppl} | "
|
| 1175 |
+
if writer:
|
| 1176 |
+
writer.add_scalar(f"{key} validation", value.item(), self.iteration)
|
| 1177 |
+
if args.pretraining_flag:
|
| 1178 |
+
writer.add_scalar(f"{key} validation ppl", ppl, self.iteration)
|
| 1179 |
+
|
| 1180 |
+
length = len(string) + 1
|
| 1181 |
+
print_rank_last("-" * length)
|
| 1182 |
+
print_rank_last(string)
|
| 1183 |
+
print_rank_last("-" * length)
|
| 1184 |
+
self.eval_total_loss_dict = {}
|
| 1185 |
+
|
| 1186 |
+
def save_checkpoint(self, output_dir):
|
| 1187 |
+
self.log_eval_results()
|
| 1188 |
+
args = get_args()
|
| 1189 |
+
args.save = output_dir
|
| 1190 |
+
torch.distributed.barrier()
|
| 1191 |
+
save_checkpoint(
|
| 1192 |
+
self.iteration,
|
| 1193 |
+
self.module,
|
| 1194 |
+
self.optimizer,
|
| 1195 |
+
self.scheduler,
|
| 1196 |
+
num_floating_point_operations_so_far=self.num_floating_point_operations_so_far,
|
| 1197 |
+
)
|
| 1198 |
+
torch.distributed.barrier()
|
| 1199 |
+
|
| 1200 |
+
def load_checkpoint(self, input_dir):
|
| 1201 |
+
args = get_args()
|
| 1202 |
+
args.load = input_dir
|
| 1203 |
+
args.consumed_train_samples = 0
|
| 1204 |
+
args.consumed_valid_samples = 0
|
| 1205 |
+
torch.distributed.barrier()
|
| 1206 |
+
iteration, num_floating_point_operations_so_far = load_checkpoint(self.module, self.optimizer, self.scheduler)
|
| 1207 |
+
torch.distributed.barrier()
|
| 1208 |
+
self.iteration = iteration
|
| 1209 |
+
self.num_floating_point_operations_so_far = num_floating_point_operations_so_far
|
| 1210 |
+
if args.fp16 and self.iteration == 0:
|
| 1211 |
+
self.optimizer.reload_model_params()
|
| 1212 |
+
|
| 1213 |
+
def megatron_generate(
|
| 1214 |
+
self,
|
| 1215 |
+
inputs,
|
| 1216 |
+
attention_mask=None,
|
| 1217 |
+
max_length=None,
|
| 1218 |
+
max_new_tokens=None,
|
| 1219 |
+
num_beams=None,
|
| 1220 |
+
temperature=None,
|
| 1221 |
+
top_k=None,
|
| 1222 |
+
top_p=None,
|
| 1223 |
+
length_penalty=None,
|
| 1224 |
+
**kwargs,
|
| 1225 |
+
):
|
| 1226 |
+
"""
|
| 1227 |
+
Generate method for GPT2 model. This method is used for inference. Supports both greedy and beam search along
|
| 1228 |
+
with sampling. Refer the Megatron-LM repo for more details
|
| 1229 |
+
|
| 1230 |
+
Args:
|
| 1231 |
+
inputs (torch.Tensor): input ids
|
| 1232 |
+
attention_mask (torch.Tensor, optional): attention mask. Defaults to None.
|
| 1233 |
+
max_length (int, optional): max length of the generated sequence. Defaults to None.
|
| 1234 |
+
Either this or max_new_tokens should be provided.
|
| 1235 |
+
max_new_tokens (int, optional): max number of tokens to be generated. Defaults to None.
|
| 1236 |
+
Either this or max_length should be provided.
|
| 1237 |
+
num_beams (int, optional): number of beams to use for beam search. Defaults to None.
|
| 1238 |
+
temperature (float, optional): temperature for sampling. Defaults to 1.0.
|
| 1239 |
+
top_k (int, optional): top k tokens to consider for sampling. Defaults to 0.0.
|
| 1240 |
+
top_p (float, optional): tokens in top p probability are considered for sampling. Defaults to 0.0.
|
| 1241 |
+
length_penalty (float, optional): length penalty for beam search. Defaults to None.
|
| 1242 |
+
kwargs: additional key-value arguments
|
| 1243 |
+
"""
|
| 1244 |
+
|
| 1245 |
+
# checking if required arguments are passed
|
| 1246 |
+
args = get_args()
|
| 1247 |
+
if args.model_type_name != "gpt":
|
| 1248 |
+
raise NotImplementedError("Generate method is not implemented for this model")
|
| 1249 |
+
|
| 1250 |
+
if args.data_parallel_size > 1:
|
| 1251 |
+
raise ValueError("Generate method requires data parallelism to be 1")
|
| 1252 |
+
|
| 1253 |
+
if args.sequence_parallel:
|
| 1254 |
+
raise ValueError("Generate method requires sequence parallelism to be False")
|
| 1255 |
+
|
| 1256 |
+
if args.recompute_granularity is not None:
|
| 1257 |
+
raise ValueError("Checkpoint activations cannot be set for inference")
|
| 1258 |
+
|
| 1259 |
+
if args.vocab_file is None:
|
| 1260 |
+
raise ValueError("Vocab file is required for inference")
|
| 1261 |
+
|
| 1262 |
+
# Prepare inputs
|
| 1263 |
+
if max_length is None and max_new_tokens is None:
|
| 1264 |
+
raise ValueError("`max_length` or `max_new_tokens` are required for inference")
|
| 1265 |
+
|
| 1266 |
+
if temperature is None:
|
| 1267 |
+
temperature = 1.0
|
| 1268 |
+
elif not (0.0 < temperature <= 100.0):
|
| 1269 |
+
raise ValueError("temperature must be a positive number less than or equal to 100.0")
|
| 1270 |
+
|
| 1271 |
+
if top_k is None:
|
| 1272 |
+
top_k = 0
|
| 1273 |
+
elif not (0 <= top_k <= 1000):
|
| 1274 |
+
raise ValueError("top_k must be a positive number less than or equal to 1000")
|
| 1275 |
+
|
| 1276 |
+
if top_p is None:
|
| 1277 |
+
top_p = 0.0
|
| 1278 |
+
elif top_p > 0.0 and top_k > 0.0:
|
| 1279 |
+
raise ValueError("top_p and top_k sampling cannot be set together")
|
| 1280 |
+
else:
|
| 1281 |
+
if not (0.0 <= top_p <= 1.0):
|
| 1282 |
+
raise ValueError("top_p must be less than or equal to 1.0")
|
| 1283 |
+
|
| 1284 |
+
top_p_decay = kwargs.get("top_p_decay", 0.0)
|
| 1285 |
+
if not (0.0 <= top_p_decay <= 1.0):
|
| 1286 |
+
raise ValueError("top_p_decay must be less than or equal to 1.0")
|
| 1287 |
+
|
| 1288 |
+
top_p_bound = kwargs.get("top_p_bound", 0.0)
|
| 1289 |
+
if not (0.0 <= top_p_bound <= 1.0):
|
| 1290 |
+
raise ValueError("top_p_bound must be less than or equal to 1.0")
|
| 1291 |
+
|
| 1292 |
+
add_BOS = kwargs.get("add_BOS", False)
|
| 1293 |
+
if not (isinstance(add_BOS, bool)):
|
| 1294 |
+
raise ValueError("add_BOS must be a boolean")
|
| 1295 |
+
|
| 1296 |
+
beam_width = num_beams
|
| 1297 |
+
if beam_width is not None:
|
| 1298 |
+
if not isinstance(beam_width, int):
|
| 1299 |
+
raise ValueError("beam_width must be an integer")
|
| 1300 |
+
if beam_width < 1:
|
| 1301 |
+
raise ValueError("beam_width must be greater than 0")
|
| 1302 |
+
if inputs.shape[0] > 1:
|
| 1303 |
+
return "When doing beam_search, batch size must be 1"
|
| 1304 |
+
|
| 1305 |
+
tokenizer = get_tokenizer()
|
| 1306 |
+
|
| 1307 |
+
stop_token = kwargs.get("stop_token", tokenizer.eod)
|
| 1308 |
+
if stop_token is not None:
|
| 1309 |
+
if not isinstance(stop_token, int):
|
| 1310 |
+
raise ValueError("stop_token must be an integer")
|
| 1311 |
+
|
| 1312 |
+
if length_penalty is None:
|
| 1313 |
+
length_penalty = 1.0
|
| 1314 |
+
|
| 1315 |
+
sizes_list = None
|
| 1316 |
+
prompts_tokens_tensor = None
|
| 1317 |
+
prompts_length_tensor = None
|
| 1318 |
+
if torch.distributed.get_rank() == 0:
|
| 1319 |
+
# Get the prompts length.
|
| 1320 |
+
if attention_mask is None:
|
| 1321 |
+
prompts_length_tensor = torch.cuda.LongTensor([inputs.shape[1]] * inputs.shape[0])
|
| 1322 |
+
else:
|
| 1323 |
+
prompts_length_tensor = attention_mask.sum(axis=-1).cuda()
|
| 1324 |
+
|
| 1325 |
+
if max_new_tokens is None:
|
| 1326 |
+
max_new_tokens = max_length - inputs.shape[1]
|
| 1327 |
+
if max_new_tokens <= 0:
|
| 1328 |
+
raise ValueError("max_new_tokens must be greater than 0")
|
| 1329 |
+
|
| 1330 |
+
if add_BOS:
|
| 1331 |
+
max_length = max_new_tokens + inputs.shape[1] + 1
|
| 1332 |
+
# making sure that `max_length` is a multiple of 4 to leverage fused kernels
|
| 1333 |
+
max_length = 4 * math.ceil(max_length / 4)
|
| 1334 |
+
max_new_tokens = max_length - (inputs.shape[1] + 1)
|
| 1335 |
+
padding = torch.cuda.LongTensor([[tokenizer.eod] * max_new_tokens] * inputs.shape[0])
|
| 1336 |
+
prompts_tokens_tensor = torch.concat(
|
| 1337 |
+
[torch.unsqueeze(padding[:, 0], axis=-1), inputs.cuda(), padding], axis=-1
|
| 1338 |
+
)
|
| 1339 |
+
else:
|
| 1340 |
+
# making sure that `max_length` is a multiple of 4 to leverage fused kernels
|
| 1341 |
+
max_length = max_new_tokens + inputs.shape[1]
|
| 1342 |
+
max_length = 4 * math.ceil(max_length / 4)
|
| 1343 |
+
max_new_tokens = max_length - inputs.shape[1]
|
| 1344 |
+
padding = torch.cuda.LongTensor([[tokenizer.eod] * max_new_tokens] * inputs.shape[0])
|
| 1345 |
+
prompts_tokens_tensor = torch.concat([inputs.cuda(), padding], axis=-1)
|
| 1346 |
+
|
| 1347 |
+
# We need the sizes of these tensors for the broadcast
|
| 1348 |
+
sizes_list = [
|
| 1349 |
+
prompts_tokens_tensor.size(0), # Batch size
|
| 1350 |
+
prompts_tokens_tensor.size(1),
|
| 1351 |
+
] # Sequence length
|
| 1352 |
+
|
| 1353 |
+
# First, broadcast the sizes.
|
| 1354 |
+
sizes_tensor = broadcast_int_list(2, int_list=sizes_list, rank=0)
|
| 1355 |
+
|
| 1356 |
+
# Now that we have the sizes, we can broadcast the tokens
|
| 1357 |
+
# and length tensors.
|
| 1358 |
+
sizes = sizes_tensor.tolist()
|
| 1359 |
+
context_tokens_tensor = broadcast_tensor(sizes, torch.int64, tensor=prompts_tokens_tensor, rank=0)
|
| 1360 |
+
context_length_tensor = broadcast_tensor(sizes[0], torch.int64, tensor=prompts_length_tensor, rank=0)
|
| 1361 |
+
|
| 1362 |
+
# Run the inference
|
| 1363 |
+
random_seed = kwargs.get("random_seed", 0)
|
| 1364 |
+
torch.random.manual_seed(random_seed)
|
| 1365 |
+
unwrapped_model = unwrap_model(self.base_model, (torchDDP, LocalDDP, Float16Module))
|
| 1366 |
+
if beam_width is not None:
|
| 1367 |
+
tokens, _ = beam_search_and_return_on_first_stage(
|
| 1368 |
+
unwrapped_model,
|
| 1369 |
+
context_tokens_tensor,
|
| 1370 |
+
context_length_tensor,
|
| 1371 |
+
beam_width,
|
| 1372 |
+
stop_token=stop_token,
|
| 1373 |
+
num_return_gen=1,
|
| 1374 |
+
length_penalty=length_penalty,
|
| 1375 |
+
)
|
| 1376 |
+
else:
|
| 1377 |
+
tokens, _, _ = generate_tokens_probs_and_return_on_first_stage(
|
| 1378 |
+
unwrapped_model,
|
| 1379 |
+
context_tokens_tensor,
|
| 1380 |
+
context_length_tensor,
|
| 1381 |
+
return_output_log_probs=False,
|
| 1382 |
+
top_k=top_k,
|
| 1383 |
+
top_p=top_p,
|
| 1384 |
+
top_p_decay=top_p_decay,
|
| 1385 |
+
top_p_bound=top_p_bound,
|
| 1386 |
+
temperature=temperature,
|
| 1387 |
+
use_eod_token_for_early_termination=True,
|
| 1388 |
+
)
|
| 1389 |
+
return tokens
|
| 1390 |
+
|
| 1391 |
+
|
| 1392 |
+
# other utilities
|
| 1393 |
+
def avg_losses_across_data_parallel_group(losses):
|
| 1394 |
+
"""
|
| 1395 |
+
Average losses across data parallel group.
|
| 1396 |
+
|
| 1397 |
+
Args:
|
| 1398 |
+
losses (List[Tensor]): List of losses to average across data parallel group.
|
| 1399 |
+
"""
|
| 1400 |
+
|
| 1401 |
+
return average_losses_across_data_parallel_group(losses)
|
| 1402 |
+
|
| 1403 |
+
|
| 1404 |
+
def gather_across_data_parallel_groups(tensor):
|
| 1405 |
+
"""
|
| 1406 |
+
Recursively gather tensor in a nested list/tuple/dictionary of tensors from data parallel ranks.
|
| 1407 |
+
|
| 1408 |
+
Args:
|
| 1409 |
+
tensor (nested list/tuple/dictionary of `torch.Tensor`):
|
| 1410 |
+
The data to gather across data parallel ranks.
|
| 1411 |
+
|
| 1412 |
+
"""
|
| 1413 |
+
|
| 1414 |
+
def _gpu_gather_one(tensor):
|
| 1415 |
+
if tensor.ndim == 0:
|
| 1416 |
+
tensor = tensor.clone()[None]
|
| 1417 |
+
output_tensors = [
|
| 1418 |
+
torch.empty_like(tensor)
|
| 1419 |
+
for _ in range(torch.distributed.get_world_size(group=mpu.get_data_parallel_group()))
|
| 1420 |
+
]
|
| 1421 |
+
torch.distributed.all_gather(output_tensors, tensor, group=mpu.get_data_parallel_group())
|
| 1422 |
+
return torch.cat(output_tensors, dim=0)
|
| 1423 |
+
|
| 1424 |
+
return recursively_apply(_gpu_gather_one, tensor, error_on_other_type=True)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/memory.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
A collection of utilities for ensuring that training can always occur. Heavily influenced by the
|
| 17 |
+
[toma](https://github.com/BlackHC/toma) library.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import functools
|
| 21 |
+
import gc
|
| 22 |
+
import importlib
|
| 23 |
+
import inspect
|
| 24 |
+
import warnings
|
| 25 |
+
from typing import Optional
|
| 26 |
+
|
| 27 |
+
import torch
|
| 28 |
+
from packaging import version
|
| 29 |
+
|
| 30 |
+
from .imports import (
|
| 31 |
+
is_cuda_available,
|
| 32 |
+
is_hpu_available,
|
| 33 |
+
is_ipex_available,
|
| 34 |
+
is_mlu_available,
|
| 35 |
+
is_mps_available,
|
| 36 |
+
is_musa_available,
|
| 37 |
+
is_npu_available,
|
| 38 |
+
is_sdaa_available,
|
| 39 |
+
is_xpu_available,
|
| 40 |
+
)
|
| 41 |
+
from .versions import compare_versions
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def clear_device_cache(garbage_collection=False):
|
| 45 |
+
"""
|
| 46 |
+
Clears the device cache by calling `torch.{backend}.empty_cache`. Can also run `gc.collect()`, but do note that
|
| 47 |
+
this is a *considerable* slowdown and should be used sparingly.
|
| 48 |
+
"""
|
| 49 |
+
if garbage_collection:
|
| 50 |
+
gc.collect()
|
| 51 |
+
|
| 52 |
+
if is_xpu_available():
|
| 53 |
+
torch.xpu.empty_cache()
|
| 54 |
+
elif is_mlu_available():
|
| 55 |
+
torch.mlu.empty_cache()
|
| 56 |
+
elif is_sdaa_available():
|
| 57 |
+
torch.sdaa.empty_cache()
|
| 58 |
+
elif is_musa_available():
|
| 59 |
+
torch.musa.empty_cache()
|
| 60 |
+
elif is_npu_available():
|
| 61 |
+
torch.npu.empty_cache()
|
| 62 |
+
elif is_mps_available(min_version="2.0"):
|
| 63 |
+
torch.mps.empty_cache()
|
| 64 |
+
elif is_cuda_available():
|
| 65 |
+
torch.cuda.empty_cache()
|
| 66 |
+
elif is_hpu_available():
|
| 67 |
+
# torch.hpu.empty_cache() # not available on hpu as it reserves all device memory for the current process
|
| 68 |
+
pass
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def release_memory(*objects):
|
| 72 |
+
"""
|
| 73 |
+
Releases memory from `objects` by setting them to `None` and calls `gc.collect()` and `torch.cuda.empty_cache()`.
|
| 74 |
+
Returned objects should be reassigned to the same variables.
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
objects (`Iterable`):
|
| 78 |
+
An iterable of objects
|
| 79 |
+
Returns:
|
| 80 |
+
A list of `None` objects to replace `objects`
|
| 81 |
+
|
| 82 |
+
Example:
|
| 83 |
+
|
| 84 |
+
```python
|
| 85 |
+
>>> import torch
|
| 86 |
+
>>> from accelerate.utils import release_memory
|
| 87 |
+
|
| 88 |
+
>>> a = torch.ones(1000, 1000).cuda()
|
| 89 |
+
>>> b = torch.ones(1000, 1000).cuda()
|
| 90 |
+
>>> a, b = release_memory(a, b)
|
| 91 |
+
```
|
| 92 |
+
"""
|
| 93 |
+
if not isinstance(objects, list):
|
| 94 |
+
objects = list(objects)
|
| 95 |
+
for i in range(len(objects)):
|
| 96 |
+
objects[i] = None
|
| 97 |
+
clear_device_cache(garbage_collection=True)
|
| 98 |
+
return objects
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def should_reduce_batch_size(exception: Exception) -> bool:
|
| 102 |
+
"""
|
| 103 |
+
Checks if `exception` relates to CUDA out-of-memory, XPU out-of-memory, CUDNN not supported, or CPU out-of-memory
|
| 104 |
+
|
| 105 |
+
Args:
|
| 106 |
+
exception (`Exception`):
|
| 107 |
+
An exception
|
| 108 |
+
"""
|
| 109 |
+
_statements = [
|
| 110 |
+
" out of memory.", # OOM for CUDA, HIP, XPU
|
| 111 |
+
"cuDNN error: CUDNN_STATUS_NOT_SUPPORTED.", # CUDNN SNAFU
|
| 112 |
+
"DefaultCPUAllocator: can't allocate memory", # CPU OOM
|
| 113 |
+
"FATAL ERROR :: MODULE:PT_DEVMEM Allocation failed", # HPU OOM
|
| 114 |
+
]
|
| 115 |
+
if isinstance(exception, RuntimeError) and len(exception.args) == 1:
|
| 116 |
+
return any(err in exception.args[0] for err in _statements)
|
| 117 |
+
return False
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def find_executable_batch_size(
|
| 121 |
+
function: Optional[callable] = None,
|
| 122 |
+
starting_batch_size: int = 128,
|
| 123 |
+
reduce_batch_size_fn: Optional[callable] = None,
|
| 124 |
+
):
|
| 125 |
+
"""
|
| 126 |
+
A basic decorator that will try to execute `function`. If it fails from exceptions related to out-of-memory or
|
| 127 |
+
CUDNN, the batch size is multiplied by 0.9 and passed to `function`
|
| 128 |
+
|
| 129 |
+
`function` must take in a `batch_size` parameter as its first argument.
|
| 130 |
+
|
| 131 |
+
Args:
|
| 132 |
+
function (`callable`, *optional*):
|
| 133 |
+
A function to wrap
|
| 134 |
+
starting_batch_size (`int`, *optional*):
|
| 135 |
+
The batch size to try and fit into memory
|
| 136 |
+
|
| 137 |
+
Example:
|
| 138 |
+
|
| 139 |
+
```python
|
| 140 |
+
>>> from accelerate.utils import find_executable_batch_size
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
>>> @find_executable_batch_size(starting_batch_size=128)
|
| 144 |
+
... def train(batch_size, model, optimizer):
|
| 145 |
+
... ...
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
>>> train(model, optimizer)
|
| 149 |
+
```
|
| 150 |
+
"""
|
| 151 |
+
if function is None:
|
| 152 |
+
return functools.partial(find_executable_batch_size, starting_batch_size=starting_batch_size)
|
| 153 |
+
|
| 154 |
+
batch_size = starting_batch_size
|
| 155 |
+
if reduce_batch_size_fn is None:
|
| 156 |
+
|
| 157 |
+
def reduce_batch_size_fn():
|
| 158 |
+
nonlocal batch_size
|
| 159 |
+
batch_size = int(batch_size * 0.9)
|
| 160 |
+
return batch_size
|
| 161 |
+
|
| 162 |
+
def decorator(*args, **kwargs):
|
| 163 |
+
nonlocal batch_size
|
| 164 |
+
clear_device_cache(garbage_collection=True)
|
| 165 |
+
params = list(inspect.signature(function).parameters.keys())
|
| 166 |
+
# Guard against user error
|
| 167 |
+
if len(params) < (len(args) + 1):
|
| 168 |
+
arg_str = ", ".join([f"{arg}={value}" for arg, value in zip(params[1:], args[1:])])
|
| 169 |
+
raise TypeError(
|
| 170 |
+
f"Batch size was passed into `{function.__name__}` as the first argument when called."
|
| 171 |
+
f"Remove this as the decorator already does so: `{function.__name__}({arg_str})`"
|
| 172 |
+
)
|
| 173 |
+
while True:
|
| 174 |
+
if batch_size == 0:
|
| 175 |
+
raise RuntimeError("No executable batch size found, reached zero.")
|
| 176 |
+
try:
|
| 177 |
+
return function(batch_size, *args, **kwargs)
|
| 178 |
+
except Exception as e:
|
| 179 |
+
if should_reduce_batch_size(e):
|
| 180 |
+
clear_device_cache(garbage_collection=True)
|
| 181 |
+
batch_size = reduce_batch_size_fn()
|
| 182 |
+
else:
|
| 183 |
+
raise
|
| 184 |
+
|
| 185 |
+
return decorator
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def get_xpu_available_memory(device_index: int):
|
| 189 |
+
if version.parse(torch.__version__).release >= version.parse("2.6").release:
|
| 190 |
+
# torch.xpu.mem_get_info API is available starting from PyTorch 2.6
|
| 191 |
+
# It further requires PyTorch built with the SYCL runtime which supports API
|
| 192 |
+
# to query available device memory. If not available, exception will be
|
| 193 |
+
# raised. Version of SYCL runtime used to build PyTorch is being reported
|
| 194 |
+
# with print(torch.version.xpu) and corresponds to the version of Intel DPC++
|
| 195 |
+
# SYCL compiler. First version to support required feature is 20250001.
|
| 196 |
+
try:
|
| 197 |
+
return torch.xpu.mem_get_info(device_index)[0]
|
| 198 |
+
except Exception:
|
| 199 |
+
pass
|
| 200 |
+
elif is_ipex_available():
|
| 201 |
+
ipex_version = version.parse(importlib.metadata.version("intel_extension_for_pytorch"))
|
| 202 |
+
if compare_versions(ipex_version, ">=", "2.5"):
|
| 203 |
+
from intel_extension_for_pytorch.xpu import mem_get_info
|
| 204 |
+
|
| 205 |
+
return mem_get_info(device_index)[0]
|
| 206 |
+
|
| 207 |
+
warnings.warn(
|
| 208 |
+
"The XPU `mem_get_info` API is available in IPEX version >=2.5 or PyTorch >=2.6. The current returned available memory is incorrect. Please consider upgrading your IPEX or PyTorch version."
|
| 209 |
+
)
|
| 210 |
+
return torch.xpu.max_memory_allocated(device_index)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/modeling.py
ADDED
|
@@ -0,0 +1,2186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import contextlib
|
| 16 |
+
import gc
|
| 17 |
+
import inspect
|
| 18 |
+
import json
|
| 19 |
+
import logging
|
| 20 |
+
import os
|
| 21 |
+
import re
|
| 22 |
+
import shutil
|
| 23 |
+
import tempfile
|
| 24 |
+
import warnings
|
| 25 |
+
from collections import OrderedDict, defaultdict
|
| 26 |
+
from typing import Optional, Union
|
| 27 |
+
|
| 28 |
+
import torch
|
| 29 |
+
from torch import distributed as dist
|
| 30 |
+
from torch import nn
|
| 31 |
+
|
| 32 |
+
from ..state import AcceleratorState
|
| 33 |
+
from .constants import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
|
| 34 |
+
from .dataclasses import AutocastKwargs, CustomDtype, DistributedType
|
| 35 |
+
from .imports import (
|
| 36 |
+
is_hpu_available,
|
| 37 |
+
is_mlu_available,
|
| 38 |
+
is_mps_available,
|
| 39 |
+
is_musa_available,
|
| 40 |
+
is_npu_available,
|
| 41 |
+
is_peft_available,
|
| 42 |
+
is_sdaa_available,
|
| 43 |
+
is_torch_xla_available,
|
| 44 |
+
is_xpu_available,
|
| 45 |
+
)
|
| 46 |
+
from .memory import clear_device_cache, get_xpu_available_memory
|
| 47 |
+
from .offload import load_offloaded_weight, offload_weight, save_offload_index
|
| 48 |
+
from .tqdm import is_tqdm_available, tqdm
|
| 49 |
+
from .versions import is_torch_version
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
if is_npu_available(check_device=False):
|
| 53 |
+
import torch_npu # noqa: F401
|
| 54 |
+
|
| 55 |
+
if is_mlu_available(check_device=False):
|
| 56 |
+
import torch_mlu # noqa: F401
|
| 57 |
+
|
| 58 |
+
if is_sdaa_available(check_device=False):
|
| 59 |
+
import torch_sdaa # noqa: F401
|
| 60 |
+
|
| 61 |
+
if is_musa_available(check_device=False):
|
| 62 |
+
import torch_musa # noqa: F401
|
| 63 |
+
|
| 64 |
+
from safetensors import safe_open
|
| 65 |
+
from safetensors.torch import load_file as safe_load_file
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
|
| 69 |
+
|
| 70 |
+
logger = logging.getLogger(__name__)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def is_peft_model(model):
|
| 74 |
+
from .other import extract_model_from_parallel
|
| 75 |
+
|
| 76 |
+
if is_peft_available():
|
| 77 |
+
from peft import PeftModel
|
| 78 |
+
|
| 79 |
+
return is_peft_available() and isinstance(extract_model_from_parallel(model), PeftModel)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def check_device_same(first_device, second_device):
|
| 83 |
+
"""
|
| 84 |
+
Utility method to check if two `torch` devices are similar. When dealing with CUDA devices, torch throws `False`
|
| 85 |
+
for `torch.device("cuda") == torch.device("cuda:0")` whereas they should be the same
|
| 86 |
+
|
| 87 |
+
Args:
|
| 88 |
+
first_device (`torch.device`):
|
| 89 |
+
First device to check
|
| 90 |
+
second_device (`torch.device`):
|
| 91 |
+
Second device to check
|
| 92 |
+
"""
|
| 93 |
+
if first_device.type != second_device.type:
|
| 94 |
+
return False
|
| 95 |
+
|
| 96 |
+
if first_device.type != "cpu" and first_device.index is None:
|
| 97 |
+
# In case the first_device is a cuda device and have
|
| 98 |
+
# the index attribute set to `None`, default it to `0`
|
| 99 |
+
first_device = torch.device(first_device.type, index=0)
|
| 100 |
+
|
| 101 |
+
if second_device.type != "cpu" and second_device.index is None:
|
| 102 |
+
# In case the second_device is a cuda device and have
|
| 103 |
+
# the index attribute set to `None`, default it to `0`
|
| 104 |
+
second_device = torch.device(second_device.type, index=0)
|
| 105 |
+
|
| 106 |
+
return first_device == second_device
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def convert_file_size_to_int(size: Union[int, str]):
|
| 110 |
+
"""
|
| 111 |
+
Converts a size expressed as a string with digits an unit (like `"5MB"`) to an integer (in bytes).
|
| 112 |
+
|
| 113 |
+
Args:
|
| 114 |
+
size (`int` or `str`): The size to convert. Will be directly returned if an `int`.
|
| 115 |
+
|
| 116 |
+
Example:
|
| 117 |
+
|
| 118 |
+
```py
|
| 119 |
+
>>> convert_file_size_to_int("1MiB")
|
| 120 |
+
1048576
|
| 121 |
+
```
|
| 122 |
+
"""
|
| 123 |
+
mem_size = -1
|
| 124 |
+
err_msg = (
|
| 125 |
+
f"`size` {size} is not in a valid format. Use an integer for bytes, or a string with an unit (like '5.0GB')."
|
| 126 |
+
)
|
| 127 |
+
try:
|
| 128 |
+
if isinstance(size, int):
|
| 129 |
+
mem_size = size
|
| 130 |
+
elif size.upper().endswith("GIB"):
|
| 131 |
+
mem_size = int(float(size[:-3]) * (2**30))
|
| 132 |
+
elif size.upper().endswith("MIB"):
|
| 133 |
+
mem_size = int(float(size[:-3]) * (2**20))
|
| 134 |
+
elif size.upper().endswith("KIB"):
|
| 135 |
+
mem_size = int(float(size[:-3]) * (2**10))
|
| 136 |
+
elif size.upper().endswith("GB"):
|
| 137 |
+
int_size = int(float(size[:-2]) * (10**9))
|
| 138 |
+
mem_size = int_size // 8 if size.endswith("b") else int_size
|
| 139 |
+
elif size.upper().endswith("MB"):
|
| 140 |
+
int_size = int(float(size[:-2]) * (10**6))
|
| 141 |
+
mem_size = int_size // 8 if size.endswith("b") else int_size
|
| 142 |
+
elif size.upper().endswith("KB"):
|
| 143 |
+
int_size = int(float(size[:-2]) * (10**3))
|
| 144 |
+
mem_size = int_size // 8 if size.endswith("b") else int_size
|
| 145 |
+
except ValueError:
|
| 146 |
+
raise ValueError(err_msg)
|
| 147 |
+
|
| 148 |
+
if mem_size < 0:
|
| 149 |
+
raise ValueError(err_msg)
|
| 150 |
+
return mem_size
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def dtype_byte_size(dtype: torch.dtype):
|
| 154 |
+
"""
|
| 155 |
+
Returns the size (in bytes) occupied by one parameter of type `dtype`.
|
| 156 |
+
|
| 157 |
+
Example:
|
| 158 |
+
|
| 159 |
+
```py
|
| 160 |
+
>>> dtype_byte_size(torch.float32)
|
| 161 |
+
4
|
| 162 |
+
```
|
| 163 |
+
"""
|
| 164 |
+
if dtype == torch.bool:
|
| 165 |
+
return 1 / 8
|
| 166 |
+
elif dtype == CustomDtype.INT2:
|
| 167 |
+
return 1 / 4
|
| 168 |
+
elif dtype == CustomDtype.INT4:
|
| 169 |
+
return 1 / 2
|
| 170 |
+
elif dtype == CustomDtype.FP8:
|
| 171 |
+
return 1
|
| 172 |
+
elif is_torch_version(">=", "2.1.0") and dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
| 173 |
+
return 1
|
| 174 |
+
bit_search = re.search(r"[^\d](\d+)$", str(dtype))
|
| 175 |
+
if bit_search is None:
|
| 176 |
+
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
|
| 177 |
+
bit_size = int(bit_search.groups()[0])
|
| 178 |
+
return bit_size // 8
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def id_tensor_storage(tensor: torch.Tensor) -> tuple[torch.device, int, int]:
|
| 182 |
+
"""
|
| 183 |
+
Unique identifier to a tensor storage. Multiple different tensors can share the same underlying storage. For
|
| 184 |
+
example, "meta" tensors all share the same storage, and thus their identifier will all be equal. This identifier is
|
| 185 |
+
guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
|
| 186 |
+
non-overlapping lifetimes may have the same id.
|
| 187 |
+
"""
|
| 188 |
+
_SIZE = {
|
| 189 |
+
torch.int64: 8,
|
| 190 |
+
torch.float32: 4,
|
| 191 |
+
torch.int32: 4,
|
| 192 |
+
torch.bfloat16: 2,
|
| 193 |
+
torch.float16: 2,
|
| 194 |
+
torch.int16: 2,
|
| 195 |
+
torch.uint8: 1,
|
| 196 |
+
torch.int8: 1,
|
| 197 |
+
torch.bool: 1,
|
| 198 |
+
torch.float64: 8,
|
| 199 |
+
}
|
| 200 |
+
try:
|
| 201 |
+
storage_ptr = tensor.untyped_storage().data_ptr()
|
| 202 |
+
storage_size = tensor.untyped_storage().nbytes()
|
| 203 |
+
except Exception:
|
| 204 |
+
try:
|
| 205 |
+
# Fallback for torch==1.10
|
| 206 |
+
storage_ptr = tensor.storage().data_ptr()
|
| 207 |
+
storage_size = tensor.storage().size() * _SIZE[tensor.dtype]
|
| 208 |
+
except NotImplementedError:
|
| 209 |
+
# Fallback for meta storage
|
| 210 |
+
storage_ptr = 0
|
| 211 |
+
# On torch >=2.0 this is the tensor size
|
| 212 |
+
storage_size = tensor.nelement() * _SIZE[tensor.dtype]
|
| 213 |
+
|
| 214 |
+
return tensor.device, storage_ptr, storage_size
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def set_module_tensor_to_device(
|
| 218 |
+
module: nn.Module,
|
| 219 |
+
tensor_name: str,
|
| 220 |
+
device: Union[int, str, torch.device],
|
| 221 |
+
value: Optional[torch.Tensor] = None,
|
| 222 |
+
dtype: Optional[Union[str, torch.dtype]] = None,
|
| 223 |
+
fp16_statistics: Optional[torch.HalfTensor] = None,
|
| 224 |
+
tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,
|
| 225 |
+
non_blocking: bool = False,
|
| 226 |
+
clear_cache: bool = True,
|
| 227 |
+
):
|
| 228 |
+
"""
|
| 229 |
+
A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing
|
| 230 |
+
`param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function).
|
| 231 |
+
|
| 232 |
+
Args:
|
| 233 |
+
module (`torch.nn.Module`):
|
| 234 |
+
The module in which the tensor we want to move lives.
|
| 235 |
+
tensor_name (`str`):
|
| 236 |
+
The full name of the parameter/buffer.
|
| 237 |
+
device (`int`, `str` or `torch.device`):
|
| 238 |
+
The device on which to set the tensor.
|
| 239 |
+
value (`torch.Tensor`, *optional*):
|
| 240 |
+
The value of the tensor (useful when going from the meta device to any other device).
|
| 241 |
+
dtype (`torch.dtype`, *optional*):
|
| 242 |
+
If passed along the value of the parameter will be cast to this `dtype`. Otherwise, `value` will be cast to
|
| 243 |
+
the dtype of the existing parameter in the model.
|
| 244 |
+
fp16_statistics (`torch.HalfTensor`, *optional*):
|
| 245 |
+
The list of fp16 statistics to set on the module, used for 8 bit model serialization.
|
| 246 |
+
tied_params_map (Dict[int, Dict[torch.device, torch.Tensor]], *optional*, defaults to `None`):
|
| 247 |
+
A map of current data pointers to dictionaries of devices to already dispatched tied weights. For a given
|
| 248 |
+
execution device, this parameter is useful to reuse the first available pointer of a shared weight on the
|
| 249 |
+
device for all others, instead of duplicating memory.
|
| 250 |
+
non_blocking (`bool`, *optional*, defaults to `False`):
|
| 251 |
+
If `True`, the device transfer will be asynchronous with respect to the host, if possible.
|
| 252 |
+
clear_cache (`bool`, *optional*, defaults to `True`):
|
| 253 |
+
Whether or not to clear the device cache after setting the tensor on the device.
|
| 254 |
+
"""
|
| 255 |
+
# Recurse if needed
|
| 256 |
+
if "." in tensor_name:
|
| 257 |
+
splits = tensor_name.split(".")
|
| 258 |
+
for split in splits[:-1]:
|
| 259 |
+
new_module = getattr(module, split)
|
| 260 |
+
if new_module is None:
|
| 261 |
+
raise ValueError(f"{module} has no attribute {split}.")
|
| 262 |
+
module = new_module
|
| 263 |
+
tensor_name = splits[-1]
|
| 264 |
+
|
| 265 |
+
if tensor_name not in module._parameters and tensor_name not in module._buffers:
|
| 266 |
+
raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
|
| 267 |
+
is_buffer = tensor_name in module._buffers
|
| 268 |
+
old_value = getattr(module, tensor_name)
|
| 269 |
+
|
| 270 |
+
# Treat the case where old_value (or a custom `value`, typically offloaded to RAM/disk) belongs to a tied group, and one of the weight
|
| 271 |
+
# in the tied group has already been dispatched to the device, by avoiding reallocating memory on the device and just copying the pointer.
|
| 272 |
+
if (
|
| 273 |
+
value is not None
|
| 274 |
+
and tied_params_map is not None
|
| 275 |
+
and value.data_ptr() in tied_params_map
|
| 276 |
+
and device in tied_params_map[value.data_ptr()]
|
| 277 |
+
):
|
| 278 |
+
module._parameters[tensor_name] = tied_params_map[value.data_ptr()][device]
|
| 279 |
+
return
|
| 280 |
+
elif (
|
| 281 |
+
tied_params_map is not None
|
| 282 |
+
and old_value.data_ptr() in tied_params_map
|
| 283 |
+
and device in tied_params_map[old_value.data_ptr()]
|
| 284 |
+
):
|
| 285 |
+
module._parameters[tensor_name] = tied_params_map[old_value.data_ptr()][device]
|
| 286 |
+
return
|
| 287 |
+
|
| 288 |
+
if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None:
|
| 289 |
+
raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.")
|
| 290 |
+
|
| 291 |
+
param = module._parameters[tensor_name] if tensor_name in module._parameters else None
|
| 292 |
+
param_cls = type(param)
|
| 293 |
+
|
| 294 |
+
if value is not None:
|
| 295 |
+
# We can expect mismatches when using bnb 4bit since Params4bit will reshape and pack the weights.
|
| 296 |
+
# In other cases, we want to make sure we're not loading checkpoints that do not match the config.
|
| 297 |
+
if old_value.shape != value.shape and param_cls.__name__ != "Params4bit":
|
| 298 |
+
raise ValueError(
|
| 299 |
+
f'Trying to set a tensor of shape {value.shape} in "{tensor_name}" (which has shape {old_value.shape}), this looks incorrect.'
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
if dtype is None:
|
| 303 |
+
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model
|
| 304 |
+
value = value.to(old_value.dtype, non_blocking=non_blocking)
|
| 305 |
+
elif not str(value.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
|
| 306 |
+
value = value.to(dtype, non_blocking=non_blocking)
|
| 307 |
+
|
| 308 |
+
device_quantization = None
|
| 309 |
+
with torch.no_grad():
|
| 310 |
+
# leave it on cpu first before moving them to cuda
|
| 311 |
+
# # fix the case where the device is meta, we don't want to put it on cpu because there is no data =0
|
| 312 |
+
if (
|
| 313 |
+
param is not None
|
| 314 |
+
and param.device.type not in ("cuda", "xpu")
|
| 315 |
+
and torch.device(device).type in ("cuda", "xpu")
|
| 316 |
+
and param_cls.__name__ in ["Int8Params", "FP4Params", "Params4bit"]
|
| 317 |
+
):
|
| 318 |
+
device_quantization = device
|
| 319 |
+
device = "cpu"
|
| 320 |
+
# `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
|
| 321 |
+
if isinstance(device, int):
|
| 322 |
+
if is_npu_available():
|
| 323 |
+
device = f"npu:{device}"
|
| 324 |
+
elif is_mlu_available():
|
| 325 |
+
device = f"mlu:{device}"
|
| 326 |
+
elif is_sdaa_available():
|
| 327 |
+
device = f"sdaa:{device}"
|
| 328 |
+
elif is_musa_available():
|
| 329 |
+
device = f"musa:{device}"
|
| 330 |
+
elif is_hpu_available():
|
| 331 |
+
device = "hpu"
|
| 332 |
+
if "xpu" in str(device) and not is_xpu_available():
|
| 333 |
+
raise ValueError(f'{device} is not available, you should use device="cpu" instead')
|
| 334 |
+
if value is None:
|
| 335 |
+
new_value = old_value.to(device, non_blocking=non_blocking)
|
| 336 |
+
if dtype is not None and device in ["meta", torch.device("meta")]:
|
| 337 |
+
if not str(old_value.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
|
| 338 |
+
new_value = new_value.to(dtype, non_blocking=non_blocking)
|
| 339 |
+
|
| 340 |
+
if not is_buffer:
|
| 341 |
+
module._parameters[tensor_name] = param_cls(new_value, requires_grad=old_value.requires_grad)
|
| 342 |
+
elif isinstance(value, torch.Tensor):
|
| 343 |
+
new_value = value.to(device, non_blocking=non_blocking)
|
| 344 |
+
else:
|
| 345 |
+
new_value = torch.tensor(value, device=device)
|
| 346 |
+
if device_quantization is not None:
|
| 347 |
+
device = device_quantization
|
| 348 |
+
if is_buffer:
|
| 349 |
+
module._buffers[tensor_name] = new_value
|
| 350 |
+
elif value is not None or not check_device_same(torch.device(device), module._parameters[tensor_name].device):
|
| 351 |
+
param_cls = type(module._parameters[tensor_name])
|
| 352 |
+
kwargs = module._parameters[tensor_name].__dict__
|
| 353 |
+
if param_cls.__name__ in ["Int8Params", "FP4Params", "Params4bit"]:
|
| 354 |
+
if param_cls.__name__ == "Int8Params" and new_value.dtype == torch.float32:
|
| 355 |
+
# downcast to fp16 if any - needed for 8bit serialization
|
| 356 |
+
new_value = new_value.to(torch.float16, non_blocking=non_blocking)
|
| 357 |
+
# quantize module that are going to stay on the cpu so that we offload quantized weights
|
| 358 |
+
if device == "cpu" and param_cls.__name__ == "Int8Params":
|
| 359 |
+
new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(0).to("cpu")
|
| 360 |
+
new_value.CB = new_value.CB.to("cpu")
|
| 361 |
+
new_value.SCB = new_value.SCB.to("cpu")
|
| 362 |
+
else:
|
| 363 |
+
new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(
|
| 364 |
+
device, non_blocking=non_blocking
|
| 365 |
+
)
|
| 366 |
+
elif param_cls.__name__ in ["QTensor", "QBitsTensor"]:
|
| 367 |
+
new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad).to(
|
| 368 |
+
device, non_blocking=non_blocking
|
| 369 |
+
)
|
| 370 |
+
elif param_cls.__name__ in ["AffineQuantizedTensor"]:
|
| 371 |
+
new_value = new_value.to(device, non_blocking=non_blocking)
|
| 372 |
+
else:
|
| 373 |
+
new_value = param_cls(new_value, requires_grad=old_value.requires_grad).to(
|
| 374 |
+
device, non_blocking=non_blocking
|
| 375 |
+
)
|
| 376 |
+
|
| 377 |
+
module._parameters[tensor_name] = new_value
|
| 378 |
+
if fp16_statistics is not None:
|
| 379 |
+
module._parameters[tensor_name].SCB = fp16_statistics.to(device, non_blocking=non_blocking)
|
| 380 |
+
del fp16_statistics
|
| 381 |
+
# as we put the weight to meta, it doesn't have SCB attr anymore. make sure that it is not a meta weight
|
| 382 |
+
if (
|
| 383 |
+
module.__class__.__name__ == "Linear8bitLt"
|
| 384 |
+
and getattr(module.weight, "SCB", None) is None
|
| 385 |
+
and str(module.weight.device) != "meta"
|
| 386 |
+
):
|
| 387 |
+
# quantize only if necessary
|
| 388 |
+
device_index = torch.device(device).index if torch.device(device).type == "cuda" else None
|
| 389 |
+
if not getattr(module.weight, "SCB", None) and device_index is not None:
|
| 390 |
+
if module.bias is not None and module.bias.device.type != "meta":
|
| 391 |
+
# if a bias exists, we need to wait until the bias is set on the correct device
|
| 392 |
+
module = module.cuda(device_index)
|
| 393 |
+
elif module.bias is None:
|
| 394 |
+
# if no bias exists, we can quantize right away
|
| 395 |
+
module = module.cuda(device_index)
|
| 396 |
+
elif (
|
| 397 |
+
module.__class__.__name__ == "Linear4bit"
|
| 398 |
+
and getattr(module.weight, "quant_state", None) is None
|
| 399 |
+
and str(module.weight.device) != "meta"
|
| 400 |
+
):
|
| 401 |
+
# quantize only if necessary
|
| 402 |
+
device_index = torch.device(device).index if torch.device(device).type == "cuda" else None
|
| 403 |
+
if not getattr(module.weight, "quant_state", None) and device_index is not None:
|
| 404 |
+
module.weight = module.weight.cuda(device_index)
|
| 405 |
+
|
| 406 |
+
# clean pre and post forward hook
|
| 407 |
+
if clear_cache and device not in ("cpu", "meta"):
|
| 408 |
+
clear_device_cache()
|
| 409 |
+
|
| 410 |
+
# When handling tied weights, we update tied_params_map to keep track of the tied weights that have already been allocated on the device in
|
| 411 |
+
# order to avoid duplicating memory, see above.
|
| 412 |
+
if (
|
| 413 |
+
tied_params_map is not None
|
| 414 |
+
and old_value.data_ptr() in tied_params_map
|
| 415 |
+
and device not in tied_params_map[old_value.data_ptr()]
|
| 416 |
+
):
|
| 417 |
+
tied_params_map[old_value.data_ptr()][device] = new_value
|
| 418 |
+
elif (
|
| 419 |
+
value is not None
|
| 420 |
+
and tied_params_map is not None
|
| 421 |
+
and value.data_ptr() in tied_params_map
|
| 422 |
+
and device not in tied_params_map[value.data_ptr()]
|
| 423 |
+
):
|
| 424 |
+
tied_params_map[value.data_ptr()][device] = new_value
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
def named_module_tensors(
|
| 428 |
+
module: nn.Module, include_buffers: bool = True, recurse: bool = False, remove_non_persistent: bool = False
|
| 429 |
+
):
|
| 430 |
+
"""
|
| 431 |
+
A helper function that gathers all the tensors (parameters + buffers) of a given module. If `include_buffers=True`
|
| 432 |
+
it's the same as doing `module.named_parameters(recurse=recurse) + module.named_buffers(recurse=recurse)`.
|
| 433 |
+
|
| 434 |
+
Args:
|
| 435 |
+
module (`torch.nn.Module`):
|
| 436 |
+
The module we want the tensors on.
|
| 437 |
+
include_buffer (`bool`, *optional*, defaults to `True`):
|
| 438 |
+
Whether or not to include the buffers in the result.
|
| 439 |
+
recurse (`bool`, *optional`, defaults to `False`):
|
| 440 |
+
Whether or not to go look in every submodule or just return the direct parameters and buffers.
|
| 441 |
+
remove_non_persistent (`bool`, *optional*, defaults to `False`):
|
| 442 |
+
Whether or not to remove the non persistent buffer from the buffers. Useful only when include_buffers =
|
| 443 |
+
True
|
| 444 |
+
"""
|
| 445 |
+
yield from module.named_parameters(recurse=recurse)
|
| 446 |
+
|
| 447 |
+
if include_buffers:
|
| 448 |
+
non_persistent_buffers = set()
|
| 449 |
+
if remove_non_persistent:
|
| 450 |
+
non_persistent_buffers = get_non_persistent_buffers(module, recurse=recurse)
|
| 451 |
+
for named_buffer in module.named_buffers(recurse=recurse):
|
| 452 |
+
name, _ = named_buffer
|
| 453 |
+
if name not in non_persistent_buffers:
|
| 454 |
+
yield named_buffer
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
def get_non_persistent_buffers(module: nn.Module, recurse: bool = False, fqns: bool = False):
|
| 458 |
+
"""
|
| 459 |
+
Gather all non persistent buffers of a given modules into a set
|
| 460 |
+
|
| 461 |
+
Args:
|
| 462 |
+
module (`nn.Module`):
|
| 463 |
+
The module we want the non persistent buffers on.
|
| 464 |
+
recurse (`bool`, *optional*, defaults to `False`):
|
| 465 |
+
Whether or not to go look in every submodule or just return the direct non persistent buffers.
|
| 466 |
+
fqns (`bool`, *optional*, defaults to `False`):
|
| 467 |
+
Whether or not to return the fully-qualified names of the non persistent buffers.
|
| 468 |
+
"""
|
| 469 |
+
|
| 470 |
+
non_persistent_buffers_set = module._non_persistent_buffers_set
|
| 471 |
+
if recurse:
|
| 472 |
+
for n, m in module.named_modules():
|
| 473 |
+
if fqns:
|
| 474 |
+
non_persistent_buffers_set |= {n + "." + b for b in m._non_persistent_buffers_set}
|
| 475 |
+
else:
|
| 476 |
+
non_persistent_buffers_set |= m._non_persistent_buffers_set
|
| 477 |
+
|
| 478 |
+
return non_persistent_buffers_set
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
def check_tied_parameters_in_config(model: nn.Module):
|
| 482 |
+
"""
|
| 483 |
+
Check if there is any indication in the given model that some weights should be tied.
|
| 484 |
+
|
| 485 |
+
Args:
|
| 486 |
+
model (`torch.nn.Module`): The model to inspect
|
| 487 |
+
|
| 488 |
+
Returns:
|
| 489 |
+
bool: True if the model needs to have tied weights
|
| 490 |
+
"""
|
| 491 |
+
|
| 492 |
+
# based on model.tie_weights() method
|
| 493 |
+
has_tied_word_embedding = False
|
| 494 |
+
has_tied_encoder_decoder = False
|
| 495 |
+
has_tied_module = False
|
| 496 |
+
|
| 497 |
+
if "PreTrainedModel" in [c.__name__ for c in inspect.getmro(model.__class__)]:
|
| 498 |
+
has_tied_word_embedding = False
|
| 499 |
+
model_decoder_config = None
|
| 500 |
+
if hasattr(model, "config"):
|
| 501 |
+
model_decoder_config = (
|
| 502 |
+
model.config.get_text_config(decoder=True)
|
| 503 |
+
if hasattr(model.config, "get_text_config")
|
| 504 |
+
else model.config
|
| 505 |
+
)
|
| 506 |
+
has_tied_word_embedding = (
|
| 507 |
+
model_decoder_config is not None
|
| 508 |
+
and getattr(model_decoder_config, "tie_word_embeddings", False)
|
| 509 |
+
and model.get_output_embeddings()
|
| 510 |
+
)
|
| 511 |
+
|
| 512 |
+
has_tied_encoder_decoder = (
|
| 513 |
+
hasattr(model, "config")
|
| 514 |
+
and getattr(model.config, "is_encoder_decoder", False)
|
| 515 |
+
and getattr(model.config, "tie_encoder_decoder", False)
|
| 516 |
+
)
|
| 517 |
+
has_tied_module = any(hasattr(module, "_tie_weights") for module in model.modules())
|
| 518 |
+
return any([has_tied_word_embedding, has_tied_encoder_decoder, has_tied_module])
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
def _get_param_device(param, device_map):
|
| 522 |
+
if param in device_map:
|
| 523 |
+
return device_map[param]
|
| 524 |
+
parent_param = ".".join(param.split(".")[:-1])
|
| 525 |
+
if parent_param == param:
|
| 526 |
+
raise ValueError(f"The `device_map` does not contain the module {param}.")
|
| 527 |
+
else:
|
| 528 |
+
return _get_param_device(parent_param, device_map)
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
def check_tied_parameters_on_same_device(tied_params, device_map):
|
| 532 |
+
"""
|
| 533 |
+
Check if tied parameters are on the same device
|
| 534 |
+
|
| 535 |
+
Args:
|
| 536 |
+
tied_params (`List[List[str]]`):
|
| 537 |
+
A list of lists of parameter names being all tied together.
|
| 538 |
+
|
| 539 |
+
device_map (`Dict[str, Union[int, str, torch.device]]`):
|
| 540 |
+
A map that specifies where each submodule should go.
|
| 541 |
+
|
| 542 |
+
"""
|
| 543 |
+
for tie_param in tied_params:
|
| 544 |
+
tie_param_devices = {}
|
| 545 |
+
for param in tie_param:
|
| 546 |
+
tie_param_devices[param] = _get_param_device(param, device_map)
|
| 547 |
+
if len(set(tie_param_devices.values())) > 1:
|
| 548 |
+
logger.warning(
|
| 549 |
+
f"Tied parameters are on different devices: {tie_param_devices}. "
|
| 550 |
+
"Please modify your custom device map or set `device_map='auto'`. "
|
| 551 |
+
)
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
def find_tied_parameters(model: torch.nn.Module, **kwargs) -> list[list[str]]:
|
| 555 |
+
"""
|
| 556 |
+
Find the tied parameters in a given model.
|
| 557 |
+
|
| 558 |
+
<Tip warning={true}>
|
| 559 |
+
|
| 560 |
+
The signature accepts keyword arguments, but they are for the recursive part of this function and you should ignore
|
| 561 |
+
them.
|
| 562 |
+
|
| 563 |
+
</Tip>
|
| 564 |
+
|
| 565 |
+
Args:
|
| 566 |
+
model (`torch.nn.Module`): The model to inspect.
|
| 567 |
+
|
| 568 |
+
Returns:
|
| 569 |
+
List[List[str]]: A list of lists of parameter names being all tied together.
|
| 570 |
+
|
| 571 |
+
Example:
|
| 572 |
+
|
| 573 |
+
```py
|
| 574 |
+
>>> from collections import OrderedDict
|
| 575 |
+
>>> import torch.nn as nn
|
| 576 |
+
|
| 577 |
+
>>> model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))]))
|
| 578 |
+
>>> model.linear2.weight = model.linear1.weight
|
| 579 |
+
>>> find_tied_parameters(model)
|
| 580 |
+
[['linear1.weight', 'linear2.weight']]
|
| 581 |
+
```
|
| 582 |
+
"""
|
| 583 |
+
|
| 584 |
+
# get ALL model parameters and their names
|
| 585 |
+
all_named_parameters = {name: param for name, param in model.named_parameters(remove_duplicate=False)}
|
| 586 |
+
|
| 587 |
+
# get ONLY unique named parameters,
|
| 588 |
+
# if parameter is tied and have multiple names, it will be included only once
|
| 589 |
+
no_duplicate_named_parameters = {name: param for name, param in model.named_parameters(remove_duplicate=True)}
|
| 590 |
+
|
| 591 |
+
# the difference of the two sets will give us the tied parameters
|
| 592 |
+
tied_param_names = set(all_named_parameters.keys()) - set(no_duplicate_named_parameters.keys())
|
| 593 |
+
|
| 594 |
+
# 'tied_param_names' contains the names of parameters that are tied in the model, but we do not know
|
| 595 |
+
# which names refer to the same parameter. To identify this, we need to group them together.
|
| 596 |
+
tied_param_groups = {}
|
| 597 |
+
for tied_param_name in tied_param_names:
|
| 598 |
+
tied_param = all_named_parameters[tied_param_name]
|
| 599 |
+
for param_name, param in no_duplicate_named_parameters.items():
|
| 600 |
+
# compare if parameters are the same, if so, group their names together
|
| 601 |
+
if param is tied_param:
|
| 602 |
+
if param_name not in tied_param_groups:
|
| 603 |
+
tied_param_groups[param_name] = []
|
| 604 |
+
tied_param_groups[param_name].append(tied_param_name)
|
| 605 |
+
|
| 606 |
+
return [sorted([weight] + list(set(tied))) for weight, tied in tied_param_groups.items()]
|
| 607 |
+
|
| 608 |
+
|
| 609 |
+
def retie_parameters(model, tied_params):
|
| 610 |
+
"""
|
| 611 |
+
Reties tied parameters in a given model if the link was broken (for instance when adding hooks).
|
| 612 |
+
|
| 613 |
+
Args:
|
| 614 |
+
model (`torch.nn.Module`):
|
| 615 |
+
The model in which to retie parameters.
|
| 616 |
+
tied_params (`List[List[str]]`):
|
| 617 |
+
A mapping parameter name to tied parameter name as obtained by `find_tied_parameters`.
|
| 618 |
+
"""
|
| 619 |
+
for tied_group in tied_params:
|
| 620 |
+
param_to_tie = None
|
| 621 |
+
# two loops : the first one to set param_to_tie , the second one to change the values of tied_group
|
| 622 |
+
for param_name in tied_group:
|
| 623 |
+
module = model
|
| 624 |
+
splits = param_name.split(".")
|
| 625 |
+
for split in splits[:-1]:
|
| 626 |
+
module = getattr(module, split)
|
| 627 |
+
param = getattr(module, splits[-1])
|
| 628 |
+
if param_to_tie is None and param.device != torch.device("meta"):
|
| 629 |
+
param_to_tie = param
|
| 630 |
+
break
|
| 631 |
+
if param_to_tie is not None:
|
| 632 |
+
for param_name in tied_group:
|
| 633 |
+
module = model
|
| 634 |
+
splits = param_name.split(".")
|
| 635 |
+
for split in splits[:-1]:
|
| 636 |
+
module = getattr(module, split)
|
| 637 |
+
setattr(module, splits[-1], param_to_tie)
|
| 638 |
+
|
| 639 |
+
|
| 640 |
+
def _get_proper_dtype(dtype: Union[str, torch.device]) -> torch.dtype:
|
| 641 |
+
"""
|
| 642 |
+
Just does torch.dtype(dtype) if necessary.
|
| 643 |
+
"""
|
| 644 |
+
if isinstance(dtype, str):
|
| 645 |
+
# We accept "torch.float16" or just "float16"
|
| 646 |
+
dtype = dtype.replace("torch.", "")
|
| 647 |
+
dtype = getattr(torch, dtype)
|
| 648 |
+
return dtype
|
| 649 |
+
|
| 650 |
+
|
| 651 |
+
def compute_module_sizes(
|
| 652 |
+
model: nn.Module,
|
| 653 |
+
dtype: Optional[Union[str, torch.device]] = None,
|
| 654 |
+
special_dtypes: Optional[dict[str, Union[str, torch.device]]] = None,
|
| 655 |
+
buffers_only: bool = False,
|
| 656 |
+
):
|
| 657 |
+
"""
|
| 658 |
+
Compute the size of each submodule of a given model.
|
| 659 |
+
"""
|
| 660 |
+
if dtype is not None:
|
| 661 |
+
dtype = _get_proper_dtype(dtype)
|
| 662 |
+
dtype_size = dtype_byte_size(dtype)
|
| 663 |
+
if special_dtypes is not None:
|
| 664 |
+
special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()}
|
| 665 |
+
special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()}
|
| 666 |
+
module_sizes = defaultdict(int)
|
| 667 |
+
|
| 668 |
+
module_list = []
|
| 669 |
+
|
| 670 |
+
if not buffers_only:
|
| 671 |
+
module_list = named_module_tensors(model, recurse=True)
|
| 672 |
+
else:
|
| 673 |
+
module_list = model.named_buffers(recurse=True)
|
| 674 |
+
|
| 675 |
+
for name, tensor in module_list:
|
| 676 |
+
if special_dtypes is not None and name in special_dtypes:
|
| 677 |
+
size = tensor.numel() * special_dtypes_size[name]
|
| 678 |
+
elif dtype is None:
|
| 679 |
+
size = tensor.numel() * dtype_byte_size(tensor.dtype)
|
| 680 |
+
elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
|
| 681 |
+
# According to the code in set_module_tensor_to_device, these types won't be converted
|
| 682 |
+
# so use their original size here
|
| 683 |
+
size = tensor.numel() * dtype_byte_size(tensor.dtype)
|
| 684 |
+
else:
|
| 685 |
+
size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype))
|
| 686 |
+
name_parts = name.split(".")
|
| 687 |
+
for idx in range(len(name_parts) + 1):
|
| 688 |
+
module_sizes[".".join(name_parts[:idx])] += size
|
| 689 |
+
|
| 690 |
+
return module_sizes
|
| 691 |
+
|
| 692 |
+
|
| 693 |
+
def compute_module_total_buffer_size(
|
| 694 |
+
model: nn.Module,
|
| 695 |
+
dtype: Optional[Union[str, torch.device]] = None,
|
| 696 |
+
special_dtypes: Optional[dict[str, Union[str, torch.device]]] = None,
|
| 697 |
+
):
|
| 698 |
+
"""
|
| 699 |
+
Compute the total size of buffers in each submodule of a given model.
|
| 700 |
+
"""
|
| 701 |
+
module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes, buffers_only=True)
|
| 702 |
+
return module_sizes.get("", 0)
|
| 703 |
+
|
| 704 |
+
|
| 705 |
+
def get_max_layer_size(
|
| 706 |
+
modules: list[tuple[str, torch.nn.Module]], module_sizes: dict[str, int], no_split_module_classes: list[str]
|
| 707 |
+
):
|
| 708 |
+
"""
|
| 709 |
+
Utility function that will scan a list of named modules and return the maximum size used by one full layer. The
|
| 710 |
+
definition of a layer being:
|
| 711 |
+
- a module with no direct children (just parameters and buffers)
|
| 712 |
+
- a module whose class name is in the list `no_split_module_classes`
|
| 713 |
+
|
| 714 |
+
Args:
|
| 715 |
+
modules (`List[Tuple[str, torch.nn.Module]]`):
|
| 716 |
+
The list of named modules where we want to determine the maximum layer size.
|
| 717 |
+
module_sizes (`Dict[str, int]`):
|
| 718 |
+
A dictionary mapping each layer name to its size (as generated by `compute_module_sizes`).
|
| 719 |
+
no_split_module_classes (`List[str]`):
|
| 720 |
+
A list of class names for layers we don't want to be split.
|
| 721 |
+
|
| 722 |
+
Returns:
|
| 723 |
+
`Tuple[int, List[str]]`: The maximum size of a layer with the list of layer names realizing that maximum size.
|
| 724 |
+
"""
|
| 725 |
+
max_size = 0
|
| 726 |
+
layer_names = []
|
| 727 |
+
modules_to_treat = modules.copy()
|
| 728 |
+
while len(modules_to_treat) > 0:
|
| 729 |
+
module_name, module = modules_to_treat.pop(0)
|
| 730 |
+
modules_children = list(module.named_children()) if isinstance(module, torch.nn.Module) else []
|
| 731 |
+
if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes:
|
| 732 |
+
# No splitting this one so we compare to the max_size
|
| 733 |
+
size = module_sizes[module_name]
|
| 734 |
+
if size > max_size:
|
| 735 |
+
max_size = size
|
| 736 |
+
layer_names = [module_name]
|
| 737 |
+
elif size == max_size:
|
| 738 |
+
layer_names.append(module_name)
|
| 739 |
+
else:
|
| 740 |
+
modules_to_treat = [(f"{module_name}.{n}", v) for n, v in modules_children] + modules_to_treat
|
| 741 |
+
return max_size, layer_names
|
| 742 |
+
|
| 743 |
+
|
| 744 |
+
def get_max_memory(max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None):
|
| 745 |
+
"""
|
| 746 |
+
Get the maximum memory available if nothing is passed, converts string to int otherwise.
|
| 747 |
+
"""
|
| 748 |
+
import psutil
|
| 749 |
+
|
| 750 |
+
if max_memory is None:
|
| 751 |
+
max_memory = {}
|
| 752 |
+
# Make sure CUDA is initialized on each GPU to have the right memory info.
|
| 753 |
+
if is_npu_available():
|
| 754 |
+
for i in range(torch.npu.device_count()):
|
| 755 |
+
try:
|
| 756 |
+
_ = torch.tensor(0, device=torch.device("npu", i))
|
| 757 |
+
max_memory[i] = torch.npu.mem_get_info(i)[0]
|
| 758 |
+
except Exception:
|
| 759 |
+
logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.")
|
| 760 |
+
continue
|
| 761 |
+
elif is_mlu_available():
|
| 762 |
+
for i in range(torch.mlu.device_count()):
|
| 763 |
+
try:
|
| 764 |
+
_ = torch.tensor(0, device=torch.device("mlu", i))
|
| 765 |
+
max_memory[i] = torch.mlu.mem_get_info(i)[0]
|
| 766 |
+
except Exception:
|
| 767 |
+
logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.")
|
| 768 |
+
continue
|
| 769 |
+
elif is_sdaa_available():
|
| 770 |
+
for i in range(torch.sdaa.device_count()):
|
| 771 |
+
try:
|
| 772 |
+
_ = torch.tensor(0, device=torch.device("sdaa", i))
|
| 773 |
+
max_memory[i] = torch.sdaa.mem_get_info(i)[0]
|
| 774 |
+
except Exception:
|
| 775 |
+
logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.")
|
| 776 |
+
continue
|
| 777 |
+
elif is_musa_available():
|
| 778 |
+
for i in range(torch.musa.device_count()):
|
| 779 |
+
try:
|
| 780 |
+
_ = torch.tensor(0, device=torch.device("musa", i))
|
| 781 |
+
max_memory[i] = torch.musa.mem_get_info(i)[0]
|
| 782 |
+
except Exception:
|
| 783 |
+
logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.")
|
| 784 |
+
continue
|
| 785 |
+
elif is_xpu_available():
|
| 786 |
+
for i in range(torch.xpu.device_count()):
|
| 787 |
+
try:
|
| 788 |
+
_ = torch.tensor(0, device=torch.device("xpu", i))
|
| 789 |
+
max_memory[i] = get_xpu_available_memory(i)
|
| 790 |
+
except Exception:
|
| 791 |
+
logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.")
|
| 792 |
+
continue
|
| 793 |
+
elif is_hpu_available():
|
| 794 |
+
for i in range(torch.hpu.device_count()):
|
| 795 |
+
try:
|
| 796 |
+
_ = torch.tensor(0, device=torch.device("hpu", i))
|
| 797 |
+
max_memory[i] = torch.hpu.mem_get_info(i)[0]
|
| 798 |
+
except Exception:
|
| 799 |
+
logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.")
|
| 800 |
+
continue
|
| 801 |
+
else:
|
| 802 |
+
for i in range(torch.cuda.device_count()):
|
| 803 |
+
try:
|
| 804 |
+
_ = torch.tensor([0], device=i)
|
| 805 |
+
max_memory[i] = torch.cuda.mem_get_info(i)[0]
|
| 806 |
+
except Exception:
|
| 807 |
+
logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.")
|
| 808 |
+
continue
|
| 809 |
+
# allocate everything in the mps device as the RAM is shared
|
| 810 |
+
if is_mps_available():
|
| 811 |
+
max_memory["mps"] = psutil.virtual_memory().available
|
| 812 |
+
else:
|
| 813 |
+
max_memory["cpu"] = psutil.virtual_memory().available
|
| 814 |
+
return max_memory
|
| 815 |
+
|
| 816 |
+
for key in max_memory:
|
| 817 |
+
if isinstance(max_memory[key], str):
|
| 818 |
+
max_memory[key] = convert_file_size_to_int(max_memory[key])
|
| 819 |
+
|
| 820 |
+
# Need to sort the device by type to make sure that we allocate the gpu first.
|
| 821 |
+
# As gpu/npu/xpu are represented by int, we need to sort them first.
|
| 822 |
+
gpu_devices = [k for k in max_memory.keys() if isinstance(k, int)]
|
| 823 |
+
gpu_devices.sort()
|
| 824 |
+
# check if gpu/npu/xpu devices are available and if not, throw a warning
|
| 825 |
+
if is_npu_available():
|
| 826 |
+
num_devices = torch.npu.device_count()
|
| 827 |
+
elif is_mlu_available():
|
| 828 |
+
num_devices = torch.mlu.device_count()
|
| 829 |
+
elif is_sdaa_available():
|
| 830 |
+
num_devices = torch.sdaa.device_count()
|
| 831 |
+
elif is_musa_available():
|
| 832 |
+
num_devices = torch.musa.device_count()
|
| 833 |
+
elif is_xpu_available():
|
| 834 |
+
num_devices = torch.xpu.device_count()
|
| 835 |
+
elif is_hpu_available():
|
| 836 |
+
num_devices = torch.hpu.device_count()
|
| 837 |
+
else:
|
| 838 |
+
num_devices = torch.cuda.device_count()
|
| 839 |
+
for device in gpu_devices:
|
| 840 |
+
if device >= num_devices or device < 0:
|
| 841 |
+
logger.warning(f"Device {device} is not available, available devices are {list(range(num_devices))}")
|
| 842 |
+
# Add the other devices in the preset order if they are available
|
| 843 |
+
all_devices = gpu_devices + [k for k in ["mps", "cpu", "disk"] if k in max_memory.keys()]
|
| 844 |
+
# Raise an error if a device is not recognized
|
| 845 |
+
for k in max_memory.keys():
|
| 846 |
+
if k not in all_devices:
|
| 847 |
+
raise ValueError(
|
| 848 |
+
f"Device {k} is not recognized, available devices are integers(for GPU/XPU), 'mps', 'cpu' and 'disk'"
|
| 849 |
+
)
|
| 850 |
+
max_memory = {k: max_memory[k] for k in all_devices}
|
| 851 |
+
|
| 852 |
+
return max_memory
|
| 853 |
+
|
| 854 |
+
|
| 855 |
+
def clean_device_map(device_map: dict[str, Union[int, str, torch.device]], module_name: str = ""):
|
| 856 |
+
"""
|
| 857 |
+
Cleans a device_map by grouping all submodules that go on the same device together.
|
| 858 |
+
"""
|
| 859 |
+
# Get the value of the current module and if there is only one split across several keys, regroup it.
|
| 860 |
+
prefix = "" if module_name == "" else f"{module_name}."
|
| 861 |
+
values = [v for k, v in device_map.items() if k.startswith(prefix)]
|
| 862 |
+
if len(set(values)) == 1 and len(values) > 1:
|
| 863 |
+
for k in [k for k in device_map if k.startswith(prefix)]:
|
| 864 |
+
del device_map[k]
|
| 865 |
+
device_map[module_name] = values[0]
|
| 866 |
+
|
| 867 |
+
# Recurse over the children
|
| 868 |
+
children_modules = [k for k in device_map.keys() if k.startswith(prefix) and len(k) > len(module_name)]
|
| 869 |
+
idx = len(module_name.split(".")) + 1 if len(module_name) > 0 else 1
|
| 870 |
+
children_modules = set(".".join(k.split(".")[:idx]) for k in children_modules)
|
| 871 |
+
for child in children_modules:
|
| 872 |
+
clean_device_map(device_map, module_name=child)
|
| 873 |
+
|
| 874 |
+
return device_map
|
| 875 |
+
|
| 876 |
+
|
| 877 |
+
def load_offloaded_weights(model, index, offload_folder):
|
| 878 |
+
"""
|
| 879 |
+
Loads the weights from the offload folder into the model.
|
| 880 |
+
|
| 881 |
+
Args:
|
| 882 |
+
model (`torch.nn.Module`):
|
| 883 |
+
The model to load the weights into.
|
| 884 |
+
index (`dict`):
|
| 885 |
+
A dictionary containing the parameter name and its metadata for each parameter that was offloaded from the
|
| 886 |
+
model.
|
| 887 |
+
offload_folder (`str`):
|
| 888 |
+
The folder where the offloaded weights are stored.
|
| 889 |
+
"""
|
| 890 |
+
if index is None or len(index) == 0:
|
| 891 |
+
# Nothing to do
|
| 892 |
+
return
|
| 893 |
+
for param_name, metadata in index.items():
|
| 894 |
+
if "SCB" in param_name:
|
| 895 |
+
continue
|
| 896 |
+
fp16_statistics = None
|
| 897 |
+
if "weight" in param_name and param_name.replace("weight", "SCB") in index.keys():
|
| 898 |
+
weight_name = param_name.replace("weight", "SCB")
|
| 899 |
+
fp16_statistics = load_offloaded_weight(
|
| 900 |
+
os.path.join(offload_folder, f"{weight_name}.dat"), index[weight_name]
|
| 901 |
+
)
|
| 902 |
+
tensor_file = os.path.join(offload_folder, f"{param_name}.dat")
|
| 903 |
+
weight = load_offloaded_weight(tensor_file, metadata)
|
| 904 |
+
set_module_tensor_to_device(model, param_name, "cpu", value=weight, fp16_statistics=fp16_statistics)
|
| 905 |
+
|
| 906 |
+
|
| 907 |
+
def get_module_leaves(module_sizes):
|
| 908 |
+
module_children = {}
|
| 909 |
+
for module in module_sizes:
|
| 910 |
+
if module == "" or "." not in module:
|
| 911 |
+
continue
|
| 912 |
+
parent = module.rsplit(".", 1)[0]
|
| 913 |
+
module_children[parent] = module_children.get(parent, 0) + 1
|
| 914 |
+
leaves = [module for module in module_sizes if module_children.get(module, 0) == 0 and module != ""]
|
| 915 |
+
return leaves
|
| 916 |
+
|
| 917 |
+
|
| 918 |
+
def get_balanced_memory(
|
| 919 |
+
model: nn.Module,
|
| 920 |
+
max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None,
|
| 921 |
+
no_split_module_classes: Optional[list[str]] = None,
|
| 922 |
+
dtype: Optional[Union[str, torch.dtype]] = None,
|
| 923 |
+
special_dtypes: Optional[dict[str, Union[str, torch.device]]] = None,
|
| 924 |
+
low_zero: bool = False,
|
| 925 |
+
):
|
| 926 |
+
"""
|
| 927 |
+
Compute a `max_memory` dictionary for [`infer_auto_device_map`] that will balance the use of each available GPU.
|
| 928 |
+
|
| 929 |
+
<Tip>
|
| 930 |
+
|
| 931 |
+
All computation is done analyzing sizes and dtypes of the model parameters. As a result, the model can be on the
|
| 932 |
+
meta device (as it would if initialized within the `init_empty_weights` context manager).
|
| 933 |
+
|
| 934 |
+
</Tip>
|
| 935 |
+
|
| 936 |
+
Args:
|
| 937 |
+
model (`torch.nn.Module`):
|
| 938 |
+
The model to analyze.
|
| 939 |
+
max_memory (`Dict`, *optional*):
|
| 940 |
+
A dictionary device identifier to maximum memory. Will default to the maximum memory available if unset.
|
| 941 |
+
Example: `max_memory={0: "1GB"}`.
|
| 942 |
+
no_split_module_classes (`List[str]`, *optional*):
|
| 943 |
+
A list of layer class names that should never be split across device (for instance any layer that has a
|
| 944 |
+
residual connection).
|
| 945 |
+
dtype (`str` or `torch.dtype`, *optional*):
|
| 946 |
+
If provided, the weights will be converted to that type when loaded.
|
| 947 |
+
special_dtypes (`Dict[str, Union[str, torch.device]]`, *optional*):
|
| 948 |
+
If provided, special dtypes to consider for some specific weights (will override dtype used as default for
|
| 949 |
+
all weights).
|
| 950 |
+
low_zero (`bool`, *optional*):
|
| 951 |
+
Minimizes the number of weights on GPU 0, which is convenient when it's used for other operations (like the
|
| 952 |
+
Transformers generate function).
|
| 953 |
+
"""
|
| 954 |
+
# Get default / clean up max_memory
|
| 955 |
+
user_not_set_max_memory = max_memory is None
|
| 956 |
+
max_memory = get_max_memory(max_memory)
|
| 957 |
+
|
| 958 |
+
if is_npu_available():
|
| 959 |
+
expected_device_type = "npu"
|
| 960 |
+
elif is_mlu_available():
|
| 961 |
+
expected_device_type = "mlu"
|
| 962 |
+
elif is_sdaa_available():
|
| 963 |
+
expected_device_type = "sdaa"
|
| 964 |
+
elif is_musa_available():
|
| 965 |
+
expected_device_type = "musa"
|
| 966 |
+
elif is_xpu_available():
|
| 967 |
+
expected_device_type = "xpu"
|
| 968 |
+
elif is_hpu_available():
|
| 969 |
+
expected_device_type = "hpu"
|
| 970 |
+
elif is_mps_available():
|
| 971 |
+
expected_device_type = "mps"
|
| 972 |
+
else:
|
| 973 |
+
expected_device_type = "cuda"
|
| 974 |
+
num_devices = len([d for d in max_memory if torch.device(d).type == expected_device_type and max_memory[d] > 0])
|
| 975 |
+
|
| 976 |
+
if num_devices == 0:
|
| 977 |
+
return max_memory
|
| 978 |
+
|
| 979 |
+
if num_devices == 1:
|
| 980 |
+
# We cannot do low_zero on just one GPU, but we will still reserve some memory for the buffer
|
| 981 |
+
low_zero = False
|
| 982 |
+
# If user just asked us to handle memory usage, we should avoid OOM
|
| 983 |
+
if user_not_set_max_memory:
|
| 984 |
+
for key in max_memory.keys():
|
| 985 |
+
if isinstance(key, int):
|
| 986 |
+
max_memory[key] *= 0.9 # 90% is a good compromise
|
| 987 |
+
logger.info(
|
| 988 |
+
f"We will use 90% of the memory on device {key} for storing the model, and 10% for the buffer to avoid OOM. "
|
| 989 |
+
"You can set `max_memory` in to a higher value to use more memory (at your own risk)."
|
| 990 |
+
)
|
| 991 |
+
break # only one device
|
| 992 |
+
|
| 993 |
+
module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes)
|
| 994 |
+
per_gpu = module_sizes[""] // (num_devices - 1 if low_zero else num_devices)
|
| 995 |
+
|
| 996 |
+
# We can't just set the memory to model_size // num_devices as it will end being too small: each GPU will get
|
| 997 |
+
# slightly less layers and some layers will end up offload at the end. So this function computes a buffer size to
|
| 998 |
+
# add which is the biggest of:
|
| 999 |
+
# - the size of no split block (if applicable)
|
| 1000 |
+
# - the mean of the layer sizes
|
| 1001 |
+
if no_split_module_classes is None:
|
| 1002 |
+
no_split_module_classes = []
|
| 1003 |
+
elif not isinstance(no_split_module_classes, (list, tuple)):
|
| 1004 |
+
no_split_module_classes = [no_split_module_classes]
|
| 1005 |
+
|
| 1006 |
+
# Identify the size of the no_split_block modules
|
| 1007 |
+
if len(no_split_module_classes) > 0:
|
| 1008 |
+
no_split_children = {}
|
| 1009 |
+
for name, size in module_sizes.items():
|
| 1010 |
+
if name == "":
|
| 1011 |
+
continue
|
| 1012 |
+
submodule = model
|
| 1013 |
+
for submodule_name in name.split("."):
|
| 1014 |
+
submodule = getattr(submodule, submodule_name)
|
| 1015 |
+
class_name = submodule.__class__.__name__
|
| 1016 |
+
if class_name in no_split_module_classes and class_name not in no_split_children:
|
| 1017 |
+
no_split_children[class_name] = size
|
| 1018 |
+
|
| 1019 |
+
if set(no_split_children.keys()) == set(no_split_module_classes):
|
| 1020 |
+
break
|
| 1021 |
+
buffer = max(no_split_children.values()) if len(no_split_children) > 0 else 0
|
| 1022 |
+
else:
|
| 1023 |
+
buffer = 0
|
| 1024 |
+
|
| 1025 |
+
# Compute mean of final modules. In the first dict of module sizes, leaves are the parameters
|
| 1026 |
+
leaves = get_module_leaves(module_sizes)
|
| 1027 |
+
leaves_set = set(leaves) # Convert to set for O(1) membership testing
|
| 1028 |
+
module_sizes = {n: v for n, v in module_sizes.items() if n not in leaves_set}
|
| 1029 |
+
# Once removed, leaves are the final modules.
|
| 1030 |
+
leaves = get_module_leaves(module_sizes)
|
| 1031 |
+
mean_leaves = int(sum([module_sizes[n] for n in leaves]) / max(len(leaves), 1))
|
| 1032 |
+
buffer = int(1.25 * max(buffer, mean_leaves))
|
| 1033 |
+
per_gpu += buffer
|
| 1034 |
+
|
| 1035 |
+
# Sorted list of GPUs id (we may have some gpu ids not included in the our max_memory list - let's ignore them)
|
| 1036 |
+
gpus_idx_list = list(
|
| 1037 |
+
sorted(
|
| 1038 |
+
device_id for device_id, device_mem in max_memory.items() if isinstance(device_id, int) and device_mem > 0
|
| 1039 |
+
)
|
| 1040 |
+
)
|
| 1041 |
+
# The last device is left with max_memory just in case the buffer is not enough.
|
| 1042 |
+
for idx in gpus_idx_list[:-1]:
|
| 1043 |
+
max_memory[idx] = min(max_memory[0] if low_zero and idx == 0 else per_gpu, max_memory[idx])
|
| 1044 |
+
|
| 1045 |
+
if low_zero:
|
| 1046 |
+
min_zero = max(0, module_sizes[""] - sum([max_memory[i] for i in range(1, num_devices)]))
|
| 1047 |
+
max_memory[0] = min(min_zero, max_memory[0])
|
| 1048 |
+
|
| 1049 |
+
return max_memory
|
| 1050 |
+
|
| 1051 |
+
|
| 1052 |
+
def calculate_maximum_sizes(model: torch.nn.Module):
|
| 1053 |
+
"Computes the total size of the model and its largest layer"
|
| 1054 |
+
sizes = compute_module_sizes(model)
|
| 1055 |
+
# `transformers` models store this information for us
|
| 1056 |
+
no_split_modules = getattr(model, "_no_split_modules", None)
|
| 1057 |
+
if no_split_modules is None:
|
| 1058 |
+
no_split_modules = []
|
| 1059 |
+
|
| 1060 |
+
modules_to_treat = (
|
| 1061 |
+
list(model.named_parameters(recurse=False))
|
| 1062 |
+
+ list(model.named_children())
|
| 1063 |
+
+ list(model.named_buffers(recurse=False))
|
| 1064 |
+
)
|
| 1065 |
+
largest_layer = get_max_layer_size(modules_to_treat, sizes, no_split_modules)
|
| 1066 |
+
total_size = sizes[""]
|
| 1067 |
+
return total_size, largest_layer
|
| 1068 |
+
|
| 1069 |
+
|
| 1070 |
+
def _init_infer_auto_device_map(
|
| 1071 |
+
model: nn.Module,
|
| 1072 |
+
max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None,
|
| 1073 |
+
no_split_module_classes: Optional[list[str]] = None,
|
| 1074 |
+
dtype: Optional[Union[str, torch.dtype]] = None,
|
| 1075 |
+
special_dtypes: Optional[dict[str, Union[str, torch.device]]] = None,
|
| 1076 |
+
) -> tuple[
|
| 1077 |
+
list[Union[int, str]],
|
| 1078 |
+
dict[Union[int, str], Union[int, str]],
|
| 1079 |
+
list[Union[int, str]],
|
| 1080 |
+
list[int],
|
| 1081 |
+
dict[str, int],
|
| 1082 |
+
list[list[str]],
|
| 1083 |
+
list[str],
|
| 1084 |
+
list[tuple[str, nn.Module]],
|
| 1085 |
+
]:
|
| 1086 |
+
"""
|
| 1087 |
+
Initialize variables required for computing the device map for model allocation.
|
| 1088 |
+
"""
|
| 1089 |
+
max_memory = get_max_memory(max_memory)
|
| 1090 |
+
if no_split_module_classes is None:
|
| 1091 |
+
no_split_module_classes = []
|
| 1092 |
+
elif not isinstance(no_split_module_classes, (list, tuple)):
|
| 1093 |
+
no_split_module_classes = [no_split_module_classes]
|
| 1094 |
+
|
| 1095 |
+
devices = list(max_memory.keys())
|
| 1096 |
+
if "disk" not in devices:
|
| 1097 |
+
devices.append("disk")
|
| 1098 |
+
gpus = [device for device in devices if device not in ["cpu", "disk"]]
|
| 1099 |
+
|
| 1100 |
+
# Devices that need to keep space for a potential offloaded layer.
|
| 1101 |
+
if "mps" in gpus:
|
| 1102 |
+
main_devices = ["mps"]
|
| 1103 |
+
elif len(gpus) > 0:
|
| 1104 |
+
main_devices = [gpus[0], "cpu"]
|
| 1105 |
+
else:
|
| 1106 |
+
main_devices = ["cpu"]
|
| 1107 |
+
|
| 1108 |
+
module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes)
|
| 1109 |
+
tied_parameters = find_tied_parameters(model)
|
| 1110 |
+
if check_tied_parameters_in_config(model) and len(tied_parameters) == 0:
|
| 1111 |
+
logger.warning(
|
| 1112 |
+
"The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function."
|
| 1113 |
+
)
|
| 1114 |
+
|
| 1115 |
+
# Direct submodules and parameters
|
| 1116 |
+
modules_to_treat = (
|
| 1117 |
+
list(model.named_parameters(recurse=False))
|
| 1118 |
+
+ list(model.named_children())
|
| 1119 |
+
+ list(model.named_buffers(recurse=False))
|
| 1120 |
+
)
|
| 1121 |
+
|
| 1122 |
+
return (
|
| 1123 |
+
devices,
|
| 1124 |
+
max_memory,
|
| 1125 |
+
main_devices,
|
| 1126 |
+
gpus,
|
| 1127 |
+
module_sizes,
|
| 1128 |
+
tied_parameters,
|
| 1129 |
+
no_split_module_classes,
|
| 1130 |
+
modules_to_treat,
|
| 1131 |
+
)
|
| 1132 |
+
|
| 1133 |
+
|
| 1134 |
+
def get_module_size_with_ties(
|
| 1135 |
+
tied_params,
|
| 1136 |
+
module_size,
|
| 1137 |
+
module_sizes,
|
| 1138 |
+
modules_to_treat,
|
| 1139 |
+
) -> tuple[int, list[str], list[nn.Module]]:
|
| 1140 |
+
"""
|
| 1141 |
+
Calculate the total size of a module, including its tied parameters.
|
| 1142 |
+
|
| 1143 |
+
Args:
|
| 1144 |
+
tied_params (`List[str]`): The list of tied parameters.
|
| 1145 |
+
module_size (`int`): The size of the module without tied parameters.
|
| 1146 |
+
module_sizes (`Dict[str, int]`): A dictionary mapping each layer name to its size.
|
| 1147 |
+
modules_to_treat (`List[Tuple[str, nn.Module]]`): The list of named modules to treat.
|
| 1148 |
+
|
| 1149 |
+
Returns:
|
| 1150 |
+
`Tuple[int, List[str], List[nn.Module]]`: The total size of the module, the names of the tied modules, and the
|
| 1151 |
+
tied modules.
|
| 1152 |
+
"""
|
| 1153 |
+
if len(tied_params) < 1:
|
| 1154 |
+
return module_size, [], []
|
| 1155 |
+
tied_module_names = []
|
| 1156 |
+
tied_modules = []
|
| 1157 |
+
|
| 1158 |
+
for tied_param in tied_params:
|
| 1159 |
+
tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if tied_param.startswith(n + ".")][0]
|
| 1160 |
+
tied_module_names.append(modules_to_treat[tied_module_index][0])
|
| 1161 |
+
tied_modules.append(modules_to_treat[tied_module_index][1])
|
| 1162 |
+
|
| 1163 |
+
module_size_with_ties = module_size
|
| 1164 |
+
for tied_param, tied_module_name in zip(tied_params, tied_module_names):
|
| 1165 |
+
module_size_with_ties += module_sizes[tied_module_name] - module_sizes[tied_param]
|
| 1166 |
+
|
| 1167 |
+
return module_size_with_ties, tied_module_names, tied_modules
|
| 1168 |
+
|
| 1169 |
+
|
| 1170 |
+
def fallback_allocate(
|
| 1171 |
+
modules: list[tuple[str, nn.Module]],
|
| 1172 |
+
module_sizes: dict[str, int],
|
| 1173 |
+
size_limit: Union[int, str],
|
| 1174 |
+
no_split_module_classes: Optional[list[str]] = None,
|
| 1175 |
+
tied_parameters: Optional[list[list[str]]] = None,
|
| 1176 |
+
) -> tuple[Optional[str], Optional[nn.Module], list[tuple[str, nn.Module]]]:
|
| 1177 |
+
"""
|
| 1178 |
+
Find a module that fits in the size limit using BFS and return it with its name and the remaining modules.
|
| 1179 |
+
|
| 1180 |
+
Args:
|
| 1181 |
+
modules (`List[Tuple[str, nn.Module]]`):
|
| 1182 |
+
The list of named modules to search in.
|
| 1183 |
+
module_sizes (`Dict[str, int]`):
|
| 1184 |
+
A dictionary mapping each layer name to its size (as generated by `compute_module_sizes`).
|
| 1185 |
+
size_limit (`Union[int, str]`):
|
| 1186 |
+
The maximum size a module can have.
|
| 1187 |
+
no_split_module_classes (`Optional[List[str]]`, *optional*):
|
| 1188 |
+
A list of class names for layers we don't want to be split.
|
| 1189 |
+
tied_parameters (`Optional[List[List[str]]`, *optional*):
|
| 1190 |
+
A list of lists of parameter names being all tied together.
|
| 1191 |
+
|
| 1192 |
+
Returns:
|
| 1193 |
+
`Tuple[Optional[str], Optional[nn.Module], List[Tuple[str, nn.Module]]]`: A tuple containing:
|
| 1194 |
+
- The name of the module that fits within the size limit.
|
| 1195 |
+
- The module itself.
|
| 1196 |
+
- The list of remaining modules after the found module is removed.
|
| 1197 |
+
"""
|
| 1198 |
+
try:
|
| 1199 |
+
size_limit = convert_file_size_to_int(size_limit)
|
| 1200 |
+
except ValueError:
|
| 1201 |
+
return None, None, modules
|
| 1202 |
+
|
| 1203 |
+
if no_split_module_classes is None:
|
| 1204 |
+
no_split_module_classes = []
|
| 1205 |
+
|
| 1206 |
+
if tied_parameters is None:
|
| 1207 |
+
tied_parameters = []
|
| 1208 |
+
|
| 1209 |
+
modules_to_search = modules.copy()
|
| 1210 |
+
module_found = False
|
| 1211 |
+
|
| 1212 |
+
while modules_to_search:
|
| 1213 |
+
name, module = modules_to_search.pop(0)
|
| 1214 |
+
|
| 1215 |
+
tied_param_groups = [
|
| 1216 |
+
tied_group
|
| 1217 |
+
for tied_group in tied_parameters
|
| 1218 |
+
if any(name + "." in k + "." for k in tied_group) and not all(name + "." in k + "." for k in tied_group)
|
| 1219 |
+
]
|
| 1220 |
+
|
| 1221 |
+
tied_params = sum(
|
| 1222 |
+
[[p for p in tied_group if name + "." not in p + "."] for tied_group in tied_param_groups], []
|
| 1223 |
+
)
|
| 1224 |
+
|
| 1225 |
+
module_size_with_ties, _, _ = get_module_size_with_ties(
|
| 1226 |
+
tied_params, module_sizes[name], module_sizes, modules_to_search
|
| 1227 |
+
)
|
| 1228 |
+
|
| 1229 |
+
# If the module fits in the size limit, we found it.
|
| 1230 |
+
if module_size_with_ties <= size_limit:
|
| 1231 |
+
module_found = True
|
| 1232 |
+
break
|
| 1233 |
+
|
| 1234 |
+
# The module is too big, we need to split it if possible.
|
| 1235 |
+
modules_children = (
|
| 1236 |
+
[]
|
| 1237 |
+
if isinstance(module, nn.Parameter) or isinstance(module, torch.Tensor)
|
| 1238 |
+
else list(module.named_children())
|
| 1239 |
+
)
|
| 1240 |
+
|
| 1241 |
+
# Split fails, move to the next module
|
| 1242 |
+
if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes:
|
| 1243 |
+
continue
|
| 1244 |
+
|
| 1245 |
+
# split is possible, add the children to the list of modules to search
|
| 1246 |
+
modules_children = list(module.named_parameters(recurse=False)) + modules_children
|
| 1247 |
+
modules_to_search = [(f"{name}.{n}", v) for n, v in modules_children] + modules_to_search
|
| 1248 |
+
|
| 1249 |
+
if not module_found:
|
| 1250 |
+
return None, None, modules
|
| 1251 |
+
|
| 1252 |
+
# Prepare the module list for removal of the found module
|
| 1253 |
+
current_names = [n for n, _ in modules]
|
| 1254 |
+
dot_idx = [i for i, c in enumerate(name) if c == "."]
|
| 1255 |
+
|
| 1256 |
+
for dot_index in dot_idx:
|
| 1257 |
+
parent_name = name[:dot_index]
|
| 1258 |
+
if parent_name in current_names:
|
| 1259 |
+
parent_module_idx = current_names.index(parent_name)
|
| 1260 |
+
_, parent_module = modules[parent_module_idx]
|
| 1261 |
+
module_children = list(parent_module.named_parameters(recurse=False)) + list(
|
| 1262 |
+
parent_module.named_children()
|
| 1263 |
+
)
|
| 1264 |
+
modules = (
|
| 1265 |
+
modules[:parent_module_idx]
|
| 1266 |
+
+ [(f"{parent_name}.{n}", v) for n, v in module_children]
|
| 1267 |
+
+ modules[parent_module_idx + 1 :]
|
| 1268 |
+
)
|
| 1269 |
+
current_names = [n for n, _ in modules]
|
| 1270 |
+
|
| 1271 |
+
# Now the target module should be directly in the list
|
| 1272 |
+
target_idx = current_names.index(name)
|
| 1273 |
+
name, module = modules.pop(target_idx)
|
| 1274 |
+
|
| 1275 |
+
return name, module, modules
|
| 1276 |
+
|
| 1277 |
+
|
| 1278 |
+
def infer_auto_device_map(
|
| 1279 |
+
model: nn.Module,
|
| 1280 |
+
max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None,
|
| 1281 |
+
no_split_module_classes: Optional[list[str]] = None,
|
| 1282 |
+
dtype: Optional[Union[str, torch.dtype]] = None,
|
| 1283 |
+
special_dtypes: Optional[dict[str, Union[str, torch.dtype]]] = None,
|
| 1284 |
+
verbose: bool = False,
|
| 1285 |
+
clean_result: bool = True,
|
| 1286 |
+
offload_buffers: bool = False,
|
| 1287 |
+
fallback_allocation: bool = False,
|
| 1288 |
+
):
|
| 1289 |
+
"""
|
| 1290 |
+
Compute a device map for a given model giving priority to GPUs, then offload on CPU and finally offload to disk,
|
| 1291 |
+
such that:
|
| 1292 |
+
- we don't exceed the memory available of any of the GPU.
|
| 1293 |
+
- if offload to the CPU is needed, there is always room left on GPU 0 to put back the layer offloaded on CPU that
|
| 1294 |
+
has the largest size.
|
| 1295 |
+
- if offload to the CPU is needed,we don't exceed the RAM available on the CPU.
|
| 1296 |
+
- if offload to the disk is needed, there is always room left on the CPU to put back the layer offloaded on disk
|
| 1297 |
+
that has the largest size.
|
| 1298 |
+
|
| 1299 |
+
<Tip>
|
| 1300 |
+
|
| 1301 |
+
All computation is done analyzing sizes and dtypes of the model parameters. As a result, the model can be on the
|
| 1302 |
+
meta device (as it would if initialized within the `init_empty_weights` context manager).
|
| 1303 |
+
|
| 1304 |
+
</Tip>
|
| 1305 |
+
|
| 1306 |
+
Args:
|
| 1307 |
+
model (`torch.nn.Module`):
|
| 1308 |
+
The model to analyze.
|
| 1309 |
+
max_memory (`Dict`, *optional*):
|
| 1310 |
+
A dictionary device identifier to maximum memory. Will default to the maximum memory available if unset.
|
| 1311 |
+
Example: `max_memory={0: "1GB"}`.
|
| 1312 |
+
no_split_module_classes (`List[str]`, *optional*):
|
| 1313 |
+
A list of layer class names that should never be split across device (for instance any layer that has a
|
| 1314 |
+
residual connection).
|
| 1315 |
+
dtype (`str` or `torch.dtype`, *optional*):
|
| 1316 |
+
If provided, the weights will be converted to that type when loaded.
|
| 1317 |
+
special_dtypes (`Dict[str, Union[str, torch.device]]`, *optional*):
|
| 1318 |
+
If provided, special dtypes to consider for some specific weights (will override dtype used as default for
|
| 1319 |
+
all weights).
|
| 1320 |
+
verbose (`bool`, *optional*, defaults to `False`):
|
| 1321 |
+
Whether or not to provide debugging statements as the function builds the device_map.
|
| 1322 |
+
clean_result (`bool`, *optional*, defaults to `True`):
|
| 1323 |
+
Clean the resulting device_map by grouping all submodules that go on the same device together.
|
| 1324 |
+
offload_buffers (`bool`, *optional*, defaults to `False`):
|
| 1325 |
+
In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as
|
| 1326 |
+
well as the parameters.
|
| 1327 |
+
fallback_allocation (`bool`, *optional*, defaults to `False`):
|
| 1328 |
+
When regular allocation fails, try to allocate a module that fits in the size limit using BFS.
|
| 1329 |
+
"""
|
| 1330 |
+
|
| 1331 |
+
# Initialize the variables
|
| 1332 |
+
(
|
| 1333 |
+
devices,
|
| 1334 |
+
max_memory,
|
| 1335 |
+
main_devices,
|
| 1336 |
+
gpus,
|
| 1337 |
+
module_sizes,
|
| 1338 |
+
tied_parameters,
|
| 1339 |
+
no_split_module_classes,
|
| 1340 |
+
modules_to_treat,
|
| 1341 |
+
) = _init_infer_auto_device_map(model, max_memory, no_split_module_classes, dtype, special_dtypes)
|
| 1342 |
+
|
| 1343 |
+
device_map = OrderedDict()
|
| 1344 |
+
current_device = 0
|
| 1345 |
+
device_memory_used = {device: 0 for device in devices}
|
| 1346 |
+
device_buffer_sizes = {}
|
| 1347 |
+
device_minimum_assignment_memory = {}
|
| 1348 |
+
|
| 1349 |
+
# Initialize maximum largest layer, to know which space to keep in memory
|
| 1350 |
+
max_layer_size, max_layer_names = get_max_layer_size(modules_to_treat, module_sizes, no_split_module_classes)
|
| 1351 |
+
|
| 1352 |
+
# Ready ? This is going to be a bit messy.
|
| 1353 |
+
while len(modules_to_treat) > 0:
|
| 1354 |
+
name, module = modules_to_treat.pop(0)
|
| 1355 |
+
if verbose:
|
| 1356 |
+
print(f"\nTreating module {name}.")
|
| 1357 |
+
# Max size in the remaining layers may have changed since we took one, so we maybe update it.
|
| 1358 |
+
max_layer_names = [n for n in max_layer_names if n != name and not n.startswith(name + ".")]
|
| 1359 |
+
if len(max_layer_names) == 0:
|
| 1360 |
+
max_layer_size, max_layer_names = get_max_layer_size(
|
| 1361 |
+
[(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],
|
| 1362 |
+
module_sizes,
|
| 1363 |
+
no_split_module_classes,
|
| 1364 |
+
)
|
| 1365 |
+
# Assess size needed
|
| 1366 |
+
module_size = module_sizes[name]
|
| 1367 |
+
|
| 1368 |
+
# We keep relevant tied parameters only: one of the tied parameters in the group is inside the current module
|
| 1369 |
+
# and the other is not.
|
| 1370 |
+
# Note: If we are currently processing the name `compute.weight`, an other parameter named
|
| 1371 |
+
# e.g. `compute.weight_submodule.parameter`
|
| 1372 |
+
# needs to be considered outside the current module, hence the check with additional dots.
|
| 1373 |
+
tied_param_groups = [
|
| 1374 |
+
tied_group
|
| 1375 |
+
for tied_group in tied_parameters
|
| 1376 |
+
if any(name + "." in k + "." for k in tied_group) and not all(name + "." in k + "." for k in tied_group)
|
| 1377 |
+
]
|
| 1378 |
+
|
| 1379 |
+
if verbose and len(tied_param_groups) > 0:
|
| 1380 |
+
print(f" Found the relevant tied param groups {tied_param_groups}")
|
| 1381 |
+
|
| 1382 |
+
# Then we keep track of all the parameters that are tied to the current module, but not in the current module
|
| 1383 |
+
tied_params = sum(
|
| 1384 |
+
[[p for p in tied_group if name + "." not in p + "."] for tied_group in tied_param_groups], []
|
| 1385 |
+
)
|
| 1386 |
+
|
| 1387 |
+
if verbose and len(tied_params) > 0:
|
| 1388 |
+
print(f" So those parameters need to be taken into account {tied_params}")
|
| 1389 |
+
|
| 1390 |
+
device = devices[current_device]
|
| 1391 |
+
current_max_size = max_memory[device] if device != "disk" else None
|
| 1392 |
+
current_memory_reserved = 0
|
| 1393 |
+
# Reduce max size available by the largest layer.
|
| 1394 |
+
if devices[current_device] in main_devices:
|
| 1395 |
+
current_max_size = current_max_size - max_layer_size
|
| 1396 |
+
current_memory_reserved = max_layer_size
|
| 1397 |
+
|
| 1398 |
+
module_size_with_ties, tied_module_names, tied_modules = get_module_size_with_ties(
|
| 1399 |
+
tied_params, module_size, module_sizes, modules_to_treat
|
| 1400 |
+
)
|
| 1401 |
+
|
| 1402 |
+
# The module and its tied modules fit on the current device.
|
| 1403 |
+
if current_max_size is None or device_memory_used[device] + module_size_with_ties <= current_max_size:
|
| 1404 |
+
if verbose:
|
| 1405 |
+
output = f"Putting {name}"
|
| 1406 |
+
|
| 1407 |
+
if tied_module_names:
|
| 1408 |
+
output += f" and {tied_module_names}"
|
| 1409 |
+
else:
|
| 1410 |
+
output += f" (size={module_size})"
|
| 1411 |
+
|
| 1412 |
+
if current_max_size is not None:
|
| 1413 |
+
output += f" (available={current_max_size - device_memory_used[device]})"
|
| 1414 |
+
|
| 1415 |
+
output += f" on {device}."
|
| 1416 |
+
print(output)
|
| 1417 |
+
|
| 1418 |
+
device_memory_used[device] += module_size_with_ties
|
| 1419 |
+
|
| 1420 |
+
# Assign the primary module to the device.
|
| 1421 |
+
device_map[name] = device
|
| 1422 |
+
|
| 1423 |
+
# Assign tied modules if any.
|
| 1424 |
+
for tied_module_name in tied_module_names:
|
| 1425 |
+
if tied_module_name in [m[0] for m in modules_to_treat]:
|
| 1426 |
+
# Find the index of the tied module in the list
|
| 1427 |
+
tied_module_index = next(i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name)
|
| 1428 |
+
# Remove the tied module from the list to prevent reprocessing
|
| 1429 |
+
modules_to_treat.pop(tied_module_index)
|
| 1430 |
+
|
| 1431 |
+
# Assign the tied module to the device
|
| 1432 |
+
device_map[tied_module_name] = device
|
| 1433 |
+
|
| 1434 |
+
# Buffer Handling
|
| 1435 |
+
if not offload_buffers and isinstance(module, nn.Module):
|
| 1436 |
+
# Compute the total buffer size for the module
|
| 1437 |
+
current_buffer_size = compute_module_total_buffer_size(
|
| 1438 |
+
module, dtype=dtype, special_dtypes=special_dtypes
|
| 1439 |
+
)
|
| 1440 |
+
# Update the buffer size on the device
|
| 1441 |
+
device_buffer_sizes[device] = device_buffer_sizes.get(device, 0) + current_buffer_size
|
| 1442 |
+
|
| 1443 |
+
continue
|
| 1444 |
+
|
| 1445 |
+
# The current module itself fits, so we try to split the tied modules.
|
| 1446 |
+
if len(tied_params) > 0 and device_memory_used[device] + module_size <= current_max_size:
|
| 1447 |
+
# can we split one of the tied modules to make it smaller or do we need to go on the next device?
|
| 1448 |
+
if verbose:
|
| 1449 |
+
print(
|
| 1450 |
+
f"Not enough space on {devices[current_device]} to put {name} and {tied_module_names} (space "
|
| 1451 |
+
f"available {current_max_size - device_memory_used[device]}, needed size {module_size_with_ties})."
|
| 1452 |
+
)
|
| 1453 |
+
split_happened = False
|
| 1454 |
+
for tied_module_name, tied_module in zip(tied_module_names, tied_modules):
|
| 1455 |
+
tied_module_children = list(tied_module.named_children())
|
| 1456 |
+
if len(tied_module_children) == 0 or tied_module.__class__.__name__ in no_split_module_classes:
|
| 1457 |
+
# can't break this one.
|
| 1458 |
+
continue
|
| 1459 |
+
|
| 1460 |
+
if verbose:
|
| 1461 |
+
print(f"Splitting {tied_module_name}.")
|
| 1462 |
+
tied_module_children = list(tied_module.named_parameters(recurse=False)) + tied_module_children
|
| 1463 |
+
tied_module_children = [(f"{tied_module_name}.{n}", v) for n, v in tied_module_children]
|
| 1464 |
+
tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name][0]
|
| 1465 |
+
|
| 1466 |
+
modules_to_treat = (
|
| 1467 |
+
[(name, module)]
|
| 1468 |
+
+ modules_to_treat[:tied_module_index]
|
| 1469 |
+
+ tied_module_children
|
| 1470 |
+
+ modules_to_treat[tied_module_index + 1 :]
|
| 1471 |
+
)
|
| 1472 |
+
# Update the max layer size.
|
| 1473 |
+
max_layer_size, max_layer_names = get_max_layer_size(
|
| 1474 |
+
[(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],
|
| 1475 |
+
module_sizes,
|
| 1476 |
+
no_split_module_classes,
|
| 1477 |
+
)
|
| 1478 |
+
split_happened = True
|
| 1479 |
+
break
|
| 1480 |
+
|
| 1481 |
+
if split_happened:
|
| 1482 |
+
continue
|
| 1483 |
+
|
| 1484 |
+
# If the tied module is not split, we go to the next device
|
| 1485 |
+
if verbose:
|
| 1486 |
+
print("None of the tied module can be split, going to the next device.")
|
| 1487 |
+
|
| 1488 |
+
# The current module itself doesn't fit, so we have to split it or go to the next device.
|
| 1489 |
+
if device_memory_used[device] + module_size >= current_max_size:
|
| 1490 |
+
# Split or not split?
|
| 1491 |
+
modules_children = (
|
| 1492 |
+
[]
|
| 1493 |
+
if isinstance(module, nn.Parameter) or isinstance(module, torch.Tensor)
|
| 1494 |
+
else list(module.named_children())
|
| 1495 |
+
)
|
| 1496 |
+
if verbose:
|
| 1497 |
+
print(
|
| 1498 |
+
f"Not enough space on {devices[current_device]} to put {name} (space available "
|
| 1499 |
+
f"{current_max_size - device_memory_used[device]}, module size {module_size})."
|
| 1500 |
+
)
|
| 1501 |
+
if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes:
|
| 1502 |
+
# -> no split, we go to the next device
|
| 1503 |
+
if verbose:
|
| 1504 |
+
print("This module cannot be split, going to the next device.")
|
| 1505 |
+
|
| 1506 |
+
else:
|
| 1507 |
+
# -> split, we replace the module studied by its children + parameters
|
| 1508 |
+
if verbose:
|
| 1509 |
+
print(f"Splitting {name}.")
|
| 1510 |
+
modules_children = list(module.named_parameters(recurse=False)) + modules_children
|
| 1511 |
+
modules_to_treat = [(f"{name}.{n}", v) for n, v in modules_children] + modules_to_treat
|
| 1512 |
+
# Update the max layer size.
|
| 1513 |
+
max_layer_size, max_layer_names = get_max_layer_size(
|
| 1514 |
+
[(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],
|
| 1515 |
+
module_sizes,
|
| 1516 |
+
no_split_module_classes,
|
| 1517 |
+
)
|
| 1518 |
+
continue
|
| 1519 |
+
|
| 1520 |
+
# If no module is assigned to the current device, we attempt to allocate a fallback module
|
| 1521 |
+
# if fallback_allocation is enabled.
|
| 1522 |
+
if device_memory_used[device] == 0 and fallback_allocation and device != "disk":
|
| 1523 |
+
# We try to allocate a module that fits in the size limit using BFS.
|
| 1524 |
+
# Recompute the current max size as we need to consider the current module as well.
|
| 1525 |
+
current_max_size = max_memory[device] - max(max_layer_size, module_size_with_ties)
|
| 1526 |
+
|
| 1527 |
+
fallback_module_name, fallback_module, remaining_modules = fallback_allocate(
|
| 1528 |
+
modules_to_treat,
|
| 1529 |
+
module_sizes,
|
| 1530 |
+
current_max_size - device_memory_used[device],
|
| 1531 |
+
no_split_module_classes,
|
| 1532 |
+
tied_parameters,
|
| 1533 |
+
)
|
| 1534 |
+
# use the next iteration to put the fallback module on the next device to avoid code duplication
|
| 1535 |
+
if fallback_module is not None:
|
| 1536 |
+
modules_to_treat = [(fallback_module_name, fallback_module)] + [(name, module)] + remaining_modules
|
| 1537 |
+
continue
|
| 1538 |
+
|
| 1539 |
+
if device_memory_used[device] == 0:
|
| 1540 |
+
device_minimum_assignment_memory[device] = module_size_with_ties + current_memory_reserved
|
| 1541 |
+
|
| 1542 |
+
# Neither the current module nor any tied modules can be split, so we move to the next device.
|
| 1543 |
+
device_memory_used[device] = device_memory_used[device] + current_memory_reserved
|
| 1544 |
+
current_device += 1
|
| 1545 |
+
modules_to_treat = [(name, module)] + modules_to_treat
|
| 1546 |
+
|
| 1547 |
+
device_memory_used = {device: mem for device, mem in device_memory_used.items() if mem > 0}
|
| 1548 |
+
|
| 1549 |
+
if clean_result:
|
| 1550 |
+
device_map = clean_device_map(device_map)
|
| 1551 |
+
|
| 1552 |
+
non_gpu_buffer_size = device_buffer_sizes.get("cpu", 0) + device_buffer_sizes.get("disk", 0)
|
| 1553 |
+
if non_gpu_buffer_size > 0 and not offload_buffers:
|
| 1554 |
+
is_buffer_fit_any_gpu = False
|
| 1555 |
+
for gpu_device, gpu_max_memory in max_memory.items():
|
| 1556 |
+
if gpu_device == "cpu" or gpu_device == "disk":
|
| 1557 |
+
continue
|
| 1558 |
+
|
| 1559 |
+
if not is_buffer_fit_any_gpu:
|
| 1560 |
+
gpu_memory_used = device_memory_used.get(gpu_device, 0)
|
| 1561 |
+
|
| 1562 |
+
if gpu_max_memory >= non_gpu_buffer_size + gpu_memory_used:
|
| 1563 |
+
is_buffer_fit_any_gpu = True
|
| 1564 |
+
|
| 1565 |
+
if len(gpus) > 0 and not is_buffer_fit_any_gpu:
|
| 1566 |
+
warnings.warn(
|
| 1567 |
+
f"Current model requires {non_gpu_buffer_size} bytes of buffer for offloaded layers, which seems does "
|
| 1568 |
+
f"not fit any GPU's remaining memory. If you are experiencing a OOM later, please consider using "
|
| 1569 |
+
f"offload_buffers=True."
|
| 1570 |
+
)
|
| 1571 |
+
|
| 1572 |
+
if device_minimum_assignment_memory:
|
| 1573 |
+
devices_info = "\n".join(
|
| 1574 |
+
f" - {device}: {mem} bytes required" for device, mem in device_minimum_assignment_memory.items()
|
| 1575 |
+
)
|
| 1576 |
+
logger.info(
|
| 1577 |
+
f"Based on the current allocation process, no modules could be assigned to the following devices due to "
|
| 1578 |
+
f"insufficient memory:\n"
|
| 1579 |
+
f"{devices_info}\n"
|
| 1580 |
+
f"These minimum requirements are specific to this allocation attempt and may vary. Consider increasing "
|
| 1581 |
+
f"the available memory for these devices to at least the specified minimum, or adjusting the model config."
|
| 1582 |
+
)
|
| 1583 |
+
return device_map
|
| 1584 |
+
|
| 1585 |
+
|
| 1586 |
+
def check_device_map(model: nn.Module, device_map: dict[str, Union[int, str, torch.device]]):
|
| 1587 |
+
"""
|
| 1588 |
+
Checks a device map covers everything in a given model.
|
| 1589 |
+
|
| 1590 |
+
Args:
|
| 1591 |
+
model (`torch.nn.Module`): The model to check the device map against.
|
| 1592 |
+
device_map (`Dict[str, Union[int, str, torch.device]]`): The device map to check.
|
| 1593 |
+
"""
|
| 1594 |
+
all_module_names = dict(model.named_modules())
|
| 1595 |
+
invalid_keys = [k for k in device_map if k != "" and k not in all_module_names]
|
| 1596 |
+
|
| 1597 |
+
if invalid_keys:
|
| 1598 |
+
warnings.warn(
|
| 1599 |
+
f"The following device_map keys do not match any submodules in the model: {invalid_keys}", UserWarning
|
| 1600 |
+
)
|
| 1601 |
+
|
| 1602 |
+
all_model_tensors = [name for name, _ in model.state_dict().items()]
|
| 1603 |
+
for module_name in device_map.keys():
|
| 1604 |
+
if module_name == "":
|
| 1605 |
+
all_model_tensors.clear()
|
| 1606 |
+
break
|
| 1607 |
+
else:
|
| 1608 |
+
all_model_tensors = [
|
| 1609 |
+
name
|
| 1610 |
+
for name in all_model_tensors
|
| 1611 |
+
if not name == module_name and not name.startswith(module_name + ".")
|
| 1612 |
+
]
|
| 1613 |
+
if len(all_model_tensors) > 0:
|
| 1614 |
+
non_covered_params = ", ".join(all_model_tensors)
|
| 1615 |
+
raise ValueError(
|
| 1616 |
+
f"The device_map provided does not give any device for the following parameters: {non_covered_params}"
|
| 1617 |
+
)
|
| 1618 |
+
|
| 1619 |
+
|
| 1620 |
+
def load_state_dict(checkpoint_file, device_map=None):
|
| 1621 |
+
"""
|
| 1622 |
+
Load a checkpoint from a given file. If the checkpoint is in the safetensors format and a device map is passed, the
|
| 1623 |
+
weights can be fast-loaded directly on the GPU.
|
| 1624 |
+
|
| 1625 |
+
Args:
|
| 1626 |
+
checkpoint_file (`str`): The path to the checkpoint to load.
|
| 1627 |
+
device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*):
|
| 1628 |
+
A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer
|
| 1629 |
+
name, once a given module name is inside, every submodule of it will be sent to the same device.
|
| 1630 |
+
"""
|
| 1631 |
+
if checkpoint_file.endswith(".safetensors"):
|
| 1632 |
+
with safe_open(checkpoint_file, framework="pt") as f:
|
| 1633 |
+
metadata = f.metadata()
|
| 1634 |
+
weight_names = f.keys()
|
| 1635 |
+
|
| 1636 |
+
if metadata is None:
|
| 1637 |
+
logger.warning(
|
| 1638 |
+
f"The safetensors archive passed at {checkpoint_file} does not contain metadata. "
|
| 1639 |
+
"Make sure to save your model with the `save_pretrained` method. Defaulting to 'pt' metadata."
|
| 1640 |
+
)
|
| 1641 |
+
metadata = {"format": "pt"}
|
| 1642 |
+
|
| 1643 |
+
if metadata.get("format") not in ["pt", "tf", "flax"]:
|
| 1644 |
+
raise OSError(
|
| 1645 |
+
f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
|
| 1646 |
+
"you save your model with the `save_pretrained` method."
|
| 1647 |
+
)
|
| 1648 |
+
elif metadata["format"] != "pt":
|
| 1649 |
+
raise ValueError(f"The checkpoint passed was saved with {metadata['format']}, we need a the pt format.")
|
| 1650 |
+
if device_map is None:
|
| 1651 |
+
return safe_load_file(checkpoint_file)
|
| 1652 |
+
else:
|
| 1653 |
+
# if we only have one device we can load everything directly
|
| 1654 |
+
if len(set(device_map.values())) == 1:
|
| 1655 |
+
device = list(device_map.values())[0]
|
| 1656 |
+
target_device = device
|
| 1657 |
+
if isinstance(device, int):
|
| 1658 |
+
if is_npu_available():
|
| 1659 |
+
target_device = f"npu:{device}"
|
| 1660 |
+
elif is_hpu_available():
|
| 1661 |
+
target_device = "hpu"
|
| 1662 |
+
|
| 1663 |
+
return safe_load_file(checkpoint_file, device=target_device)
|
| 1664 |
+
|
| 1665 |
+
devices = list(set(device_map.values()) - {"disk"})
|
| 1666 |
+
# cpu device should always exist as fallback option
|
| 1667 |
+
if "cpu" not in devices:
|
| 1668 |
+
devices.append("cpu")
|
| 1669 |
+
|
| 1670 |
+
# For each device, get the weights that go there
|
| 1671 |
+
device_weights = {device: [] for device in devices}
|
| 1672 |
+
for module_name, device in device_map.items():
|
| 1673 |
+
if device in devices:
|
| 1674 |
+
device_weights[device].extend(
|
| 1675 |
+
[k for k in weight_names if k == module_name or k.startswith(module_name + ".")]
|
| 1676 |
+
)
|
| 1677 |
+
|
| 1678 |
+
# all weights that haven't defined a device should be loaded on CPU
|
| 1679 |
+
device_weights["cpu"].extend([k for k in weight_names if k not in sum(device_weights.values(), [])])
|
| 1680 |
+
tensors = {}
|
| 1681 |
+
if is_tqdm_available():
|
| 1682 |
+
progress_bar = tqdm(
|
| 1683 |
+
main_process_only=False,
|
| 1684 |
+
total=sum([len(device_weights[device]) for device in devices]),
|
| 1685 |
+
unit="w",
|
| 1686 |
+
smoothing=0,
|
| 1687 |
+
leave=False,
|
| 1688 |
+
)
|
| 1689 |
+
else:
|
| 1690 |
+
progress_bar = None
|
| 1691 |
+
for device in devices:
|
| 1692 |
+
target_device = device
|
| 1693 |
+
if isinstance(device, int):
|
| 1694 |
+
if is_npu_available():
|
| 1695 |
+
target_device = f"npu:{device}"
|
| 1696 |
+
elif is_hpu_available():
|
| 1697 |
+
target_device = "hpu"
|
| 1698 |
+
|
| 1699 |
+
with safe_open(checkpoint_file, framework="pt", device=target_device) as f:
|
| 1700 |
+
for key in device_weights[device]:
|
| 1701 |
+
if progress_bar is not None:
|
| 1702 |
+
progress_bar.set_postfix(dev=device, refresh=False)
|
| 1703 |
+
progress_bar.set_description(key)
|
| 1704 |
+
tensors[key] = f.get_tensor(key)
|
| 1705 |
+
if progress_bar is not None:
|
| 1706 |
+
progress_bar.update()
|
| 1707 |
+
if progress_bar is not None:
|
| 1708 |
+
progress_bar.close()
|
| 1709 |
+
|
| 1710 |
+
return tensors
|
| 1711 |
+
else:
|
| 1712 |
+
return torch.load(checkpoint_file, map_location=torch.device("cpu"), weights_only=True)
|
| 1713 |
+
|
| 1714 |
+
|
| 1715 |
+
def get_state_dict_offloaded_model(model: nn.Module):
|
| 1716 |
+
"""
|
| 1717 |
+
Returns the state dictionary for an offloaded model via iterative onloading
|
| 1718 |
+
|
| 1719 |
+
Args:
|
| 1720 |
+
model (`torch.nn.Module`):
|
| 1721 |
+
The offloaded model we want to save
|
| 1722 |
+
"""
|
| 1723 |
+
|
| 1724 |
+
state_dict = {}
|
| 1725 |
+
placeholders = set()
|
| 1726 |
+
for name, module in model.named_modules():
|
| 1727 |
+
if name == "":
|
| 1728 |
+
continue
|
| 1729 |
+
|
| 1730 |
+
try:
|
| 1731 |
+
with align_module_device(module, "cpu"):
|
| 1732 |
+
module_state_dict = module.state_dict()
|
| 1733 |
+
except MemoryError:
|
| 1734 |
+
raise MemoryError("Offloaded module must fit in CPU memory to call save_model!") from None
|
| 1735 |
+
|
| 1736 |
+
for key in module_state_dict:
|
| 1737 |
+
# ignore placeholder parameters that are still on the meta device
|
| 1738 |
+
if module_state_dict[key].device == torch.device("meta"):
|
| 1739 |
+
placeholders.add(name + f".{key}")
|
| 1740 |
+
continue
|
| 1741 |
+
params = module_state_dict[key]
|
| 1742 |
+
state_dict[name + f".{key}"] = params.to("cpu") # move buffers to cpu
|
| 1743 |
+
for key in placeholders.copy():
|
| 1744 |
+
if key in state_dict:
|
| 1745 |
+
placeholders.remove(key)
|
| 1746 |
+
if placeholders:
|
| 1747 |
+
logger.warning(f"The following tensors were not saved because they were still on meta device: {placeholders}")
|
| 1748 |
+
|
| 1749 |
+
return state_dict
|
| 1750 |
+
|
| 1751 |
+
|
| 1752 |
+
def get_state_dict_from_offload(
|
| 1753 |
+
module: nn.Module,
|
| 1754 |
+
module_name: str,
|
| 1755 |
+
state_dict: dict[str, Union[str, torch.tensor]],
|
| 1756 |
+
device_to_put_offload: Union[int, str, torch.device] = "cpu",
|
| 1757 |
+
):
|
| 1758 |
+
"""
|
| 1759 |
+
Retrieve the state dictionary (with parameters) from an offloaded module and load into a specified device (defaults
|
| 1760 |
+
to cpu).
|
| 1761 |
+
|
| 1762 |
+
Args:
|
| 1763 |
+
module: (`torch.nn.Module`):
|
| 1764 |
+
The module we want to retrieve a state dictionary from
|
| 1765 |
+
module_name: (`str`):
|
| 1766 |
+
The name of the module of interest
|
| 1767 |
+
state_dict (`Dict[str, Union[int, str, torch.device]]`):
|
| 1768 |
+
Dictionary of {module names: parameters}
|
| 1769 |
+
device_to_put_offload (`Union[int, str, torch.device]`):
|
| 1770 |
+
Device to load offloaded parameters into, defaults to the cpu.
|
| 1771 |
+
"""
|
| 1772 |
+
|
| 1773 |
+
root = module_name[: module_name.rfind(".")] # module name without .weight or .bias
|
| 1774 |
+
|
| 1775 |
+
# do not move parameters if the module is not offloaded
|
| 1776 |
+
if not has_offloaded_params(module):
|
| 1777 |
+
device_to_put_offload = None
|
| 1778 |
+
|
| 1779 |
+
# assign the device to which the offloaded parameters will be sent
|
| 1780 |
+
with align_module_device(module, device_to_put_offload):
|
| 1781 |
+
for m_key, params in module.state_dict().items():
|
| 1782 |
+
if (root + f".{m_key}") in state_dict:
|
| 1783 |
+
state_dict[root + f".{m_key}"] = params
|
| 1784 |
+
|
| 1785 |
+
return state_dict
|
| 1786 |
+
|
| 1787 |
+
|
| 1788 |
+
def load_checkpoint_in_model(
|
| 1789 |
+
model: nn.Module,
|
| 1790 |
+
checkpoint: Union[str, os.PathLike],
|
| 1791 |
+
device_map: Optional[dict[str, Union[int, str, torch.device]]] = None,
|
| 1792 |
+
offload_folder: Optional[Union[str, os.PathLike]] = None,
|
| 1793 |
+
dtype: Optional[Union[str, torch.dtype]] = None,
|
| 1794 |
+
offload_state_dict: bool = False,
|
| 1795 |
+
offload_buffers: bool = False,
|
| 1796 |
+
keep_in_fp32_modules: Optional[list[str]] = None,
|
| 1797 |
+
offload_8bit_bnb: bool = False,
|
| 1798 |
+
strict: bool = False,
|
| 1799 |
+
full_state_dict: bool = True,
|
| 1800 |
+
broadcast_from_rank0: bool = False,
|
| 1801 |
+
):
|
| 1802 |
+
"""
|
| 1803 |
+
Loads a (potentially sharded) checkpoint inside a model, potentially sending weights to a given device as they are
|
| 1804 |
+
loaded.
|
| 1805 |
+
|
| 1806 |
+
<Tip warning={true}>
|
| 1807 |
+
|
| 1808 |
+
Once loaded across devices, you still need to call [`dispatch_model`] on your model to make it able to run. To
|
| 1809 |
+
group the checkpoint loading and dispatch in one single call, use [`load_checkpoint_and_dispatch`].
|
| 1810 |
+
|
| 1811 |
+
</Tip>
|
| 1812 |
+
|
| 1813 |
+
Args:
|
| 1814 |
+
model (`torch.nn.Module`):
|
| 1815 |
+
The model in which we want to load a checkpoint.
|
| 1816 |
+
checkpoint (`str` or `os.PathLike`):
|
| 1817 |
+
The folder checkpoint to load. It can be:
|
| 1818 |
+
- a path to a file containing a whole model state dict
|
| 1819 |
+
- a path to a `.json` file containing the index to a sharded checkpoint
|
| 1820 |
+
- a path to a folder containing a unique `.index.json` file and the shards of a checkpoint.
|
| 1821 |
+
- a path to a folder containing a unique pytorch_model.bin or a model.safetensors file.
|
| 1822 |
+
device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*):
|
| 1823 |
+
A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer
|
| 1824 |
+
name, once a given module name is inside, every submodule of it will be sent to the same device.
|
| 1825 |
+
offload_folder (`str` or `os.PathLike`, *optional*):
|
| 1826 |
+
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
|
| 1827 |
+
dtype (`str` or `torch.dtype`, *optional*):
|
| 1828 |
+
If provided, the weights will be converted to that type when loaded.
|
| 1829 |
+
offload_state_dict (`bool`, *optional*, defaults to `False`):
|
| 1830 |
+
If `True`, will temporarily offload the CPU state dict on the hard drive to avoid getting out of CPU RAM if
|
| 1831 |
+
the weight of the CPU state dict + the biggest shard does not fit.
|
| 1832 |
+
offload_buffers (`bool`, *optional*, defaults to `False`):
|
| 1833 |
+
Whether or not to include the buffers in the weights offloaded to disk.
|
| 1834 |
+
keep_in_fp32_modules(`List[str]`, *optional*):
|
| 1835 |
+
A list of the modules that we keep in `torch.float32` dtype.
|
| 1836 |
+
offload_8bit_bnb (`bool`, *optional*):
|
| 1837 |
+
Whether or not to enable offload of 8-bit modules on cpu/disk.
|
| 1838 |
+
strict (`bool`, *optional*, defaults to `False`):
|
| 1839 |
+
Whether to strictly enforce that the keys in the checkpoint state_dict match the keys of the model's
|
| 1840 |
+
state_dict.
|
| 1841 |
+
full_state_dict (`bool`, *optional*, defaults to `True`): if this is set to `True`, all the tensors in the
|
| 1842 |
+
loaded state_dict will be gathered. No ShardedTensor and DTensor will be in the loaded state_dict.
|
| 1843 |
+
broadcast_from_rank0 (`False`, *optional*, defaults to `False`): when the option is `True`, a distributed
|
| 1844 |
+
`ProcessGroup` must be initialized. rank0 should receive a full state_dict and will broadcast the tensors
|
| 1845 |
+
in the state_dict one by one to other ranks. Other ranks will receive the tensors and shard (if applicable)
|
| 1846 |
+
according to the local shards in the model.
|
| 1847 |
+
|
| 1848 |
+
"""
|
| 1849 |
+
if offload_8bit_bnb:
|
| 1850 |
+
from .bnb import quantize_and_offload_8bit
|
| 1851 |
+
|
| 1852 |
+
tied_params = find_tied_parameters(model)
|
| 1853 |
+
|
| 1854 |
+
if check_tied_parameters_in_config(model) and len(tied_params) == 0:
|
| 1855 |
+
logger.warning(
|
| 1856 |
+
"The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function."
|
| 1857 |
+
)
|
| 1858 |
+
if device_map is not None:
|
| 1859 |
+
check_tied_parameters_on_same_device(tied_params, device_map)
|
| 1860 |
+
|
| 1861 |
+
if offload_folder is None and device_map is not None and "disk" in device_map.values():
|
| 1862 |
+
raise ValueError(
|
| 1863 |
+
"At least one of the model submodule will be offloaded to disk, please pass along an `offload_folder`."
|
| 1864 |
+
)
|
| 1865 |
+
elif offload_folder is not None and device_map is not None and "disk" in device_map.values():
|
| 1866 |
+
os.makedirs(offload_folder, exist_ok=True)
|
| 1867 |
+
|
| 1868 |
+
if isinstance(dtype, str):
|
| 1869 |
+
# We accept "torch.float16" or just "float16"
|
| 1870 |
+
dtype = dtype.replace("torch.", "")
|
| 1871 |
+
dtype = getattr(torch, dtype)
|
| 1872 |
+
|
| 1873 |
+
checkpoint_files = None
|
| 1874 |
+
index_filename = None
|
| 1875 |
+
if os.path.isfile(checkpoint):
|
| 1876 |
+
if str(checkpoint).endswith(".json"):
|
| 1877 |
+
index_filename = checkpoint
|
| 1878 |
+
else:
|
| 1879 |
+
checkpoint_files = [checkpoint]
|
| 1880 |
+
elif os.path.isdir(checkpoint):
|
| 1881 |
+
# check if the whole state dict is present
|
| 1882 |
+
potential_state_bin = [f for f in os.listdir(checkpoint) if f == WEIGHTS_NAME]
|
| 1883 |
+
potential_state_safetensor = [f for f in os.listdir(checkpoint) if f == SAFE_WEIGHTS_NAME]
|
| 1884 |
+
if len(potential_state_bin) == 1:
|
| 1885 |
+
checkpoint_files = [os.path.join(checkpoint, potential_state_bin[0])]
|
| 1886 |
+
elif len(potential_state_safetensor) == 1:
|
| 1887 |
+
checkpoint_files = [os.path.join(checkpoint, potential_state_safetensor[0])]
|
| 1888 |
+
else:
|
| 1889 |
+
# otherwise check for sharded checkpoints
|
| 1890 |
+
potential_index = [f for f in os.listdir(checkpoint) if f.endswith(".index.json")]
|
| 1891 |
+
if len(potential_index) == 0:
|
| 1892 |
+
raise ValueError(
|
| 1893 |
+
f"{checkpoint} is not a folder containing a `.index.json` file or a {WEIGHTS_NAME} or a {SAFE_WEIGHTS_NAME} file"
|
| 1894 |
+
)
|
| 1895 |
+
elif len(potential_index) == 1:
|
| 1896 |
+
index_filename = os.path.join(checkpoint, potential_index[0])
|
| 1897 |
+
else:
|
| 1898 |
+
raise ValueError(
|
| 1899 |
+
f"{checkpoint} containing more than one `.index.json` file, delete the irrelevant ones."
|
| 1900 |
+
)
|
| 1901 |
+
else:
|
| 1902 |
+
raise ValueError(
|
| 1903 |
+
"`checkpoint` should be the path to a file containing a whole state dict, or the index of a sharded "
|
| 1904 |
+
f"checkpoint, or a folder containing a sharded checkpoint or the whole state dict, but got {checkpoint}."
|
| 1905 |
+
)
|
| 1906 |
+
|
| 1907 |
+
if index_filename is not None:
|
| 1908 |
+
checkpoint_folder = os.path.split(index_filename)[0]
|
| 1909 |
+
with open(index_filename) as f:
|
| 1910 |
+
index = json.loads(f.read())
|
| 1911 |
+
|
| 1912 |
+
if "weight_map" in index:
|
| 1913 |
+
index = index["weight_map"]
|
| 1914 |
+
checkpoint_files = sorted(list(set(index.values())))
|
| 1915 |
+
checkpoint_files = [os.path.join(checkpoint_folder, f) for f in checkpoint_files]
|
| 1916 |
+
|
| 1917 |
+
# Logic for missing/unexepected keys goes here.
|
| 1918 |
+
|
| 1919 |
+
offload_index = {}
|
| 1920 |
+
if offload_state_dict:
|
| 1921 |
+
state_dict_folder = tempfile.mkdtemp()
|
| 1922 |
+
state_dict_index = {}
|
| 1923 |
+
|
| 1924 |
+
unexpected_keys = set()
|
| 1925 |
+
model_keys = set(model.state_dict().keys())
|
| 1926 |
+
buffer_names = [name for name, _ in model.named_buffers()]
|
| 1927 |
+
model_devices = {t.device for t in model.state_dict().values() if isinstance(t, torch.Tensor)}
|
| 1928 |
+
model_physical_devices = model_devices - {torch.device("meta")}
|
| 1929 |
+
for checkpoint_file in checkpoint_files:
|
| 1930 |
+
if device_map is None:
|
| 1931 |
+
# exception for multi-device loading was made for the meta device in torch v2.7.0
|
| 1932 |
+
# https://github.com/pytorch/pytorch/blob/v2.6.0/torch/distributed/checkpoint/state_dict.py#L557-L563
|
| 1933 |
+
# https://github.com/pytorch/pytorch/blob/v2.7.0-rc2/torch/distributed/checkpoint/state_dict.py#L575-L587
|
| 1934 |
+
if is_torch_version(">=", "2.2.0") and (
|
| 1935 |
+
(is_torch_version(">=", "2.7.0") and len(model_physical_devices) <= 1) or len(model_devices) <= 1
|
| 1936 |
+
):
|
| 1937 |
+
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict
|
| 1938 |
+
|
| 1939 |
+
broadcast_from_rank0 &= is_torch_version(">=", "2.4.0")
|
| 1940 |
+
loaded_checkpoint = (
|
| 1941 |
+
load_state_dict(checkpoint_file, device_map=device_map)
|
| 1942 |
+
if not broadcast_from_rank0 or dist.get_rank() == 0
|
| 1943 |
+
else {}
|
| 1944 |
+
)
|
| 1945 |
+
set_model_state_dict(
|
| 1946 |
+
model,
|
| 1947 |
+
loaded_checkpoint,
|
| 1948 |
+
options=StateDictOptions(
|
| 1949 |
+
full_state_dict=full_state_dict,
|
| 1950 |
+
strict=strict,
|
| 1951 |
+
**({"broadcast_from_rank0": broadcast_from_rank0} if is_torch_version(">=", "2.4.0") else {}),
|
| 1952 |
+
),
|
| 1953 |
+
)
|
| 1954 |
+
else:
|
| 1955 |
+
loaded_checkpoint = load_state_dict(checkpoint_file, device_map=device_map)
|
| 1956 |
+
model.load_state_dict(loaded_checkpoint, strict=strict)
|
| 1957 |
+
|
| 1958 |
+
unexpected_keys.update(set(loaded_checkpoint.keys()) - model_keys)
|
| 1959 |
+
else:
|
| 1960 |
+
loaded_checkpoint = load_state_dict(checkpoint_file, device_map=device_map)
|
| 1961 |
+
|
| 1962 |
+
for param_name, param in loaded_checkpoint.items():
|
| 1963 |
+
# skip SCB parameter (for 8-bit serialization)
|
| 1964 |
+
if "SCB" in param_name:
|
| 1965 |
+
continue
|
| 1966 |
+
|
| 1967 |
+
if param_name not in model_keys:
|
| 1968 |
+
unexpected_keys.add(param_name)
|
| 1969 |
+
if not strict:
|
| 1970 |
+
continue # Skip loading this parameter.
|
| 1971 |
+
|
| 1972 |
+
module_name = param_name
|
| 1973 |
+
|
| 1974 |
+
while len(module_name) > 0 and module_name not in device_map:
|
| 1975 |
+
module_name = ".".join(module_name.split(".")[:-1])
|
| 1976 |
+
if module_name == "" and "" not in device_map:
|
| 1977 |
+
# TODO: group all errors and raise at the end.
|
| 1978 |
+
raise ValueError(f"{param_name} doesn't have any device set.")
|
| 1979 |
+
param_device = device_map[module_name]
|
| 1980 |
+
new_dtype = dtype
|
| 1981 |
+
if dtype is not None and torch.is_floating_point(param):
|
| 1982 |
+
if keep_in_fp32_modules is not None and dtype == torch.float16:
|
| 1983 |
+
proceed = False
|
| 1984 |
+
for key in keep_in_fp32_modules:
|
| 1985 |
+
if ((key in param_name) and (key + "." in param_name)) or key == param_name:
|
| 1986 |
+
proceed = True
|
| 1987 |
+
break
|
| 1988 |
+
if proceed:
|
| 1989 |
+
new_dtype = torch.float32
|
| 1990 |
+
|
| 1991 |
+
if "weight" in param_name and param_name.replace("weight", "SCB") in loaded_checkpoint.keys():
|
| 1992 |
+
if param.dtype == torch.int8:
|
| 1993 |
+
fp16_statistics = loaded_checkpoint[param_name.replace("weight", "SCB")]
|
| 1994 |
+
else:
|
| 1995 |
+
fp16_statistics = None
|
| 1996 |
+
|
| 1997 |
+
if param_device == "disk":
|
| 1998 |
+
if offload_buffers or param_name not in buffer_names:
|
| 1999 |
+
if new_dtype is None:
|
| 2000 |
+
new_dtype = param.dtype
|
| 2001 |
+
if offload_8bit_bnb:
|
| 2002 |
+
quantize_and_offload_8bit(
|
| 2003 |
+
model, param, param_name, new_dtype, offload_folder, offload_index, fp16_statistics
|
| 2004 |
+
)
|
| 2005 |
+
continue
|
| 2006 |
+
else:
|
| 2007 |
+
set_module_tensor_to_device(model, param_name, "meta", dtype=new_dtype)
|
| 2008 |
+
offload_weight(param, param_name, offload_folder, index=offload_index)
|
| 2009 |
+
elif param_device == "cpu" and offload_state_dict:
|
| 2010 |
+
if new_dtype is None:
|
| 2011 |
+
new_dtype = param.dtype
|
| 2012 |
+
if offload_8bit_bnb:
|
| 2013 |
+
quantize_and_offload_8bit(
|
| 2014 |
+
model, param, param_name, new_dtype, state_dict_folder, state_dict_index, fp16_statistics
|
| 2015 |
+
)
|
| 2016 |
+
else:
|
| 2017 |
+
set_module_tensor_to_device(model, param_name, "meta", dtype=new_dtype)
|
| 2018 |
+
offload_weight(param, param_name, state_dict_folder, index=state_dict_index)
|
| 2019 |
+
else:
|
| 2020 |
+
set_module_tensor_to_device(
|
| 2021 |
+
model,
|
| 2022 |
+
param_name,
|
| 2023 |
+
param_device,
|
| 2024 |
+
value=param,
|
| 2025 |
+
dtype=new_dtype,
|
| 2026 |
+
fp16_statistics=fp16_statistics,
|
| 2027 |
+
)
|
| 2028 |
+
|
| 2029 |
+
# Force Python to clean up.
|
| 2030 |
+
del loaded_checkpoint
|
| 2031 |
+
gc.collect()
|
| 2032 |
+
|
| 2033 |
+
if not strict and len(unexpected_keys) > 0:
|
| 2034 |
+
logger.warning(
|
| 2035 |
+
f"Some weights of the model checkpoint at {checkpoint} were not used when"
|
| 2036 |
+
f" initializing {model.__class__.__name__}: {unexpected_keys}. This may or may not be an issue - make sure that the checkpoint does not have unnecessary parameters, or that the model definition correctly corresponds to the checkpoint."
|
| 2037 |
+
)
|
| 2038 |
+
|
| 2039 |
+
save_offload_index(offload_index, offload_folder)
|
| 2040 |
+
|
| 2041 |
+
# Load back offloaded state dict on CPU
|
| 2042 |
+
if offload_state_dict:
|
| 2043 |
+
load_offloaded_weights(model, state_dict_index, state_dict_folder)
|
| 2044 |
+
shutil.rmtree(state_dict_folder)
|
| 2045 |
+
|
| 2046 |
+
retie_parameters(model, tied_params)
|
| 2047 |
+
|
| 2048 |
+
|
| 2049 |
+
def get_mixed_precision_context_manager(native_amp: bool = False, autocast_kwargs: AutocastKwargs = None):
|
| 2050 |
+
"""
|
| 2051 |
+
Return a context manager for autocasting mixed precision
|
| 2052 |
+
|
| 2053 |
+
Args:
|
| 2054 |
+
native_amp (`bool`, *optional*, defaults to False):
|
| 2055 |
+
Whether mixed precision is actually enabled.
|
| 2056 |
+
cache_enabled (`bool`, *optional*, defaults to True):
|
| 2057 |
+
Whether the weight cache inside autocast should be enabled.
|
| 2058 |
+
"""
|
| 2059 |
+
state = AcceleratorState()
|
| 2060 |
+
if autocast_kwargs is None:
|
| 2061 |
+
autocast_kwargs = {}
|
| 2062 |
+
else:
|
| 2063 |
+
autocast_kwargs = autocast_kwargs.to_kwargs()
|
| 2064 |
+
if native_amp:
|
| 2065 |
+
device_type = (
|
| 2066 |
+
"cuda"
|
| 2067 |
+
if (state.distributed_type == DistributedType.XLA and is_torch_xla_available(check_is_gpu=True))
|
| 2068 |
+
else state.device.type
|
| 2069 |
+
)
|
| 2070 |
+
if state.mixed_precision == "fp16":
|
| 2071 |
+
return torch.autocast(device_type=device_type, dtype=torch.float16, **autocast_kwargs)
|
| 2072 |
+
elif state.mixed_precision in ["bf16", "fp8"] and state.distributed_type in [
|
| 2073 |
+
DistributedType.NO,
|
| 2074 |
+
DistributedType.MULTI_CPU,
|
| 2075 |
+
DistributedType.MULTI_GPU,
|
| 2076 |
+
DistributedType.MULTI_MLU,
|
| 2077 |
+
DistributedType.MULTI_SDAA,
|
| 2078 |
+
DistributedType.MULTI_MUSA,
|
| 2079 |
+
DistributedType.MULTI_NPU,
|
| 2080 |
+
DistributedType.MULTI_XPU,
|
| 2081 |
+
DistributedType.MULTI_HPU,
|
| 2082 |
+
DistributedType.FSDP,
|
| 2083 |
+
DistributedType.XLA,
|
| 2084 |
+
]:
|
| 2085 |
+
return torch.autocast(device_type=device_type, dtype=torch.bfloat16, **autocast_kwargs)
|
| 2086 |
+
else:
|
| 2087 |
+
return torch.autocast(device_type=device_type, **autocast_kwargs)
|
| 2088 |
+
else:
|
| 2089 |
+
return contextlib.nullcontext()
|
| 2090 |
+
|
| 2091 |
+
|
| 2092 |
+
def get_grad_scaler(distributed_type: DistributedType = None, **kwargs):
|
| 2093 |
+
"""
|
| 2094 |
+
A generic helper which will initialize the correct `GradScaler` implementation based on the environment and return
|
| 2095 |
+
it.
|
| 2096 |
+
|
| 2097 |
+
Args:
|
| 2098 |
+
distributed_type (`DistributedType`, *optional*, defaults to None):
|
| 2099 |
+
The type of distributed environment.
|
| 2100 |
+
kwargs:
|
| 2101 |
+
Additional arguments for the utilized `GradScaler` constructor.
|
| 2102 |
+
"""
|
| 2103 |
+
if distributed_type == DistributedType.FSDP:
|
| 2104 |
+
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
|
| 2105 |
+
|
| 2106 |
+
return ShardedGradScaler(**kwargs)
|
| 2107 |
+
if is_torch_xla_available(check_is_gpu=True):
|
| 2108 |
+
import torch_xla.amp as xamp
|
| 2109 |
+
|
| 2110 |
+
return xamp.GradScaler(**kwargs)
|
| 2111 |
+
elif is_mlu_available():
|
| 2112 |
+
return torch.mlu.amp.GradScaler(**kwargs)
|
| 2113 |
+
elif is_sdaa_available():
|
| 2114 |
+
return torch.sdaa.amp.GradScaler(**kwargs)
|
| 2115 |
+
elif is_musa_available():
|
| 2116 |
+
return torch.musa.amp.GradScaler(**kwargs)
|
| 2117 |
+
elif is_npu_available():
|
| 2118 |
+
return torch.npu.amp.GradScaler(**kwargs)
|
| 2119 |
+
elif is_hpu_available():
|
| 2120 |
+
return torch.amp.GradScaler("hpu", **kwargs)
|
| 2121 |
+
elif is_xpu_available():
|
| 2122 |
+
return torch.amp.GradScaler("xpu", **kwargs)
|
| 2123 |
+
elif is_mps_available():
|
| 2124 |
+
if not is_torch_version(">=", "2.8.0"):
|
| 2125 |
+
raise ValueError("Grad Scaler with MPS device requires a Pytorch >= 2.8.0")
|
| 2126 |
+
return torch.amp.GradScaler("mps", **kwargs)
|
| 2127 |
+
else:
|
| 2128 |
+
if is_torch_version(">=", "2.3"):
|
| 2129 |
+
return torch.amp.GradScaler("cuda", **kwargs)
|
| 2130 |
+
else:
|
| 2131 |
+
return torch.cuda.amp.GradScaler(**kwargs)
|
| 2132 |
+
|
| 2133 |
+
|
| 2134 |
+
def has_offloaded_params(module: torch.nn.Module) -> bool:
|
| 2135 |
+
"""
|
| 2136 |
+
Checks if a module has offloaded parameters by checking if the given module has a AlignDevicesHook attached with
|
| 2137 |
+
offloading enabled
|
| 2138 |
+
|
| 2139 |
+
Args:
|
| 2140 |
+
module (`torch.nn.Module`): The module to check for an offload hook.
|
| 2141 |
+
|
| 2142 |
+
Returns:
|
| 2143 |
+
bool: `True` if the module has an offload hook and offloading is enabled, `False` otherwise.
|
| 2144 |
+
"""
|
| 2145 |
+
from ..hooks import AlignDevicesHook # avoid circular import
|
| 2146 |
+
|
| 2147 |
+
return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, AlignDevicesHook) and module._hf_hook.offload
|
| 2148 |
+
|
| 2149 |
+
|
| 2150 |
+
@contextlib.contextmanager
|
| 2151 |
+
def align_module_device(module: torch.nn.Module, execution_device: Optional[torch.device] = None):
|
| 2152 |
+
"""
|
| 2153 |
+
Context manager that moves a module's parameters to the specified execution device.
|
| 2154 |
+
|
| 2155 |
+
Args:
|
| 2156 |
+
module (`torch.nn.Module`):
|
| 2157 |
+
Module with parameters to align.
|
| 2158 |
+
execution_device (`torch.device`, *optional*):
|
| 2159 |
+
If provided, overrides the module's execution device within the context. Otherwise, use hook execution
|
| 2160 |
+
device or pass
|
| 2161 |
+
"""
|
| 2162 |
+
if has_offloaded_params(module):
|
| 2163 |
+
if execution_device is not None:
|
| 2164 |
+
original_device = module._hf_hook.execution_device
|
| 2165 |
+
module._hf_hook.execution_device = execution_device
|
| 2166 |
+
|
| 2167 |
+
try:
|
| 2168 |
+
module._hf_hook.pre_forward(module)
|
| 2169 |
+
yield
|
| 2170 |
+
finally:
|
| 2171 |
+
module._hf_hook.post_forward(module, None)
|
| 2172 |
+
if execution_device is not None:
|
| 2173 |
+
module._hf_hook.execution_device = original_device
|
| 2174 |
+
|
| 2175 |
+
elif execution_device is not None:
|
| 2176 |
+
devices = {name: param.device for name, param in module.named_parameters(recurse=False)}
|
| 2177 |
+
try:
|
| 2178 |
+
for name in devices:
|
| 2179 |
+
set_module_tensor_to_device(module, name, execution_device)
|
| 2180 |
+
yield
|
| 2181 |
+
finally:
|
| 2182 |
+
for name, device in devices.items():
|
| 2183 |
+
set_module_tensor_to_device(module, name, device)
|
| 2184 |
+
|
| 2185 |
+
else:
|
| 2186 |
+
yield
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/offload.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import json
|
| 16 |
+
import os
|
| 17 |
+
from collections.abc import Mapping
|
| 18 |
+
from typing import Optional, Union
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
import torch
|
| 22 |
+
from safetensors import safe_open
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def offload_weight(weight, weight_name, offload_folder, index=None):
|
| 26 |
+
dtype = None
|
| 27 |
+
# Check the string instead of the dtype to be compatible with versions of PyTorch that don't have bfloat16.
|
| 28 |
+
if str(weight.dtype) == "torch.bfloat16":
|
| 29 |
+
# Need to reinterpret the underlined data as int16 since NumPy does not handle bfloat16s.
|
| 30 |
+
weight = weight.view(torch.int16)
|
| 31 |
+
dtype = "bfloat16"
|
| 32 |
+
array = weight.cpu().numpy()
|
| 33 |
+
tensor_file = os.path.join(offload_folder, f"{weight_name}.dat")
|
| 34 |
+
if index is not None:
|
| 35 |
+
if dtype is None:
|
| 36 |
+
dtype = str(array.dtype)
|
| 37 |
+
index[weight_name] = {"dtype": dtype, "shape": list(array.shape)}
|
| 38 |
+
if array.ndim == 0:
|
| 39 |
+
array = array[None]
|
| 40 |
+
file_array = np.memmap(tensor_file, dtype=array.dtype, mode="w+", shape=array.shape)
|
| 41 |
+
file_array[:] = array[:]
|
| 42 |
+
file_array.flush()
|
| 43 |
+
return index
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def load_offloaded_weight(weight_file, weight_info):
|
| 47 |
+
shape = tuple(weight_info["shape"])
|
| 48 |
+
if shape == ():
|
| 49 |
+
# NumPy memory-mapped arrays can't have 0 dims so it was saved as 1d tensor
|
| 50 |
+
shape = (1,)
|
| 51 |
+
|
| 52 |
+
dtype = weight_info["dtype"]
|
| 53 |
+
if dtype == "bfloat16":
|
| 54 |
+
# NumPy does not support bfloat16 so this was saved as a int16
|
| 55 |
+
dtype = "int16"
|
| 56 |
+
|
| 57 |
+
weight = np.memmap(weight_file, dtype=dtype, shape=shape, mode="r")
|
| 58 |
+
|
| 59 |
+
if len(weight_info["shape"]) == 0:
|
| 60 |
+
weight = weight[0]
|
| 61 |
+
weight = torch.tensor(weight)
|
| 62 |
+
if weight_info["dtype"] == "bfloat16":
|
| 63 |
+
weight = weight.view(torch.bfloat16)
|
| 64 |
+
|
| 65 |
+
return weight
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def save_offload_index(index, offload_folder):
|
| 69 |
+
if index is None or len(index) == 0:
|
| 70 |
+
# Nothing to save
|
| 71 |
+
return
|
| 72 |
+
|
| 73 |
+
offload_index_file = os.path.join(offload_folder, "index.json")
|
| 74 |
+
if os.path.isfile(offload_index_file):
|
| 75 |
+
with open(offload_index_file, encoding="utf-8") as f:
|
| 76 |
+
current_index = json.load(f)
|
| 77 |
+
else:
|
| 78 |
+
current_index = {}
|
| 79 |
+
current_index.update(index)
|
| 80 |
+
|
| 81 |
+
with open(offload_index_file, "w", encoding="utf-8") as f:
|
| 82 |
+
json.dump(current_index, f, indent=2)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def offload_state_dict(save_dir: Union[str, os.PathLike], state_dict: dict[str, torch.Tensor]):
|
| 86 |
+
"""
|
| 87 |
+
Offload a state dict in a given folder.
|
| 88 |
+
|
| 89 |
+
Args:
|
| 90 |
+
save_dir (`str` or `os.PathLike`):
|
| 91 |
+
The directory in which to offload the state dict.
|
| 92 |
+
state_dict (`Dict[str, torch.Tensor]`):
|
| 93 |
+
The dictionary of tensors to offload.
|
| 94 |
+
"""
|
| 95 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 96 |
+
index = {}
|
| 97 |
+
for name, parameter in state_dict.items():
|
| 98 |
+
index = offload_weight(parameter, name, save_dir, index=index)
|
| 99 |
+
|
| 100 |
+
# Update index
|
| 101 |
+
save_offload_index(index, save_dir)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class PrefixedDataset(Mapping):
|
| 105 |
+
"""
|
| 106 |
+
Will access keys in a given dataset by adding a prefix.
|
| 107 |
+
|
| 108 |
+
Args:
|
| 109 |
+
dataset (`Mapping`): Any map with string keys.
|
| 110 |
+
prefix (`str`): A prefix to add when trying to access any element in the underlying dataset.
|
| 111 |
+
"""
|
| 112 |
+
|
| 113 |
+
def __init__(self, dataset: Mapping, prefix: str):
|
| 114 |
+
self.dataset = dataset
|
| 115 |
+
self.prefix = prefix
|
| 116 |
+
|
| 117 |
+
def __getitem__(self, key):
|
| 118 |
+
return self.dataset[f"{self.prefix}{key}"]
|
| 119 |
+
|
| 120 |
+
def __iter__(self):
|
| 121 |
+
return iter([key for key in self.dataset if key.startswith(self.prefix)])
|
| 122 |
+
|
| 123 |
+
def __len__(self):
|
| 124 |
+
return len(self.dataset)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class OffloadedWeightsLoader(Mapping):
|
| 128 |
+
"""
|
| 129 |
+
A collection that loads weights stored in a given state dict or memory-mapped on disk.
|
| 130 |
+
|
| 131 |
+
Args:
|
| 132 |
+
state_dict (`Dict[str, torch.Tensor]`, *optional*):
|
| 133 |
+
A dictionary parameter name to tensor.
|
| 134 |
+
save_folder (`str` or `os.PathLike`, *optional*):
|
| 135 |
+
The directory in which the weights are stored (by `offload_state_dict` for instance).
|
| 136 |
+
index (`Dict`, *optional*):
|
| 137 |
+
A dictionary from weight name to their information (`dtype`/ `shape` or safetensors filename). Will default
|
| 138 |
+
to the index saved in `save_folder`.
|
| 139 |
+
"""
|
| 140 |
+
|
| 141 |
+
def __init__(
|
| 142 |
+
self,
|
| 143 |
+
state_dict: Optional[dict[str, torch.Tensor]] = None,
|
| 144 |
+
save_folder: Optional[Union[str, os.PathLike]] = None,
|
| 145 |
+
index: Optional[Mapping] = None,
|
| 146 |
+
device=None,
|
| 147 |
+
):
|
| 148 |
+
if state_dict is None and save_folder is None and index is None:
|
| 149 |
+
raise ValueError("Need either a `state_dict`, a `save_folder` or an `index` containing offloaded weights.")
|
| 150 |
+
|
| 151 |
+
self.state_dict = {} if state_dict is None else state_dict
|
| 152 |
+
self.save_folder = save_folder
|
| 153 |
+
if index is None and save_folder is not None:
|
| 154 |
+
with open(os.path.join(save_folder, "index.json")) as f:
|
| 155 |
+
index = json.load(f)
|
| 156 |
+
self.index = {} if index is None else index
|
| 157 |
+
self.all_keys = list(self.state_dict.keys())
|
| 158 |
+
self.all_keys.extend([key for key in self.index if key not in self.all_keys])
|
| 159 |
+
self.device = device
|
| 160 |
+
|
| 161 |
+
def __getitem__(self, key: str):
|
| 162 |
+
# State dict gets priority
|
| 163 |
+
if key in self.state_dict:
|
| 164 |
+
return self.state_dict[key]
|
| 165 |
+
weight_info = self.index[key]
|
| 166 |
+
if weight_info.get("safetensors_file") is not None:
|
| 167 |
+
device = "cpu" if self.device is None else self.device
|
| 168 |
+
tensor = None
|
| 169 |
+
try:
|
| 170 |
+
with safe_open(weight_info["safetensors_file"], framework="pt", device=device) as f:
|
| 171 |
+
tensor = f.get_tensor(weight_info.get("weight_name", key))
|
| 172 |
+
except TypeError:
|
| 173 |
+
# if failed to get_tensor on the device, such as bf16 on mps, try to load it on CPU first
|
| 174 |
+
with safe_open(weight_info["safetensors_file"], framework="pt", device="cpu") as f:
|
| 175 |
+
tensor = f.get_tensor(weight_info.get("weight_name", key))
|
| 176 |
+
|
| 177 |
+
if "dtype" in weight_info:
|
| 178 |
+
tensor = tensor.to(getattr(torch, weight_info["dtype"]))
|
| 179 |
+
|
| 180 |
+
if tensor.device != torch.device(device):
|
| 181 |
+
tensor = tensor.to(device)
|
| 182 |
+
return tensor
|
| 183 |
+
|
| 184 |
+
weight_file = os.path.join(self.save_folder, f"{key}.dat")
|
| 185 |
+
return load_offloaded_weight(weight_file, weight_info)
|
| 186 |
+
|
| 187 |
+
def __iter__(self):
|
| 188 |
+
return iter(self.all_keys)
|
| 189 |
+
|
| 190 |
+
def __len__(self):
|
| 191 |
+
return len(self.all_keys)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def extract_submodules_state_dict(state_dict: dict[str, torch.Tensor], submodule_names: list[str]):
|
| 195 |
+
"""
|
| 196 |
+
Extract the sub state-dict corresponding to a list of given submodules.
|
| 197 |
+
|
| 198 |
+
Args:
|
| 199 |
+
state_dict (`Dict[str, torch.Tensor]`): The state dict to extract from.
|
| 200 |
+
submodule_names (`List[str]`): The list of submodule names we want to extract.
|
| 201 |
+
"""
|
| 202 |
+
result = {}
|
| 203 |
+
for module_name in submodule_names:
|
| 204 |
+
# We want to catch module_name parameter (module_name.xxx) or potentially module_name, but not any of the
|
| 205 |
+
# submodules that could being like module_name (transformers.h.1 and transformers.h.10 for instance)
|
| 206 |
+
result.update(
|
| 207 |
+
{
|
| 208 |
+
key: param
|
| 209 |
+
for key, param in state_dict.items()
|
| 210 |
+
if key == module_name or key.startswith(module_name + ".")
|
| 211 |
+
}
|
| 212 |
+
)
|
| 213 |
+
return result
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/operations.py
ADDED
|
@@ -0,0 +1,867 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
A set of basic tensor ops compatible with tpu, gpu, and multigpu
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import pickle
|
| 19 |
+
import warnings
|
| 20 |
+
from collections.abc import Mapping
|
| 21 |
+
from contextlib import contextmanager, nullcontext
|
| 22 |
+
from functools import update_wrapper, wraps
|
| 23 |
+
from typing import Any
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
|
| 27 |
+
from ..state import AcceleratorState, PartialState
|
| 28 |
+
from .constants import TORCH_DISTRIBUTED_OPERATION_TYPES
|
| 29 |
+
from .dataclasses import DistributedType, TensorInformation
|
| 30 |
+
from .imports import (
|
| 31 |
+
is_npu_available,
|
| 32 |
+
is_torch_distributed_available,
|
| 33 |
+
is_torch_xla_available,
|
| 34 |
+
)
|
| 35 |
+
from .versions import is_torch_version
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
if is_torch_xla_available():
|
| 39 |
+
import torch_xla.core.xla_model as xm
|
| 40 |
+
|
| 41 |
+
if is_torch_distributed_available():
|
| 42 |
+
from torch.distributed import ReduceOp
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def is_torch_tensor(tensor):
|
| 46 |
+
return isinstance(tensor, torch.Tensor)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def is_torch_xpu_tensor(tensor):
|
| 50 |
+
return isinstance(
|
| 51 |
+
tensor,
|
| 52 |
+
torch.xpu.FloatTensor,
|
| 53 |
+
torch.xpu.ByteTensor,
|
| 54 |
+
torch.xpu.IntTensor,
|
| 55 |
+
torch.xpu.LongTensor,
|
| 56 |
+
torch.xpu.HalfTensor,
|
| 57 |
+
torch.xpu.DoubleTensor,
|
| 58 |
+
torch.xpu.BFloat16Tensor,
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def is_tensor_information(tensor_info):
|
| 63 |
+
return isinstance(tensor_info, TensorInformation)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def is_namedtuple(data):
|
| 67 |
+
"""
|
| 68 |
+
Checks if `data` is a `namedtuple` or not. Can have false positives, but only if a user is trying to mimic a
|
| 69 |
+
`namedtuple` perfectly.
|
| 70 |
+
"""
|
| 71 |
+
return isinstance(data, tuple) and hasattr(data, "_asdict") and hasattr(data, "_fields")
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def honor_type(obj, generator):
|
| 75 |
+
"""
|
| 76 |
+
Cast a generator to the same type as obj (list, tuple, or namedtuple)
|
| 77 |
+
"""
|
| 78 |
+
# Some objects may not be able to instantiate from a generator directly
|
| 79 |
+
if is_namedtuple(obj):
|
| 80 |
+
return type(obj)(*list(generator))
|
| 81 |
+
else:
|
| 82 |
+
return type(obj)(generator)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def recursively_apply(func, data, *args, test_type=is_torch_tensor, error_on_other_type=False, **kwargs):
|
| 86 |
+
"""
|
| 87 |
+
Recursively apply a function on a data structure that is a nested list/tuple/dictionary of a given base type.
|
| 88 |
+
|
| 89 |
+
Args:
|
| 90 |
+
func (`callable`):
|
| 91 |
+
The function to recursively apply.
|
| 92 |
+
data (nested list/tuple/dictionary of `main_type`):
|
| 93 |
+
The data on which to apply `func`
|
| 94 |
+
*args:
|
| 95 |
+
Positional arguments that will be passed to `func` when applied on the unpacked data.
|
| 96 |
+
main_type (`type`, *optional*, defaults to `torch.Tensor`):
|
| 97 |
+
The base type of the objects to which apply `func`.
|
| 98 |
+
error_on_other_type (`bool`, *optional*, defaults to `False`):
|
| 99 |
+
Whether to return an error or not if after unpacking `data`, we get on an object that is not of type
|
| 100 |
+
`main_type`. If `False`, the function will leave objects of types different than `main_type` unchanged.
|
| 101 |
+
**kwargs (additional keyword arguments, *optional*):
|
| 102 |
+
Keyword arguments that will be passed to `func` when applied on the unpacked data.
|
| 103 |
+
|
| 104 |
+
Returns:
|
| 105 |
+
The same data structure as `data` with `func` applied to every object of type `main_type`.
|
| 106 |
+
"""
|
| 107 |
+
if isinstance(data, (tuple, list)):
|
| 108 |
+
return honor_type(
|
| 109 |
+
data,
|
| 110 |
+
(
|
| 111 |
+
recursively_apply(
|
| 112 |
+
func, o, *args, test_type=test_type, error_on_other_type=error_on_other_type, **kwargs
|
| 113 |
+
)
|
| 114 |
+
for o in data
|
| 115 |
+
),
|
| 116 |
+
)
|
| 117 |
+
elif isinstance(data, Mapping):
|
| 118 |
+
return type(data)(
|
| 119 |
+
{
|
| 120 |
+
k: recursively_apply(
|
| 121 |
+
func, v, *args, test_type=test_type, error_on_other_type=error_on_other_type, **kwargs
|
| 122 |
+
)
|
| 123 |
+
for k, v in data.items()
|
| 124 |
+
}
|
| 125 |
+
)
|
| 126 |
+
elif test_type(data):
|
| 127 |
+
return func(data, *args, **kwargs)
|
| 128 |
+
elif error_on_other_type:
|
| 129 |
+
raise TypeError(
|
| 130 |
+
f"Unsupported types ({type(data)}) passed to `{func.__name__}`. Only nested list/tuple/dicts of "
|
| 131 |
+
f"objects that are valid for `{test_type.__name__}` should be passed."
|
| 132 |
+
)
|
| 133 |
+
return data
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def send_to_device(tensor, device, non_blocking=False, skip_keys=None):
|
| 137 |
+
"""
|
| 138 |
+
Recursively sends the elements in a nested list/tuple/dictionary of tensors to a given device.
|
| 139 |
+
|
| 140 |
+
Args:
|
| 141 |
+
tensor (nested list/tuple/dictionary of `torch.Tensor`):
|
| 142 |
+
The data to send to a given device.
|
| 143 |
+
device (`torch.device`):
|
| 144 |
+
The device to send the data to.
|
| 145 |
+
|
| 146 |
+
Returns:
|
| 147 |
+
The same data structure as `tensor` with all tensors sent to the proper device.
|
| 148 |
+
"""
|
| 149 |
+
if is_torch_tensor(tensor) or hasattr(tensor, "to"):
|
| 150 |
+
# `torch.Tensor.to("npu")` could not find context when called for the first time (see this [issue](https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue)).
|
| 151 |
+
if device == "npu":
|
| 152 |
+
device = "npu:0"
|
| 153 |
+
try:
|
| 154 |
+
return tensor.to(device, non_blocking=non_blocking)
|
| 155 |
+
except TypeError: # .to() doesn't accept non_blocking as kwarg
|
| 156 |
+
return tensor.to(device)
|
| 157 |
+
except AssertionError as error:
|
| 158 |
+
# `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
|
| 159 |
+
# This call is inside the try-block since is_npu_available is not supported by torch.compile.
|
| 160 |
+
if is_npu_available():
|
| 161 |
+
if isinstance(device, int):
|
| 162 |
+
device = f"npu:{device}"
|
| 163 |
+
else:
|
| 164 |
+
raise error
|
| 165 |
+
try:
|
| 166 |
+
return tensor.to(device, non_blocking=non_blocking)
|
| 167 |
+
except TypeError: # .to() doesn't accept non_blocking as kwarg
|
| 168 |
+
return tensor.to(device)
|
| 169 |
+
elif isinstance(tensor, (tuple, list)):
|
| 170 |
+
return honor_type(
|
| 171 |
+
tensor, (send_to_device(t, device, non_blocking=non_blocking, skip_keys=skip_keys) for t in tensor)
|
| 172 |
+
)
|
| 173 |
+
elif isinstance(tensor, Mapping):
|
| 174 |
+
if isinstance(skip_keys, str):
|
| 175 |
+
skip_keys = [skip_keys]
|
| 176 |
+
elif skip_keys is None:
|
| 177 |
+
skip_keys = []
|
| 178 |
+
return type(tensor)(
|
| 179 |
+
{
|
| 180 |
+
k: t if k in skip_keys else send_to_device(t, device, non_blocking=non_blocking, skip_keys=skip_keys)
|
| 181 |
+
for k, t in tensor.items()
|
| 182 |
+
}
|
| 183 |
+
)
|
| 184 |
+
else:
|
| 185 |
+
return tensor
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def get_data_structure(data):
|
| 189 |
+
"""
|
| 190 |
+
Recursively gathers the information needed to rebuild a nested list/tuple/dictionary of tensors.
|
| 191 |
+
|
| 192 |
+
Args:
|
| 193 |
+
data (nested list/tuple/dictionary of `torch.Tensor`):
|
| 194 |
+
The data to send to analyze.
|
| 195 |
+
|
| 196 |
+
Returns:
|
| 197 |
+
The same data structure as `data` with [`~utils.TensorInformation`] instead of tensors.
|
| 198 |
+
"""
|
| 199 |
+
|
| 200 |
+
def _get_data_structure(tensor):
|
| 201 |
+
return TensorInformation(shape=tensor.shape, dtype=tensor.dtype)
|
| 202 |
+
|
| 203 |
+
return recursively_apply(_get_data_structure, data)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def get_shape(data):
|
| 207 |
+
"""
|
| 208 |
+
Recursively gathers the shape of a nested list/tuple/dictionary of tensors as a list.
|
| 209 |
+
|
| 210 |
+
Args:
|
| 211 |
+
data (nested list/tuple/dictionary of `torch.Tensor`):
|
| 212 |
+
The data to send to analyze.
|
| 213 |
+
|
| 214 |
+
Returns:
|
| 215 |
+
The same data structure as `data` with lists of tensor shapes instead of tensors.
|
| 216 |
+
"""
|
| 217 |
+
|
| 218 |
+
def _get_shape(tensor):
|
| 219 |
+
return list(tensor.shape)
|
| 220 |
+
|
| 221 |
+
return recursively_apply(_get_shape, data)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def initialize_tensors(data_structure):
|
| 225 |
+
"""
|
| 226 |
+
Recursively initializes tensors from a nested list/tuple/dictionary of [`~utils.TensorInformation`].
|
| 227 |
+
|
| 228 |
+
Returns:
|
| 229 |
+
The same data structure as `data` with tensors instead of [`~utils.TensorInformation`].
|
| 230 |
+
"""
|
| 231 |
+
|
| 232 |
+
def _initialize_tensor(tensor_info):
|
| 233 |
+
return torch.empty(*tensor_info.shape, dtype=tensor_info.dtype)
|
| 234 |
+
|
| 235 |
+
return recursively_apply(_initialize_tensor, data_structure, test_type=is_tensor_information)
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def find_batch_size(data):
|
| 239 |
+
"""
|
| 240 |
+
Recursively finds the batch size in a nested list/tuple/dictionary of lists of tensors.
|
| 241 |
+
|
| 242 |
+
Args:
|
| 243 |
+
data (nested list/tuple/dictionary of `torch.Tensor`): The data from which to find the batch size.
|
| 244 |
+
|
| 245 |
+
Returns:
|
| 246 |
+
`int`: The batch size.
|
| 247 |
+
"""
|
| 248 |
+
if isinstance(data, (tuple, list, Mapping)) and (len(data) == 0):
|
| 249 |
+
raise ValueError(f"Cannot find the batch size from empty {type(data)}.")
|
| 250 |
+
|
| 251 |
+
if isinstance(data, (tuple, list)):
|
| 252 |
+
return find_batch_size(data[0])
|
| 253 |
+
elif isinstance(data, Mapping):
|
| 254 |
+
for k in data.keys():
|
| 255 |
+
return find_batch_size(data[k])
|
| 256 |
+
elif not isinstance(data, torch.Tensor):
|
| 257 |
+
raise TypeError(f"Can only find the batch size of tensors but got {type(data)}.")
|
| 258 |
+
return data.shape[0]
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def ignorant_find_batch_size(data):
|
| 262 |
+
"""
|
| 263 |
+
Same as [`utils.operations.find_batch_size`] except will ignore if `ValueError` and `TypeErrors` are raised
|
| 264 |
+
|
| 265 |
+
Args:
|
| 266 |
+
data (nested list/tuple/dictionary of `torch.Tensor`): The data from which to find the batch size.
|
| 267 |
+
|
| 268 |
+
Returns:
|
| 269 |
+
`int`: The batch size.
|
| 270 |
+
"""
|
| 271 |
+
try:
|
| 272 |
+
return find_batch_size(data)
|
| 273 |
+
except (ValueError, TypeError):
|
| 274 |
+
pass
|
| 275 |
+
return None
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def listify(data):
|
| 279 |
+
"""
|
| 280 |
+
Recursively finds tensors in a nested list/tuple/dictionary and converts them to a list of numbers.
|
| 281 |
+
|
| 282 |
+
Args:
|
| 283 |
+
data (nested list/tuple/dictionary of `torch.Tensor`): The data from which to convert to regular numbers.
|
| 284 |
+
|
| 285 |
+
Returns:
|
| 286 |
+
The same data structure as `data` with lists of numbers instead of `torch.Tensor`.
|
| 287 |
+
"""
|
| 288 |
+
|
| 289 |
+
def _convert_to_list(tensor):
|
| 290 |
+
tensor = tensor.detach().cpu()
|
| 291 |
+
if tensor.dtype == torch.bfloat16:
|
| 292 |
+
# As of Numpy 1.21.4, NumPy does not support bfloat16 (see
|
| 293 |
+
# https://github.com/numpy/numpy/blob/a47ecdea856986cd60eabbd53265c2ca5916ad5d/doc/source/user/basics.types.rst ).
|
| 294 |
+
# Until Numpy adds bfloat16, we must convert float32.
|
| 295 |
+
tensor = tensor.to(torch.float32)
|
| 296 |
+
return tensor.tolist()
|
| 297 |
+
|
| 298 |
+
return recursively_apply(_convert_to_list, data)
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
def _tpu_gather(tensor):
|
| 302 |
+
def _tpu_gather_one(tensor):
|
| 303 |
+
if tensor.ndim == 0:
|
| 304 |
+
tensor = tensor.clone()[None]
|
| 305 |
+
|
| 306 |
+
# Can only gather contiguous tensors
|
| 307 |
+
if not tensor.is_contiguous():
|
| 308 |
+
tensor = tensor.contiguous()
|
| 309 |
+
return xm.all_gather(tensor)
|
| 310 |
+
|
| 311 |
+
res = recursively_apply(_tpu_gather_one, tensor, error_on_other_type=True)
|
| 312 |
+
xm.mark_step()
|
| 313 |
+
return res
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def _gpu_gather(tensor):
|
| 317 |
+
state = PartialState()
|
| 318 |
+
gather_op = torch.distributed.all_gather_into_tensor
|
| 319 |
+
|
| 320 |
+
# NOTE: need manually synchronize to workaourd a INT64 collectives bug in oneCCL before torch 2.9.0
|
| 321 |
+
if state.device.type == "xpu" and is_torch_version("<=", "2.8"):
|
| 322 |
+
torch.xpu.synchronize()
|
| 323 |
+
|
| 324 |
+
def _gpu_gather_one(tensor):
|
| 325 |
+
if tensor.ndim == 0:
|
| 326 |
+
tensor = tensor.clone()[None]
|
| 327 |
+
|
| 328 |
+
# Can only gather contiguous tensors
|
| 329 |
+
if not tensor.is_contiguous():
|
| 330 |
+
tensor = tensor.contiguous()
|
| 331 |
+
|
| 332 |
+
if state.backend is not None and state.backend != "gloo":
|
| 333 |
+
# We use `empty` as `all_gather_into_tensor` slightly
|
| 334 |
+
# differs from `all_gather` for better efficiency,
|
| 335 |
+
# and we rely on the number of items in the tensor
|
| 336 |
+
# rather than its direct shape
|
| 337 |
+
output_tensors = torch.empty(
|
| 338 |
+
state.num_processes * tensor.numel(),
|
| 339 |
+
dtype=tensor.dtype,
|
| 340 |
+
device=state.device,
|
| 341 |
+
)
|
| 342 |
+
gather_op(output_tensors, tensor)
|
| 343 |
+
return output_tensors.view(-1, *tensor.size()[1:])
|
| 344 |
+
else:
|
| 345 |
+
# a backend of `None` is always CPU
|
| 346 |
+
# also gloo does not support `all_gather_into_tensor`,
|
| 347 |
+
# which will result in a larger memory overhead for the op
|
| 348 |
+
output_tensors = [torch.empty_like(tensor) for _ in range(state.num_processes)]
|
| 349 |
+
torch.distributed.all_gather(output_tensors, tensor)
|
| 350 |
+
return torch.cat(output_tensors, dim=0)
|
| 351 |
+
|
| 352 |
+
return recursively_apply(_gpu_gather_one, tensor, error_on_other_type=True)
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
class DistributedOperationException(Exception):
|
| 356 |
+
"""
|
| 357 |
+
An exception class for distributed operations. Raised if the operation cannot be performed due to the shape of the
|
| 358 |
+
tensors.
|
| 359 |
+
"""
|
| 360 |
+
|
| 361 |
+
pass
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def verify_operation(function):
|
| 365 |
+
"""
|
| 366 |
+
Verifies that `tensor` is the same shape across all processes. Only ran if `PartialState().debug` is `True`.
|
| 367 |
+
"""
|
| 368 |
+
|
| 369 |
+
@wraps(function)
|
| 370 |
+
def wrapper(*args, **kwargs):
|
| 371 |
+
if PartialState().distributed_type == DistributedType.NO or not PartialState().debug:
|
| 372 |
+
return function(*args, **kwargs)
|
| 373 |
+
operation = f"{function.__module__}.{function.__name__}"
|
| 374 |
+
if "tensor" in kwargs:
|
| 375 |
+
tensor = kwargs["tensor"]
|
| 376 |
+
else:
|
| 377 |
+
tensor = args[0]
|
| 378 |
+
if PartialState().device.type != find_device(tensor).type:
|
| 379 |
+
raise DistributedOperationException(
|
| 380 |
+
f"One or more of the tensors passed to {operation} were not on the {tensor.device.type} while the `Accelerator` is configured for {PartialState().device.type}. "
|
| 381 |
+
f"Please move it to the {PartialState().device.type} before calling {operation}."
|
| 382 |
+
)
|
| 383 |
+
shapes = get_shape(tensor)
|
| 384 |
+
output = gather_object([shapes])
|
| 385 |
+
if output[0] is not None:
|
| 386 |
+
are_same = output.count(output[0]) == len(output)
|
| 387 |
+
if not are_same:
|
| 388 |
+
process_shape_str = "\n - ".join([f"Process {i}: {shape}" for i, shape in enumerate(output)])
|
| 389 |
+
raise DistributedOperationException(
|
| 390 |
+
f"Cannot apply desired operation due to shape mismatches. "
|
| 391 |
+
"All shapes across devices must be valid."
|
| 392 |
+
f"\n\nOperation: `{operation}`\nInput shapes:\n - {process_shape_str}"
|
| 393 |
+
)
|
| 394 |
+
return function(*args, **kwargs)
|
| 395 |
+
|
| 396 |
+
return wrapper
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
def chained_operation(function):
|
| 400 |
+
"""
|
| 401 |
+
Checks that `verify_operation` failed and if so reports a more helpful error chaining the existing
|
| 402 |
+
`DistributedOperationException`.
|
| 403 |
+
"""
|
| 404 |
+
|
| 405 |
+
@wraps(function)
|
| 406 |
+
def wrapper(*args, **kwargs):
|
| 407 |
+
try:
|
| 408 |
+
return function(*args, **kwargs)
|
| 409 |
+
except DistributedOperationException as e:
|
| 410 |
+
operation = f"{function.__module__}.{function.__name__}"
|
| 411 |
+
raise DistributedOperationException(
|
| 412 |
+
f"Error found while calling `{operation}`. Please see the earlier error for more details."
|
| 413 |
+
) from e
|
| 414 |
+
|
| 415 |
+
return wrapper
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
@verify_operation
|
| 419 |
+
def gather(tensor):
|
| 420 |
+
"""
|
| 421 |
+
Recursively gather tensor in a nested list/tuple/dictionary of tensors from all devices.
|
| 422 |
+
|
| 423 |
+
Args:
|
| 424 |
+
tensor (nested list/tuple/dictionary of `torch.Tensor`):
|
| 425 |
+
The data to gather.
|
| 426 |
+
|
| 427 |
+
Returns:
|
| 428 |
+
The same data structure as `tensor` with all tensors sent to the proper device.
|
| 429 |
+
"""
|
| 430 |
+
if PartialState().distributed_type == DistributedType.XLA:
|
| 431 |
+
return _tpu_gather(tensor)
|
| 432 |
+
elif PartialState().distributed_type in TORCH_DISTRIBUTED_OPERATION_TYPES:
|
| 433 |
+
return _gpu_gather(tensor)
|
| 434 |
+
else:
|
| 435 |
+
return tensor
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
def _gpu_gather_object(object: Any):
|
| 439 |
+
output_objects = [None for _ in range(PartialState().num_processes)]
|
| 440 |
+
torch.distributed.all_gather_object(output_objects, object)
|
| 441 |
+
# all_gather_object returns a list of lists, so we need to flatten it
|
| 442 |
+
return [x for y in output_objects for x in y]
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
def gather_object(object: Any):
|
| 446 |
+
"""
|
| 447 |
+
Recursively gather object in a nested list/tuple/dictionary of objects from all devices.
|
| 448 |
+
|
| 449 |
+
Args:
|
| 450 |
+
object (nested list/tuple/dictionary of picklable object):
|
| 451 |
+
The data to gather.
|
| 452 |
+
|
| 453 |
+
Returns:
|
| 454 |
+
The same data structure as `object` with all the objects sent to every device.
|
| 455 |
+
"""
|
| 456 |
+
if PartialState().distributed_type == DistributedType.XLA:
|
| 457 |
+
raise NotImplementedError("gather objects in TPU is not supported")
|
| 458 |
+
elif PartialState().distributed_type in TORCH_DISTRIBUTED_OPERATION_TYPES:
|
| 459 |
+
return _gpu_gather_object(object)
|
| 460 |
+
else:
|
| 461 |
+
return object
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
def _gpu_broadcast(data, src=0):
|
| 465 |
+
def _gpu_broadcast_one(tensor, src=0):
|
| 466 |
+
torch.distributed.broadcast(tensor, src=src)
|
| 467 |
+
return tensor
|
| 468 |
+
|
| 469 |
+
return recursively_apply(_gpu_broadcast_one, data, error_on_other_type=True, src=src)
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
def _tpu_broadcast(tensor, src=0, name="broadcast tensor"):
|
| 473 |
+
if isinstance(tensor, (list, tuple)):
|
| 474 |
+
return honor_type(tensor, (_tpu_broadcast(t, name=f"{name}_{i}") for i, t in enumerate(tensor)))
|
| 475 |
+
elif isinstance(tensor, Mapping):
|
| 476 |
+
return type(tensor)({k: _tpu_broadcast(v, name=f"{name}_{k}") for k, v in tensor.items()})
|
| 477 |
+
return xm.mesh_reduce(name, tensor, lambda x: x[src])
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
TENSOR_TYPE_TO_INT = {
|
| 481 |
+
torch.float: 1,
|
| 482 |
+
torch.double: 2,
|
| 483 |
+
torch.half: 3,
|
| 484 |
+
torch.bfloat16: 4,
|
| 485 |
+
torch.uint8: 5,
|
| 486 |
+
torch.int8: 6,
|
| 487 |
+
torch.int16: 7,
|
| 488 |
+
torch.int32: 8,
|
| 489 |
+
torch.int64: 9,
|
| 490 |
+
torch.bool: 10,
|
| 491 |
+
}
|
| 492 |
+
|
| 493 |
+
TENSOR_INT_TO_DTYPE = {v: k for k, v in TENSOR_TYPE_TO_INT.items()}
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
def gather_tensor_shape(tensor):
|
| 497 |
+
"""
|
| 498 |
+
Grabs the shape of `tensor` only available on one process and returns a tensor of its shape
|
| 499 |
+
"""
|
| 500 |
+
# Allocate 80 bytes to store the shape
|
| 501 |
+
max_tensor_dimension = 2**20
|
| 502 |
+
state = PartialState()
|
| 503 |
+
base_tensor = torch.empty(max_tensor_dimension, dtype=torch.int, device=state.device)
|
| 504 |
+
|
| 505 |
+
# Since PyTorch can't just send a tensor to another GPU without
|
| 506 |
+
# knowing its size, we store the size of the tensor with data
|
| 507 |
+
# in an allocation
|
| 508 |
+
if tensor is not None:
|
| 509 |
+
shape = tensor.shape
|
| 510 |
+
tensor_dtype = TENSOR_TYPE_TO_INT[tensor.dtype]
|
| 511 |
+
base_tensor[: len(shape) + 1] = torch.tensor(list(shape) + [tensor_dtype], dtype=int)
|
| 512 |
+
# Perform a reduction to copy the size data onto all GPUs
|
| 513 |
+
base_tensor = reduce(base_tensor, reduction="sum")
|
| 514 |
+
base_tensor = base_tensor[base_tensor.nonzero()]
|
| 515 |
+
# The last non-zero data contains the coded dtype the source tensor is
|
| 516 |
+
dtype = int(base_tensor[-1:][0])
|
| 517 |
+
base_tensor = base_tensor[:-1]
|
| 518 |
+
return base_tensor, dtype
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
def copy_tensor_to_devices(tensor=None) -> torch.Tensor:
|
| 522 |
+
"""
|
| 523 |
+
Copies a tensor that only exists on a single device and broadcasts it to other devices. Differs from `broadcast` as
|
| 524 |
+
each worker doesn't need to know its shape when used (and tensor can be `None`)
|
| 525 |
+
|
| 526 |
+
Args:
|
| 527 |
+
tensor (`torch.tensor`):
|
| 528 |
+
The tensor that should be sent to all devices. Must only have it be defined on a single device, the rest
|
| 529 |
+
should be `None`.
|
| 530 |
+
"""
|
| 531 |
+
state = PartialState()
|
| 532 |
+
shape, dtype = gather_tensor_shape(tensor)
|
| 533 |
+
if tensor is None:
|
| 534 |
+
tensor = torch.zeros(shape, dtype=TENSOR_INT_TO_DTYPE[dtype]).to(state.device)
|
| 535 |
+
return reduce(tensor, reduction="sum")
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
@verify_operation
|
| 539 |
+
def broadcast(tensor, from_process: int = 0):
|
| 540 |
+
"""
|
| 541 |
+
Recursively broadcast tensor in a nested list/tuple/dictionary of tensors to all devices.
|
| 542 |
+
|
| 543 |
+
Args:
|
| 544 |
+
tensor (nested list/tuple/dictionary of `torch.Tensor`):
|
| 545 |
+
The data to gather.
|
| 546 |
+
from_process (`int`, *optional*, defaults to 0):
|
| 547 |
+
The process from which to send the data
|
| 548 |
+
|
| 549 |
+
Returns:
|
| 550 |
+
The same data structure as `tensor` with all tensors broadcasted to the proper device.
|
| 551 |
+
"""
|
| 552 |
+
if PartialState().distributed_type == DistributedType.XLA:
|
| 553 |
+
return _tpu_broadcast(tensor, src=from_process, name="accelerate.utils.broadcast")
|
| 554 |
+
elif PartialState().distributed_type in TORCH_DISTRIBUTED_OPERATION_TYPES:
|
| 555 |
+
return _gpu_broadcast(tensor, src=from_process)
|
| 556 |
+
else:
|
| 557 |
+
return tensor
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
def broadcast_object_list(object_list, from_process: int = 0):
|
| 561 |
+
"""
|
| 562 |
+
Broadcast a list of picklable objects from one process to the others.
|
| 563 |
+
|
| 564 |
+
Args:
|
| 565 |
+
object_list (list of picklable objects):
|
| 566 |
+
The list of objects to broadcast. This list will be modified inplace.
|
| 567 |
+
from_process (`int`, *optional*, defaults to 0):
|
| 568 |
+
The process from which to send the data.
|
| 569 |
+
|
| 570 |
+
Returns:
|
| 571 |
+
The same list containing the objects from process 0.
|
| 572 |
+
"""
|
| 573 |
+
if PartialState().distributed_type == DistributedType.XLA:
|
| 574 |
+
for i, obj in enumerate(object_list):
|
| 575 |
+
object_list[i] = xm.mesh_reduce("accelerate.utils.broadcast_object_list", obj, lambda x: x[from_process])
|
| 576 |
+
elif PartialState().distributed_type in TORCH_DISTRIBUTED_OPERATION_TYPES:
|
| 577 |
+
torch.distributed.broadcast_object_list(object_list, src=from_process)
|
| 578 |
+
return object_list
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
def slice_tensors(data, tensor_slice, process_index=None, num_processes=None):
|
| 582 |
+
"""
|
| 583 |
+
Recursively takes a slice in a nested list/tuple/dictionary of tensors.
|
| 584 |
+
|
| 585 |
+
Args:
|
| 586 |
+
data (nested list/tuple/dictionary of `torch.Tensor`):
|
| 587 |
+
The data to slice.
|
| 588 |
+
tensor_slice (`slice`):
|
| 589 |
+
The slice to take.
|
| 590 |
+
|
| 591 |
+
Returns:
|
| 592 |
+
The same data structure as `data` with all the tensors slices.
|
| 593 |
+
"""
|
| 594 |
+
|
| 595 |
+
def _slice_tensor(tensor, tensor_slice):
|
| 596 |
+
return tensor[tensor_slice]
|
| 597 |
+
|
| 598 |
+
return recursively_apply(_slice_tensor, data, tensor_slice)
|
| 599 |
+
|
| 600 |
+
|
| 601 |
+
def concatenate(data, dim=0):
|
| 602 |
+
"""
|
| 603 |
+
Recursively concatenate the tensors in a nested list/tuple/dictionary of lists of tensors with the same shape.
|
| 604 |
+
|
| 605 |
+
Args:
|
| 606 |
+
data (nested list/tuple/dictionary of lists of tensors `torch.Tensor`):
|
| 607 |
+
The data to concatenate.
|
| 608 |
+
dim (`int`, *optional*, defaults to 0):
|
| 609 |
+
The dimension on which to concatenate.
|
| 610 |
+
|
| 611 |
+
Returns:
|
| 612 |
+
The same data structure as `data` with all the tensors concatenated.
|
| 613 |
+
"""
|
| 614 |
+
if isinstance(data[0], (tuple, list)):
|
| 615 |
+
return honor_type(data[0], (concatenate([d[i] for d in data], dim=dim) for i in range(len(data[0]))))
|
| 616 |
+
elif isinstance(data[0], Mapping):
|
| 617 |
+
return type(data[0])({k: concatenate([d[k] for d in data], dim=dim) for k in data[0].keys()})
|
| 618 |
+
elif not isinstance(data[0], torch.Tensor):
|
| 619 |
+
raise TypeError(f"Can only concatenate tensors but got {type(data[0])}")
|
| 620 |
+
return torch.cat(data, dim=dim)
|
| 621 |
+
|
| 622 |
+
|
| 623 |
+
class CannotPadNestedTensorWarning(UserWarning):
|
| 624 |
+
pass
|
| 625 |
+
|
| 626 |
+
|
| 627 |
+
@chained_operation
|
| 628 |
+
def pad_across_processes(tensor, dim=0, pad_index=0, pad_first=False):
|
| 629 |
+
"""
|
| 630 |
+
Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so they
|
| 631 |
+
can safely be gathered.
|
| 632 |
+
|
| 633 |
+
Args:
|
| 634 |
+
tensor (nested list/tuple/dictionary of `torch.Tensor`):
|
| 635 |
+
The data to gather.
|
| 636 |
+
dim (`int`, *optional*, defaults to 0):
|
| 637 |
+
The dimension on which to pad.
|
| 638 |
+
pad_index (`int`, *optional*, defaults to 0):
|
| 639 |
+
The value with which to pad.
|
| 640 |
+
pad_first (`bool`, *optional*, defaults to `False`):
|
| 641 |
+
Whether to pad at the beginning or the end.
|
| 642 |
+
"""
|
| 643 |
+
|
| 644 |
+
def _pad_across_processes(tensor, dim=0, pad_index=0, pad_first=False):
|
| 645 |
+
if getattr(tensor, "is_nested", False):
|
| 646 |
+
warnings.warn(
|
| 647 |
+
"Cannot pad nested tensors without more information. Leaving unprocessed.",
|
| 648 |
+
CannotPadNestedTensorWarning,
|
| 649 |
+
)
|
| 650 |
+
return tensor
|
| 651 |
+
if dim >= len(tensor.shape) or dim < -len(tensor.shape):
|
| 652 |
+
return tensor
|
| 653 |
+
# Convert negative dimensions to non-negative
|
| 654 |
+
if dim < 0:
|
| 655 |
+
dim += len(tensor.shape)
|
| 656 |
+
|
| 657 |
+
# Gather all sizes
|
| 658 |
+
size = torch.tensor(tensor.shape, device=tensor.device)[None]
|
| 659 |
+
sizes = gather(size).cpu()
|
| 660 |
+
# Then pad to the maximum size
|
| 661 |
+
max_size = max(s[dim] for s in sizes)
|
| 662 |
+
if max_size == tensor.shape[dim]:
|
| 663 |
+
return tensor
|
| 664 |
+
|
| 665 |
+
old_size = tensor.shape
|
| 666 |
+
new_size = list(old_size)
|
| 667 |
+
new_size[dim] = max_size
|
| 668 |
+
new_tensor = tensor.new_zeros(tuple(new_size)) + pad_index
|
| 669 |
+
if pad_first:
|
| 670 |
+
indices = tuple(
|
| 671 |
+
slice(max_size - old_size[dim], max_size) if i == dim else slice(None) for i in range(len(new_size))
|
| 672 |
+
)
|
| 673 |
+
else:
|
| 674 |
+
indices = tuple(slice(0, old_size[dim]) if i == dim else slice(None) for i in range(len(new_size)))
|
| 675 |
+
new_tensor[indices] = tensor
|
| 676 |
+
return new_tensor
|
| 677 |
+
|
| 678 |
+
return recursively_apply(
|
| 679 |
+
_pad_across_processes, tensor, error_on_other_type=True, dim=dim, pad_index=pad_index, pad_first=pad_first
|
| 680 |
+
)
|
| 681 |
+
|
| 682 |
+
|
| 683 |
+
def pad_input_tensors(tensor, batch_size, num_processes, dim=0):
|
| 684 |
+
"""
|
| 685 |
+
Takes a `tensor` of arbitrary size and pads it so that it can work given `num_processes` needed dimensions.
|
| 686 |
+
|
| 687 |
+
New tensors are just the last input repeated.
|
| 688 |
+
|
| 689 |
+
E.g.:
|
| 690 |
+
Tensor: ([3,4,4]) Num processes: 4 Expected result shape: ([4,4,4])
|
| 691 |
+
|
| 692 |
+
"""
|
| 693 |
+
|
| 694 |
+
def _pad_input_tensors(tensor, batch_size, num_processes, dim=0):
|
| 695 |
+
remainder = batch_size // num_processes
|
| 696 |
+
last_inputs = batch_size - (remainder * num_processes)
|
| 697 |
+
if batch_size // num_processes == 0:
|
| 698 |
+
to_pad = num_processes - batch_size
|
| 699 |
+
else:
|
| 700 |
+
to_pad = num_processes - (batch_size // num_processes)
|
| 701 |
+
# In the rare case that `to_pad` is negative,
|
| 702 |
+
# we need to pad the last inputs - the found `to_pad`
|
| 703 |
+
if last_inputs > to_pad & to_pad < 1:
|
| 704 |
+
to_pad = last_inputs - to_pad
|
| 705 |
+
old_size = tensor.shape
|
| 706 |
+
new_size = list(old_size)
|
| 707 |
+
new_size[0] = batch_size + to_pad
|
| 708 |
+
new_tensor = tensor.new_zeros(tuple(new_size))
|
| 709 |
+
indices = tuple(slice(0, old_size[dim]) if i == dim else slice(None) for i in range(len(new_size)))
|
| 710 |
+
new_tensor[indices] = tensor
|
| 711 |
+
return new_tensor
|
| 712 |
+
|
| 713 |
+
return recursively_apply(
|
| 714 |
+
_pad_input_tensors,
|
| 715 |
+
tensor,
|
| 716 |
+
error_on_other_type=True,
|
| 717 |
+
batch_size=batch_size,
|
| 718 |
+
num_processes=num_processes,
|
| 719 |
+
dim=dim,
|
| 720 |
+
)
|
| 721 |
+
|
| 722 |
+
|
| 723 |
+
@verify_operation
|
| 724 |
+
def reduce(tensor, reduction="mean", scale=1.0):
|
| 725 |
+
"""
|
| 726 |
+
Recursively reduce the tensors in a nested list/tuple/dictionary of lists of tensors across all processes by the
|
| 727 |
+
mean of a given operation.
|
| 728 |
+
|
| 729 |
+
Args:
|
| 730 |
+
tensor (nested list/tuple/dictionary of `torch.Tensor`):
|
| 731 |
+
The data to reduce.
|
| 732 |
+
reduction (`str`, *optional*, defaults to `"mean"`):
|
| 733 |
+
A reduction method. Can be of "mean", "sum", or "none"
|
| 734 |
+
scale (`float`, *optional*):
|
| 735 |
+
A default scaling value to be applied after the reduce, only valid on XLA.
|
| 736 |
+
|
| 737 |
+
Returns:
|
| 738 |
+
The same data structure as `data` with all the tensors reduced.
|
| 739 |
+
"""
|
| 740 |
+
|
| 741 |
+
def _reduce_across_processes(tensor, reduction="mean", scale=1.0):
|
| 742 |
+
state = PartialState()
|
| 743 |
+
cloned_tensor = tensor.clone()
|
| 744 |
+
if state.distributed_type == DistributedType.NO:
|
| 745 |
+
return cloned_tensor
|
| 746 |
+
if state.distributed_type == DistributedType.XLA:
|
| 747 |
+
# Some processes may have different HLO graphs than other
|
| 748 |
+
# processes, for example in the breakpoint API
|
| 749 |
+
# accelerator.set_trigger(). Use mark_step to make HLOs
|
| 750 |
+
# the same on all processes.
|
| 751 |
+
xm.mark_step()
|
| 752 |
+
xm.all_reduce(xm.REDUCE_SUM, [cloned_tensor], scale)
|
| 753 |
+
xm.mark_step()
|
| 754 |
+
elif state.distributed_type.value in TORCH_DISTRIBUTED_OPERATION_TYPES:
|
| 755 |
+
torch.distributed.all_reduce(cloned_tensor, ReduceOp.SUM)
|
| 756 |
+
if reduction == "mean":
|
| 757 |
+
cloned_tensor /= state.num_processes
|
| 758 |
+
return cloned_tensor
|
| 759 |
+
|
| 760 |
+
return recursively_apply(
|
| 761 |
+
_reduce_across_processes, tensor, error_on_other_type=True, reduction=reduction, scale=scale
|
| 762 |
+
)
|
| 763 |
+
|
| 764 |
+
|
| 765 |
+
def convert_to_fp32(tensor):
|
| 766 |
+
"""
|
| 767 |
+
Recursively converts the elements nested list/tuple/dictionary of tensors in FP16/BF16 precision to FP32.
|
| 768 |
+
|
| 769 |
+
Args:
|
| 770 |
+
tensor (nested list/tuple/dictionary of `torch.Tensor`):
|
| 771 |
+
The data to convert from FP16/BF16 to FP32.
|
| 772 |
+
|
| 773 |
+
Returns:
|
| 774 |
+
The same data structure as `tensor` with all tensors that were in FP16/BF16 precision converted to FP32.
|
| 775 |
+
"""
|
| 776 |
+
|
| 777 |
+
def _convert_to_fp32(tensor):
|
| 778 |
+
return tensor.float()
|
| 779 |
+
|
| 780 |
+
def _is_fp16_bf16_tensor(tensor):
|
| 781 |
+
return (is_torch_tensor(tensor) or hasattr(tensor, "dtype")) and tensor.dtype in (
|
| 782 |
+
torch.float16,
|
| 783 |
+
torch.bfloat16,
|
| 784 |
+
)
|
| 785 |
+
|
| 786 |
+
return recursively_apply(_convert_to_fp32, tensor, test_type=_is_fp16_bf16_tensor)
|
| 787 |
+
|
| 788 |
+
|
| 789 |
+
class ConvertOutputsToFp32:
|
| 790 |
+
"""
|
| 791 |
+
Decorator to apply to a function outputting tensors (like a model forward pass) that ensures the outputs in FP16
|
| 792 |
+
precision will be convert back to FP32.
|
| 793 |
+
|
| 794 |
+
Args:
|
| 795 |
+
model_forward (`Callable`):
|
| 796 |
+
The function which outputs we want to treat.
|
| 797 |
+
|
| 798 |
+
Returns:
|
| 799 |
+
The same function as `model_forward` but with converted outputs.
|
| 800 |
+
"""
|
| 801 |
+
|
| 802 |
+
def __init__(self, model_forward):
|
| 803 |
+
self.model_forward = model_forward
|
| 804 |
+
update_wrapper(self, model_forward)
|
| 805 |
+
|
| 806 |
+
def __call__(self, *args, **kwargs):
|
| 807 |
+
return convert_to_fp32(self.model_forward(*args, **kwargs))
|
| 808 |
+
|
| 809 |
+
def __getstate__(self):
|
| 810 |
+
raise pickle.PicklingError(
|
| 811 |
+
"Cannot pickle a prepared model with automatic mixed precision, please unwrap the model with `Accelerator.unwrap_model(model)` before pickling it."
|
| 812 |
+
)
|
| 813 |
+
|
| 814 |
+
|
| 815 |
+
def convert_outputs_to_fp32(model_forward):
|
| 816 |
+
model_forward = ConvertOutputsToFp32(model_forward)
|
| 817 |
+
|
| 818 |
+
def forward(*args, **kwargs):
|
| 819 |
+
return model_forward(*args, **kwargs)
|
| 820 |
+
|
| 821 |
+
# To act like a decorator so that it can be popped when doing `extract_model_from_parallel`
|
| 822 |
+
forward.__wrapped__ = model_forward
|
| 823 |
+
|
| 824 |
+
return forward
|
| 825 |
+
|
| 826 |
+
|
| 827 |
+
def find_device(data):
|
| 828 |
+
"""
|
| 829 |
+
Finds the device on which a nested dict/list/tuple of tensors lies (assuming they are all on the same device).
|
| 830 |
+
|
| 831 |
+
Args:
|
| 832 |
+
(nested list/tuple/dictionary of `torch.Tensor`): The data we want to know the device of.
|
| 833 |
+
"""
|
| 834 |
+
if isinstance(data, Mapping):
|
| 835 |
+
for obj in data.values():
|
| 836 |
+
device = find_device(obj)
|
| 837 |
+
if device is not None:
|
| 838 |
+
return device
|
| 839 |
+
elif isinstance(data, (tuple, list)):
|
| 840 |
+
for obj in data:
|
| 841 |
+
device = find_device(obj)
|
| 842 |
+
if device is not None:
|
| 843 |
+
return device
|
| 844 |
+
elif isinstance(data, torch.Tensor):
|
| 845 |
+
return data.device
|
| 846 |
+
|
| 847 |
+
|
| 848 |
+
@contextmanager
|
| 849 |
+
def GatheredParameters(params, modifier_rank=None, fwd_module=None, enabled=True):
|
| 850 |
+
"""
|
| 851 |
+
Wrapper around `deepspeed.runtime.zero.GatheredParameters`, but if Zero-3 is not enabled, will be a no-op context
|
| 852 |
+
manager.
|
| 853 |
+
"""
|
| 854 |
+
# We need to use the `AcceleratorState` here since it has access to the deepspeed plugin
|
| 855 |
+
if AcceleratorState().distributed_type != DistributedType.DEEPSPEED or (
|
| 856 |
+
AcceleratorState().deepspeed_plugin is not None
|
| 857 |
+
and not AcceleratorState().deepspeed_plugin.is_zero3_init_enabled()
|
| 858 |
+
):
|
| 859 |
+
gather_param_context = nullcontext()
|
| 860 |
+
else:
|
| 861 |
+
import deepspeed
|
| 862 |
+
|
| 863 |
+
gather_param_context = deepspeed.zero.GatheredParameters(
|
| 864 |
+
params, modifier_rank=modifier_rank, fwd_module=fwd_module, enabled=enabled
|
| 865 |
+
)
|
| 866 |
+
with gather_param_context:
|
| 867 |
+
yield
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/other.py
ADDED
|
@@ -0,0 +1,561 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import collections
|
| 16 |
+
import platform
|
| 17 |
+
import re
|
| 18 |
+
import socket
|
| 19 |
+
from codecs import encode
|
| 20 |
+
from collections import OrderedDict
|
| 21 |
+
from functools import partial, reduce
|
| 22 |
+
from types import MethodType
|
| 23 |
+
from typing import Optional
|
| 24 |
+
|
| 25 |
+
import numpy as np
|
| 26 |
+
import torch
|
| 27 |
+
from packaging.version import Version
|
| 28 |
+
from safetensors.torch import save_file as safe_save_file
|
| 29 |
+
|
| 30 |
+
from ..commands.config.default import write_basic_config # noqa: F401
|
| 31 |
+
from ..logging import get_logger
|
| 32 |
+
from ..state import PartialState
|
| 33 |
+
from .constants import FSDP_PYTORCH_VERSION
|
| 34 |
+
from .dataclasses import DistributedType
|
| 35 |
+
from .imports import (
|
| 36 |
+
is_deepspeed_available,
|
| 37 |
+
is_numpy_available,
|
| 38 |
+
is_torch_distributed_available,
|
| 39 |
+
is_torch_xla_available,
|
| 40 |
+
is_weights_only_available,
|
| 41 |
+
)
|
| 42 |
+
from .modeling import id_tensor_storage
|
| 43 |
+
from .transformer_engine import convert_model
|
| 44 |
+
from .versions import is_torch_version
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
logger = get_logger(__name__)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
if is_torch_xla_available():
|
| 51 |
+
import torch_xla.core.xla_model as xm
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def is_compiled_module(module: torch.nn.Module) -> bool:
|
| 55 |
+
"""
|
| 56 |
+
Check whether the module was compiled with torch.compile()
|
| 57 |
+
"""
|
| 58 |
+
if not hasattr(torch, "_dynamo"):
|
| 59 |
+
return False
|
| 60 |
+
|
| 61 |
+
return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def has_compiled_regions(module: torch.nn.Module) -> bool:
|
| 65 |
+
"""
|
| 66 |
+
Check whether the module has submodules that were compiled with `torch.compile()`.
|
| 67 |
+
"""
|
| 68 |
+
if not hasattr(torch, "_dynamo"):
|
| 69 |
+
return False
|
| 70 |
+
|
| 71 |
+
if module._modules:
|
| 72 |
+
for submodule in module.modules():
|
| 73 |
+
if isinstance(submodule, torch._dynamo.eval_frame.OptimizedModule):
|
| 74 |
+
return True
|
| 75 |
+
|
| 76 |
+
return False
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def is_repeated_blocks(module: torch.nn.Module) -> bool:
|
| 80 |
+
"""
|
| 81 |
+
Check whether the module is a repeated block, i.e. `torch.nn.ModuleList` with all children of the same class. This
|
| 82 |
+
is useful to determine whether we should apply regional compilation to the module.
|
| 83 |
+
"""
|
| 84 |
+
|
| 85 |
+
return isinstance(module, torch.nn.ModuleList) and all(isinstance(m, module[0].__class__) for m in module)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def has_repeated_blocks(module: torch.nn.Module) -> bool:
|
| 89 |
+
"""
|
| 90 |
+
Check whether the module has repeated blocks, i.e. `torch.nn.ModuleList` with all children of the same class, at
|
| 91 |
+
any level of the module hierarchy. This is useful to determine whether we should apply regional compilation to the
|
| 92 |
+
module.
|
| 93 |
+
"""
|
| 94 |
+
if module._modules:
|
| 95 |
+
for submodule in module.modules():
|
| 96 |
+
if is_repeated_blocks(submodule):
|
| 97 |
+
return True
|
| 98 |
+
|
| 99 |
+
return False
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def compile_regions(module: torch.nn.Module, **compile_kwargs) -> torch.nn.Module:
|
| 103 |
+
"""
|
| 104 |
+
Performs regional compilation where we target repeated blocks of the same class and compile them sequentially to
|
| 105 |
+
hit the compiler's cache. For example, in `GPT2LMHeadModel`, the repeated block/class is `GPT2Block`, and can be
|
| 106 |
+
accessed as `model.transformer.h[0]`. The rest of the model (e.g. model.lm_head) is compiled separately.
|
| 107 |
+
|
| 108 |
+
This allows us to speed up the compilation overhead / cold start of models like LLMs and Transformers in general.
|
| 109 |
+
See https://pytorch.org/tutorials/recipes/regional_compilation.html for more details.
|
| 110 |
+
|
| 111 |
+
Args:
|
| 112 |
+
module (`torch.nn.Module`):
|
| 113 |
+
The model to compile.
|
| 114 |
+
**compile_kwargs:
|
| 115 |
+
Additional keyword arguments to pass to `torch.compile()`.
|
| 116 |
+
|
| 117 |
+
Returns:
|
| 118 |
+
`torch.nn.Module`: A new instance of the model with some compiled regions.
|
| 119 |
+
|
| 120 |
+
Example:
|
| 121 |
+
```python
|
| 122 |
+
>>> from accelerate.utils import compile_regions
|
| 123 |
+
>>> from transformers import AutoModelForCausalLM
|
| 124 |
+
|
| 125 |
+
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
|
| 126 |
+
>>> compiled_model = compile_regions(model, mode="reduce-overhead")
|
| 127 |
+
>>> compiled_model.transformer.h[0]
|
| 128 |
+
OptimizedModule(
|
| 129 |
+
(_orig_mod): GPT2Block(
|
| 130 |
+
(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
|
| 131 |
+
(attn): GPT2Attention(
|
| 132 |
+
(c_attn): Conv1D(nf=2304, nx=768)
|
| 133 |
+
(c_proj): Conv1D(nf=768, nx=768)
|
| 134 |
+
(attn_dropout): Dropout(p=0.1, inplace=False)
|
| 135 |
+
(resid_dropout): Dropout(p=0.1, inplace=False)
|
| 136 |
+
)
|
| 137 |
+
(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
|
| 138 |
+
(mlp): GPT2MLP(
|
| 139 |
+
(c_fc): Conv1D(nf=3072, nx=768)
|
| 140 |
+
(c_proj): Conv1D(nf=768, nx=3072)
|
| 141 |
+
(act): NewGELUActivation()
|
| 142 |
+
(dropout): Dropout(p=0.1, inplace=False)
|
| 143 |
+
)
|
| 144 |
+
)
|
| 145 |
+
)
|
| 146 |
+
```
|
| 147 |
+
"""
|
| 148 |
+
|
| 149 |
+
def _compile_regions(module: torch.nn.Module, **compile_kwargs) -> torch.nn.Module:
|
| 150 |
+
if is_repeated_blocks(module):
|
| 151 |
+
new_module = torch.nn.ModuleList()
|
| 152 |
+
for submodule in module:
|
| 153 |
+
new_module.append(torch.compile(submodule, **compile_kwargs))
|
| 154 |
+
elif has_repeated_blocks(module):
|
| 155 |
+
new_module = module.__class__.__new__(module.__class__)
|
| 156 |
+
new_module.__dict__.update(module.__dict__)
|
| 157 |
+
new_module._modules = {}
|
| 158 |
+
for name, submodule in module.named_children():
|
| 159 |
+
new_module.add_module(name, _compile_regions(submodule, **compile_kwargs))
|
| 160 |
+
else:
|
| 161 |
+
new_module = torch.compile(module, **compile_kwargs)
|
| 162 |
+
|
| 163 |
+
return new_module
|
| 164 |
+
|
| 165 |
+
new_module = _compile_regions(module, **compile_kwargs)
|
| 166 |
+
|
| 167 |
+
if "_orig_mod" not in new_module.__dict__:
|
| 168 |
+
# Keeps a reference to the original module to decompile/unwrap it later
|
| 169 |
+
new_module.__dict__["_orig_mod"] = module
|
| 170 |
+
|
| 171 |
+
return new_module
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def compile_regions_deepspeed(module: torch.nn.Module, **compile_kwargs):
|
| 175 |
+
"""
|
| 176 |
+
Performs regional compilation the same way as `compile_regions`, but specifically for `DeepSpeedEngine.module`.
|
| 177 |
+
Since the model is wrapped in a `DeepSpeedEngine` and has many added hooks, offloaded parameters, etc that
|
| 178 |
+
`torch.compile(...)` interferes with, version of trgional compilation uses the inplace `module.compile()` method
|
| 179 |
+
instead.
|
| 180 |
+
|
| 181 |
+
Args:
|
| 182 |
+
module (`torch.nn.Module`):
|
| 183 |
+
The model to compile.
|
| 184 |
+
**compile_kwargs:
|
| 185 |
+
Additional keyword arguments to pass to `module.compile()`.
|
| 186 |
+
"""
|
| 187 |
+
|
| 188 |
+
if is_repeated_blocks(module):
|
| 189 |
+
for submodule in module:
|
| 190 |
+
submodule.compile(**compile_kwargs)
|
| 191 |
+
elif has_repeated_blocks(module):
|
| 192 |
+
for child in module.children():
|
| 193 |
+
compile_regions_deepspeed(child, **compile_kwargs)
|
| 194 |
+
else: # leaf node
|
| 195 |
+
module.compile(**compile_kwargs)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def model_has_dtensor(model: torch.nn.Module) -> bool:
|
| 199 |
+
"""
|
| 200 |
+
Check if the model has DTensor parameters.
|
| 201 |
+
|
| 202 |
+
Args:
|
| 203 |
+
model (`torch.nn.Module`):
|
| 204 |
+
The model to check.
|
| 205 |
+
|
| 206 |
+
Returns:
|
| 207 |
+
`bool`: Whether the model has DTensor parameters.
|
| 208 |
+
"""
|
| 209 |
+
if is_torch_version(">=", "2.5.0"):
|
| 210 |
+
from torch.distributed.tensor import DTensor
|
| 211 |
+
else:
|
| 212 |
+
# from torch 2.0.0 (oldest supported accelerate torch version), DTensor is in torch.distributed._tensor
|
| 213 |
+
from torch.distributed._tensor import DTensor
|
| 214 |
+
|
| 215 |
+
return any(isinstance(p, DTensor) for p in model.parameters())
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def extract_model_from_parallel(
|
| 219 |
+
model, keep_fp32_wrapper: bool = True, keep_torch_compile: bool = True, recursive: bool = False
|
| 220 |
+
):
|
| 221 |
+
"""
|
| 222 |
+
Extract a model from its distributed containers.
|
| 223 |
+
|
| 224 |
+
Args:
|
| 225 |
+
model (`torch.nn.Module`):
|
| 226 |
+
The model to extract.
|
| 227 |
+
keep_fp32_wrapper (`bool`, *optional*):
|
| 228 |
+
Whether to remove mixed precision hooks from the model.
|
| 229 |
+
keep_torch_compile (`bool`, *optional*):
|
| 230 |
+
Whether to unwrap compiled model.
|
| 231 |
+
recursive (`bool`, *optional*, defaults to `False`):
|
| 232 |
+
Whether to recursively extract all cases of `module.module` from `model` as well as unwrap child sublayers
|
| 233 |
+
recursively, not just the top-level distributed containers.
|
| 234 |
+
|
| 235 |
+
Returns:
|
| 236 |
+
`torch.nn.Module`: The extracted model.
|
| 237 |
+
"""
|
| 238 |
+
options = (torch.nn.parallel.DistributedDataParallel, torch.nn.DataParallel)
|
| 239 |
+
|
| 240 |
+
is_compiled = is_compiled_module(model)
|
| 241 |
+
has_compiled = has_compiled_regions(model)
|
| 242 |
+
|
| 243 |
+
if is_compiled:
|
| 244 |
+
compiled_model = model
|
| 245 |
+
model = model._orig_mod
|
| 246 |
+
elif has_compiled:
|
| 247 |
+
compiled_model = model
|
| 248 |
+
model = model.__dict__["_orig_mod"]
|
| 249 |
+
|
| 250 |
+
if is_deepspeed_available():
|
| 251 |
+
from deepspeed import DeepSpeedEngine
|
| 252 |
+
|
| 253 |
+
options += (DeepSpeedEngine,)
|
| 254 |
+
|
| 255 |
+
if is_torch_version(">=", FSDP_PYTORCH_VERSION) and is_torch_distributed_available():
|
| 256 |
+
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
|
| 257 |
+
|
| 258 |
+
options += (FSDP,)
|
| 259 |
+
|
| 260 |
+
while isinstance(model, options):
|
| 261 |
+
model = model.module
|
| 262 |
+
|
| 263 |
+
if recursive:
|
| 264 |
+
# This is needed in cases such as using FSDPv2 on XLA
|
| 265 |
+
def _recursive_unwrap(module):
|
| 266 |
+
# Wrapped modules are standardly wrapped as `module`, similar to the cases earlier
|
| 267 |
+
# with DDP, DataParallel, DeepSpeed, and FSDP
|
| 268 |
+
if hasattr(module, "module"):
|
| 269 |
+
unwrapped_module = _recursive_unwrap(module.module)
|
| 270 |
+
else:
|
| 271 |
+
unwrapped_module = module
|
| 272 |
+
# Next unwrap child sublayers recursively
|
| 273 |
+
for name, child in unwrapped_module.named_children():
|
| 274 |
+
setattr(unwrapped_module, name, _recursive_unwrap(child))
|
| 275 |
+
return unwrapped_module
|
| 276 |
+
|
| 277 |
+
# Start with top-level
|
| 278 |
+
model = _recursive_unwrap(model)
|
| 279 |
+
|
| 280 |
+
if not keep_fp32_wrapper:
|
| 281 |
+
forward = model.forward
|
| 282 |
+
original_forward = model.__dict__.pop("_original_forward", None)
|
| 283 |
+
if original_forward is not None:
|
| 284 |
+
while hasattr(forward, "__wrapped__"):
|
| 285 |
+
forward = forward.__wrapped__
|
| 286 |
+
if forward == original_forward:
|
| 287 |
+
break
|
| 288 |
+
model.forward = MethodType(forward, model)
|
| 289 |
+
if getattr(model, "_converted_to_transformer_engine", False):
|
| 290 |
+
convert_model(model, to_transformer_engine=False)
|
| 291 |
+
|
| 292 |
+
if keep_torch_compile:
|
| 293 |
+
if is_compiled:
|
| 294 |
+
compiled_model._orig_mod = model
|
| 295 |
+
model = compiled_model
|
| 296 |
+
elif has_compiled:
|
| 297 |
+
compiled_model.__dict__["_orig_mod"] = model
|
| 298 |
+
model = compiled_model
|
| 299 |
+
|
| 300 |
+
return model
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def wait_for_everyone():
|
| 304 |
+
"""
|
| 305 |
+
Introduces a blocking point in the script, making sure all processes have reached this point before continuing.
|
| 306 |
+
|
| 307 |
+
<Tip warning={true}>
|
| 308 |
+
|
| 309 |
+
Make sure all processes will reach this instruction otherwise one of your processes will hang forever.
|
| 310 |
+
|
| 311 |
+
</Tip>
|
| 312 |
+
"""
|
| 313 |
+
PartialState().wait_for_everyone()
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def clean_state_dict_for_safetensors(state_dict: dict):
|
| 317 |
+
"""
|
| 318 |
+
Cleans the state dictionary from a model and removes tensor aliasing if present.
|
| 319 |
+
|
| 320 |
+
Args:
|
| 321 |
+
state_dict (`dict`):
|
| 322 |
+
The state dictionary from a model
|
| 323 |
+
"""
|
| 324 |
+
ptrs = collections.defaultdict(list)
|
| 325 |
+
# When bnb serialization is used, weights in state dict can be strings
|
| 326 |
+
for name, tensor in state_dict.items():
|
| 327 |
+
if not isinstance(tensor, str):
|
| 328 |
+
ptrs[id_tensor_storage(tensor)].append(name)
|
| 329 |
+
|
| 330 |
+
# These are all pointers of tensors with shared memory
|
| 331 |
+
shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
|
| 332 |
+
warn_names = set()
|
| 333 |
+
for names in shared_ptrs.values():
|
| 334 |
+
# When not all duplicates have been cleaned, we still remove those keys but put a clear warning.
|
| 335 |
+
# If the link between tensors was done at runtime then `from_pretrained` will not get
|
| 336 |
+
# the key back leading to random tensor. A proper warning will be shown
|
| 337 |
+
# during reload (if applicable), but since the file is not necessarily compatible with
|
| 338 |
+
# the config, better show a proper warning.
|
| 339 |
+
found_names = [name for name in names if name in state_dict]
|
| 340 |
+
warn_names.update(found_names[1:])
|
| 341 |
+
for name in found_names[1:]:
|
| 342 |
+
del state_dict[name]
|
| 343 |
+
if len(warn_names) > 0:
|
| 344 |
+
logger.warning(
|
| 345 |
+
f"Removed shared tensor {warn_names} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading",
|
| 346 |
+
)
|
| 347 |
+
state_dict = {k: v.contiguous() if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()}
|
| 348 |
+
return state_dict
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def save(obj, f, save_on_each_node: bool = False, safe_serialization: bool = False):
|
| 352 |
+
"""
|
| 353 |
+
Save the data to disk. Use in place of `torch.save()`.
|
| 354 |
+
|
| 355 |
+
Args:
|
| 356 |
+
obj:
|
| 357 |
+
The data to save
|
| 358 |
+
f:
|
| 359 |
+
The file (or file-like object) to use to save the data
|
| 360 |
+
save_on_each_node (`bool`, *optional*, defaults to `False`):
|
| 361 |
+
Whether to only save on the global main process
|
| 362 |
+
safe_serialization (`bool`, *optional*, defaults to `False`):
|
| 363 |
+
Whether to save `obj` using `safetensors` or the traditional PyTorch way (that uses `pickle`).
|
| 364 |
+
"""
|
| 365 |
+
# When TorchXLA is enabled, it's necessary to transfer all data to the CPU before saving.
|
| 366 |
+
# Another issue arises with `id_tensor_storage`, which treats all XLA tensors as identical.
|
| 367 |
+
# If tensors remain on XLA, calling `clean_state_dict_for_safetensors` will result in only
|
| 368 |
+
# one XLA tensor remaining.
|
| 369 |
+
if PartialState().distributed_type == DistributedType.XLA:
|
| 370 |
+
obj = xm._maybe_convert_to_cpu(obj)
|
| 371 |
+
# Check if it's a model and remove duplicates
|
| 372 |
+
if safe_serialization:
|
| 373 |
+
save_func = partial(safe_save_file, metadata={"format": "pt"})
|
| 374 |
+
if isinstance(obj, OrderedDict):
|
| 375 |
+
obj = clean_state_dict_for_safetensors(obj)
|
| 376 |
+
else:
|
| 377 |
+
save_func = torch.save
|
| 378 |
+
|
| 379 |
+
if PartialState().is_main_process and not save_on_each_node:
|
| 380 |
+
save_func(obj, f)
|
| 381 |
+
elif PartialState().is_local_main_process and save_on_each_node:
|
| 382 |
+
save_func(obj, f)
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
# The following are considered "safe" globals to reconstruct various types of objects when using `weights_only=True`
|
| 386 |
+
# These should be added and then removed after loading in the file
|
| 387 |
+
np_core = np._core if is_numpy_available("2.0.0") else np.core
|
| 388 |
+
TORCH_SAFE_GLOBALS = [
|
| 389 |
+
# numpy arrays are just numbers, not objects, so we can reconstruct them safely
|
| 390 |
+
np_core.multiarray._reconstruct,
|
| 391 |
+
np.ndarray,
|
| 392 |
+
# The following are needed for the RNG states
|
| 393 |
+
encode,
|
| 394 |
+
np.dtype,
|
| 395 |
+
]
|
| 396 |
+
|
| 397 |
+
if is_numpy_available("1.25.0"):
|
| 398 |
+
TORCH_SAFE_GLOBALS.append(np.dtypes.UInt32DType)
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
def load(f, map_location=None, **kwargs):
|
| 402 |
+
"""
|
| 403 |
+
Compatible drop-in replacement of `torch.load()` which allows for `weights_only` to be used if `torch` version is
|
| 404 |
+
2.4.0 or higher. Otherwise will ignore the kwarg.
|
| 405 |
+
|
| 406 |
+
Will also add (and then remove) an exception for numpy arrays
|
| 407 |
+
|
| 408 |
+
Args:
|
| 409 |
+
f:
|
| 410 |
+
The file (or file-like object) to use to load the data
|
| 411 |
+
map_location:
|
| 412 |
+
a function, `torch.device`, string or a dict specifying how to remap storage locations
|
| 413 |
+
**kwargs:
|
| 414 |
+
Additional keyword arguments to pass to `torch.load()`.
|
| 415 |
+
"""
|
| 416 |
+
try:
|
| 417 |
+
if is_weights_only_available():
|
| 418 |
+
old_safe_globals = torch.serialization.get_safe_globals()
|
| 419 |
+
if "weights_only" not in kwargs:
|
| 420 |
+
kwargs["weights_only"] = True
|
| 421 |
+
torch.serialization.add_safe_globals(TORCH_SAFE_GLOBALS)
|
| 422 |
+
else:
|
| 423 |
+
kwargs.pop("weights_only", None)
|
| 424 |
+
loaded_obj = torch.load(f, map_location=map_location, **kwargs)
|
| 425 |
+
finally:
|
| 426 |
+
if is_weights_only_available():
|
| 427 |
+
torch.serialization.clear_safe_globals()
|
| 428 |
+
if old_safe_globals:
|
| 429 |
+
torch.serialization.add_safe_globals(old_safe_globals)
|
| 430 |
+
return loaded_obj
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
def get_pretty_name(obj):
|
| 434 |
+
"""
|
| 435 |
+
Gets a pretty name from `obj`.
|
| 436 |
+
"""
|
| 437 |
+
if not hasattr(obj, "__qualname__") and not hasattr(obj, "__name__"):
|
| 438 |
+
obj = getattr(obj, "__class__", obj)
|
| 439 |
+
if hasattr(obj, "__qualname__"):
|
| 440 |
+
return obj.__qualname__
|
| 441 |
+
if hasattr(obj, "__name__"):
|
| 442 |
+
return obj.__name__
|
| 443 |
+
return str(obj)
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
def merge_dicts(source, destination):
|
| 447 |
+
"""
|
| 448 |
+
Recursively merges two dictionaries.
|
| 449 |
+
|
| 450 |
+
Args:
|
| 451 |
+
source (`dict`): The dictionary to merge into `destination`.
|
| 452 |
+
destination (`dict`): The dictionary to merge `source` into.
|
| 453 |
+
"""
|
| 454 |
+
for key, value in source.items():
|
| 455 |
+
if isinstance(value, dict):
|
| 456 |
+
node = destination.setdefault(key, {})
|
| 457 |
+
merge_dicts(value, node)
|
| 458 |
+
else:
|
| 459 |
+
destination[key] = value
|
| 460 |
+
|
| 461 |
+
return destination
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
def is_port_in_use(port: Optional[int] = None) -> bool:
|
| 465 |
+
"""
|
| 466 |
+
Checks if a port is in use on `localhost`. Useful for checking if multiple `accelerate launch` commands have been
|
| 467 |
+
run and need to see if the port is already in use.
|
| 468 |
+
"""
|
| 469 |
+
if port is None:
|
| 470 |
+
port = 29500
|
| 471 |
+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
| 472 |
+
return s.connect_ex(("localhost", port)) == 0
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
def get_free_port() -> int:
|
| 476 |
+
"""
|
| 477 |
+
Gets a free port on `localhost`. Useful for automatic port selection when port 0 is specified in distributed
|
| 478 |
+
training scenarios.
|
| 479 |
+
|
| 480 |
+
Returns:
|
| 481 |
+
int: An available port number
|
| 482 |
+
"""
|
| 483 |
+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
| 484 |
+
s.bind(("", 0)) # bind to port 0 for OS to assign a free port
|
| 485 |
+
return s.getsockname()[1]
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
def convert_bytes(size):
|
| 489 |
+
"Converts `size` from bytes to the largest possible unit"
|
| 490 |
+
for x in ["bytes", "KB", "MB", "GB", "TB"]:
|
| 491 |
+
if size < 1024.0:
|
| 492 |
+
return f"{round(size, 2)} {x}"
|
| 493 |
+
size /= 1024.0
|
| 494 |
+
|
| 495 |
+
return f"{round(size, 2)} PB"
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
def check_os_kernel():
|
| 499 |
+
"""Warns if the kernel version is below the recommended minimum on Linux."""
|
| 500 |
+
# see issue #1929
|
| 501 |
+
info = platform.uname()
|
| 502 |
+
system = info.system
|
| 503 |
+
if system != "Linux":
|
| 504 |
+
return
|
| 505 |
+
|
| 506 |
+
_, version, *_ = re.split(r"(\d+\.\d+\.\d+)", info.release)
|
| 507 |
+
min_version = "5.5.0"
|
| 508 |
+
if Version(version) < Version(min_version):
|
| 509 |
+
msg = (
|
| 510 |
+
f"Detected kernel version {version}, which is below the recommended minimum of {min_version}; this can "
|
| 511 |
+
"cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher."
|
| 512 |
+
)
|
| 513 |
+
logger.warning(msg, main_process_only=True)
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
def recursive_getattr(obj, attr: str):
|
| 517 |
+
"""
|
| 518 |
+
Recursive `getattr`.
|
| 519 |
+
|
| 520 |
+
Args:
|
| 521 |
+
obj:
|
| 522 |
+
A class instance holding the attribute.
|
| 523 |
+
attr (`str`):
|
| 524 |
+
The attribute that is to be retrieved, e.g. 'attribute1.attribute2'.
|
| 525 |
+
"""
|
| 526 |
+
|
| 527 |
+
def _getattr(obj, attr):
|
| 528 |
+
return getattr(obj, attr)
|
| 529 |
+
|
| 530 |
+
return reduce(_getattr, [obj] + attr.split("."))
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
def get_module_children_bottom_up(model: torch.nn.Module, return_fqns: bool = False) -> list[torch.nn.Module]:
|
| 534 |
+
"""Traverse the model in bottom-up order and return the children modules in that order.
|
| 535 |
+
|
| 536 |
+
Args:
|
| 537 |
+
model (`torch.nn.Module`): the model to get the children of
|
| 538 |
+
|
| 539 |
+
Returns:
|
| 540 |
+
`list[torch.nn.Module]`: a list of children modules of `model` in bottom-up order. The last element is the
|
| 541 |
+
`model` itself.
|
| 542 |
+
"""
|
| 543 |
+
top = model if not return_fqns else ("", model)
|
| 544 |
+
stack = [top]
|
| 545 |
+
ordered_modules = []
|
| 546 |
+
while stack:
|
| 547 |
+
current_module = stack.pop()
|
| 548 |
+
if return_fqns:
|
| 549 |
+
current_module_name, current_module = current_module
|
| 550 |
+
for name, attr in current_module.named_children():
|
| 551 |
+
if isinstance(attr, torch.nn.Module):
|
| 552 |
+
if return_fqns:
|
| 553 |
+
child_name = current_module_name + "." + name if current_module_name else name
|
| 554 |
+
stack.append((child_name, attr))
|
| 555 |
+
else:
|
| 556 |
+
stack.append(attr)
|
| 557 |
+
if return_fqns:
|
| 558 |
+
ordered_modules.append((current_module_name, current_module))
|
| 559 |
+
else:
|
| 560 |
+
ordered_modules.append(current_module)
|
| 561 |
+
return ordered_modules[::-1]
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/random.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import random
|
| 16 |
+
from typing import Optional, Union
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch
|
| 20 |
+
|
| 21 |
+
from ..state import AcceleratorState
|
| 22 |
+
from .constants import CUDA_DISTRIBUTED_TYPES
|
| 23 |
+
from .dataclasses import DistributedType, RNGType
|
| 24 |
+
from .imports import (
|
| 25 |
+
is_hpu_available,
|
| 26 |
+
is_mlu_available,
|
| 27 |
+
is_musa_available,
|
| 28 |
+
is_npu_available,
|
| 29 |
+
is_sdaa_available,
|
| 30 |
+
is_torch_xla_available,
|
| 31 |
+
is_xpu_available,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
if is_torch_xla_available():
|
| 36 |
+
import torch_xla.core.xla_model as xm
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def set_seed(seed: int, device_specific: bool = False, deterministic: bool = False):
|
| 40 |
+
"""
|
| 41 |
+
Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
seed (`int`):
|
| 45 |
+
The seed to set.
|
| 46 |
+
device_specific (`bool`, *optional*, defaults to `False`):
|
| 47 |
+
Whether to differ the seed on each device slightly with `self.process_index`.
|
| 48 |
+
deterministic (`bool`, *optional*, defaults to `False`):
|
| 49 |
+
Whether to use deterministic algorithms where available. Can slow down training.
|
| 50 |
+
"""
|
| 51 |
+
if device_specific:
|
| 52 |
+
seed += AcceleratorState().process_index
|
| 53 |
+
random.seed(seed)
|
| 54 |
+
np.random.seed(seed)
|
| 55 |
+
torch.manual_seed(seed)
|
| 56 |
+
if is_xpu_available():
|
| 57 |
+
torch.xpu.manual_seed_all(seed)
|
| 58 |
+
elif is_npu_available():
|
| 59 |
+
torch.npu.manual_seed_all(seed)
|
| 60 |
+
elif is_mlu_available():
|
| 61 |
+
torch.mlu.manual_seed_all(seed)
|
| 62 |
+
elif is_sdaa_available():
|
| 63 |
+
torch.sdaa.manual_seed_all(seed)
|
| 64 |
+
elif is_musa_available():
|
| 65 |
+
torch.musa.manual_seed_all(seed)
|
| 66 |
+
elif is_hpu_available():
|
| 67 |
+
torch.hpu.manual_seed_all(seed)
|
| 68 |
+
else:
|
| 69 |
+
torch.cuda.manual_seed_all(seed)
|
| 70 |
+
# ^^ safe to call this function even if cuda is not available
|
| 71 |
+
if is_torch_xla_available():
|
| 72 |
+
xm.set_rng_state(seed)
|
| 73 |
+
|
| 74 |
+
if deterministic:
|
| 75 |
+
torch.use_deterministic_algorithms(True)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def synchronize_rng_state(rng_type: Optional[RNGType] = None, generator: Optional[torch.Generator] = None):
|
| 79 |
+
# Get the proper rng state
|
| 80 |
+
if rng_type == RNGType.TORCH:
|
| 81 |
+
rng_state = torch.get_rng_state()
|
| 82 |
+
elif rng_type == RNGType.CUDA:
|
| 83 |
+
rng_state = torch.cuda.get_rng_state()
|
| 84 |
+
elif rng_type == RNGType.XLA:
|
| 85 |
+
assert is_torch_xla_available(), "Can't synchronize XLA seeds as torch_xla is unavailable."
|
| 86 |
+
rng_state = torch.tensor(xm.get_rng_state())
|
| 87 |
+
elif rng_type == RNGType.NPU:
|
| 88 |
+
assert is_npu_available(), "Can't synchronize NPU seeds on an environment without NPUs."
|
| 89 |
+
rng_state = torch.npu.get_rng_state()
|
| 90 |
+
elif rng_type == RNGType.MLU:
|
| 91 |
+
assert is_mlu_available(), "Can't synchronize MLU seeds on an environment without MLUs."
|
| 92 |
+
rng_state = torch.mlu.get_rng_state()
|
| 93 |
+
elif rng_type == RNGType.SDAA:
|
| 94 |
+
assert is_sdaa_available(), "Can't synchronize SDAA seeds on an environment without SDAAs."
|
| 95 |
+
rng_state = torch.sdaa.get_rng_state()
|
| 96 |
+
elif rng_type == RNGType.MUSA:
|
| 97 |
+
assert is_musa_available(), "Can't synchronize MUSA seeds on an environment without MUSAs."
|
| 98 |
+
rng_state = torch.musa.get_rng_state()
|
| 99 |
+
elif rng_type == RNGType.XPU:
|
| 100 |
+
assert is_xpu_available(), "Can't synchronize XPU seeds on an environment without XPUs."
|
| 101 |
+
rng_state = torch.xpu.get_rng_state()
|
| 102 |
+
elif rng_type == RNGType.HPU:
|
| 103 |
+
assert is_hpu_available(), "Can't synchronize HPU seeds on an environment without HPUs."
|
| 104 |
+
rng_state = torch.hpu.get_rng_state()
|
| 105 |
+
elif rng_type == RNGType.GENERATOR:
|
| 106 |
+
assert generator is not None, "Need a generator to synchronize its seed."
|
| 107 |
+
rng_state = generator.get_state()
|
| 108 |
+
|
| 109 |
+
# Broadcast the rng state from device 0 to other devices
|
| 110 |
+
state = AcceleratorState()
|
| 111 |
+
if state.distributed_type == DistributedType.XLA:
|
| 112 |
+
rng_state = rng_state.to(xm.xla_device())
|
| 113 |
+
xm.collective_broadcast([rng_state])
|
| 114 |
+
xm.mark_step()
|
| 115 |
+
rng_state = rng_state.cpu()
|
| 116 |
+
elif (
|
| 117 |
+
state.distributed_type in CUDA_DISTRIBUTED_TYPES
|
| 118 |
+
or state.distributed_type == DistributedType.MULTI_MLU
|
| 119 |
+
or state.distributed_type == DistributedType.MULTI_SDAA
|
| 120 |
+
or state.distributed_type == DistributedType.MULTI_MUSA
|
| 121 |
+
or state.distributed_type == DistributedType.MULTI_NPU
|
| 122 |
+
or state.distributed_type == DistributedType.MULTI_XPU
|
| 123 |
+
or state.distributed_type == DistributedType.MULTI_HPU
|
| 124 |
+
):
|
| 125 |
+
rng_state = rng_state.to(state.device)
|
| 126 |
+
torch.distributed.broadcast(rng_state, 0)
|
| 127 |
+
rng_state = rng_state.cpu()
|
| 128 |
+
elif state.distributed_type == DistributedType.MULTI_CPU:
|
| 129 |
+
torch.distributed.broadcast(rng_state, 0)
|
| 130 |
+
|
| 131 |
+
# Set the broadcast rng state
|
| 132 |
+
if rng_type == RNGType.TORCH:
|
| 133 |
+
torch.set_rng_state(rng_state)
|
| 134 |
+
elif rng_type == RNGType.CUDA:
|
| 135 |
+
torch.cuda.set_rng_state(rng_state)
|
| 136 |
+
elif rng_type == RNGType.NPU:
|
| 137 |
+
torch.npu.set_rng_state(rng_state)
|
| 138 |
+
elif rng_type == RNGType.MLU:
|
| 139 |
+
torch.mlu.set_rng_state(rng_state)
|
| 140 |
+
elif rng_type == RNGType.SDAA:
|
| 141 |
+
torch.sdaa.set_rng_state(rng_state)
|
| 142 |
+
elif rng_type == RNGType.MUSA:
|
| 143 |
+
torch.musa.set_rng_state(rng_state)
|
| 144 |
+
elif rng_type == RNGType.XPU:
|
| 145 |
+
torch.xpu.set_rng_state(rng_state)
|
| 146 |
+
elif rng_state == RNGType.HPU:
|
| 147 |
+
torch.hpu.set_rng_state(rng_state)
|
| 148 |
+
elif rng_type == RNGType.XLA:
|
| 149 |
+
xm.set_rng_state(rng_state.item())
|
| 150 |
+
elif rng_type == RNGType.GENERATOR:
|
| 151 |
+
generator.set_state(rng_state)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def synchronize_rng_states(rng_types: list[Union[str, RNGType]], generator: Optional[torch.Generator] = None):
|
| 155 |
+
for rng_type in rng_types:
|
| 156 |
+
synchronize_rng_state(RNGType(rng_type), generator=generator)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/rich.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from .imports import is_rich_available
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
if is_rich_available():
|
| 19 |
+
from rich.traceback import install
|
| 20 |
+
|
| 21 |
+
install(show_locals=False)
|
| 22 |
+
|
| 23 |
+
else:
|
| 24 |
+
raise ModuleNotFoundError("To use the rich extension, install rich with `pip install rich`")
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/torch_xla.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import importlib.metadata
|
| 16 |
+
import subprocess
|
| 17 |
+
import sys
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def install_xla(upgrade: bool = False):
|
| 21 |
+
"""
|
| 22 |
+
Helper function to install appropriate xla wheels based on the `torch` version in Google Colaboratory.
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
upgrade (`bool`, *optional*, defaults to `False`):
|
| 26 |
+
Whether to upgrade `torch` and install the latest `torch_xla` wheels.
|
| 27 |
+
|
| 28 |
+
Example:
|
| 29 |
+
|
| 30 |
+
```python
|
| 31 |
+
>>> from accelerate.utils import install_xla
|
| 32 |
+
|
| 33 |
+
>>> install_xla(upgrade=True)
|
| 34 |
+
```
|
| 35 |
+
"""
|
| 36 |
+
in_colab = False
|
| 37 |
+
if "IPython" in sys.modules:
|
| 38 |
+
in_colab = "google.colab" in str(sys.modules["IPython"].get_ipython())
|
| 39 |
+
|
| 40 |
+
if in_colab:
|
| 41 |
+
if upgrade:
|
| 42 |
+
torch_install_cmd = ["pip", "install", "-U", "torch"]
|
| 43 |
+
subprocess.run(torch_install_cmd, check=True)
|
| 44 |
+
# get the current version of torch
|
| 45 |
+
torch_version = importlib.metadata.version("torch")
|
| 46 |
+
torch_version_trunc = torch_version[: torch_version.rindex(".")]
|
| 47 |
+
xla_wheel = f"https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-{torch_version_trunc}-cp37-cp37m-linux_x86_64.whl"
|
| 48 |
+
xla_install_cmd = ["pip", "install", xla_wheel]
|
| 49 |
+
subprocess.run(xla_install_cmd, check=True)
|
| 50 |
+
else:
|
| 51 |
+
raise RuntimeError("`install_xla` utility works only on google colab.")
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/tqdm.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
from .imports import is_tqdm_available
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
if is_tqdm_available():
|
| 20 |
+
from tqdm.auto import tqdm as _tqdm
|
| 21 |
+
|
| 22 |
+
from ..state import PartialState
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def tqdm(*args, main_process_only: bool = True, **kwargs):
|
| 26 |
+
"""
|
| 27 |
+
Wrapper around `tqdm.tqdm` that optionally displays only on the main process.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
main_process_only (`bool`, *optional*):
|
| 31 |
+
Whether to display the progress bar only on the main process
|
| 32 |
+
"""
|
| 33 |
+
if not is_tqdm_available():
|
| 34 |
+
raise ImportError("Accelerate's `tqdm` module requires `tqdm` to be installed. Please run `pip install tqdm`.")
|
| 35 |
+
if len(args) > 0 and isinstance(args[0], bool):
|
| 36 |
+
raise ValueError(
|
| 37 |
+
"Passing `True` or `False` as the first argument to Accelerate's `tqdm` wrapper is unsupported. "
|
| 38 |
+
"Please use the `main_process_only` keyword argument instead."
|
| 39 |
+
)
|
| 40 |
+
disable = kwargs.pop("disable", False)
|
| 41 |
+
if main_process_only and not disable:
|
| 42 |
+
disable = PartialState().local_process_index != 0
|
| 43 |
+
return _tqdm(*args, **kwargs, disable=disable)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/transformer_engine.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from types import MethodType
|
| 16 |
+
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
|
| 19 |
+
from .imports import is_hpu_available, is_transformer_engine_available
|
| 20 |
+
from .operations import GatheredParameters
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# Do not import `transformer_engine` at package level to avoid potential issues
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def convert_model(model, to_transformer_engine=True, _convert_linear=True, _convert_ln=True):
|
| 27 |
+
"""
|
| 28 |
+
Recursively converts the linear and layernorm layers of a model to their `transformers_engine` counterpart.
|
| 29 |
+
"""
|
| 30 |
+
if not is_transformer_engine_available():
|
| 31 |
+
raise ImportError("Using `convert_model` requires transformer_engine to be installed.")
|
| 32 |
+
|
| 33 |
+
if is_hpu_available():
|
| 34 |
+
import intel_transformer_engine as te
|
| 35 |
+
|
| 36 |
+
if not hasattr(te, "LayerNorm"):
|
| 37 |
+
# HPU does not have a LayerNorm implementation in TE
|
| 38 |
+
te.LayerNorm = nn.LayerNorm
|
| 39 |
+
else:
|
| 40 |
+
import transformer_engine.pytorch as te
|
| 41 |
+
|
| 42 |
+
for name, module in model.named_children():
|
| 43 |
+
if isinstance(module, nn.Linear) and to_transformer_engine and _convert_linear:
|
| 44 |
+
has_bias = module.bias is not None
|
| 45 |
+
params_to_gather = [module.weight]
|
| 46 |
+
if has_bias:
|
| 47 |
+
params_to_gather.append(module.bias)
|
| 48 |
+
|
| 49 |
+
with GatheredParameters(params_to_gather, modifier_rank=0):
|
| 50 |
+
if any(p % 16 != 0 for p in module.weight.shape):
|
| 51 |
+
return
|
| 52 |
+
te_module = te.Linear(
|
| 53 |
+
module.in_features, module.out_features, bias=has_bias, params_dtype=module.weight.dtype
|
| 54 |
+
)
|
| 55 |
+
te_module.weight.copy_(module.weight)
|
| 56 |
+
if has_bias:
|
| 57 |
+
te_module.bias.copy_(module.bias)
|
| 58 |
+
|
| 59 |
+
setattr(model, name, te_module)
|
| 60 |
+
# Note: @xrsrke (Phuc) found that te.LayerNorm doesn't have any real memory savings or speedups over nn.LayerNorm
|
| 61 |
+
elif isinstance(module, nn.LayerNorm) and to_transformer_engine and _convert_ln:
|
| 62 |
+
with GatheredParameters([module.weight, module.bias], modifier_rank=0):
|
| 63 |
+
has_bias = module.bias is not None
|
| 64 |
+
te_module = te.LayerNorm(module.normalized_shape[0], eps=module.eps, params_dtype=module.weight.dtype)
|
| 65 |
+
te_module.weight.copy_(module.weight)
|
| 66 |
+
if has_bias:
|
| 67 |
+
te_module.bias.copy_(module.bias)
|
| 68 |
+
|
| 69 |
+
setattr(model, name, te_module)
|
| 70 |
+
elif isinstance(module, te.Linear) and not to_transformer_engine and _convert_linear:
|
| 71 |
+
has_bias = module.bias is not None
|
| 72 |
+
new_module = nn.Linear(
|
| 73 |
+
module.in_features, module.out_features, bias=has_bias, params_dtype=module.weight.dtype
|
| 74 |
+
)
|
| 75 |
+
new_module.weight.copy_(module.weight)
|
| 76 |
+
if has_bias:
|
| 77 |
+
new_module.bias.copy_(module.bias)
|
| 78 |
+
|
| 79 |
+
setattr(model, name, new_module)
|
| 80 |
+
elif isinstance(module, te.LayerNorm) and not to_transformer_engine and _convert_ln:
|
| 81 |
+
new_module = nn.LayerNorm(module.normalized_shape[0], eps=module.eps, params_dtype=module.weight.dtype)
|
| 82 |
+
new_module.weight.copy_(module.weight)
|
| 83 |
+
new_module.bias.copy_(module.bias)
|
| 84 |
+
|
| 85 |
+
setattr(model, name, new_module)
|
| 86 |
+
else:
|
| 87 |
+
convert_model(
|
| 88 |
+
module,
|
| 89 |
+
to_transformer_engine=to_transformer_engine,
|
| 90 |
+
_convert_linear=_convert_linear,
|
| 91 |
+
_convert_ln=_convert_ln,
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def has_transformer_engine_layers(model):
|
| 96 |
+
"""
|
| 97 |
+
Returns whether a given model has some `transformer_engine` layer or not.
|
| 98 |
+
"""
|
| 99 |
+
if not is_transformer_engine_available():
|
| 100 |
+
raise ImportError("Using `has_transformer_engine_layers` requires transformer_engine to be installed.")
|
| 101 |
+
|
| 102 |
+
if is_hpu_available():
|
| 103 |
+
import intel_transformer_engine as te
|
| 104 |
+
|
| 105 |
+
module_cls_to_check = te.Linear
|
| 106 |
+
else:
|
| 107 |
+
import transformer_engine.pytorch as te
|
| 108 |
+
|
| 109 |
+
module_cls_to_check = (te.LayerNorm, te.Linear, te.TransformerLayer)
|
| 110 |
+
|
| 111 |
+
for m in model.modules():
|
| 112 |
+
if isinstance(m, module_cls_to_check):
|
| 113 |
+
return True
|
| 114 |
+
|
| 115 |
+
return False
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def contextual_fp8_autocast(model_forward, fp8_recipe, use_during_eval=False):
|
| 119 |
+
"""
|
| 120 |
+
Wrapper for a model's forward method to apply FP8 autocast. Is context aware, meaning that by default it will
|
| 121 |
+
disable FP8 autocast during eval mode, which is generally better for more accurate metrics.
|
| 122 |
+
"""
|
| 123 |
+
if not is_transformer_engine_available():
|
| 124 |
+
raise ImportError("Using `contextual_fp8_autocast` requires transformer_engine to be installed.")
|
| 125 |
+
|
| 126 |
+
if is_hpu_available():
|
| 127 |
+
from intel_transformer_engine import fp8_autocast
|
| 128 |
+
else:
|
| 129 |
+
from transformer_engine.pytorch import fp8_autocast
|
| 130 |
+
|
| 131 |
+
def forward(self, *args, **kwargs):
|
| 132 |
+
enabled = use_during_eval or self.training
|
| 133 |
+
with fp8_autocast(enabled=enabled, fp8_recipe=fp8_recipe):
|
| 134 |
+
return model_forward(*args, **kwargs)
|
| 135 |
+
|
| 136 |
+
# To act like a decorator so that it can be popped when doing `extract_model_from_parallel`
|
| 137 |
+
forward.__wrapped__ = model_forward
|
| 138 |
+
|
| 139 |
+
return forward
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def apply_fp8_autowrap(model, fp8_recipe_handler):
|
| 143 |
+
"""
|
| 144 |
+
Applies FP8 context manager to the model's forward method
|
| 145 |
+
"""
|
| 146 |
+
if not is_transformer_engine_available():
|
| 147 |
+
raise ImportError("Using `apply_fp8_autowrap` requires transformer_engine to be installed.")
|
| 148 |
+
|
| 149 |
+
if is_hpu_available():
|
| 150 |
+
import intel_transformer_engine.recipe as te_recipe
|
| 151 |
+
|
| 152 |
+
is_fp8_block_scaling_available = False
|
| 153 |
+
message = "MXFP8 block scaling is not available on HPU."
|
| 154 |
+
|
| 155 |
+
else:
|
| 156 |
+
import transformer_engine.common.recipe as te_recipe
|
| 157 |
+
import transformer_engine.pytorch as te
|
| 158 |
+
|
| 159 |
+
is_fp8_block_scaling_available, message = te.fp8.check_mxfp8_support()
|
| 160 |
+
|
| 161 |
+
kwargs = fp8_recipe_handler.to_kwargs() if fp8_recipe_handler is not None else {}
|
| 162 |
+
if "fp8_format" in kwargs:
|
| 163 |
+
kwargs["fp8_format"] = getattr(te_recipe.Format, kwargs["fp8_format"])
|
| 164 |
+
use_during_eval = kwargs.pop("use_autocast_during_eval", False)
|
| 165 |
+
use_mxfp8_block_scaling = kwargs.pop("use_mxfp8_block_scaling", False)
|
| 166 |
+
|
| 167 |
+
if use_mxfp8_block_scaling and not is_fp8_block_scaling_available:
|
| 168 |
+
raise ValueError(f"MXFP8 block scaling is not available: {message}")
|
| 169 |
+
|
| 170 |
+
if use_mxfp8_block_scaling:
|
| 171 |
+
if "amax_compute_algo" in kwargs:
|
| 172 |
+
raise ValueError("`amax_compute_algo` is not supported for MXFP8 block scaling.")
|
| 173 |
+
if "amax_history_len" in kwargs:
|
| 174 |
+
raise ValueError("`amax_history_len` is not supported for MXFP8 block scaling.")
|
| 175 |
+
fp8_recipe = te_recipe.MXFP8BlockScaling(**kwargs)
|
| 176 |
+
else:
|
| 177 |
+
fp8_recipe = te_recipe.DelayedScaling(**kwargs)
|
| 178 |
+
|
| 179 |
+
new_forward = contextual_fp8_autocast(model.forward, fp8_recipe, use_during_eval)
|
| 180 |
+
|
| 181 |
+
if hasattr(model.forward, "__func__"):
|
| 182 |
+
model.forward = MethodType(new_forward, model)
|
| 183 |
+
else:
|
| 184 |
+
model.forward = new_forward
|
| 185 |
+
|
| 186 |
+
return model
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/versions.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import importlib.metadata
|
| 16 |
+
from typing import Union
|
| 17 |
+
|
| 18 |
+
from packaging.version import Version, parse
|
| 19 |
+
|
| 20 |
+
from .constants import STR_OPERATION_TO_FUNC
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
torch_version = parse(importlib.metadata.version("torch"))
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str):
|
| 27 |
+
"""
|
| 28 |
+
Compares a library version to some requirement using a given operation.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
library_or_version (`str` or `packaging.version.Version`):
|
| 32 |
+
A library name or a version to check.
|
| 33 |
+
operation (`str`):
|
| 34 |
+
A string representation of an operator, such as `">"` or `"<="`.
|
| 35 |
+
requirement_version (`str`):
|
| 36 |
+
The version to compare the library version against
|
| 37 |
+
"""
|
| 38 |
+
if operation not in STR_OPERATION_TO_FUNC.keys():
|
| 39 |
+
raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}")
|
| 40 |
+
operation = STR_OPERATION_TO_FUNC[operation]
|
| 41 |
+
if isinstance(library_or_version, str):
|
| 42 |
+
library_or_version = parse(importlib.metadata.version(library_or_version))
|
| 43 |
+
return operation(library_or_version, parse(requirement_version))
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def is_torch_version(operation: str, version: str):
|
| 47 |
+
"""
|
| 48 |
+
Compares the current PyTorch version to a given reference with an operation.
|
| 49 |
+
|
| 50 |
+
Args:
|
| 51 |
+
operation (`str`):
|
| 52 |
+
A string representation of an operator, such as `">"` or `"<="`
|
| 53 |
+
version (`str`):
|
| 54 |
+
A string version of PyTorch
|
| 55 |
+
"""
|
| 56 |
+
return compare_versions(torch_version, operation, version)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/annotated_doc/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (283 Bytes). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/annotated_doc/__pycache__/main.cpython-312.pyc
ADDED
|
Binary file (1.93 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/cuda_pathfinder-1.4.0.dist-info/licenses/LICENSE
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
Apache License
|
| 3 |
+
Version 2.0, January 2004
|
| 4 |
+
http://www.apache.org/licenses/
|
| 5 |
+
|
| 6 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 7 |
+
|
| 8 |
+
1. Definitions.
|
| 9 |
+
|
| 10 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 11 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 12 |
+
|
| 13 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 14 |
+
the copyright owner that is granting the License.
|
| 15 |
+
|
| 16 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 17 |
+
other entities that control, are controlled by, or are under common
|
| 18 |
+
control with that entity. For the purposes of this definition,
|
| 19 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 20 |
+
direction or management of such entity, whether by contract or
|
| 21 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 22 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 23 |
+
|
| 24 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 25 |
+
exercising permissions granted by this License.
|
| 26 |
+
|
| 27 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 28 |
+
including but not limited to software source code, documentation
|
| 29 |
+
source, and configuration files.
|
| 30 |
+
|
| 31 |
+
"Object" form shall mean any form resulting from mechanical
|
| 32 |
+
transformation or translation of a Source form, including but
|
| 33 |
+
not limited to compiled object code, generated documentation,
|
| 34 |
+
and conversions to other media types.
|
| 35 |
+
|
| 36 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 37 |
+
Object form, made available under the License, as indicated by a
|
| 38 |
+
copyright notice that is included in or attached to the work
|
| 39 |
+
(an example is provided in the Appendix below).
|
| 40 |
+
|
| 41 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 42 |
+
form, that is based on (or derived from) the Work and for which the
|
| 43 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 44 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 45 |
+
of this License, Derivative Works shall not include works that remain
|
| 46 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 47 |
+
the Work and Derivative Works thereof.
|
| 48 |
+
|
| 49 |
+
"Contribution" shall mean any work of authorship, including
|
| 50 |
+
the original version of the Work and any modifications or additions
|
| 51 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 52 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 53 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 54 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 55 |
+
means any form of electronic, verbal, or written communication sent
|
| 56 |
+
to the Licensor or its representatives, including but not limited to
|
| 57 |
+
communication on electronic mailing lists, source code control systems,
|
| 58 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 59 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 60 |
+
excluding communication that is conspicuously marked or otherwise
|
| 61 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 62 |
+
|
| 63 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 64 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 65 |
+
subsequently incorporated within the Work.
|
| 66 |
+
|
| 67 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 68 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 69 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 70 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 71 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 72 |
+
Work and such Derivative Works in Source or Object form.
|
| 73 |
+
|
| 74 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 75 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 76 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 77 |
+
(except as stated in this section) patent license to make, have made,
|
| 78 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 79 |
+
where such license applies only to those patent claims licensable
|
| 80 |
+
by such Contributor that are necessarily infringed by their
|
| 81 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 82 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 83 |
+
institute patent litigation against any entity (including a
|
| 84 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 85 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 86 |
+
or contributory patent infringement, then any patent licenses
|
| 87 |
+
granted to You under this License for that Work shall terminate
|
| 88 |
+
as of the date such litigation is filed.
|
| 89 |
+
|
| 90 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 91 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 92 |
+
modifications, and in Source or Object form, provided that You
|
| 93 |
+
meet the following conditions:
|
| 94 |
+
|
| 95 |
+
(a) You must give any other recipients of the Work or
|
| 96 |
+
Derivative Works a copy of this License; and
|
| 97 |
+
|
| 98 |
+
(b) You must cause any modified files to carry prominent notices
|
| 99 |
+
stating that You changed the files; and
|
| 100 |
+
|
| 101 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 102 |
+
that You distribute, all copyright, patent, trademark, and
|
| 103 |
+
attribution notices from the Source form of the Work,
|
| 104 |
+
excluding those notices that do not pertain to any part of
|
| 105 |
+
the Derivative Works; and
|
| 106 |
+
|
| 107 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 108 |
+
distribution, then any Derivative Works that You distribute must
|
| 109 |
+
include a readable copy of the attribution notices contained
|
| 110 |
+
within such NOTICE file, excluding those notices that do not
|
| 111 |
+
pertain to any part of the Derivative Works, in at least one
|
| 112 |
+
of the following places: within a NOTICE text file distributed
|
| 113 |
+
as part of the Derivative Works; within the Source form or
|
| 114 |
+
documentation, if provided along with the Derivative Works; or,
|
| 115 |
+
within a display generated by the Derivative Works, if and
|
| 116 |
+
wherever such third-party notices normally appear. The contents
|
| 117 |
+
of the NOTICE file are for informational purposes only and
|
| 118 |
+
do not modify the License. You may add Your own attribution
|
| 119 |
+
notices within Derivative Works that You distribute, alongside
|
| 120 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 121 |
+
that such additional attribution notices cannot be construed
|
| 122 |
+
as modifying the License.
|
| 123 |
+
|
| 124 |
+
You may add Your own copyright statement to Your modifications and
|
| 125 |
+
may provide additional or different license terms and conditions
|
| 126 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 127 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 128 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 129 |
+
the conditions stated in this License.
|
| 130 |
+
|
| 131 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 132 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 133 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 134 |
+
this License, without any additional terms or conditions.
|
| 135 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 136 |
+
the terms of any separate license agreement you may have executed
|
| 137 |
+
with Licensor regarding such Contributions.
|
| 138 |
+
|
| 139 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 140 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 141 |
+
except as required for reasonable and customary use in describing the
|
| 142 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 143 |
+
|
| 144 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 145 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 146 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 147 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 148 |
+
implied, including, without limitation, any warranties or conditions
|
| 149 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 150 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 151 |
+
appropriateness of using or redistributing the Work and assume any
|
| 152 |
+
risks associated with Your exercise of permissions under this License.
|
| 153 |
+
|
| 154 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 155 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 156 |
+
unless required by applicable law (such as deliberate and grossly
|
| 157 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 158 |
+
liable to You for damages, including any direct, indirect, special,
|
| 159 |
+
incidental, or consequential damages of any character arising as a
|
| 160 |
+
result of this License or out of the use or inability to use the
|
| 161 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 162 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 163 |
+
other commercial damages or losses), even if such Contributor
|
| 164 |
+
has been advised of the possibility of such damages.
|
| 165 |
+
|
| 166 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 167 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 168 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 169 |
+
or other liability obligations and/or rights consistent with this
|
| 170 |
+
License. However, in accepting such obligations, You may act only
|
| 171 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 172 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 173 |
+
defend, and hold each Contributor harmless for any liability
|
| 174 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 175 |
+
of your accepting any such warranty or additional liability.
|
| 176 |
+
|
| 177 |
+
END OF TERMS AND CONDITIONS
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (1.38 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/__version__.cpython-312.pyc
ADDED
|
Binary file (619 Bytes). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_align.cpython-312.pyc
ADDED
|
Binary file (1.37 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_align_getter.cpython-312.pyc
ADDED
|
Binary file (1.99 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_base.cpython-312.pyc
ADDED
|
Binary file (3.77 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_column.cpython-312.pyc
ADDED
|
Binary file (18.5 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_common.cpython-312.pyc
ADDED
|
Binary file (3.47 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_container.cpython-312.pyc
ADDED
|
Binary file (9.41 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_converter.cpython-312.pyc
ADDED
|
Binary file (5.21 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_dataproperty.cpython-312.pyc
ADDED
|
Binary file (15.8 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_extractor.cpython-312.pyc
ADDED
|
Binary file (34.7 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_formatter.cpython-312.pyc
ADDED
|
Binary file (4.91 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_function.cpython-312.pyc
ADDED
|
Binary file (5.44 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_interface.cpython-312.pyc
ADDED
|
Binary file (1.64 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_line_break.cpython-312.pyc
ADDED
|
Binary file (546 Bytes). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_preprocessor.cpython-312.pyc
ADDED
|
Binary file (8.11 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/typing.cpython-312.pyc
ADDED
|
Binary file (1.97 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/logger/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ._logger import logger, set_logger # type: ignore
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
__all__ = (
|
| 5 |
+
"logger",
|
| 6 |
+
"set_logger",
|
| 7 |
+
)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/logger/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (307 Bytes). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/logger/__pycache__/_logger.cpython-312.pyc
ADDED
|
Binary file (950 Bytes). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/logger/__pycache__/_null_logger.cpython-312.pyc
ADDED
|
Binary file (2.04 kB). View file
|
|
|