Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/data_loader.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/inference.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/local_sgd.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/logging.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/parallelism_config.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/scheduler.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/tracking.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__init__.py +13 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/accelerate_cli.py +54 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/env.py +131 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/estimate.py +316 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/launch.py +1291 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/merge.py +69 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/test.py +65 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/to_fsdp2.py +172 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/utils.py +123 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/__init__.py +66 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/examples.py +148 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/testing.py +881 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/training.py +148 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/__pycache__/__init__.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/__pycache__/abstract_nodes.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/__pycache__/algorithms.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/__pycache__/approximations.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/__pycache__/ast.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/__pycache__/cfunctions.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/__pycache__/cnodes.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/__pycache__/cutils.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/__pycache__/cxxnodes.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/__pycache__/fnodes.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/__pycache__/futils.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/__pycache__/matrix_nodes.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/__pycache__/numpy_nodes.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/__pycache__/pynodes.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/__pycache__/pyutils.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/__pycache__/rewriting.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/__pycache__/scipy_nodes.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/tests/__init__.py +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/tests/__pycache__/__init__.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/tests/__pycache__/test_abstract_nodes.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/tests/__pycache__/test_algorithms.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/tests/__pycache__/test_applications.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/tests/__pycache__/test_approximations.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/tests/__pycache__/test_ast.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/tests/__pycache__/test_cfunctions.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/tests/__pycache__/test_cnodes.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/tests/__pycache__/test_cxxnodes.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/tests/__pycache__/test_fnodes.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/tests/__pycache__/test_matrix_nodes.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/tests/__pycache__/test_numpy_nodes.cpython-312.pyc +0 -0
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/data_loader.cpython-312.pyc
ADDED
|
Binary file (65.1 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/inference.cpython-312.pyc
ADDED
|
Binary file (7.56 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/local_sgd.cpython-312.pyc
ADDED
|
Binary file (5.14 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/logging.cpython-312.pyc
ADDED
|
Binary file (5.76 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/parallelism_config.cpython-312.pyc
ADDED
|
Binary file (19.7 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/scheduler.cpython-312.pyc
ADDED
|
Binary file (4.69 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/tracking.cpython-312.pyc
ADDED
|
Binary file (64.1 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/accelerate_cli.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
from accelerate.commands.config import get_config_parser
|
| 18 |
+
from accelerate.commands.env import env_command_parser
|
| 19 |
+
from accelerate.commands.estimate import estimate_command_parser
|
| 20 |
+
from accelerate.commands.launch import launch_command_parser
|
| 21 |
+
from accelerate.commands.merge import merge_command_parser
|
| 22 |
+
from accelerate.commands.test import test_command_parser
|
| 23 |
+
from accelerate.commands.to_fsdp2 import to_fsdp2_command_parser
|
| 24 |
+
from accelerate.commands.tpu import tpu_command_parser
|
| 25 |
+
from accelerate.commands.utils import CustomArgumentParser
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def main():
|
| 29 |
+
parser = CustomArgumentParser("Accelerate CLI tool", usage="accelerate <command> [<args>]", allow_abbrev=False)
|
| 30 |
+
subparsers = parser.add_subparsers(help="accelerate command helpers")
|
| 31 |
+
|
| 32 |
+
# Register commands
|
| 33 |
+
get_config_parser(subparsers=subparsers)
|
| 34 |
+
estimate_command_parser(subparsers=subparsers)
|
| 35 |
+
env_command_parser(subparsers=subparsers)
|
| 36 |
+
launch_command_parser(subparsers=subparsers)
|
| 37 |
+
merge_command_parser(subparsers=subparsers)
|
| 38 |
+
tpu_command_parser(subparsers=subparsers)
|
| 39 |
+
test_command_parser(subparsers=subparsers)
|
| 40 |
+
to_fsdp2_command_parser(subparsers=subparsers)
|
| 41 |
+
|
| 42 |
+
# Let's go
|
| 43 |
+
args = parser.parse_args()
|
| 44 |
+
|
| 45 |
+
if not hasattr(args, "func"):
|
| 46 |
+
parser.print_help()
|
| 47 |
+
exit(1)
|
| 48 |
+
|
| 49 |
+
# Run
|
| 50 |
+
args.func(args)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
if __name__ == "__main__":
|
| 54 |
+
main()
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/env.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import os
|
| 19 |
+
import platform
|
| 20 |
+
import subprocess
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
import psutil
|
| 24 |
+
import torch
|
| 25 |
+
|
| 26 |
+
from accelerate import __version__ as version
|
| 27 |
+
from accelerate.commands.config import default_config_file, load_config_from_file
|
| 28 |
+
|
| 29 |
+
from ..utils import is_mlu_available, is_musa_available, is_npu_available, is_sdaa_available, is_xpu_available
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def env_command_parser(subparsers=None):
|
| 33 |
+
if subparsers is not None:
|
| 34 |
+
parser = subparsers.add_parser("env")
|
| 35 |
+
else:
|
| 36 |
+
parser = argparse.ArgumentParser("Accelerate env command")
|
| 37 |
+
|
| 38 |
+
parser.add_argument(
|
| 39 |
+
"--config_file", default=None, help="The config file to use for the default values in the launching script."
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
if subparsers is not None:
|
| 43 |
+
parser.set_defaults(func=env_command)
|
| 44 |
+
return parser
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def env_command(args):
|
| 48 |
+
pt_version = torch.__version__
|
| 49 |
+
pt_cuda_available = torch.cuda.is_available()
|
| 50 |
+
pt_xpu_available = is_xpu_available()
|
| 51 |
+
pt_mlu_available = is_mlu_available()
|
| 52 |
+
pt_sdaa_available = is_sdaa_available()
|
| 53 |
+
pt_musa_available = is_musa_available()
|
| 54 |
+
pt_npu_available = is_npu_available()
|
| 55 |
+
|
| 56 |
+
accelerator = "N/A"
|
| 57 |
+
if pt_cuda_available:
|
| 58 |
+
accelerator = "CUDA"
|
| 59 |
+
elif pt_xpu_available:
|
| 60 |
+
accelerator = "XPU"
|
| 61 |
+
elif pt_mlu_available:
|
| 62 |
+
accelerator = "MLU"
|
| 63 |
+
elif pt_sdaa_available:
|
| 64 |
+
accelerator = "SDAA"
|
| 65 |
+
elif pt_musa_available:
|
| 66 |
+
accelerator = "MUSA"
|
| 67 |
+
elif pt_npu_available:
|
| 68 |
+
accelerator = "NPU"
|
| 69 |
+
|
| 70 |
+
accelerate_config = "Not found"
|
| 71 |
+
# Get the default from the config file.
|
| 72 |
+
if args.config_file is not None or os.path.isfile(default_config_file):
|
| 73 |
+
accelerate_config = load_config_from_file(args.config_file).to_dict()
|
| 74 |
+
|
| 75 |
+
# if we can run which, get it
|
| 76 |
+
command = None
|
| 77 |
+
bash_location = "Not found"
|
| 78 |
+
if os.name == "nt":
|
| 79 |
+
command = ["where", "accelerate"]
|
| 80 |
+
elif os.name == "posix":
|
| 81 |
+
command = ["which", "accelerate"]
|
| 82 |
+
if command is not None:
|
| 83 |
+
bash_location = subprocess.check_output(command, text=True, stderr=subprocess.STDOUT).strip()
|
| 84 |
+
info = {
|
| 85 |
+
"`Accelerate` version": version,
|
| 86 |
+
"Platform": platform.platform(),
|
| 87 |
+
"`accelerate` bash location": bash_location,
|
| 88 |
+
"Python version": platform.python_version(),
|
| 89 |
+
"Numpy version": np.__version__,
|
| 90 |
+
"PyTorch version": f"{pt_version}",
|
| 91 |
+
"PyTorch accelerator": accelerator,
|
| 92 |
+
"System RAM": f"{psutil.virtual_memory().total / 1024**3:.2f} GB",
|
| 93 |
+
}
|
| 94 |
+
if pt_cuda_available:
|
| 95 |
+
info["GPU type"] = torch.cuda.get_device_name()
|
| 96 |
+
elif pt_xpu_available:
|
| 97 |
+
info["XPU type"] = torch.xpu.get_device_name()
|
| 98 |
+
elif pt_mlu_available:
|
| 99 |
+
info["MLU type"] = torch.mlu.get_device_name()
|
| 100 |
+
elif pt_sdaa_available:
|
| 101 |
+
info["SDAA type"] = torch.sdaa.get_device_name()
|
| 102 |
+
elif pt_musa_available:
|
| 103 |
+
info["MUSA type"] = torch.musa.get_device_name()
|
| 104 |
+
elif pt_npu_available:
|
| 105 |
+
info["CANN version"] = torch.version.cann
|
| 106 |
+
|
| 107 |
+
print("\nCopy-and-paste the text below in your GitHub issue\n")
|
| 108 |
+
print("\n".join([f"- {prop}: {val}" for prop, val in info.items()]))
|
| 109 |
+
|
| 110 |
+
print("- `Accelerate` default config:" if args.config_file is None else "- `Accelerate` config passed:")
|
| 111 |
+
accelerate_config_str = (
|
| 112 |
+
"\n".join([f"\t- {prop}: {val}" for prop, val in accelerate_config.items()])
|
| 113 |
+
if isinstance(accelerate_config, dict)
|
| 114 |
+
else f"\t{accelerate_config}"
|
| 115 |
+
)
|
| 116 |
+
print(accelerate_config_str)
|
| 117 |
+
|
| 118 |
+
info["`Accelerate` configs"] = accelerate_config
|
| 119 |
+
|
| 120 |
+
return info
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def main() -> int:
|
| 124 |
+
parser = env_command_parser()
|
| 125 |
+
args = parser.parse_args()
|
| 126 |
+
env_command(args)
|
| 127 |
+
return 0
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
if __name__ == "__main__":
|
| 131 |
+
raise SystemExit(main())
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/estimate.py
ADDED
|
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
from typing import Optional
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
from huggingface_hub import model_info
|
| 20 |
+
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
|
| 21 |
+
|
| 22 |
+
from accelerate import init_empty_weights
|
| 23 |
+
from accelerate.commands.utils import CustomArgumentParser
|
| 24 |
+
from accelerate.utils import (
|
| 25 |
+
calculate_maximum_sizes,
|
| 26 |
+
convert_bytes,
|
| 27 |
+
is_timm_available,
|
| 28 |
+
is_transformers_available,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
if is_transformers_available():
|
| 33 |
+
import transformers
|
| 34 |
+
from transformers import AutoConfig, AutoModel
|
| 35 |
+
|
| 36 |
+
if is_timm_available():
|
| 37 |
+
import timm
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def verify_on_hub(repo: str, token: Optional[str] = None):
|
| 41 |
+
"Verifies that the model is on the hub and returns the model info."
|
| 42 |
+
try:
|
| 43 |
+
return model_info(repo, token=token)
|
| 44 |
+
except (OSError, GatedRepoError):
|
| 45 |
+
return "gated"
|
| 46 |
+
except RepositoryNotFoundError:
|
| 47 |
+
return "repo"
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def check_has_model(error):
|
| 51 |
+
"""
|
| 52 |
+
Checks what library spawned `error` when a model is not found
|
| 53 |
+
"""
|
| 54 |
+
if is_timm_available() and isinstance(error, RuntimeError) and "Unknown model" in error.args[0]:
|
| 55 |
+
return "timm"
|
| 56 |
+
elif (
|
| 57 |
+
is_transformers_available()
|
| 58 |
+
and isinstance(error, OSError)
|
| 59 |
+
and "does not appear to have a file named" in error.args[0]
|
| 60 |
+
):
|
| 61 |
+
return "transformers"
|
| 62 |
+
else:
|
| 63 |
+
return "unknown"
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def create_empty_model(
|
| 67 |
+
model_name: str, library_name: str, trust_remote_code: bool = False, access_token: Optional[str] = None
|
| 68 |
+
):
|
| 69 |
+
"""
|
| 70 |
+
Creates an empty model in full precision from its parent library on the `Hub` to calculate the overall memory
|
| 71 |
+
consumption.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
model_name (`str`):
|
| 75 |
+
The model name on the Hub
|
| 76 |
+
library_name (`str`):
|
| 77 |
+
The library the model has an integration with, such as `transformers`. Will be used if `model_name` has no
|
| 78 |
+
metadata on the Hub to determine the library.
|
| 79 |
+
trust_remote_code (`bool`, `optional`, defaults to `False`):
|
| 80 |
+
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
|
| 81 |
+
should only be set to `True` for repositories you trust and in which you have read the code, as it will
|
| 82 |
+
execute code present on the Hub on your local machine.
|
| 83 |
+
access_token (`str`, `optional`, defaults to `None`):
|
| 84 |
+
The access token to use to access private or gated models on the Hub. (for use on the Gradio app)
|
| 85 |
+
|
| 86 |
+
Returns:
|
| 87 |
+
`torch.nn.Module`: The torch model that has been initialized on the `meta` device.
|
| 88 |
+
|
| 89 |
+
"""
|
| 90 |
+
model_info = verify_on_hub(model_name, access_token)
|
| 91 |
+
# Simplified errors
|
| 92 |
+
if model_info == "gated":
|
| 93 |
+
raise GatedRepoError(
|
| 94 |
+
f"Repo for model `{model_name}` is gated. You must be authenticated to access it. Please run `huggingface-cli login`."
|
| 95 |
+
)
|
| 96 |
+
elif model_info == "repo":
|
| 97 |
+
raise RepositoryNotFoundError(
|
| 98 |
+
f"Repo for model `{model_name}` does not exist on the Hub. If you are trying to access a private repo,"
|
| 99 |
+
" make sure you are authenticated via `huggingface-cli login` and have access."
|
| 100 |
+
)
|
| 101 |
+
if library_name is None:
|
| 102 |
+
library_name = getattr(model_info, "library_name", False)
|
| 103 |
+
if not library_name:
|
| 104 |
+
raise ValueError(
|
| 105 |
+
f"Model `{model_name}` does not have any library metadata on the Hub, please manually pass in a `--library_name` to use (such as `transformers`)"
|
| 106 |
+
)
|
| 107 |
+
if library_name == "transformers":
|
| 108 |
+
if not is_transformers_available():
|
| 109 |
+
raise ImportError(
|
| 110 |
+
f"To check `{model_name}`, `transformers` must be installed. Please install it via `pip install transformers`"
|
| 111 |
+
)
|
| 112 |
+
print(f"Loading pretrained config for `{model_name}` from `transformers`...")
|
| 113 |
+
if model_info.config is None:
|
| 114 |
+
raise RuntimeError(f"Tried to load `{model_name}` with `transformers` but it does not have any metadata.")
|
| 115 |
+
|
| 116 |
+
auto_map = model_info.config.get("auto_map", False)
|
| 117 |
+
config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code, token=access_token)
|
| 118 |
+
with init_empty_weights():
|
| 119 |
+
# remote code could specify a specific `AutoModel` class in the `auto_map`
|
| 120 |
+
constructor = AutoModel
|
| 121 |
+
if isinstance(auto_map, dict):
|
| 122 |
+
value = None
|
| 123 |
+
for key in auto_map.keys():
|
| 124 |
+
if key.startswith("AutoModelFor"):
|
| 125 |
+
value = key
|
| 126 |
+
break
|
| 127 |
+
if value is not None:
|
| 128 |
+
constructor = getattr(transformers, value)
|
| 129 |
+
# we need to pass the dtype, otherwise it is going to use the torch_dtype that is saved in the config
|
| 130 |
+
model = constructor.from_config(config, torch_dtype=torch.float32, trust_remote_code=trust_remote_code)
|
| 131 |
+
elif library_name == "timm":
|
| 132 |
+
if not is_timm_available():
|
| 133 |
+
raise ImportError(
|
| 134 |
+
f"To check `{model_name}`, `timm` must be installed. Please install it via `pip install timm`"
|
| 135 |
+
)
|
| 136 |
+
print(f"Loading pretrained config for `{model_name}` from `timm`...")
|
| 137 |
+
with init_empty_weights():
|
| 138 |
+
model = timm.create_model(model_name, pretrained=False)
|
| 139 |
+
else:
|
| 140 |
+
raise ValueError(
|
| 141 |
+
f"Library `{library_name}` is not supported yet, please open an issue on GitHub for us to add support."
|
| 142 |
+
)
|
| 143 |
+
return model
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def create_ascii_table(headers: list, rows: list, title: str):
|
| 147 |
+
"Creates a pretty table from a list of rows, minimal version of `tabulate`."
|
| 148 |
+
sep_char, in_between = "│", "─"
|
| 149 |
+
column_widths = []
|
| 150 |
+
for i in range(len(headers)):
|
| 151 |
+
column_values = [row[i] for row in rows] + [headers[i]]
|
| 152 |
+
max_column_width = max(len(value) for value in column_values)
|
| 153 |
+
column_widths.append(max_column_width)
|
| 154 |
+
|
| 155 |
+
formats = [f"%{column_widths[i]}s" for i in range(len(rows[0]))]
|
| 156 |
+
|
| 157 |
+
pattern = f"{sep_char}{sep_char.join(formats)}{sep_char}"
|
| 158 |
+
diff = 0
|
| 159 |
+
|
| 160 |
+
def make_row(left_char, middle_char, right_char):
|
| 161 |
+
return f"{left_char}{middle_char.join([in_between * n for n in column_widths])}{in_between * diff}{right_char}"
|
| 162 |
+
|
| 163 |
+
separator = make_row("├", "┼", "┤")
|
| 164 |
+
if len(title) > sum(column_widths):
|
| 165 |
+
diff = abs(len(title) - len(separator))
|
| 166 |
+
column_widths[-1] += diff
|
| 167 |
+
|
| 168 |
+
# Update with diff
|
| 169 |
+
separator = make_row("├", "┼", "┤")
|
| 170 |
+
initial_rows = [
|
| 171 |
+
make_row("┌", in_between, "┐"),
|
| 172 |
+
f"{sep_char}{title.center(len(separator) - 2)}{sep_char}",
|
| 173 |
+
make_row("├", "┬", "┤"),
|
| 174 |
+
]
|
| 175 |
+
table = "\n".join(initial_rows) + "\n"
|
| 176 |
+
column_widths[-1] += diff
|
| 177 |
+
centered_line = [text.center(column_widths[i]) for i, text in enumerate(headers)]
|
| 178 |
+
table += f"{pattern % tuple(centered_line)}\n{separator}\n"
|
| 179 |
+
for i, line in enumerate(rows):
|
| 180 |
+
centered_line = [t.center(column_widths[i]) for i, t in enumerate(line)]
|
| 181 |
+
table += f"{pattern % tuple(centered_line)}\n"
|
| 182 |
+
table += f"└{'┴'.join([in_between * n for n in column_widths])}┘"
|
| 183 |
+
|
| 184 |
+
return table
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def estimate_command_parser(subparsers=None):
|
| 188 |
+
if subparsers is not None:
|
| 189 |
+
parser = subparsers.add_parser("estimate-memory")
|
| 190 |
+
else:
|
| 191 |
+
parser = CustomArgumentParser(description="Model size estimator for fitting a model onto CUDA memory.")
|
| 192 |
+
|
| 193 |
+
parser.add_argument("model_name", type=str, help="The model name on the Hugging Face Hub.")
|
| 194 |
+
parser.add_argument(
|
| 195 |
+
"--library_name",
|
| 196 |
+
type=str,
|
| 197 |
+
help="The library the model has an integration with, such as `transformers`, needed only if this information is not stored on the Hub.",
|
| 198 |
+
choices=["timm", "transformers"],
|
| 199 |
+
)
|
| 200 |
+
parser.add_argument(
|
| 201 |
+
"--dtypes",
|
| 202 |
+
type=str,
|
| 203 |
+
nargs="+",
|
| 204 |
+
default=["float32", "float16", "int8", "int4"],
|
| 205 |
+
help="The dtypes to use for the model, must be one (or many) of `float32`, `float16`, `int8`, and `int4`",
|
| 206 |
+
choices=["float32", "float16", "int8", "int4"],
|
| 207 |
+
)
|
| 208 |
+
parser.add_argument(
|
| 209 |
+
"--trust_remote_code",
|
| 210 |
+
action="store_true",
|
| 211 |
+
help="""Whether or not to allow for custom models defined on the Hub in their own modeling files. This flag
|
| 212 |
+
should only be used for repositories you trust and in which you have read the code, as it will execute
|
| 213 |
+
code present on the Hub on your local machine.""",
|
| 214 |
+
default=False,
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
if subparsers is not None:
|
| 218 |
+
parser.set_defaults(func=estimate_command)
|
| 219 |
+
return parser
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def estimate_training_usage(bytes: int, mixed_precision: str, msamp_config: Optional[str] = None) -> dict:
|
| 223 |
+
"""
|
| 224 |
+
Given an amount of `bytes` and `mixed_precision`, calculates how much training memory is needed for a batch size of
|
| 225 |
+
1.
|
| 226 |
+
|
| 227 |
+
Args:
|
| 228 |
+
bytes (`int`):
|
| 229 |
+
The size of the model being trained.
|
| 230 |
+
mixed_precision (`str`):
|
| 231 |
+
The mixed precision that would be ran.
|
| 232 |
+
msamp_config (`str`):
|
| 233 |
+
The msamp config to estimate the training memory for if `mixed_precision` is set to `"fp8"`.
|
| 234 |
+
"""
|
| 235 |
+
memory_sizes = {"model": -1, "optimizer": -1, "gradients": -1, "step": -1}
|
| 236 |
+
fp32_size = bytes
|
| 237 |
+
fp16_size = bytes // 2
|
| 238 |
+
|
| 239 |
+
if mixed_precision == "float32":
|
| 240 |
+
memory_sizes["model"] = fp32_size
|
| 241 |
+
memory_sizes["gradients"] = fp32_size
|
| 242 |
+
memory_sizes["optimizer"] = fp32_size * 2
|
| 243 |
+
memory_sizes["step"] = fp32_size * 4
|
| 244 |
+
elif mixed_precision in ("float16", "bfloat16") or (mixed_precision == "fp8" and msamp_config is None):
|
| 245 |
+
# With native `TransformersEngine`, there is no memory savings with FP8
|
| 246 |
+
# With mixed precision training, the model has weights stored
|
| 247 |
+
# in FP16 and FP32
|
| 248 |
+
memory_sizes["model"] = fp32_size
|
| 249 |
+
# 1.5 from weight gradient + computation (GEMM)
|
| 250 |
+
memory_sizes["gradients"] = fp32_size + fp16_size
|
| 251 |
+
# 2x from optimizer states
|
| 252 |
+
memory_sizes["optimizer"] = fp32_size * 2 # Optimizer states
|
| 253 |
+
memory_sizes["step"] = memory_sizes["optimizer"]
|
| 254 |
+
return memory_sizes
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def gather_data(args):
|
| 258 |
+
"Creates an empty model and gathers the data for the sizes"
|
| 259 |
+
try:
|
| 260 |
+
model = create_empty_model(
|
| 261 |
+
args.model_name, library_name=args.library_name, trust_remote_code=args.trust_remote_code
|
| 262 |
+
)
|
| 263 |
+
except (RuntimeError, OSError) as e:
|
| 264 |
+
library = check_has_model(e)
|
| 265 |
+
if library != "unknown":
|
| 266 |
+
raise RuntimeError(
|
| 267 |
+
f"Tried to load `{args.model_name}` with `{library}` but a possible model to load was not found inside the repo."
|
| 268 |
+
)
|
| 269 |
+
raise e
|
| 270 |
+
|
| 271 |
+
total_size, largest_layer = calculate_maximum_sizes(model)
|
| 272 |
+
|
| 273 |
+
data = []
|
| 274 |
+
|
| 275 |
+
for dtype in args.dtypes:
|
| 276 |
+
dtype_total_size = total_size
|
| 277 |
+
dtype_largest_layer = largest_layer[0]
|
| 278 |
+
dtype_training_size = estimate_training_usage(dtype_total_size, dtype)
|
| 279 |
+
if dtype == "float16":
|
| 280 |
+
dtype_total_size /= 2
|
| 281 |
+
dtype_largest_layer /= 2
|
| 282 |
+
elif dtype == "int8":
|
| 283 |
+
dtype_total_size /= 4
|
| 284 |
+
dtype_largest_layer /= 4
|
| 285 |
+
elif dtype == "int4":
|
| 286 |
+
dtype_total_size /= 8
|
| 287 |
+
dtype_largest_layer /= 8
|
| 288 |
+
data.append([dtype, dtype_largest_layer, dtype_total_size, dtype_training_size])
|
| 289 |
+
return data
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def estimate_command(args):
    """Gather the size estimates for `args.model_name` and print them as an ASCII table.

    Numeric cells are rendered with `convert_bytes`. A dict cell (the training
    usage breakdown) is collapsed to its largest value, shown as "N/A" when that
    value is -1.
    """
    data = gather_data(args)
    for row in data:
        for column, cell in enumerate(row):
            if isinstance(cell, dict):
                peak_usage = max(cell.values())
                row[column] = "N/A" if peak_usage == -1 else convert_bytes(peak_usage)
            elif isinstance(cell, (int, float)):
                row[column] = convert_bytes(cell)

    print(
        create_ascii_table(
            ["dtype", "Largest Layer", "Total Size", "Training using Adam"],
            data,
            f"Memory Usage for loading `{args.model_name}`",
        )
    )
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def main():
    """Parse the command-line arguments and run the estimate command with them."""
    estimate_command(estimate_command_parser().parse_args())


# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/launch.py
ADDED
|
@@ -0,0 +1,1291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import importlib
|
| 19 |
+
import logging
|
| 20 |
+
import os
|
| 21 |
+
import subprocess
|
| 22 |
+
import sys
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
|
| 25 |
+
import psutil
|
| 26 |
+
import torch
|
| 27 |
+
|
| 28 |
+
from accelerate.commands.config import default_config_file, load_config_from_file
|
| 29 |
+
from accelerate.commands.config.config_args import SageMakerConfig
|
| 30 |
+
from accelerate.commands.config.config_utils import DYNAMO_BACKENDS
|
| 31 |
+
from accelerate.commands.utils import CustomArgumentParser
|
| 32 |
+
from accelerate.state import get_int_from_env
|
| 33 |
+
from accelerate.utils import (
|
| 34 |
+
ComputeEnvironment,
|
| 35 |
+
DistributedType,
|
| 36 |
+
PrepareForLaunch,
|
| 37 |
+
_filter_args,
|
| 38 |
+
check_cuda_p2p_ib_support,
|
| 39 |
+
convert_dict_to_env_variables,
|
| 40 |
+
is_bf16_available,
|
| 41 |
+
is_deepspeed_available,
|
| 42 |
+
is_hpu_available,
|
| 43 |
+
is_mlu_available,
|
| 44 |
+
is_musa_available,
|
| 45 |
+
is_npu_available,
|
| 46 |
+
is_rich_available,
|
| 47 |
+
is_sagemaker_available,
|
| 48 |
+
is_sdaa_available,
|
| 49 |
+
is_torch_xla_available,
|
| 50 |
+
is_xpu_available,
|
| 51 |
+
patch_environment,
|
| 52 |
+
prepare_deepspeed_cmd_env,
|
| 53 |
+
prepare_multi_gpu_env,
|
| 54 |
+
prepare_sagemager_args_inputs,
|
| 55 |
+
prepare_simple_launcher_cmd_env,
|
| 56 |
+
prepare_tpu,
|
| 57 |
+
str_to_bool,
|
| 58 |
+
)
|
| 59 |
+
from accelerate.utils.constants import DEEPSPEED_MULTINODE_LAUNCHERS, TORCH_DYNAMO_MODES
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# When `rich` is installed, configure the root logger to emit through rich's handler.
if is_rich_available():
    # `get_console` is not used in this setup block; presumably it is needed
    # further down in the module -- verify before removing.
    from rich import get_console
    from rich.logging import RichHandler

    FORMAT = "%(message)s"
    logging.basicConfig(
        format=FORMAT,
        datefmt="[%X]",
        handlers=[RichHandler()],
    )


# Module-level logger.
logger = logging.getLogger(__name__)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# Maps a cleaned command-line flag (no leading dashes, underscores only) to the
# title of the argument group whose help becomes visible when that flag is given.
options_to_group = dict(
    multi_gpu="Distributed GPUs",
    tpu="TPU",
    use_deepspeed="DeepSpeed Arguments",
    use_fsdp="FSDP Arguments",
    use_megatron_lm="Megatron-LM Arguments",
    fp8_backend="FP8 Arguments",
)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def clean_option(option):
    """Strip the leading `--` from a long option and turn its remaining hyphens into underscores.

    Any string containing `fp8_backend` is first collapsed to `--fp8_backend`.
    Strings that do not start with `--` (short flags, positional values) yield
    `None`.
    """
    candidate = "--fp8_backend" if "fp8_backend" in option else option
    if not candidate.startswith("--"):
        return None
    return candidate[2:].replace("-", "_")
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class CustomHelpFormatter(argparse.HelpFormatter):
    """
    Help formatter that suppresses the help of arguments unrelated to the options present on the command line, so a
    user targeting one platform only sees the arguments for that platform.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Argument-group titles that are always eligible for display.
        self.titles = [
            "Hardware Selection Arguments",
            "Resource Selection Arguments",
            "Training Paradigm Arguments",
            "positional arguments",
            "optional arguments",
        ]

    def add_argument(self, action: argparse.Action):
        """Register `action` for help output, hiding it when its group is not relevant to the current command line."""
        # Skip the leading sub-command slot when invoked as `accelerate ... launch ...`.
        via_accelerate_launch = "accelerate" in sys.argv[0] and "launch" in sys.argv[1:]
        cli_args = sys.argv[2:] if via_accelerate_launch else sys.argv[1:]

        if len(cli_args) > 1:
            cli_args = [clean_option(arg) for arg in cli_args]
            # Titles of the extra groups unlocked by flags found on the command line.
            used_titles = [options_to_group[arg] for arg in cli_args if arg in options_to_group.keys()]
            group_title = action.container.title
            if group_title not in self.titles + used_titles:
                action.help = argparse.SUPPRESS
            elif group_title in ("Hardware Selection Arguments", "Training Paradigm Arguments"):
                # In these two groups only the options actually passed are shown.
                if set(action.option_strings).isdisjoint(set(cli_args)):
                    action.help = argparse.SUPPRESS
                else:
                    action.help = action.help + " (currently selected)"

        # Keep only the option strings that have no hyphen after their first two characters.
        action.option_strings = [name for name in action.option_strings if "-" not in name[2:]]
        super().add_argument(action)

    def end_section(self):
        """Close the current section, blanking it first when it holds fewer than two items."""
        section = self._current_section
        if len(section.items) < 2:
            section.items = []
            section.heading = ""
        super().end_section()
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def launch_command_parser(subparsers=None):
|
| 142 |
+
description = "Launch a python script in a distributed scenario. Arguments can be passed in with either hyphens (`--num-processes=2`) or underscores (`--num_processes=2`)"
|
| 143 |
+
if subparsers is not None:
|
| 144 |
+
parser = subparsers.add_parser(
|
| 145 |
+
"launch", description=description, add_help=False, allow_abbrev=False, formatter_class=CustomHelpFormatter
|
| 146 |
+
)
|
| 147 |
+
else:
|
| 148 |
+
parser = CustomArgumentParser(
|
| 149 |
+
"Accelerate launch command",
|
| 150 |
+
description=description,
|
| 151 |
+
add_help=False,
|
| 152 |
+
allow_abbrev=False,
|
| 153 |
+
formatter_class=CustomHelpFormatter,
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
parser.add_argument("-h", "--help", action="help", help="Show this help message and exit.")
|
| 157 |
+
|
| 158 |
+
parser.add_argument(
|
| 159 |
+
"--config_file",
|
| 160 |
+
default=None,
|
| 161 |
+
help="The config file to use for the default values in the launching script.",
|
| 162 |
+
)
|
| 163 |
+
parser.add_argument(
|
| 164 |
+
"--quiet",
|
| 165 |
+
"-q",
|
| 166 |
+
action="store_true",
|
| 167 |
+
help="Silence subprocess errors from the launch stack trace and only show the relevant tracebacks. (Only applicable to DeepSpeed and single-process configurations)",
|
| 168 |
+
)
|
| 169 |
+
# Hardware selection arguments
|
| 170 |
+
hardware_args = parser.add_argument_group(
|
| 171 |
+
"Hardware Selection Arguments", "Arguments for selecting the hardware to be used."
|
| 172 |
+
)
|
| 173 |
+
hardware_args.add_argument(
|
| 174 |
+
"--cpu", default=False, action="store_true", help="Whether or not to force the training on the CPU."
|
| 175 |
+
)
|
| 176 |
+
hardware_args.add_argument(
|
| 177 |
+
"--multi_gpu",
|
| 178 |
+
default=False,
|
| 179 |
+
action="store_true",
|
| 180 |
+
help="Whether or not this should launch a distributed GPU training.",
|
| 181 |
+
)
|
| 182 |
+
hardware_args.add_argument(
|
| 183 |
+
"--tpu", default=False, action="store_true", help="Whether or not this should launch a TPU training."
|
| 184 |
+
)
|
| 185 |
+
# Resource selection arguments
|
| 186 |
+
resource_args = parser.add_argument_group(
|
| 187 |
+
"Resource Selection Arguments", "Arguments for fine-tuning how available hardware should be used."
|
| 188 |
+
)
|
| 189 |
+
resource_args.add_argument(
|
| 190 |
+
"--mixed_precision",
|
| 191 |
+
type=str,
|
| 192 |
+
choices=["no", "fp16", "bf16", "fp8"],
|
| 193 |
+
help="Whether or not to use mixed precision training. "
|
| 194 |
+
"Choose between FP16 and BF16 (bfloat16) training. "
|
| 195 |
+
"BF16 training is only supported on Nvidia Ampere GPUs and PyTorch 1.10 or later.",
|
| 196 |
+
)
|
| 197 |
+
resource_args.add_argument(
|
| 198 |
+
"--num_processes", type=int, default=None, help="The total number of processes to be launched in parallel."
|
| 199 |
+
)
|
| 200 |
+
resource_args.add_argument(
|
| 201 |
+
"--num_machines", type=int, default=None, help="The total number of machines used in this training."
|
| 202 |
+
)
|
| 203 |
+
resource_args.add_argument(
|
| 204 |
+
"--num_cpu_threads_per_process",
|
| 205 |
+
type=int,
|
| 206 |
+
default=None,
|
| 207 |
+
help="The number of CPU threads per process. Can be tuned for optimal performance.",
|
| 208 |
+
)
|
| 209 |
+
resource_args.add_argument(
|
| 210 |
+
"--enable_cpu_affinity",
|
| 211 |
+
default=False,
|
| 212 |
+
action="store_true",
|
| 213 |
+
help="Whether or not CPU affinity and balancing should be enabled. Currently only supported on NVIDIA hardware.",
|
| 214 |
+
)
|
| 215 |
+
# Dynamo arguments
|
| 216 |
+
resource_args.add_argument(
|
| 217 |
+
"--dynamo_backend",
|
| 218 |
+
type=str,
|
| 219 |
+
choices=["no"] + [b.lower() for b in DYNAMO_BACKENDS],
|
| 220 |
+
help="Choose a backend to optimize your training with dynamo, see more at "
|
| 221 |
+
"https://github.com/pytorch/torchdynamo.",
|
| 222 |
+
)
|
| 223 |
+
resource_args.add_argument(
|
| 224 |
+
"--dynamo_mode",
|
| 225 |
+
type=str,
|
| 226 |
+
default="default",
|
| 227 |
+
choices=TORCH_DYNAMO_MODES,
|
| 228 |
+
help="Choose a mode to optimize your training with dynamo.",
|
| 229 |
+
)
|
| 230 |
+
resource_args.add_argument(
|
| 231 |
+
"--dynamo_use_fullgraph",
|
| 232 |
+
default=False,
|
| 233 |
+
action="store_true",
|
| 234 |
+
help="Whether to use full graph mode for dynamo or it is ok to break model into several subgraphs",
|
| 235 |
+
)
|
| 236 |
+
resource_args.add_argument(
|
| 237 |
+
"--dynamo_use_dynamic",
|
| 238 |
+
default=False,
|
| 239 |
+
action="store_true",
|
| 240 |
+
help="Whether to enable dynamic shape tracing.",
|
| 241 |
+
)
|
| 242 |
+
resource_args.add_argument(
|
| 243 |
+
"--dynamo_use_regional_compilation",
|
| 244 |
+
default=False,
|
| 245 |
+
action="store_true",
|
| 246 |
+
help="Whether to enable regional compilation.",
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
# Training Paradigm arguments
|
| 250 |
+
paradigm_args = parser.add_argument_group(
|
| 251 |
+
"Training Paradigm Arguments", "Arguments for selecting which training paradigm to be used."
|
| 252 |
+
)
|
| 253 |
+
paradigm_args.add_argument(
|
| 254 |
+
"--use_deepspeed",
|
| 255 |
+
default=False,
|
| 256 |
+
action="store_true",
|
| 257 |
+
help="Whether to use deepspeed.",
|
| 258 |
+
)
|
| 259 |
+
paradigm_args.add_argument(
|
| 260 |
+
"--use_fsdp",
|
| 261 |
+
default=False,
|
| 262 |
+
action="store_true",
|
| 263 |
+
help="Whether to use fsdp.",
|
| 264 |
+
)
|
| 265 |
+
paradigm_args.add_argument(
|
| 266 |
+
"--use_parallelism_config",
|
| 267 |
+
default=False,
|
| 268 |
+
action="store_true",
|
| 269 |
+
help="Whether to use the parallelism config to configure the N-d distributed training.",
|
| 270 |
+
)
|
| 271 |
+
paradigm_args.add_argument(
|
| 272 |
+
"--use_megatron_lm",
|
| 273 |
+
default=False,
|
| 274 |
+
action="store_true",
|
| 275 |
+
help="Whether to use Megatron-LM.",
|
| 276 |
+
)
|
| 277 |
+
|
| 278 |
+
paradigm_args.add_argument(
|
| 279 |
+
"--use_xpu",
|
| 280 |
+
default=None,
|
| 281 |
+
action="store_true",
|
| 282 |
+
help="Whether to use IPEX plugin to speed up training on XPU specifically. This argument is deprecated and ignored, will be removed in Accelerate v1.20.",
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
# distributed GPU training arguments
|
| 286 |
+
distributed_args = parser.add_argument_group("Distributed GPUs", "Arguments related to distributed GPU training.")
|
| 287 |
+
distributed_args.add_argument(
|
| 288 |
+
"--gpu_ids",
|
| 289 |
+
default=None,
|
| 290 |
+
help="What GPUs (by id) should be used for training on this machine as a comma-separated list",
|
| 291 |
+
)
|
| 292 |
+
distributed_args.add_argument(
|
| 293 |
+
"--same_network",
|
| 294 |
+
default=False,
|
| 295 |
+
action="store_true",
|
| 296 |
+
help="Whether all machines used for multinode training exist on the same local network.",
|
| 297 |
+
)
|
| 298 |
+
distributed_args.add_argument(
|
| 299 |
+
"--machine_rank", type=int, default=None, help="The rank of the machine on which this script is launched."
|
| 300 |
+
)
|
| 301 |
+
distributed_args.add_argument(
|
| 302 |
+
"--main_process_ip", type=str, default=None, help="The IP address of the machine of rank 0."
|
| 303 |
+
)
|
| 304 |
+
distributed_args.add_argument(
|
| 305 |
+
"--main_process_port",
|
| 306 |
+
type=int,
|
| 307 |
+
default=None,
|
| 308 |
+
help="The port to use to communicate with the machine of rank 0.",
|
| 309 |
+
)
|
| 310 |
+
distributed_args.add_argument(
|
| 311 |
+
"-t",
|
| 312 |
+
"--tee",
|
| 313 |
+
default="0",
|
| 314 |
+
type=str,
|
| 315 |
+
help="Tee std streams into a log file and also to console.",
|
| 316 |
+
)
|
| 317 |
+
distributed_args.add_argument(
|
| 318 |
+
"--log_dir",
|
| 319 |
+
type=str,
|
| 320 |
+
default=None,
|
| 321 |
+
help=(
|
| 322 |
+
"Base directory to use for log files when using torchrun/torch.distributed.run as launcher. "
|
| 323 |
+
"Use with --tee to redirect std streams info log files."
|
| 324 |
+
),
|
| 325 |
+
)
|
| 326 |
+
distributed_args.add_argument(
|
| 327 |
+
"--role",
|
| 328 |
+
type=str,
|
| 329 |
+
default="default",
|
| 330 |
+
help="User-defined role for the workers.",
|
| 331 |
+
)
|
| 332 |
+
# Rendezvous related arguments
|
| 333 |
+
distributed_args.add_argument(
|
| 334 |
+
"--rdzv_backend",
|
| 335 |
+
type=str,
|
| 336 |
+
default="static",
|
| 337 |
+
help="The rendezvous method to use, such as 'static' (the default) or 'c10d'",
|
| 338 |
+
)
|
| 339 |
+
distributed_args.add_argument(
|
| 340 |
+
"--rdzv_conf",
|
| 341 |
+
type=str,
|
| 342 |
+
default="",
|
| 343 |
+
help="Additional rendezvous configuration (<key1>=<value1>,<key2>=<value2>,...).",
|
| 344 |
+
)
|
| 345 |
+
distributed_args.add_argument(
|
| 346 |
+
"--max_restarts",
|
| 347 |
+
type=int,
|
| 348 |
+
default=0,
|
| 349 |
+
help="Maximum number of worker group restarts before failing.",
|
| 350 |
+
)
|
| 351 |
+
distributed_args.add_argument(
|
| 352 |
+
"--monitor_interval",
|
| 353 |
+
type=float,
|
| 354 |
+
default=0.1,
|
| 355 |
+
help="Interval, in seconds, to monitor the state of workers.",
|
| 356 |
+
)
|
| 357 |
+
parser.add_argument(
|
| 358 |
+
"-m",
|
| 359 |
+
"--module",
|
| 360 |
+
action="store_true",
|
| 361 |
+
help="Change each process to interpret the launch script as a Python module, executing with the same behavior as 'python -m'.",
|
| 362 |
+
)
|
| 363 |
+
parser.add_argument(
|
| 364 |
+
"--no_python",
|
| 365 |
+
action="store_true",
|
| 366 |
+
help="Skip prepending the training script with 'python' - just execute it directly. Useful when the script is not a Python script.",
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
# TPU arguments
|
| 370 |
+
tpu_args = parser.add_argument_group("TPU", "Arguments related to TPU.")
|
| 371 |
+
tpu_args.add_argument(
|
| 372 |
+
"--tpu_cluster",
|
| 373 |
+
action="store_true",
|
| 374 |
+
dest="tpu_use_cluster",
|
| 375 |
+
help="Whether to use a GCP TPU pod for training.",
|
| 376 |
+
)
|
| 377 |
+
tpu_args.add_argument(
|
| 378 |
+
"--no_tpu_cluster",
|
| 379 |
+
action="store_false",
|
| 380 |
+
dest="tpu_use_cluster",
|
| 381 |
+
help="Should not be passed explicitly, this is for internal use only.",
|
| 382 |
+
)
|
| 383 |
+
tpu_args.add_argument(
|
| 384 |
+
"--tpu_use_sudo",
|
| 385 |
+
action="store_true",
|
| 386 |
+
help="Whether to use `sudo` when running the TPU training script in each pod.",
|
| 387 |
+
)
|
| 388 |
+
tpu_args.add_argument(
|
| 389 |
+
"--vm",
|
| 390 |
+
type=str,
|
| 391 |
+
action="append",
|
| 392 |
+
help=(
|
| 393 |
+
"List of single Compute VM instance names. "
|
| 394 |
+
"If not provided we assume usage of instance groups. For TPU pods."
|
| 395 |
+
),
|
| 396 |
+
)
|
| 397 |
+
tpu_args.add_argument(
|
| 398 |
+
"--env",
|
| 399 |
+
type=str,
|
| 400 |
+
action="append",
|
| 401 |
+
help="List of environment variables to set on the Compute VM instances. For TPU pods.",
|
| 402 |
+
)
|
| 403 |
+
tpu_args.add_argument(
|
| 404 |
+
"--main_training_function",
|
| 405 |
+
type=str,
|
| 406 |
+
default=None,
|
| 407 |
+
help="The name of the main function to be executed in your script (only for TPU training).",
|
| 408 |
+
)
|
| 409 |
+
tpu_args.add_argument(
|
| 410 |
+
"--downcast_bf16",
|
| 411 |
+
action="store_true",
|
| 412 |
+
help="Whether when using bf16 precision on TPUs if both float and double tensors are cast to bfloat16 or if double tensors remain as float32.",
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
# DeepSpeed arguments
|
| 416 |
+
deepspeed_args = parser.add_argument_group("DeepSpeed Arguments", "Arguments related to DeepSpeed.")
|
| 417 |
+
deepspeed_args.add_argument(
|
| 418 |
+
"--deepspeed_config_file",
|
| 419 |
+
default=None,
|
| 420 |
+
type=str,
|
| 421 |
+
help="DeepSpeed config file.",
|
| 422 |
+
)
|
| 423 |
+
deepspeed_args.add_argument(
|
| 424 |
+
"--zero_stage",
|
| 425 |
+
default=None,
|
| 426 |
+
type=int,
|
| 427 |
+
help="DeepSpeed's ZeRO optimization stage (useful only when `use_deepspeed` flag is passed). "
|
| 428 |
+
"If unspecified, will default to `2`.",
|
| 429 |
+
)
|
| 430 |
+
deepspeed_args.add_argument(
|
| 431 |
+
"--offload_optimizer_device",
|
| 432 |
+
default=None,
|
| 433 |
+
type=str,
|
| 434 |
+
help="Decides where (none|cpu|nvme) to offload optimizer states (useful only when `use_deepspeed` flag is passed). "
|
| 435 |
+
"If unspecified, will default to 'none'.",
|
| 436 |
+
)
|
| 437 |
+
deepspeed_args.add_argument(
|
| 438 |
+
"--offload_param_device",
|
| 439 |
+
default=None,
|
| 440 |
+
type=str,
|
| 441 |
+
help="Decides where (none|cpu|nvme) to offload parameters (useful only when `use_deepspeed` flag is passed). "
|
| 442 |
+
"If unspecified, will default to 'none'.",
|
| 443 |
+
)
|
| 444 |
+
deepspeed_args.add_argument(
|
| 445 |
+
"--offload_optimizer_nvme_path",
|
| 446 |
+
default=None,
|
| 447 |
+
type=str,
|
| 448 |
+
help="Decides Nvme Path to offload optimizer states (useful only when `use_deepspeed` flag is passed). "
|
| 449 |
+
"If unspecified, will default to 'none'.",
|
| 450 |
+
)
|
| 451 |
+
deepspeed_args.add_argument(
|
| 452 |
+
"--offload_param_nvme_path",
|
| 453 |
+
default=None,
|
| 454 |
+
type=str,
|
| 455 |
+
help="Decides Nvme Path to offload parameters (useful only when `use_deepspeed` flag is passed). "
|
| 456 |
+
"If unspecified, will default to 'none'.",
|
| 457 |
+
)
|
| 458 |
+
deepspeed_args.add_argument(
|
| 459 |
+
"--gradient_accumulation_steps",
|
| 460 |
+
default=None,
|
| 461 |
+
type=int,
|
| 462 |
+
help="No of gradient_accumulation_steps used in your training script (useful only when `use_deepspeed` flag is passed). "
|
| 463 |
+
"If unspecified, will default to `1`.",
|
| 464 |
+
)
|
| 465 |
+
deepspeed_args.add_argument(
|
| 466 |
+
"--gradient_clipping",
|
| 467 |
+
default=None,
|
| 468 |
+
type=float,
|
| 469 |
+
help="gradient clipping value used in your training script (useful only when `use_deepspeed` flag is passed). "
|
| 470 |
+
"If unspecified, will default to `1.0`.",
|
| 471 |
+
)
|
| 472 |
+
deepspeed_args.add_argument(
|
| 473 |
+
"--zero3_init_flag",
|
| 474 |
+
default=None,
|
| 475 |
+
type=str,
|
| 476 |
+
help="Decides Whether (true|false) to enable `deepspeed.zero.Init` for constructing massive models. "
|
| 477 |
+
"Only applicable with DeepSpeed ZeRO Stage-3. If unspecified, will default to `true`.",
|
| 478 |
+
)
|
| 479 |
+
deepspeed_args.add_argument(
|
| 480 |
+
"--zero3_save_16bit_model",
|
| 481 |
+
default=None,
|
| 482 |
+
type=str,
|
| 483 |
+
help="Decides Whether (true|false) to save 16-bit model weights when using ZeRO Stage-3. "
|
| 484 |
+
"Only applicable with DeepSpeed ZeRO Stage-3. If unspecified, will default to `false`.",
|
| 485 |
+
)
|
| 486 |
+
deepspeed_args.add_argument(
|
| 487 |
+
"--deepspeed_hostfile",
|
| 488 |
+
default=None,
|
| 489 |
+
type=str,
|
| 490 |
+
help="DeepSpeed hostfile for configuring multi-node compute resources.",
|
| 491 |
+
)
|
| 492 |
+
deepspeed_args.add_argument(
|
| 493 |
+
"--deepspeed_exclusion_filter",
|
| 494 |
+
default=None,
|
| 495 |
+
type=str,
|
| 496 |
+
help="DeepSpeed exclusion filter string when using multi-node setup.",
|
| 497 |
+
)
|
| 498 |
+
deepspeed_args.add_argument(
|
| 499 |
+
"--deepspeed_inclusion_filter",
|
| 500 |
+
default=None,
|
| 501 |
+
type=str,
|
| 502 |
+
help="DeepSpeed inclusion filter string when using multi-node setup.",
|
| 503 |
+
)
|
| 504 |
+
deepspeed_args.add_argument(
|
| 505 |
+
"--deepspeed_multinode_launcher",
|
| 506 |
+
default=None,
|
| 507 |
+
type=str,
|
| 508 |
+
help="DeepSpeed multi-node launcher to use, e.g. `pdsh`, `standard`, `openmpi`, `mvapich`, `mpich`, `slurm`, `nossh` (requires DeepSpeed >= 0.14.5). If unspecified, will default to `pdsh`.",
|
| 509 |
+
)
|
| 510 |
+
deepspeed_args.add_argument(
|
| 511 |
+
"--deepspeed_moe_layer_cls_names",
|
| 512 |
+
default=None,
|
| 513 |
+
type=str,
|
| 514 |
+
help="comma-separated list of transformer MoE layer class names (case-sensitive) to wrap ,e.g, `MixtralSparseMoeBlock`, `Qwen2MoeSparseMoeBlock`, `JetMoEAttention,JetMoEBlock` ..."
|
| 515 |
+
" (useful only when `use_deepspeed` flag is passed).",
|
| 516 |
+
)
|
| 517 |
+
|
| 518 |
+
# fsdp arguments
|
| 519 |
+
fsdp_args = parser.add_argument_group("FSDP Arguments", "Arguments related to Fully Shared Data Parallelism.")
|
| 520 |
+
fsdp_args.add_argument(
|
| 521 |
+
"--fsdp_version",
|
| 522 |
+
type=str,
|
| 523 |
+
default="1",
|
| 524 |
+
choices=["1", "2"],
|
| 525 |
+
help="FSDP version to use. (useful only when `use_fsdp` flag is passed).",
|
| 526 |
+
)
|
| 527 |
+
fsdp_args.add_argument(
|
| 528 |
+
"--fsdp_offload_params",
|
| 529 |
+
default="false",
|
| 530 |
+
type=str,
|
| 531 |
+
help="Decides Whether (true|false) to offload parameters and gradients to CPU. (useful only when `use_fsdp` flag is passed).",
|
| 532 |
+
)
|
| 533 |
+
fsdp_args.add_argument(
|
| 534 |
+
"--fsdp_min_num_params",
|
| 535 |
+
type=int,
|
| 536 |
+
default=1e8,
|
| 537 |
+
help="FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `use_fsdp` flag is passed).",
|
| 538 |
+
)
|
| 539 |
+
# We enable this for backwards compatibility, throw a warning if this is set in `FullyShardedDataParallelPlugin`
|
| 540 |
+
fsdp_args.add_argument(
|
| 541 |
+
"--fsdp_sharding_strategy",
|
| 542 |
+
type=str,
|
| 543 |
+
default="FULL_SHARD",
|
| 544 |
+
help="FSDP's sharding strategy. (useful only when `use_fsdp` flag is passed and `fsdp_version=1`).",
|
| 545 |
+
)
|
| 546 |
+
fsdp_args.add_argument(
|
| 547 |
+
"--fsdp_reshard_after_forward",
|
| 548 |
+
type=str,
|
| 549 |
+
default="true",
|
| 550 |
+
help="FSDP's Reshard After Forward Strategy. (useful only when `use_fsdp` flag is passed). Supports either boolean (FSDP2) or `FULL_SHARD | SHARD_GRAD_OP | NO_RESHARD` (FSDP1).",
|
| 551 |
+
)
|
| 552 |
+
fsdp_args.add_argument(
|
| 553 |
+
"--fsdp_auto_wrap_policy",
|
| 554 |
+
type=str,
|
| 555 |
+
default=None,
|
| 556 |
+
help="FSDP's auto wrap policy. (useful only when `use_fsdp` flag is passed).",
|
| 557 |
+
)
|
| 558 |
+
fsdp_args.add_argument(
|
| 559 |
+
"--fsdp_transformer_layer_cls_to_wrap",
|
| 560 |
+
default=None,
|
| 561 |
+
type=str,
|
| 562 |
+
help="Transformer layer class name (case-sensitive) to wrap ,e.g, `BertLayer`, `GPTJBlock`, `T5Block` .... "
|
| 563 |
+
"(useful only when `use_fsdp` flag is passed).",
|
| 564 |
+
)
|
| 565 |
+
fsdp_args.add_argument(
|
| 566 |
+
"--fsdp_backward_prefetch",
|
| 567 |
+
default=None,
|
| 568 |
+
type=str,
|
| 569 |
+
help="FSDP's backward prefetch policy. (useful only when `use_fsdp` flag is passed).",
|
| 570 |
+
)
|
| 571 |
+
fsdp_args.add_argument(
|
| 572 |
+
"--fsdp_state_dict_type",
|
| 573 |
+
default=None,
|
| 574 |
+
type=str,
|
| 575 |
+
help="FSDP's state dict type. (useful only when `use_fsdp` flag is passed).",
|
| 576 |
+
)
|
| 577 |
+
fsdp_args.add_argument(
|
| 578 |
+
"--fsdp_forward_prefetch",
|
| 579 |
+
default="false",
|
| 580 |
+
type=str,
|
| 581 |
+
help="If True, then FSDP explicitly prefetches the next upcoming "
|
| 582 |
+
"all-gather while executing in the forward pass (useful only when `use_fsdp` flag is passed).",
|
| 583 |
+
)
|
| 584 |
+
fsdp_args.add_argument(
|
| 585 |
+
"--fsdp_use_orig_params",
|
| 586 |
+
default="true",
|
| 587 |
+
type=str,
|
| 588 |
+
help="If True, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable parameters."
|
| 589 |
+
" (useful only when `use_fsdp` flag is passed).",
|
| 590 |
+
)
|
| 591 |
+
fsdp_args.add_argument(
|
| 592 |
+
"--fsdp_cpu_ram_efficient_loading",
|
| 593 |
+
default="true",
|
| 594 |
+
type=str,
|
| 595 |
+
help="If True, only the first process loads the pretrained model checkoint while all other processes have empty weights. "
|
| 596 |
+
"Only applicable for 🤗 Transformers. When using this, `--fsdp_sync_module_states` needs to True. "
|
| 597 |
+
"(useful only when `use_fsdp` flag is passed).",
|
| 598 |
+
)
|
| 599 |
+
fsdp_args.add_argument(
|
| 600 |
+
"--fsdp_sync_module_states",
|
| 601 |
+
default="true",
|
| 602 |
+
type=str,
|
| 603 |
+
help="If True, each individually wrapped FSDP unit will broadcast module parameters from rank 0."
|
| 604 |
+
" (useful only when `use_fsdp` flag is passed).",
|
| 605 |
+
)
|
| 606 |
+
fsdp_args.add_argument(
|
| 607 |
+
"--fsdp_activation_checkpointing",
|
| 608 |
+
default="false",
|
| 609 |
+
type=str,
|
| 610 |
+
help="Decides Whether (true|false) intermediate activations are freed during the forward pass, and a checkpoint is left as a placeholder. (useful only when `use_fsdp` flag is passed).",
|
| 611 |
+
)
|
| 612 |
+
|
| 613 |
+
# megatron_lm args
|
| 614 |
+
megatron_lm_args = parser.add_argument_group("Megatron-LM Arguments", "Arguments related to Megatron-LM.")
|
| 615 |
+
megatron_lm_args.add_argument(
|
| 616 |
+
"--megatron_lm_tp_degree",
|
| 617 |
+
type=int,
|
| 618 |
+
default=1,
|
| 619 |
+
help="Megatron-LM's Tensor Parallelism (TP) degree. (useful only when `use_megatron_lm` flag is passed).",
|
| 620 |
+
)
|
| 621 |
+
megatron_lm_args.add_argument(
|
| 622 |
+
"--megatron_lm_pp_degree",
|
| 623 |
+
type=int,
|
| 624 |
+
default=1,
|
| 625 |
+
help="Megatron-LM's Pipeline Parallelism (PP) degree. (useful only when `use_megatron_lm` flag is passed).",
|
| 626 |
+
)
|
| 627 |
+
megatron_lm_args.add_argument(
|
| 628 |
+
"--megatron_lm_num_micro_batches",
|
| 629 |
+
type=int,
|
| 630 |
+
default=None,
|
| 631 |
+
help="Megatron-LM's number of micro batches when PP degree > 1. (useful only when `use_megatron_lm` flag is passed).",
|
| 632 |
+
)
|
| 633 |
+
megatron_lm_args.add_argument(
|
| 634 |
+
"--megatron_lm_sequence_parallelism",
|
| 635 |
+
default=None,
|
| 636 |
+
type=str,
|
| 637 |
+
help="Decides Whether (true|false) to enable Sequence Parallelism when TP degree > 1. "
|
| 638 |
+
"(useful only when `use_megatron_lm` flag is passed).",
|
| 639 |
+
)
|
| 640 |
+
megatron_lm_args.add_argument(
|
| 641 |
+
"--megatron_lm_recompute_activations",
|
| 642 |
+
default=None,
|
| 643 |
+
type=str,
|
| 644 |
+
help="Decides Whether (true|false) to enable Selective Activation Recomputation. "
|
| 645 |
+
"(useful only when `use_megatron_lm` flag is passed).",
|
| 646 |
+
)
|
| 647 |
+
megatron_lm_args.add_argument(
|
| 648 |
+
"--megatron_lm_use_distributed_optimizer",
|
| 649 |
+
default=None,
|
| 650 |
+
type=str,
|
| 651 |
+
help="Decides Whether (true|false) to use distributed optimizer "
|
| 652 |
+
"which shards optimizer state and gradients across Data Pralellel (DP) ranks. "
|
| 653 |
+
"(useful only when `use_megatron_lm` flag is passed).",
|
| 654 |
+
)
|
| 655 |
+
megatron_lm_args.add_argument(
|
| 656 |
+
"--megatron_lm_gradient_clipping",
|
| 657 |
+
default=1.0,
|
| 658 |
+
type=float,
|
| 659 |
+
help="Megatron-LM's gradient clipping value based on global L2 Norm (0 to disable). "
|
| 660 |
+
"(useful only when `use_megatron_lm` flag is passed).",
|
| 661 |
+
)
|
| 662 |
+
|
| 663 |
+
# FP8 arguments
|
| 664 |
+
fp8_args = parser.add_argument_group(
|
| 665 |
+
"FP8 Arguments", "Arguments related to FP8 training (requires `--mixed_precision=fp8`)"
|
| 666 |
+
)
|
| 667 |
+
fp8_args.add_argument(
|
| 668 |
+
"--fp8_backend",
|
| 669 |
+
type=str,
|
| 670 |
+
choices=["te", "msamp"],
|
| 671 |
+
help="Choose a backend to train with FP8 (te: TransformerEngine, msamp: MS-AMP)",
|
| 672 |
+
)
|
| 673 |
+
fp8_args.add_argument(
|
| 674 |
+
"--fp8_use_autocast_during_eval",
|
| 675 |
+
default=False,
|
| 676 |
+
action="store_true",
|
| 677 |
+
help="Whether to use FP8 autocast during eval mode (useful only when `--fp8_backend=te` is passed). Generally better metrics are found when this is not passed.",
|
| 678 |
+
)
|
| 679 |
+
fp8_args.add_argument(
|
| 680 |
+
"--fp8_margin",
|
| 681 |
+
type=int,
|
| 682 |
+
default=0,
|
| 683 |
+
help="The margin to use for the gradient scaling (useful only when `--fp8_backend=te` is passed).",
|
| 684 |
+
)
|
| 685 |
+
fp8_args.add_argument(
|
| 686 |
+
"--fp8_interval",
|
| 687 |
+
type=int,
|
| 688 |
+
default=1,
|
| 689 |
+
help="The interval to use for how often the scaling factor is recomputed (useful only when `--fp8_backend=te` is passed).",
|
| 690 |
+
)
|
| 691 |
+
fp8_args.add_argument(
|
| 692 |
+
"--fp8_format",
|
| 693 |
+
type=str,
|
| 694 |
+
default="HYBRID",
|
| 695 |
+
choices=["HYBRID", "E4M3", "E5M2"],
|
| 696 |
+
help="The format to use for the FP8 recipe (useful only when `--fp8_backend=te` is passed).",
|
| 697 |
+
)
|
| 698 |
+
fp8_args.add_argument(
|
| 699 |
+
"--fp8_amax_history_len",
|
| 700 |
+
type=int,
|
| 701 |
+
default=1024,
|
| 702 |
+
help="The length of the history to use for the scaling factor computation (useful only when `--fp8_backend=te` is passed).",
|
| 703 |
+
)
|
| 704 |
+
fp8_args.add_argument(
|
| 705 |
+
"--fp8_amax_compute_algo",
|
| 706 |
+
type=str,
|
| 707 |
+
default="most_recent",
|
| 708 |
+
choices=["max", "most_recent"],
|
| 709 |
+
help="The algorithm to use for the scaling factor computation. (useful only when `--fp8_backend=te` is passed).",
|
| 710 |
+
)
|
| 711 |
+
fp8_args.add_argument(
|
| 712 |
+
"--fp8_override_linear_precision",
|
| 713 |
+
type=lambda x: tuple(map(str_to_bool, x.split(","))),
|
| 714 |
+
default=(False, False, False),
|
| 715 |
+
help="Whether or not to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision. Should be passed in a comma-separated string of booleans (useful only when `--fp8_backend=te` is passed).",
|
| 716 |
+
)
|
| 717 |
+
fp8_args.add_argument(
|
| 718 |
+
"--fp8_opt_level",
|
| 719 |
+
type=str,
|
| 720 |
+
default="O2",
|
| 721 |
+
choices=["O1", "O2"],
|
| 722 |
+
help="What level of 8-bit collective communication should be used with MS-AMP (useful only when `--fp8_backend=msamp` is passed).",
|
| 723 |
+
)
|
| 724 |
+
|
| 725 |
+
# AWS arguments
|
| 726 |
+
aws_args = parser.add_argument_group("AWS Arguments", "Arguments related to AWS.")
|
| 727 |
+
aws_args.add_argument(
|
| 728 |
+
"--aws_access_key_id",
|
| 729 |
+
type=str,
|
| 730 |
+
default=None,
|
| 731 |
+
help="The AWS_ACCESS_KEY_ID used to launch the Amazon SageMaker training job",
|
| 732 |
+
)
|
| 733 |
+
aws_args.add_argument(
|
| 734 |
+
"--aws_secret_access_key",
|
| 735 |
+
type=str,
|
| 736 |
+
default=None,
|
| 737 |
+
help="The AWS_SECRET_ACCESS_KEY used to launch the Amazon SageMaker training job.",
|
| 738 |
+
)
|
| 739 |
+
parser.add_argument(
|
| 740 |
+
"--debug",
|
| 741 |
+
action="store_true",
|
| 742 |
+
help="Whether to print out the torch.distributed stack trace when something fails.",
|
| 743 |
+
)
|
| 744 |
+
parser.add_argument(
|
| 745 |
+
"training_script",
|
| 746 |
+
type=str,
|
| 747 |
+
help=(
|
| 748 |
+
"The full path to the script to be launched in parallel, followed by all the arguments for the training "
|
| 749 |
+
"script."
|
| 750 |
+
),
|
| 751 |
+
)
|
| 752 |
+
|
| 753 |
+
# MPI arguments
|
| 754 |
+
mpirun_args = parser.add_argument_group("MPI Arguments", "Arguments related to mpirun for Multi-CPU")
|
| 755 |
+
mpirun_args.add_argument(
|
| 756 |
+
"--mpirun_hostfile",
|
| 757 |
+
type=str,
|
| 758 |
+
default=None,
|
| 759 |
+
help="Location for a hostfile for using Accelerate to launch a multi-CPU training job with mpirun. This will "
|
| 760 |
+
"get passed to the MPI --hostfile or -f parameter, depending on which MPI program is installed.",
|
| 761 |
+
)
|
| 762 |
+
mpirun_args.add_argument(
|
| 763 |
+
"--mpirun_ccl",
|
| 764 |
+
type=int,
|
| 765 |
+
default=1,
|
| 766 |
+
help="The number of oneCCL worker threads when using Accelerate to launch multi-CPU training with mpirun.",
|
| 767 |
+
)
|
| 768 |
+
|
| 769 |
+
# ParallelismConfig arguments
|
| 770 |
+
parallelism_config_args = parser.add_argument_group(
|
| 771 |
+
"ParallelismConfig Arguments",
|
| 772 |
+
"Arguments related to the ParallelismConfig used for distributed training.",
|
| 773 |
+
)
|
| 774 |
+
|
| 775 |
+
parallelism_config_args.add_argument(
|
| 776 |
+
"--parallelism_config_dp_replicate_size",
|
| 777 |
+
type=int,
|
| 778 |
+
default=1,
|
| 779 |
+
help="The number of processes for data parallel training. Defaults to 1 (no data parallelism).",
|
| 780 |
+
)
|
| 781 |
+
|
| 782 |
+
parallelism_config_args.add_argument(
|
| 783 |
+
"--parallelism_config_dp_shard_size",
|
| 784 |
+
type=int,
|
| 785 |
+
default=1,
|
| 786 |
+
help="The number of processes for FSDP sharding. Defaults to 1 (No FSDP sharding).",
|
| 787 |
+
)
|
| 788 |
+
|
| 789 |
+
parallelism_config_args.add_argument(
|
| 790 |
+
"--parallelism_config_tp_size",
|
| 791 |
+
type=int,
|
| 792 |
+
default=1,
|
| 793 |
+
help="The number of processes for tensor parallel training. Defaults to 1 (no tensor parallelism).",
|
| 794 |
+
)
|
| 795 |
+
|
| 796 |
+
parallelism_config_args.add_argument(
|
| 797 |
+
"--parallelism_config_cp_size",
|
| 798 |
+
type=int,
|
| 799 |
+
default=1,
|
| 800 |
+
help="The number of processese for context parallel training. Defaults to 1 (no context parallelism).",
|
| 801 |
+
)
|
| 802 |
+
|
| 803 |
+
parallelism_config_args.add_argument(
|
| 804 |
+
"--parallelism_config_cp_backend",
|
| 805 |
+
type=str,
|
| 806 |
+
choices=["torch"],
|
| 807 |
+
default="torch",
|
| 808 |
+
help="Context Parallelism backend: torch (FSDP2) or deepspeed (ALST/Ulysses)",
|
| 809 |
+
)
|
| 810 |
+
|
| 811 |
+
parallelism_config_args.add_argument(
|
| 812 |
+
"--parallelism_config_cp_comm_strategy",
|
| 813 |
+
type=str,
|
| 814 |
+
default="allgather",
|
| 815 |
+
help="The communication strategy for context parallel training. Defaults to 'allgather'. Other option is alltoall",
|
| 816 |
+
)
|
| 817 |
+
|
| 818 |
+
parallelism_config_args.add_argument(
|
| 819 |
+
"--parallelism_config_sp_size",
|
| 820 |
+
type=int,
|
| 821 |
+
default=1,
|
| 822 |
+
help="The number of processese for context parallel training. Defaults to 1 (no context parallelism).",
|
| 823 |
+
)
|
| 824 |
+
|
| 825 |
+
parallelism_config_args.add_argument(
|
| 826 |
+
"--parallelism_config_sp_backend",
|
| 827 |
+
type=str,
|
| 828 |
+
choices=["deepspeed"],
|
| 829 |
+
default="deepspeed",
|
| 830 |
+
help="Sequence Parallelism backend: deepspeed (ALST/Ulysses)",
|
| 831 |
+
)
|
| 832 |
+
|
| 833 |
+
parallelism_config_args.add_argument(
|
| 834 |
+
"--parallelism_config_sp_seq_length",
|
| 835 |
+
type=str,
|
| 836 |
+
default=None,
|
| 837 |
+
help="Sequence length for when batches are all of the same length. For variable sequence lengths across batches set `parallelism_config_sp_seq_length_is_variable=True`",
|
| 838 |
+
)
|
| 839 |
+
|
| 840 |
+
parallelism_config_args.add_argument(
|
| 841 |
+
"--parallelism_config_sp_seq_length_is_variable",
|
| 842 |
+
type=bool,
|
| 843 |
+
default=True,
|
| 844 |
+
help="If `True` will work with a sequence length that may change between batches, in which case `parallelism_config_sp_seq_length` value can be set to anything divisible by sp size or remain unset. If `False` then `parallelism_config_sp_seq_length` needs to match the batch's sequence length dimension. The default is `True`.",
|
| 845 |
+
)
|
| 846 |
+
|
| 847 |
+
parallelism_config_args.add_argument(
|
| 848 |
+
"--parallelism_config_sp_attn_implementation",
|
| 849 |
+
type=str,
|
| 850 |
+
default="sdpa",
|
| 851 |
+
help="Attention implementation to use. Can be one of 'flash_attention_2', 'flash_attention_3' or 'sdpa'. Defaults to `sdpa`.",
|
| 852 |
+
)
|
| 853 |
+
|
| 854 |
+
# Other arguments of the training scripts
|
| 855 |
+
parser.add_argument("training_script_args", nargs=argparse.REMAINDER, help="Arguments of the training script.")
|
| 856 |
+
|
| 857 |
+
if subparsers is not None:
|
| 858 |
+
parser.set_defaults(func=launch_command)
|
| 859 |
+
return parser
|
| 860 |
+
|
| 861 |
+
|
| 862 |
+
def simple_launcher(args):
    """Run the training command as one child process and report its failure.

    The command line and environment come from `prepare_simple_launcher_cmd_env`.
    The call blocks until the child exits.

    Raises:
        subprocess.CalledProcessError: the child exited with a non-zero code and
            `args.quiet` is falsy.
        SystemExit: the child exited with a non-zero code and `args.quiet` is
            truthy (exit status 1).
    """
    command, environment = prepare_simple_launcher_cmd_env(args)

    child = subprocess.Popen(command, env=environment)
    # `wait()` returns the same value that is stored on `child.returncode`.
    exit_code = child.wait()
    if exit_code == 0:
        return
    if args.quiet:
        sys.exit(1)
    raise subprocess.CalledProcessError(returncode=exit_code, cmd=command)
|
| 872 |
+
|
| 873 |
+
|
| 874 |
+
def multi_gpu_launcher(args):
    """Launch the training script in-process through `torch.distributed.run`.

    The environment is built by `prepare_multi_gpu_env` and applied with
    `patch_environment` for the duration of the run. When
    `check_cuda_p2p_ib_support()` reports no support, `NCCL_P2P_DISABLE` and
    `NCCL_IB_DISABLE` are set to "1" unless already present, and a warning is
    logged if at least one of them had to be set.

    An exception raised by the run is re-raised, except when `rich` is available
    and `args.debug` is truthy: then the traceback is printed through the rich
    console and the exception is not propagated.
    """
    import torch.distributed.run as distrib_run

    current_env = prepare_multi_gpu_env(args)
    if not check_cuda_p2p_ib_support():
        message = "Using RTX 4000 series which doesn't support faster communication speedups. Ensuring P2P and IB communications are disabled."
        changed_env = False
        # Only fill in the variables the user has not set explicitly.
        for nccl_var in ("NCCL_P2P_DISABLE", "NCCL_IB_DISABLE"):
            if nccl_var not in current_env:
                current_env[nccl_var] = "1"
                changed_env = True
        if changed_env:
            logger.warning(message)

    # Read `debug` before `args` is replaced by the filtered namespace below.
    debug = getattr(args, "debug", False)
    script_options = [
        "--training_script",
        args.training_script,
        "--training_script_args",
        args.training_script_args,
    ]
    args = _filter_args(args, distrib_run.get_args_parser(), script_options)

    with patch_environment(**current_env):
        try:
            distrib_run.run(args)
        except Exception:
            if not (is_rich_available() and debug):
                raise
            console = get_console()
            console.print("\n[bold red]Using --debug, `torch.distributed` Stack Trace:[/bold red]")
            console.print_exception(suppress=[__file__], show_locals=False)
|
| 907 |
+
|
| 908 |
+
|
| 909 |
+
def deepspeed_launcher(args):
    """Launch a DeepSpeed training job.

    Two paths exist:

    * More than one machine and a multinode launcher other than
      `DEEPSPEED_MULTINODE_LAUNCHERS[1]`: the environment is appended to the
      DeepSpeed environment file and the prepared command runs as a subprocess.
    * Otherwise: the script runs in-process through `torch.distributed.run`
      with the prepared environment patched in.

    Raises:
        ImportError: DeepSpeed is not installed.
        subprocess.CalledProcessError: the subprocess path exited non-zero and
            `args.quiet` is falsy.
        SystemExit: the subprocess path exited non-zero and `args.quiet` is
            truthy (exit status 1).
    """
    import torch.distributed.run as distrib_run

    if not is_deepspeed_available():
        raise ImportError("DeepSpeed is not installed => run `pip3 install deepspeed` or build it from source.")
    from deepspeed.launcher.runner import DEEPSPEED_ENVIRONMENT_NAME

    cmd, current_env = prepare_deepspeed_cmd_env(args)
    if not check_cuda_p2p_ib_support():
        message = "Using RTX 4000 series which doesn't support faster communication speedups. Ensuring P2P and IB communications are disabled."
        changed_env = False
        # Only fill in the variables the user has not set explicitly.
        for nccl_var in ("NCCL_P2P_DISABLE", "NCCL_IB_DISABLE"):
            if nccl_var not in current_env:
                current_env[nccl_var] = "1"
                changed_env = True
        if changed_env:
            logger.warning(message)

    # NOTE(review): index 1 is presumably the torch-native ("standard") launcher;
    # confirm against the definition of DEEPSPEED_MULTINODE_LAUNCHERS.
    run_as_subprocess = (
        args.num_machines > 1 and args.deepspeed_multinode_launcher != DEEPSPEED_MULTINODE_LAUNCHERS[1]
    )
    if run_as_subprocess:
        # Append mode: existing content of the environment file is kept.
        with open(DEEPSPEED_ENVIRONMENT_NAME, "a") as env_file:
            valid_env_items = convert_dict_to_env_variables(current_env)
            if len(valid_env_items) > 1:
                env_file.writelines(valid_env_items)

        process = subprocess.Popen(cmd, env=current_env)
        process.wait()
        if process.returncode != 0:
            if args.quiet:
                sys.exit(1)
            raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
        return

    # Read `debug` before `args` is replaced by the filtered namespace below.
    debug = getattr(args, "debug", False)
    script_options = [
        "--training_script",
        args.training_script,
        "--training_script_args",
        args.training_script_args,
    ]
    args = _filter_args(args, distrib_run.get_args_parser(), script_options)
    with patch_environment(**current_env):
        try:
            distrib_run.run(args)
        except Exception:
            if not (is_rich_available() and debug):
                raise
            console = get_console()
            console.print("\n[bold red]Using --debug, `torch.distributed` Stack Trace:[/bold red]")
            console.print_exception(suppress=[__file__], show_locals=False)
|
| 960 |
+
|
| 961 |
+
|
| 962 |
+
def tpu_launcher(args):
    """Import the training script as a module and spawn its entry function via XLA.

    With `args.module` set, `args.training_script` is used as the module name
    directly. Otherwise the script's parent directory is appended to `sys.path`
    and the file stem is imported. `sys.argv` is replaced so the training code
    sees its own arguments.

    Raises:
        ValueError: `--no_python` was passed, or the imported module has no
            attribute named `args.main_training_function`.
    """
    import torch_xla.distributed.xla_multiprocessing as xmp

    if args.no_python:
        raise ValueError("--no_python cannot be used with TPU launcher")

    args, current_env = prepare_tpu(args, {})

    module_name = args.training_script
    if not args.module:
        # Import training_script as a module
        script = Path(args.training_script)
        sys.path.append(str(script.parent.resolve()))
        module_name = script.stem

    training_module = importlib.import_module(module_name)
    entry_name = args.main_training_function
    if not hasattr(training_module, entry_name):
        raise ValueError(
            f"Your training script should have a function named {args.main_training_function}, or you should pass a "
            "different value to `--main_training_function`."
        )

    # Patch sys.argv
    sys.argv = [training_module.__file__] + args.training_script_args

    entry_point = getattr(training_module, entry_name)
    with patch_environment(**current_env):
        xmp.spawn(PrepareForLaunch(entry_point), args=())
|
| 991 |
+
|
| 992 |
+
|
| 993 |
+
def tpu_pod_launcher(args):
    """Launch training on a TPU pod by handing an `accelerate-launch` command to `xla_dist`.

    A single-machine `accelerate-launch --tpu --no_tpu_cluster ...` command line
    (optionally prefixed with `sudo`) is built and stored as the positional
    command of the `xla_dist` argument namespace. The prepared environment is
    passed along as `KEY=VALUE` entries, plus `ACCELERATE_IN_TPU_POD=1`.

    An exception raised by `xla_dist` is re-raised, except when `rich` is
    available and `args.debug` is truthy: then the traceback is printed through
    the rich console and the exception is not propagated.

    Raises:
        ValueError: any `docker_*` option on the `xla_dist` namespace has a
            value other than `""` or `None`.
    """
    from torch_xla.distributed import xla_dist

    args, current_env = prepare_tpu(args, {}, True)
    debug = getattr(args, "debug", False)

    training_script = args.training_script
    training_script_args = args.training_script_args
    new_args = _filter_args(
        args, xla_dist.get_args_parser(), ["--tpu", args.tpu_name, "--positional", "", "--restart-tpuvm-pod-server"]
    )

    new_cmd = ["sudo"] if args.tpu_use_sudo else []
    new_cmd += [
        "accelerate-launch",
        "--tpu",
        "--no_tpu_cluster",
        "--num_machines",
        "1",
        "--mixed_precision",
        "no",
        "--dynamo_backend",
        "no",
        "--num_processes",
        str(args.num_processes),
        "--main_training_function",
        str(args.main_training_function),
        training_script,
    ] + training_script_args

    new_args.positional = new_cmd

    # One `name="value"` line per docker option that carries a value.
    docker_flag_lines = [
        f'{name}="{value}"\n'
        for name, value in vars(new_args).items()
        if name.startswith("docker_") and value != "" and value is not None
    ]
    if docker_flag_lines:
        bad_flags = "".join(docker_flag_lines)
        raise ValueError(
            f"Docker containers are not supported for TPU pod launcher currently, please remove the following flags:\n{bad_flags}"
        )

    new_args.env = [f"{k}={v}" for k, v in current_env.items()] + ["ACCELERATE_IN_TPU_POD=1"]
    try:
        xla_dist.resolve_and_execute(new_args)
    except Exception:
        if not (is_rich_available() and debug):
            raise
        console = get_console()
        console.print("\n[bold red]Using --debug, `torch_xla.xla_dist` Stack Trace:[/bold red]")
        console.print_exception(suppress=[__file__], show_locals=False)
|
| 1050 |
+
|
| 1051 |
+
|
| 1052 |
+
def sagemaker_launcher(sagemaker_config: SageMakerConfig, args):
    """Start a SageMaker training job through the `HuggingFace` estimator.

    The estimator keyword arguments and the `fit` inputs are produced by
    `prepare_sagemager_args_inputs`. After `fit` returns, the estimator's
    `model_data` value is printed.

    Args:
        sagemaker_config: the SageMaker configuration forwarded to
            `prepare_sagemager_args_inputs`.
        args: parsed launch arguments.

    Raises:
        ImportError: the `sagemaker` package is not available.
        ValueError: `--module` or `--no_python` was requested.
    """
    if not is_sagemaker_available():
        raise ImportError(
            "Please install sagemaker to be able to launch training on Amazon SageMaker with `pip install accelerate[sagemaker]`"
        )
    if args.module or args.no_python:
        raise ValueError(
            "SageMaker requires a python training script file and cannot be used with --module or --no_python"
        )

    # Imported late so the dependency is only needed when this launcher is used.
    from sagemaker.huggingface import HuggingFace

    estimator_kwargs, sagemaker_inputs = prepare_sagemager_args_inputs(sagemaker_config, args)

    estimator = HuggingFace(**estimator_kwargs)
    estimator.fit(inputs=sagemaker_inputs)
    print(f"You can find your model data at: {estimator.model_data}")
|
| 1070 |
+
|
| 1071 |
+
|
| 1072 |
+
def _validate_launch_command(args):
    """Validate the parsed launch arguments and fill in unset values.

    Values come from the config file when one is used, otherwise from detected
    hardware and fixed fallbacks. `args` is mutated in place.

    Returns:
        A tuple `(args, defaults, mp_from_config_flag)`: the updated namespace,
        the loaded config object (or `None` when no config file was used), and a
        flag that is `True` only when `mixed_precision` was taken from a
        non-`None` value in the config file.

    Raises:
        ValueError: on conflicting launcher flags, `--multi_gpu` with fewer than
            two processes, `--use_parallelism_config` without FSDP, fewer than
            two GPU ids for a single-machine multi-GPU run, unsupported bf16, or
            `num_processes == -1` when a config file is used.
    """
    # Sanity checks
    if sum([args.multi_gpu, args.cpu, args.tpu, args.use_deepspeed, args.use_fsdp]) > 1:
        raise ValueError(
            "You can only use one of `--cpu`, `--multi_gpu`, `--tpu`, `--use_deepspeed`, `--use_fsdp` at a time."
        )
    if args.multi_gpu and (args.num_processes is not None) and (args.num_processes < 2):
        raise ValueError("You need to use at least 2 processes to use `--multi_gpu`.")

    # NOTE(review): the parser declares `--fsdp_version` with `type=str` and
    # choices "1"/"2", so the comparison with the integer `1` looks like it can
    # never be true for CLI-parsed values. Confirm the intent before changing it;
    # at this point values from the config file have not been merged yet.
    if (not args.use_fsdp or args.fsdp_version == 1) and args.use_parallelism_config:
        raise ValueError("You cannot use `--use_parallelism_config` without `--use_fsdp` and `--fsdp_version=2`. ")

    defaults = None
    # Notes about values that were defaulted; emitted as one warning at the end.
    warned = []
    mp_from_config_flag = False
    # Get the default from the config file.
    # `and` binds tighter than `or`, so this reads as:
    # config_file given OR (default config file exists AND not --cpu).
    if args.config_file is not None or os.path.isfile(default_config_file) and not args.cpu:
        defaults = load_config_from_file(args.config_file)
        # The distributed type is taken from the config only when no launcher
        # flag was set on the namespace.
        if (
            not args.multi_gpu
            and not args.tpu
            and not args.tpu_use_cluster
            and not args.use_deepspeed
            and not args.use_fsdp
            and not args.use_megatron_lm
        ):
            args.use_deepspeed = defaults.distributed_type == DistributedType.DEEPSPEED
            args.multi_gpu = (
                True
                if defaults.distributed_type
                in (
                    DistributedType.MULTI_GPU,
                    DistributedType.MULTI_NPU,
                    DistributedType.MULTI_MLU,
                    DistributedType.MULTI_SDAA,
                    DistributedType.MULTI_MUSA,
                    DistributedType.MULTI_XPU,
                    DistributedType.MULTI_HPU,
                )
                else False
            )
            args.tpu = defaults.distributed_type == DistributedType.XLA
            args.use_fsdp = defaults.distributed_type == DistributedType.FSDP
            args.use_megatron_lm = defaults.distributed_type == DistributedType.MEGATRON_LM
            args.tpu_use_cluster = defaults.tpu_use_cluster if args.tpu else False
            args.use_parallelism_config = defaults.parallelism_config != {}
        if args.gpu_ids is None:
            if defaults.gpu_ids is not None:
                args.gpu_ids = defaults.gpu_ids
            else:
                args.gpu_ids = "all"

        if args.multi_gpu and args.num_machines is None:
            args.num_machines = defaults.num_machines

        if len(args.gpu_ids.split(",")) < 2 and (args.gpu_ids != "all") and args.multi_gpu and args.num_machines <= 1:
            raise ValueError(
                "Less than two GPU ids were configured and tried to run on on multiple GPUs. "
                "Please ensure at least two are specified for `--gpu_ids`, or use `--gpu_ids='all'`."
            )
        if defaults.compute_environment == ComputeEnvironment.LOCAL_MACHINE:
            # Update args with the defaults
            for name, attr in defaults.__dict__.items():
                if isinstance(attr, dict):
                    # Copy defaults.somedict.somearg to args.somearg and
                    # defaults.fsdp_config.x to args.fsdp_x
                    for key, value in attr.items():
                        if name == "fsdp_config" and not key.startswith("fsdp"):
                            key = "fsdp_" + key
                        elif name == "fp8_config" and not key.startswith("fp8"):
                            key = "fp8_" + key
                        # Nested values are copied only when `args.nondefault`
                        # exists and does not list the key.
                        # NOTE(review): `nondefault` is presumably the set of
                        # options passed explicitly on the command line; confirm
                        # where it is populated.
                        if hasattr(args, "nondefault") and key not in args.nondefault:
                            setattr(args, key, value)
                elif (
                    name not in ["compute_environment", "mixed_precision", "distributed_type"]
                    and getattr(args, name, None) is None
                ):
                    # Those args are handled separately
                    setattr(args, name, attr)
        if not args.debug:
            args.debug = defaults.debug

        if not args.mixed_precision:
            if defaults.mixed_precision is None:
                args.mixed_precision = "no"
            else:
                args.mixed_precision = defaults.mixed_precision
                # Records that the value originates from the config file.
                mp_from_config_flag = True
        else:
            native_amp = is_bf16_available(True)
            if (
                args.mixed_precision == "bf16"
                and not native_amp
                and not (args.tpu and is_torch_xla_available(check_is_tpu=True))
            ):
                raise ValueError("bf16 mixed precision requires PyTorch >= 1.10 and a supported device.")

        # Silently set the default here
        if args.dynamo_backend is None:
            args.dynamo_backend = "no"
        if args.num_processes == -1:
            raise ValueError("You need to manually pass in `--num_processes` using this config yaml.")
    else:
        # No config file is used: detect values and record each defaulted one.
        if args.num_processes is None:
            # The first available backend in this order supplies the device
            # count; CUDA is the fallback.
            if is_xpu_available():
                args.num_processes = torch.xpu.device_count()
            elif is_mlu_available():
                args.num_processes = torch.mlu.device_count()
            elif is_sdaa_available():
                args.num_processes = torch.sdaa.device_count()
            elif is_musa_available():
                args.num_processes = torch.musa.device_count()
            elif is_npu_available():
                args.num_processes = torch.npu.device_count()
            elif is_hpu_available():
                args.num_processes = torch.hpu.device_count()
            else:
                args.num_processes = torch.cuda.device_count()
            warned.append(f"\t`--num_processes` was set to a value of `{args.num_processes}`")
        if args.debug is None:
            args.debug = False
        # Switch to multi-GPU when more than one process is requested and some
        # backend reports more than one device.
        if (
            not args.multi_gpu
            and args.num_processes > 1
            and (
                (is_xpu_available() and torch.xpu.device_count() > 1)
                or (is_npu_available() and torch.npu.device_count() > 1)
                or (is_hpu_available() and torch.hpu.device_count() > 1)
                or (is_mlu_available() and torch.mlu.device_count() > 1)
                or (is_sdaa_available() and torch.sdaa.device_count() > 1)
                or (is_musa_available() and torch.musa.device_count() > 1)
                or (torch.cuda.is_available() and torch.cuda.device_count() > 1)
            )
        ):
            warned.append(
                "\t\tMore than one GPU was found, enabling multi-GPU training.\n"
                "\t\tIf this was unintended please pass in `--num_processes=1`."
            )
            args.multi_gpu = True
        if args.num_machines is None:
            warned.append("\t`--num_machines` was set to a value of `1`")
            args.num_machines = 1
        if args.mixed_precision is None:
            warned.append("\t`--mixed_precision` was set to a value of `'no'`")
            args.mixed_precision = "no"
        if not hasattr(args, "use_cpu"):
            args.use_cpu = args.cpu
        if args.dynamo_backend is None:
            warned.append("\t`--dynamo_backend` was set to a value of `'no'`")
            args.dynamo_backend = "no"
    if args.debug:
        logger.debug("Running script in debug mode, expect distributed operations to be slightly slower.")

    # True unless the loaded config targets Amazon SageMaker.
    is_aws_env_disabled = defaults is None or (
        defaults is not None and defaults.compute_environment != ComputeEnvironment.AMAZON_SAGEMAKER
    )
    if is_aws_env_disabled and args.num_cpu_threads_per_process is None:
        args.num_cpu_threads_per_process = get_int_from_env(["OMP_NUM_THREADS"], 1)
        # NOTE(review): presumably `get_int_from_env` returns the given default
        # when the variable is unset, so the `== 0` test means "OMP_NUM_THREADS
        # not configured"; confirm against the helper.
        if args.use_cpu and args.num_processes >= 1 and get_int_from_env(["OMP_NUM_THREADS"], 0) == 0:
            local_size = get_int_from_env(
                ["MPI_LOCALNRANKS", "OMPI_COMM_WORLD_LOCAL_SIZE", "MV2_COMM_WORLD_LOCAL_SIZE"],
                max(int(args.num_processes / args.num_machines), 1),
            )
            # NOTE(review): `psutil.cpu_count(logical=False)` is documented to
            # possibly return None, which would make this division fail; verify
            # whether that case matters here.
            threads_per_process = int(psutil.cpu_count(logical=False) / local_size)
            if threads_per_process > 1:
                args.num_cpu_threads_per_process = threads_per_process
                warned.append(
                    f"\t`--num_cpu_threads_per_process` was set to `{args.num_cpu_threads_per_process}` to improve out-of-box performance when training on CPUs"
                )

    if args.use_xpu is not None:
        logger.warning(
            "use_xpu is deprecated and ignored, will be removed in Accelerate v1.20. "
            "XPU is a PyTorch native citizen now, we don't need extra argument to enable it any more."
        )

    # Emit all collected notes as one warning.
    if any(warned):
        message = "The following values were not passed to `accelerate launch` and had defaults used instead:\n"
        message += "\n".join(warned)
        message += (
            "\nTo avoid this warning pass in values for each of the problematic parameters or run `accelerate config`."
        )
        logger.warning(message)
    return args, defaults, mp_from_config_flag
|
| 1256 |
+
|
| 1257 |
+
|
| 1258 |
+
def launch_command(args):
    """
    Validate the launch arguments and dispatch them to the matching launcher.

    Launchers are tried in a fixed priority order: DeepSpeed, the shared multi-GPU launcher (used for FSDP,
    Megatron-LM and plain multi-GPU runs), TPU (pod or single host), SageMaker, and finally the simple launcher. Every
    accelerator-specific branch is skipped when `args.cpu` is set.

    Args:
        args:
            The parsed launch arguments. They are first passed through `_validate_launch_command`, whose returned
            `args` object is the one handed to the launcher.
    """
    args, defaults, mp_from_config_flag = _validate_launch_command(args)

    # Use the proper launcher
    if args.use_deepspeed and not args.cpu:
        # Record which DeepSpeed fields came from the Accelerate config file, as a comma-separated string.
        config_fields = list(defaults.deepspeed_config.keys()) if defaults else []
        if mp_from_config_flag:
            config_fields.append("mixed_precision")
        args.deepspeed_fields_from_accelerate_config = ",".join(config_fields)
        deepspeed_launcher(args)
        return

    # FSDP, Megatron-LM and plain multi-GPU all go through the same launcher.
    if (
        (args.use_fsdp and not args.cpu)
        or (args.use_megatron_lm and not args.cpu)
        or (args.multi_gpu and not args.cpu)
    ):
        multi_gpu_launcher(args)
        return

    if args.tpu and not args.cpu:
        launcher = tpu_pod_launcher if args.tpu_use_cluster else tpu_launcher
        launcher(args)
        return

    if defaults is not None and defaults.compute_environment == ComputeEnvironment.AMAZON_SAGEMAKER:
        sagemaker_launcher(defaults, args)
        return

    simple_launcher(args)
|
| 1282 |
+
|
| 1283 |
+
|
| 1284 |
+
def main():
    """
    Command-line entry point: build the launch parser, parse the command line and run `launch_command`.
    """
    launch_command(launch_command_parser().parse_args())


if __name__ == "__main__":
    main()
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/merge.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
from accelerate.commands.utils import CustomArgumentParser
|
| 17 |
+
from accelerate.utils import merge_fsdp_weights
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
description = """Utility to merge the weights from multiple FSDP checkpoints into a single combined checkpoint. Should be used if
|
| 21 |
+
`SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}`.
|
| 22 |
+
|
| 23 |
+
This is a CPU-bound process and requires enough RAM to load the entire model state dict."""
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def merge_command(args):
    """
    Run `merge_fsdp_weights` with the options parsed from the command line.

    Args:
        args:
            Parsed arguments providing `checkpoint_directory`, `output_path`, `unsafe_serialization` and
            `remove_checkpoint_dir`.
    """
    source_directory = args.checkpoint_directory
    destination = args.output_path
    # The CLI flag is phrased negatively (`--unsafe_serialization`); the value passed on is its negation.
    safe_serialization = not args.unsafe_serialization
    remove_source = args.remove_checkpoint_dir
    merge_fsdp_weights(source_directory, destination, safe_serialization, remove_source)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def merge_command_parser(subparsers=None):
    """
    Build the argument parser for the `merge-weights` command.

    Args:
        subparsers (*optional*):
            The sub-parser collection of a parent parser. When given, the command is registered on it as
            `merge-weights` and `merge_command` is stored as the parser's `func` default. When `None`, a standalone
            `CustomArgumentParser` is created instead.

    Returns:
        The configured parser.
    """
    if subparsers is not None:
        parser = subparsers.add_parser("merge-weights", description=description)
    else:
        parser = CustomArgumentParser(description=description)

    parser.add_argument("checkpoint_directory", type=str, help="A directory containing sharded weights saved by FSDP.")
    parser.add_argument(
        "output_path",
        type=str,
        # This positional has no `nargs`/`default`, so argparse requires it; the help text must not advertise a
        # default value that does not exist.
        help="The path to save the merged weights to.",
    )
    parser.add_argument(
        "--unsafe_serialization",
        action="store_true",
        default=False,
        help="Whether to save the merged weights as `.bin` rather than `.safetensors` (not recommended).",
    )
    parser.add_argument(
        "--remove_checkpoint_dir",
        action="store_true",
        help="Whether to remove the checkpoint directory after merging.",
        default=False,
    )

    if subparsers is not None:
        parser.set_defaults(func=merge_command)
    return parser
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def main():
    """
    Command-line entry point: parse the merge arguments with a standalone parser and run `merge_command`.
    """
    cli_parser = merge_command_parser()
    merge_command(cli_parser.parse_args())


if __name__ == "__main__":
    main()
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/test.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
|
| 19 |
+
from accelerate.test_utils import execute_subprocess_async, path_in_accelerate_package
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def test_command_parser(subparsers=None):
    """
    Build the argument parser for the `test` command.

    Args:
        subparsers (*optional*):
            The sub-parser collection of a parent parser. When given, the command is registered on it as `test` and
            `test_command` is stored as the parser's `func` default. When `None`, a standalone
            `argparse.ArgumentParser` is created.

    Returns:
        The configured parser, which accepts a single optional `--config_file` argument.
    """
    config_file_help = (
        "The path to use to store the config file. Will default to a file named default_config.yaml in the cache "
        "location, which is the content of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have "
        "such an environment variable, your cache directory ('~/.cache' or the content of `XDG_CACHE_HOME`) suffixed "
        "with 'huggingface'."
    )

    as_subcommand = subparsers is not None
    parser = subparsers.add_parser("test") if as_subcommand else argparse.ArgumentParser("Accelerate test command")
    parser.add_argument("--config_file", default=None, help=config_file_help)

    if as_subcommand:
        parser.set_defaults(func=test_command)
    return parser
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def test_command(args):
    """
    Run the bundled `test_script.py` through `accelerate-launch` and report success.

    Args:
        args:
            Parsed arguments providing `config_file`. When it is not `None`, it is forwarded to `accelerate-launch`
            as `--config_file=<path>`.
    """
    script_name = path_in_accelerate_package("test_utils", "scripts", "test_script.py")

    if args.config_file is None:
        test_args = [script_name]
    else:
        # Build the argument list element by element. Formatting both values into one string and calling
        # `.split()` would cut a config or script path containing whitespace into several arguments.
        test_args = [f"--config_file={args.config_file}", str(script_name)]

    cmd = ["accelerate-launch"] + test_args
    result = execute_subprocess_async(cmd)
    if result.returncode == 0:
        print("Test is a success! You are ready for your distributed training!")
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def main():
    """
    Command-line entry point: parse the test arguments with a standalone parser and run `test_command`.
    """
    cli_parser = test_command_parser()
    test_command(cli_parser.parse_args())


if __name__ == "__main__":
    main()
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/to_fsdp2.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import enum
|
| 18 |
+
import logging
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
|
| 21 |
+
import yaml
|
| 22 |
+
|
| 23 |
+
from accelerate.commands.utils import CustomArgumentParser
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class ConversionStatus(enum.Enum):
    """
    Markers used as values in `ARGUMENT_KEY_MAPPING` for FSDP1 config keys that do not map to an FSDP2 key name.
    """

    # The argument is not yet implemented in FSDP2.
    NOT_YET_IMPLEMENTED = 0
    # The argument has been removed in FSDP2.
    REMOVED = -1
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Maps each known `fsdp_config` key to either the key name to use in the converted (FSDP2) config, or a
# `ConversionStatus` marker when there is no FSDP2 key to map onto.
ARGUMENT_KEY_MAPPING = {
    # New keys in FSDP2
    "fsdp_version": "fsdp_version",
    "fsdp_reshard_after_forward": "fsdp_reshard_after_forward",
    # https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md
    # https://huggingface.co/docs/accelerate/en/usage_guides/fsdp
    "fsdp_auto_wrap_policy": "fsdp_auto_wrap_policy",
    "fsdp_backward_prefetch": ConversionStatus.REMOVED,
    "fsdp_forward_prefetch": ConversionStatus.NOT_YET_IMPLEMENTED,
    "fsdp_cpu_ram_efficient_loading": "fsdp_cpu_ram_efficient_loading",
    "fsdp_offload_params": "fsdp_offload_params",
    # The FSDP1 sharding strategy is expressed through `fsdp_reshard_after_forward` in the converted config
    "fsdp_sharding_strategy": "fsdp_reshard_after_forward",
    "fsdp_state_dict_type": "fsdp_state_dict_type",
    "fsdp_sync_module_states": ConversionStatus.REMOVED,
    "fsdp_transformer_layer_cls_to_wrap": "fsdp_transformer_layer_cls_to_wrap",
    "fsdp_min_num_params": "fsdp_min_num_params",
    "fsdp_use_orig_params": ConversionStatus.REMOVED,
    "fsdp_activation_checkpointing": "fsdp_activation_checkpointing",
}
|
| 50 |
+
|
| 51 |
+
# Per-key value translations applied during conversion: for a key listed here, a value found in the inner dict is
# replaced by the mapped value. Both entries translate FSDP1 sharding-strategy names into booleans.
ARGUMENT_VALUE_MAPPING = {
    "fsdp_sharding_strategy": {
        "FULL_SHARD": True,
        "SHARD_GRAD_OP": False,
        "HYBRID_SHARD": True,
        "HYBRID_SHARD_ZERO2": False,
        "NO_SHARD": False,
    },
    "fsdp_reshard_after_forward": {  # Needed to convert newly created configs using FSDP1 to FSDP2
        "FULL_SHARD": True,
        "SHARD_GRAD_OP": False,
        "HYBRID_SHARD": True,
        "HYBRID_SHARD_ZERO2": False,
        "NO_SHARD": False,
    },
}
|
| 67 |
+
|
| 68 |
+
logger = logging.getLogger(__name__)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def _validate_to_fsdp2_args(args):
    """
    Check the parsed `to-fsdp2` arguments before any file is read or written.

    Args:
        args:
            Parsed arguments providing `config_file`, `overwrite` and `output_file`.

    Raises:
        FileNotFoundError: If `args.config_file` does not exist.
        ValueError: If `--overwrite` is not set and no `--output_file` was given.
        FileExistsError: If `--overwrite` is not set and `args.output_file` already exists.
    """
    config_path = Path(args.config_file)
    if not config_path.exists():
        raise FileNotFoundError(f"Config file {args.config_file} not found")

    # With --overwrite there is nothing else to check: the output may be missing or may already exist.
    if args.overwrite:
        return

    if args.output_file is None:
        raise ValueError("If --overwrite is not set, --output_file must be provided")

    if Path(args.output_file).exists():
        raise FileExistsError(f"Output file {args.output_file} already exists and --overwrite is not set")
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def convert_config_to_fsdp2(config: dict) -> dict:
    """
    Convert the `fsdp_config` section of a loaded config dict from FSDP1 to FSDP2 argument names.

    Keys that `ARGUMENT_KEY_MAPPING` maps to a string are renamed to that string, and their values are translated
    through `ARGUMENT_VALUE_MAPPING` when the key has an entry there. Keys that are unknown, or that are mapped to a
    `ConversionStatus` marker, are copied over unchanged. The converted section always ends with `fsdp_version` set to
    `2`.

    Args:
        config (`dict`):
            The loaded config. It is modified in place: its `fsdp_config` entry is replaced by the converted one.

    Returns:
        `dict`: The same `config` object. It is returned untouched when it has no (or an empty) `fsdp_config`, or
        when `fsdp_config` already declares `fsdp_version` 2.
    """
    fsdp_config = config.get("fsdp_config", {})

    if not fsdp_config:
        logger.info("No FSDP config found in the config file, skipping conversion...")
        return config

    if fsdp_config.get("fsdp_version", 1) == 2:
        logger.warning("Config already specifies FSDP2, skipping conversion...")
        logger.warning(
            "If the config doesn't use new argument names, change `fsdp_version` to `1` and rerun the command."
        )
        return config

    new_fsdp_config = {}
    for key, value in fsdp_config.items():
        new_key = ARGUMENT_KEY_MAPPING.get(key, None)

        # NOTE(review): keys marked `ConversionStatus.REMOVED` / `NOT_YET_IMPLEMENTED` are carried over as-is. The
        # previous code had warn-and-skip branches for them after this guard, but the guard made those branches
        # unreachable, so they were removed here without changing behavior. Confirm whether dropping such keys
        # (with a warning) is the intended behavior before changing this.
        if isinstance(new_key, ConversionStatus) or new_key is None:
            new_fsdp_config[key] = value
            continue

        if key in ARGUMENT_VALUE_MAPPING:
            # Values without a translation entry are kept as they are.
            value = ARGUMENT_VALUE_MAPPING[key].get(value, value)
        new_fsdp_config[new_key] = value

    new_fsdp_config["fsdp_version"] = 2
    config["fsdp_config"] = new_fsdp_config
    return config
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def to_fsdp2_command_parser(subparsers=None):
    """
    Build the argument parser for the `to-fsdp2` command.

    Args:
        subparsers (*optional*):
            The sub-parser collection of a parent parser. When given, the command is registered on it as `to-fsdp2`
            and `to_fsdp2_command` is stored as the parser's `func` default. When `None`, a standalone
            `CustomArgumentParser` is created instead.

    Returns:
        The configured parser.
    """
    description = "Convert an Accelerate config from FSDP1 to FSDP2"
    as_subcommand = subparsers is not None

    parser = (
        subparsers.add_parser("to-fsdp2", description=description)
        if as_subcommand
        else CustomArgumentParser(description=description)
    )

    parser.add_argument("--config_file", required=True, type=str, help="The config file to convert to FSDP2")
    parser.add_argument(
        "--overwrite",
        default=False,
        action="store_true",
        help="Overwrite the config file if it exists",
    )
    parser.add_argument(
        "--output_file",
        default=None,
        type=str,
        help="The path to the output file to write the converted config to. If not provided, the input file will be overwritten (if --overwrite is set)",
    )

    if as_subcommand:
        parser.set_defaults(func=to_fsdp2_command)
    return parser
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def load_config(config_file: str) -> dict:
    """
    Load a YAML config file with `yaml.safe_load`.

    Args:
        config_file (`str`):
            Path of the file to read.

    Returns:
        The parsed content of the file.

    Raises:
        ValueError: If the parsed content is falsy (for example an empty file).
    """
    with open(config_file) as stream:
        loaded = yaml.safe_load(stream)

    if loaded:
        return loaded
    raise ValueError("Config file is empty")
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def to_fsdp2_command(args):
    """
    Convert the config file named by `args.config_file` and write the result as YAML.

    Args:
        args:
            Parsed arguments providing `config_file`, `overwrite` and `output_file`. When `overwrite` is set and no
            output file was given, `args.output_file` is set to `args.config_file`, so the input file is rewritten.
    """
    _validate_to_fsdp2_args(args)
    loaded_config = load_config(args.config_file)

    write_in_place = args.overwrite and args.output_file is None
    if write_in_place:
        args.output_file = args.config_file

    converted_config = convert_config_to_fsdp2(loaded_config)

    with open(args.output_file, "w") as destination:
        yaml.dump(converted_config, destination)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/utils.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class _StoreAction(argparse.Action):
    """
    Store action whose long options can be spelled with either `_` or `-`.

    It also records, in a `nondefault` set on the namespace, the destination of every argument handled by this
    action during parsing.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        spellings = []
        for flag in self.option_strings:
            # Keep the declared spelling first; if the part after the two-character prefix contains `_`, follow it
            # with the `-` variant.
            spellings += [flag, flag.replace("_", "-")] if "_" in flag[2:] else [flag]
        self.option_strings = spellings

    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, values)
        try:
            seen = namespace.nondefault
        except AttributeError:
            # First argument handled by this action for this namespace: create the tracking set.
            seen = namespace.nondefault = set()
        seen.add(self.dest)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class _StoreConstAction(_StoreAction):
    """
    Counterpart of `argparse._StoreConstAction` built on the custom `_StoreAction`.
    """

    def __init__(self, option_strings, dest, const, default=None, required=False, help=None):
        init_kwargs = {
            "option_strings": option_strings,
            "dest": dest,
            # A const option consumes no value from the command line.
            "nargs": 0,
            "const": const,
            "default": default,
            "required": required,
            "help": help,
        }
        super().__init__(**init_kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        # The parsed `values` are ignored; the configured constant is stored instead.
        super().__call__(parser, namespace, self.const, option_string)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class _StoreTrueAction(_StoreConstAction):
    """
    Counterpart of `argparse._StoreTrueAction` built on the custom `_StoreConstAction`.
    """

    def __init__(self, option_strings, dest, default=None, required=False, help=None):
        # Same as the const action, with the stored constant fixed to `True`.
        super().__init__(option_strings, dest, True, default, required, help)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class CustomArgumentGroup(argparse._ArgumentGroup):
    """
    Argument group that swaps the standard store-style actions for the custom ones, so that arguments added to the
    group accept `-` or `_` in their names.
    """

    def _add_action(self, action):
        fields = vars(action)
        # `argparse._StoreTrueAction` subclasses `argparse._StoreConstAction`, so the most specific type must be
        # tested first.
        if isinstance(action, argparse._StoreTrueAction):
            replacement = _StoreTrueAction(
                fields["option_strings"], fields["dest"], fields["default"], fields["required"], fields["help"]
            )
        elif isinstance(action, argparse._StoreConstAction):
            replacement = _StoreConstAction(
                fields["option_strings"],
                fields["dest"],
                fields["const"],
                fields["default"],
                fields["required"],
                fields["help"],
            )
        elif isinstance(action, argparse._StoreAction):
            replacement = _StoreAction(**fields)
        else:
            # Any other action type is registered untouched.
            replacement = action
        return super()._add_action(replacement)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class CustomArgumentParser(argparse.ArgumentParser):
    """
    Custom argument parser that allows for the use of `-` or `_` in arguments passed and overrides the help for each
    when applicable.
    """

    def add_argument(self, *args, **kwargs):
        """
        Add an argument, substituting the custom store actions where applicable.

        `action="store_true"` is replaced by `_StoreTrueAction`, and an argument declared without any `action` uses
        `_StoreAction`. Any other explicitly requested action is passed through unchanged.

        Returns:
            The action object created by `argparse.ArgumentParser.add_argument` (previously it was discarded and
            `None` was returned).
        """
        if "action" in kwargs:
            # Translate action -> class
            if kwargs["action"] == "store_true":
                kwargs["action"] = _StoreTrueAction
        else:
            kwargs["action"] = _StoreAction
        return super().add_argument(*args, **kwargs)

    def add_argument_group(self, *args, **kwargs):
        """
        Create a `CustomArgumentGroup` bound to this parser, register it among the parser's groups and return it.
        """
        group = CustomArgumentGroup(self, *args, **kwargs)
        self._action_groups.append(group)
        return group
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/__init__.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from .testing import (
|
| 15 |
+
DEFAULT_LAUNCH_COMMAND,
|
| 16 |
+
are_the_same_tensors,
|
| 17 |
+
assert_exception,
|
| 18 |
+
capture_call_output,
|
| 19 |
+
device_count,
|
| 20 |
+
execute_subprocess_async,
|
| 21 |
+
get_launch_command,
|
| 22 |
+
get_torch_dist_unique_port,
|
| 23 |
+
memory_allocated_func,
|
| 24 |
+
path_in_accelerate_package,
|
| 25 |
+
pytest_xdist_worker_id,
|
| 26 |
+
require_bnb,
|
| 27 |
+
require_cpu,
|
| 28 |
+
require_cuda,
|
| 29 |
+
require_cuda_or_hpu,
|
| 30 |
+
require_cuda_or_xpu,
|
| 31 |
+
require_fp8,
|
| 32 |
+
require_fp16,
|
| 33 |
+
require_huggingface_suite,
|
| 34 |
+
require_mlu,
|
| 35 |
+
require_mps,
|
| 36 |
+
require_multi_device,
|
| 37 |
+
require_multi_gpu,
|
| 38 |
+
require_multi_gpu_or_xpu,
|
| 39 |
+
require_multi_xpu,
|
| 40 |
+
require_musa,
|
| 41 |
+
require_non_cpu,
|
| 42 |
+
require_non_hpu,
|
| 43 |
+
require_non_torch_xla,
|
| 44 |
+
require_non_xpu,
|
| 45 |
+
require_npu,
|
| 46 |
+
require_pippy,
|
| 47 |
+
require_sdaa,
|
| 48 |
+
require_single_device,
|
| 49 |
+
require_single_gpu,
|
| 50 |
+
require_single_xpu,
|
| 51 |
+
require_torch_min_version,
|
| 52 |
+
require_torchao,
|
| 53 |
+
require_torchvision,
|
| 54 |
+
require_tpu,
|
| 55 |
+
require_transformer_engine,
|
| 56 |
+
require_transformer_engine_mxfp8,
|
| 57 |
+
require_xpu,
|
| 58 |
+
run_first,
|
| 59 |
+
skip,
|
| 60 |
+
slow,
|
| 61 |
+
torch_device,
|
| 62 |
+
)
|
| 63 |
+
from .training import RegressionDataset, RegressionModel
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
from .scripts import test_script, test_sync, test_ops # isort: skip
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/examples.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""
|
| 17 |
+
A collection of utilities for comparing `examples/complete_*_example.py` scripts with the capabilities inside of each
|
| 18 |
+
`examples/by_feature` example. `compare_against_test` is the main function that should be used when testing, while the
|
| 19 |
+
others are used to either get the code that matters, or to preprocess them (such as stripping comments)
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import os
|
| 23 |
+
from typing import Optional
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def get_function_contents_by_name(lines: list[str], name: str):
    """
    Extracts a function from `lines` of segmented source code with the name `name`.

    Collection starts at the first line containing `def {name}` and stops right before the line that marks the next
    section: a line containing `def main` when extracting `training_function`, or a line containing `if __name__`
    when extracting `main`.

    Args:
        lines (`List[str]`):
            Source code of a script separated by line.
        name (`str`):
            The name of the function to extract. Should be either `training_function` or `main`

    Returns:
        `List[str]`: The lines of the function, starting with its `def` line. If the end marker is never found, every
        line collected up to the end of `lines` is returned; if the function itself is never found, the list is empty.

    Raises:
        ValueError: If `name` is neither `training_function` nor `main`.
    """
    if name not in ("training_function", "main"):
        raise ValueError(f"Incorrect function name passed: {name}, choose either 'main' or 'training_function'")

    start_marker = f"def {name}"
    end_marker = "def main" if name == "training_function" else "if __name__"

    good_lines, found_start = [], False
    for line in lines:
        if not found_start:
            if start_marker in line:
                found_start = True
                good_lines.append(line)
            continue
        if end_marker in line:
            return good_lines
        good_lines.append(line)

    # The end marker was never seen. Return what was collected rather than falling off the end and returning
    # `None`, which made callers fail when iterating over the result.
    return good_lines
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def clean_lines(lines: list[str]):
    """
    Drops every entry of `lines` that is a comment line (its first non-whitespace character is '#') or that is
    exactly a newline ('\n').

    Args:
        lines (`List[str]`):
            Source code of a script separated by line.

    Returns:
        `List[str]`: The remaining lines, in their original order.
    """
    kept = []
    for line in lines:
        if line == "\n" or line.lstrip().startswith("#"):
            continue
        kept.append(line)
    return kept
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def compare_against_test(
    base_filename: str, feature_filename: str, parser_only: bool, secondary_filename: Optional[str] = None
):
    """
    Tests whether the additional code inside of `feature_filename` was implemented in `base_filename`. This should be
    used when testing to see if `complete_*_.py` examples have all of the implementations from each of the
    `examples/by_feature/*` scripts.

    It utilizes `nlp_example.py` to extract out all of the repeated training code, so that only the new additional code
    is examined and checked. If something *other* than `nlp_example.py` should be used, such as `cv_example.py` for the
    `complete_cv_example.py` script, it should be passed in for the `secondary_filename` parameter.

    Note that `examples/nlp_example.py` is opened relative to the current working directory.

    Args:
        base_filename (`str` or `os.PathLike`):
            The filepath of a single "complete" example script to test, such as `examples/complete_cv_example.py`
        feature_filename (`str` or `os.PathLike`):
            The filepath of a single feature example script. The contents of this script are checked to see if they
            exist in `base_filename`
        parser_only (`bool`):
            Whether to compare only the `main()` sections in both files, or to compare the contents of
            `training_function()`
        secondary_filename (`str`, *optional*):
            A potential secondary filepath that should be included in the check. This function extracts the base
            functionalities off of "examples/nlp_example.py", so if `base_filename` is a script other than
            `complete_nlp_example.py`, the template script should be included here. Such as `examples/cv_example.py`

    Returns:
        `List[str]`: The lines of new code found in the feature script that were not found among the new code of the
        base script (further filtered against the secondary script when one is given). An empty list means nothing is
        missing.
    """
    with open(base_filename) as f:
        base_file_contents = f.readlines()
    with open(os.path.abspath(os.path.join("examples", "nlp_example.py"))) as f:
        full_file_contents = f.readlines()
    with open(feature_filename) as f:
        feature_file_contents = f.readlines()
    if secondary_filename is not None:
        with open(secondary_filename) as f:
            secondary_file_contents = f.readlines()

    # This is our base, we remove all the code from here in our `full_filename` and `feature_filename` to find the new content
    if parser_only:
        base_file_func = clean_lines(get_function_contents_by_name(base_file_contents, "main"))
        full_file_func = clean_lines(get_function_contents_by_name(full_file_contents, "main"))
        feature_file_func = clean_lines(get_function_contents_by_name(feature_file_contents, "main"))
        if secondary_filename is not None:
            secondary_file_func = clean_lines(get_function_contents_by_name(secondary_file_contents, "main"))
    else:
        base_file_func = clean_lines(get_function_contents_by_name(base_file_contents, "training_function"))
        full_file_func = clean_lines(get_function_contents_by_name(full_file_contents, "training_function"))
        feature_file_func = clean_lines(get_function_contents_by_name(feature_file_contents, "training_function"))
        if secondary_filename is not None:
            secondary_file_func = clean_lines(
                get_function_contents_by_name(secondary_file_contents, "training_function")
            )

    # This dataloader-creation line is never counted as new code, whatever its indentation.
    _dl_line = "train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size)\n"

    # Specific code in our script that differs from the full version, aka what is new
    new_feature_code = []
    passed_idxs = []  # We keep track of the idxs just in case it's a repeated statement
    # The lines are pulled from an explicit iterator so that the `else` branch below can consume an extra line.
    # NOTE(review): `i` is only appended to `passed_idxs` after the membership check of the same iteration, so that
    # check looks like it always passes - confirm.
    it = iter(feature_file_func)
    for i in range(len(feature_file_func) - 1):
        if i not in passed_idxs:
            line = next(it)
            if (line not in full_file_func) and (line.lstrip() != _dl_line):
                if "TESTING_MOCKED_DATALOADERS" not in line:
                    new_feature_code.append(line)
                    passed_idxs.append(i)
                else:
                    # Skip over the `config['num_epochs'] = 2` statement
                    _ = next(it)

    # Extract out just the new parts from the full_file_training_func
    new_full_example_parts = []
    passed_idxs = []  # We keep track of the idxs just in case it's a repeated statement
    for i, line in enumerate(base_file_func):
        if i not in passed_idxs:
            if (line not in full_file_func) and (line.lstrip() != _dl_line):
                if "TESTING_MOCKED_DATALOADERS" not in line:
                    new_full_example_parts.append(line)
                    passed_idxs.append(i)

    # Finally, get the overall diff
    diff_from_example = [line for line in new_feature_code if line not in new_full_example_parts]
    if secondary_filename is not None:
        # Lines of the whole `nlp_example.py` file that are absent from the secondary script's function are also
        # excluded from the reported diff.
        diff_from_two = [line for line in full_file_contents if line not in secondary_file_func]
        diff_from_example = [line for line in diff_from_example if line not in diff_from_two]

    return diff_from_example
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/testing.py
ADDED
|
@@ -0,0 +1,881 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import asyncio
|
| 16 |
+
import inspect
|
| 17 |
+
import io
|
| 18 |
+
import os
|
| 19 |
+
import re
|
| 20 |
+
import shutil
|
| 21 |
+
import subprocess
|
| 22 |
+
import sys
|
| 23 |
+
import tempfile
|
| 24 |
+
import unittest
|
| 25 |
+
from contextlib import contextmanager
|
| 26 |
+
from functools import partial
|
| 27 |
+
from pathlib import Path
|
| 28 |
+
from typing import Optional, Union
|
| 29 |
+
from unittest import mock
|
| 30 |
+
|
| 31 |
+
import torch
|
| 32 |
+
|
| 33 |
+
import accelerate
|
| 34 |
+
|
| 35 |
+
from ..state import AcceleratorState
|
| 36 |
+
from ..utils import (
|
| 37 |
+
check_cuda_fp8_capability,
|
| 38 |
+
compare_versions,
|
| 39 |
+
gather,
|
| 40 |
+
is_aim_available,
|
| 41 |
+
is_bnb_available,
|
| 42 |
+
is_clearml_available,
|
| 43 |
+
is_comet_ml_available,
|
| 44 |
+
is_cuda_available,
|
| 45 |
+
is_datasets_available,
|
| 46 |
+
is_deepspeed_available,
|
| 47 |
+
is_dvclive_available,
|
| 48 |
+
is_fp8_available,
|
| 49 |
+
is_fp16_available,
|
| 50 |
+
is_habana_gaudi1,
|
| 51 |
+
is_hpu_available,
|
| 52 |
+
is_import_timer_available,
|
| 53 |
+
is_matplotlib_available,
|
| 54 |
+
is_mlflow_available,
|
| 55 |
+
is_mlu_available,
|
| 56 |
+
is_mps_available,
|
| 57 |
+
is_musa_available,
|
| 58 |
+
is_npu_available,
|
| 59 |
+
is_pandas_available,
|
| 60 |
+
is_pippy_available,
|
| 61 |
+
is_pytest_available,
|
| 62 |
+
is_schedulefree_available,
|
| 63 |
+
is_sdaa_available,
|
| 64 |
+
is_swanlab_available,
|
| 65 |
+
is_tensorboard_available,
|
| 66 |
+
is_timm_available,
|
| 67 |
+
is_torch_version,
|
| 68 |
+
is_torch_xla_available,
|
| 69 |
+
is_torchao_available,
|
| 70 |
+
is_torchdata_stateful_dataloader_available,
|
| 71 |
+
is_torchvision_available,
|
| 72 |
+
is_trackio_available,
|
| 73 |
+
is_transformer_engine_available,
|
| 74 |
+
is_transformer_engine_mxfp8_available,
|
| 75 |
+
is_transformers_available,
|
| 76 |
+
is_triton_available,
|
| 77 |
+
is_wandb_available,
|
| 78 |
+
is_xpu_available,
|
| 79 |
+
str_to_bool,
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def get_backend():
    """
    Pick the backend the test-suite should target.

    The availability probes are tried in a fixed priority order (XLA, CUDA, MPS, MLU, SDAA, MUSA, NPU, XPU, HPU) and
    the first one that succeeds wins; CPU is the fallback.

    Returns:
        A 3-tuple `(device_name, device_count, memory_allocated_fn)` where `memory_allocated_fn` is a callable. For
        backends without a memory query (old MPS, CPU) the callable always returns 0.
    """
    # XLA reports the CUDA device count / memory function, exactly like the CUDA branch below.
    if is_torch_xla_available():
        return "xla", torch.cuda.device_count(), torch.cuda.memory_allocated
    if is_cuda_available():
        return "cuda", torch.cuda.device_count(), torch.cuda.memory_allocated
    # MPS is always treated as a single device; the memory query is only used when the min_version="2.0" check passes.
    if is_mps_available(min_version="2.0"):
        return "mps", 1, torch.mps.current_allocated_memory
    if is_mps_available():
        return "mps", 1, lambda: 0
    # The `torch.<backend>` attributes below are only touched once the matching probe has succeeded.
    if is_mlu_available():
        return "mlu", torch.mlu.device_count(), torch.mlu.memory_allocated
    if is_sdaa_available():
        return "sdaa", torch.sdaa.device_count(), torch.sdaa.memory_allocated
    if is_musa_available():
        return "musa", torch.musa.device_count(), torch.musa.memory_allocated
    if is_npu_available():
        return "npu", torch.npu.device_count(), torch.npu.memory_allocated
    if is_xpu_available():
        return "xpu", torch.xpu.device_count(), torch.xpu.memory_allocated
    if is_hpu_available():
        return "hpu", torch.hpu.device_count(), torch.hpu.memory_allocated
    return "cpu", 1, lambda: 0
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# Resolved once at import time: backend name, number of devices, and a callable that reports allocated memory.
torch_device, device_count, memory_allocated_func = get_backend()
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def get_launch_command(**kwargs) -> list:
    """
    Build an `accelerate launch` argument list from keyword arguments, ready to hand to `subprocess`.

    Each keyword is translated as follows:
      - `None` is dropped entirely,
      - the boolean `True` becomes a bare `--key` flag,
      - anything else (including `False`) becomes `--key=value`.

    Example:
    ```python
    # returns ['accelerate', 'launch', '--num_processes=2', '--device_count=2']
    get_launch_command(num_processes=2, device_count=2)
    ```
    """
    launch_args = ["accelerate", "launch"]
    for name, value in kwargs.items():
        if value is None:
            continue
        if value is True:
            launch_args.append(f"--{name}")
        else:
            launch_args.append(f"--{name}={value}")
    return launch_args
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
# Baseline `accelerate launch` argv: one process per detected device, with a 0.1 monitor interval.
DEFAULT_LAUNCH_COMMAND = get_launch_command(num_processes=device_count, monitor_interval=0.1)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def parse_flag_from_env(key, default=False):
    """
    Read a boolean flag from the environment variable `key`.

    Args:
        key: Name of the environment variable to look up.
        default: Value returned unchanged when the variable is not set.

    Returns:
        `default` when `key` is absent, otherwise the result of `str_to_bool` applied to the variable's value.

    Raises:
        ValueError: If the variable is set but `str_to_bool` rejects its value.
    """
    raw_value = os.environ.get(key)
    if raw_value is None:
        # KEY isn't set, fall back to `default`.
        return default

    # KEY is set, convert it to True or False.
    try:
        return str_to_bool(raw_value)
    except ValueError:
        # More values are supported, but let's keep the message simple.
        raise ValueError(f"If set, {key} must be yes or no.")
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
# Evaluated once at import time; slow tests stay disabled unless the RUN_SLOW environment variable enables them.
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def skip(test_case):
    "Decorator that always skips the decorated test"
    always_skip = unittest.skip("Test was skipped")
    return always_skip(test_case)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def slow(test_case):
    """
    Decorator flagging a test as slow. Such tests only run when the module-level `_run_slow_tests` flag (read from the
    RUN_SLOW environment variable) is truthy; otherwise they are skipped.
    """
    skip_unless_slow_enabled = unittest.skipUnless(_run_slow_tests, "test is slow")
    return skip_unless_slow_enabled(test_case)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def require_cpu(test_case):
    """
    Decorator for tests that must run on the CPU only. The test is skipped whenever the detected `torch_device` is
    anything other than "cpu".
    """
    on_cpu = torch_device == "cpu"
    return unittest.skipUnless(on_cpu, "test requires only a CPU")(test_case)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def require_non_cpu(test_case):
    """
    Decorator for tests that need a hardware accelerator backend. The test is skipped when the detected
    `torch_device` is "cpu".
    """
    has_accelerator = torch_device != "cpu"
    return unittest.skipUnless(has_accelerator, "test requires a GPU")(test_case)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def require_cuda(test_case):
    """
    Decorator for tests that need CUDA. The test is skipped when CUDA is unavailable or when TorchXLA is available.
    """
    cuda_without_xla = is_cuda_available() and not is_torch_xla_available()
    return unittest.skipUnless(cuda_without_xla, "test requires a GPU")(test_case)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def require_cuda_or_hpu(test_case):
    """
    Decorator for tests that need CUDA or HPU. The test runs when CUDA is available without TorchXLA, or when an HPU
    is available; otherwise it is skipped.
    """
    # Kept as one expression so the HPU probe only runs when the CUDA condition fails.
    cuda_or_hpu = (is_cuda_available() and not is_torch_xla_available()) or is_hpu_available()
    return unittest.skipUnless(cuda_or_hpu, "test requires a GPU or HPU")(test_case)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def require_xpu(test_case):
    """
    Decorator for tests that need an XPU. The test is skipped when `is_xpu_available()` is falsy.
    """
    xpu_present = is_xpu_available()
    return unittest.skipUnless(xpu_present, "test requires a XPU")(test_case)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def require_cuda_or_xpu(test_case):
    """
    Decorator for tests that need CUDA or XPU. The test runs when CUDA is available without TorchXLA, or when an XPU
    is available; otherwise it is skipped.
    """
    # Both probes are evaluated up front, before they are combined.
    has_plain_cuda = is_cuda_available() and not is_torch_xla_available()
    has_xpu = is_xpu_available()
    skip_unless_supported = unittest.skipUnless(has_plain_cuda or has_xpu, "test requires a CUDA GPU or XPU")
    return skip_unless_supported(test_case)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def require_non_xpu(test_case):
    """
    Decorator for tests that must not run on XPU: skipped when the detected `torch_device` is "xpu".
    """
    not_on_xpu = torch_device != "xpu"
    return unittest.skipUnless(not_on_xpu, "test requires a non-XPU")(test_case)
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def require_non_hpu(test_case):
    """
    Decorator for tests that must not run on HPU: skipped when the detected `torch_device` is "hpu".
    """
    not_on_hpu = torch_device != "hpu"
    return unittest.skipUnless(not_on_hpu, "test requires a non-HPU")(test_case)
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def require_fp16(test_case):
    """
    Decorator for tests that need FP16. The test is skipped when `is_fp16_available()` is falsy.
    """
    fp16_supported = is_fp16_available()
    return unittest.skipUnless(fp16_supported, "test requires FP16 support")(test_case)
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def require_fp8(test_case):
    """
    Decorator for tests that need FP8. The test is skipped when FP8 is not supported.

    `is_fp8_available` only checks for libraries, so two device-level vetoes are layered on top of it: a CUDA device
    that fails `check_cuda_fp8_capability()`, and a Habana Gaudi1 HPU.
    """
    fp8_supported = is_fp8_available()

    cuda_lacks_fp8 = torch.cuda.is_available() and not check_cuda_fp8_capability()
    running_on_gaudi1 = is_hpu_available() and is_habana_gaudi1()
    if cuda_lacks_fp8 or running_on_gaudi1:
        fp8_supported = False

    return unittest.skipUnless(fp8_supported, "test requires FP8 support")(test_case)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def require_fsdp2(test_case):
    """
    Decorator for tests that need FSDP2. The test is skipped when the installed torch is older than 2.5.0.
    """
    torch_is_recent_enough = is_torch_version(">=", "2.5.0")
    return unittest.skipUnless(torch_is_recent_enough, "test requires FSDP2 (torch >= 2.5.0)")(test_case)
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def require_mlu(test_case):
    """
    Decorator for tests that need an MLU. The test is skipped when `is_mlu_available()` is falsy.
    """
    mlu_present = is_mlu_available()
    return unittest.skipUnless(mlu_present, "test require a MLU")(test_case)
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def require_sdaa(test_case):
    """
    Decorator for tests that need an SDAA device. The test is skipped when `is_sdaa_available()` is falsy.
    """
    sdaa_present = is_sdaa_available()
    return unittest.skipUnless(sdaa_present, "test require a SDAA")(test_case)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def require_musa(test_case):
    """
    Decorator for tests that need a MUSA device. The test is skipped when `is_musa_available()` is falsy.
    """
    musa_present = is_musa_available()
    return unittest.skipUnless(musa_present, "test require a MUSA")(test_case)
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def require_npu(test_case):
    """
    Decorator for tests that need an NPU. The test is skipped when `is_npu_available()` is falsy.
    """
    npu_present = is_npu_available()
    return unittest.skipUnless(npu_present, "test require a NPU")(test_case)
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def require_mps(test_case):
    """
    Decorator for tests that need the MPS backend. The test is skipped when `is_mps_available()` is falsy, i.e. torch
    has no usable `mps` backend.
    """
    mps_present = is_mps_available()
    return unittest.skipUnless(mps_present, "test requires a `mps` backend support in `torch`")(test_case)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def require_huggingface_suite(test_case):
    """
    Decorator for tests that need both transformers and datasets. The test is skipped unless both are available.
    """
    suite_installed = is_transformers_available() and is_datasets_available()
    return unittest.skipUnless(suite_installed, "test requires the Hugging Face suite")(test_case)
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def require_transformers(test_case):
    """
    Decorator for tests that need transformers. The test is skipped when `is_transformers_available()` is falsy.
    """
    transformers_installed = is_transformers_available()
    return unittest.skipUnless(transformers_installed, "test requires the transformers library")(test_case)
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def require_timm(test_case):
    """
    Decorator for tests that need timm. The test is skipped when `is_timm_available()` is falsy.
    """
    timm_installed = is_timm_available()
    return unittest.skipUnless(timm_installed, "test requires the timm library")(test_case)
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def require_torchvision(test_case):
    """
    Decorator for tests that need torchvision. The test is skipped when `is_torchvision_available()` is falsy.
    """
    torchvision_installed = is_torchvision_available()
    return unittest.skipUnless(torchvision_installed, "test requires the torchvision library")(test_case)
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def require_triton(test_case):
    """
    Decorator for tests that need triton. The test is skipped when `is_triton_available()` is falsy.
    """
    triton_installed = is_triton_available()
    return unittest.skipUnless(triton_installed, "test requires the triton library")(test_case)
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def require_schedulefree(test_case):
    """
    Decorator for tests that need schedulefree. The test is skipped when `is_schedulefree_available()` is falsy.
    """
    schedulefree_installed = is_schedulefree_available()
    return unittest.skipUnless(schedulefree_installed, "test requires the schedulefree library")(test_case)
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def require_bnb(test_case):
    """
    Decorator for tests that need bitsandbytes. The test is skipped when `is_bnb_available()` is falsy.
    """
    bnb_installed = is_bnb_available()
    return unittest.skipUnless(bnb_installed, "test requires the bitsandbytes library")(test_case)
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
def require_tpu(test_case):
    """
    Decorator for tests that need TPUs. The test is skipped when `is_torch_xla_available(check_is_tpu=True)` is falsy.
    """
    tpu_present = is_torch_xla_available(check_is_tpu=True)
    return unittest.skipUnless(tpu_present, "test requires TPU")(test_case)
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def require_non_torch_xla(test_case):
    """
    Decorator for tests that need an environment without TorchXLA. The test is skipped when TorchXLA is available.
    """
    xla_absent = not is_torch_xla_available()
    return unittest.skipUnless(xla_absent, "test requires an env without TorchXLA")(test_case)
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def require_single_device(test_case):
    """
    Decorator for tests that need exactly one accelerator device. The test is skipped when the detected backend is
    "cpu" or when `device_count` is not 1.
    """
    exactly_one_accelerator = torch_device != "cpu" and device_count == 1
    skip_unless_single = unittest.skipUnless(exactly_one_accelerator, "test requires a single device accelerator")
    return skip_unless_single(test_case)
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def require_single_gpu(test_case):
    """
    Decorator for tests that need exactly one CUDA GPU. The test is skipped when `torch.cuda.device_count()` is
    anything other than 1 (no GPU, or more than one).
    """
    exactly_one_gpu = torch.cuda.device_count() == 1
    return unittest.skipUnless(exactly_one_gpu, "test requires a GPU")(test_case)
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
def require_single_xpu(test_case):
|
| 381 |
+
"""
|
| 382 |
+
Decorator marking a test that requires CUDA on a single XPU. These tests are skipped when there are no XPU
|
| 383 |
+
available or number of xPUs is more than one.
|
| 384 |
+
"""
|
| 385 |
+
return unittest.skipUnless(torch.xpu.device_count() == 1, "test requires a XPU")(test_case)
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
def require_multi_device(test_case):
    """
    Decorator for tests that need a multi-device setup. The test is skipped when the detected `device_count` is not
    greater than 1.
    """
    several_devices = device_count > 1
    return unittest.skipUnless(several_devices, "test requires multiple hardware accelerators")(test_case)
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
def require_multi_gpu(test_case):
    """
    Decorator for tests that need a multi-GPU setup. The test is skipped when `torch.cuda.device_count()` is not
    greater than 1.
    """
    several_gpus = torch.cuda.device_count() > 1
    return unittest.skipUnless(several_gpus, "test requires multiple GPUs")(test_case)
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def require_multi_xpu(test_case):
|
| 405 |
+
"""
|
| 406 |
+
Decorator marking a test that requires a multi-XPU setup. These tests are skipped on a machine without multiple
|
| 407 |
+
XPUs.
|
| 408 |
+
"""
|
| 409 |
+
return unittest.skipUnless(torch.xpu.device_count() > 1, "test requires multiple XPUs")(test_case)
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
def require_multi_gpu_or_xpu(test_case):
    """
    Decorator for tests that need several GPUs or several XPUs. The test is skipped unless CUDA or XPU is available
    and the detected `device_count` is greater than 1.
    """
    several_gpus_or_xpus = (is_cuda_available() or is_xpu_available()) and device_count > 1
    return unittest.skipUnless(several_gpus_or_xpus, "test requires multiple GPUs or XPUs")(test_case)
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
def require_deepspeed(test_case):
    """
    Decorator for tests that need DeepSpeed. The test is skipped when `is_deepspeed_available()` is falsy.
    """
    deepspeed_installed = is_deepspeed_available()
    return unittest.skipUnless(deepspeed_installed, "test requires DeepSpeed")(test_case)
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
def require_tp(test_case):
    """
    Decorator for tests that need TP support. The test is skipped unless torch is at least 2.3.0 and transformers is
    at least 4.52.0.
    """
    # Single expression on purpose: the transformers version is only compared when the torch check passes.
    versions_ok = is_torch_version(">=", "2.3.0") and compare_versions("transformers", ">=", "4.52.0")
    skip_unless_tp = unittest.skipUnless(
        versions_ok, "test requires torch version >= 2.3.0 and transformers version >= 4.52.0"
    )
    return skip_unless_tp(test_case)
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
def require_torch_min_version(test_case=None, version=None):
    """
    Decorator for tests that need at least a given torch version. The test is skipped when the installed torch is
    older than `version`.

    When called without `test_case` (i.e. `@require_torch_min_version(version=...)`), a partial of this function bound
    to `version` is returned so that it can then be applied to the test.
    """
    if test_case is not None:
        new_enough = is_torch_version(">=", version)
        return unittest.skipUnless(new_enough, f"test requires torch version >= {version}")(test_case)
    return partial(require_torch_min_version, version=version)
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
def require_tensorboard(test_case):
    """
    Decorator for tests that need tensorboard. The test is skipped when `is_tensorboard_available()` is falsy.
    """
    tensorboard_installed = is_tensorboard_available()
    return unittest.skipUnless(tensorboard_installed, "test requires Tensorboard")(test_case)
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
def require_wandb(test_case):
    """
    Decorator for tests that need wandb. The test is skipped when `is_wandb_available()` is falsy.
    """
    wandb_installed = is_wandb_available()
    return unittest.skipUnless(wandb_installed, "test requires wandb")(test_case)
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
def require_trackio(test_case):
    """
    Decorator for tests that need trackio. The test is skipped when `is_trackio_available()` is falsy.
    """
    trackio_installed = is_trackio_available()
    return unittest.skipUnless(trackio_installed, "test requires trackio")(test_case)
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
def require_comet_ml(test_case):
    """
    Decorator for tests that need comet_ml. The test is skipped when `is_comet_ml_available()` is falsy.
    """
    comet_ml_installed = is_comet_ml_available()
    return unittest.skipUnless(comet_ml_installed, "test requires comet_ml")(test_case)
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
def require_aim(test_case):
    """
    Decorator for tests that need aim. The test is skipped when `is_aim_available()` is falsy.
    """
    aim_installed = is_aim_available()
    return unittest.skipUnless(aim_installed, "test requires aim")(test_case)
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
def require_clearml(test_case):
    """
    Decorator for tests that need clearml. The test is skipped when `is_clearml_available()` is falsy.
    """
    clearml_installed = is_clearml_available()
    return unittest.skipUnless(clearml_installed, "test requires clearml")(test_case)
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
def require_dvclive(test_case):
    """
    Decorator for tests that need dvclive. The test is skipped when `is_dvclive_available()` is falsy.
    """
    dvclive_installed = is_dvclive_available()
    return unittest.skipUnless(dvclive_installed, "test requires dvclive")(test_case)
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
def require_swanlab(test_case):
    """
    Decorator for tests that need swanlab. The test is skipped when `is_swanlab_available()` is falsy.
    """
    swanlab_installed = is_swanlab_available()
    return unittest.skipUnless(swanlab_installed, "test requires swanlab")(test_case)
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
def require_pandas(test_case):
    """
    Decorator for tests that need pandas. The test is skipped when `is_pandas_available()` is falsy.
    """
    pandas_installed = is_pandas_available()
    return unittest.skipUnless(pandas_installed, "test requires pandas")(test_case)
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
def require_mlflow(test_case):
    """
    Decorator for tests that need mlflow. The test is skipped when `is_mlflow_available()` is falsy.
    """
    mlflow_installed = is_mlflow_available()
    return unittest.skipUnless(mlflow_installed, "test requires mlflow")(test_case)
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
def require_pippy(test_case):
    """
    Decorator for tests that need pippy. The test is skipped when pippy is not available, and also when running on a
    Gaudi1 device, which doesn't support pippy.
    """
    pippy_usable = is_pippy_available() and not is_habana_gaudi1()
    return unittest.skipUnless(pippy_usable, "test requires pippy")(test_case)
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
def require_import_timer(test_case):
    """
    Decorator for tests that need the tuna interpreter. The test is skipped when `is_import_timer_available()` is
    falsy.
    """
    import_timer_installed = is_import_timer_available()
    return unittest.skipUnless(import_timer_installed, "test requires tuna interpreter")(test_case)
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
def require_transformer_engine(test_case):
    """
    Decorator for tests that need transformers engine. The test is skipped when
    `is_transformer_engine_available()` is falsy.
    """
    te_installed = is_transformer_engine_available()
    return unittest.skipUnless(te_installed, "test requires transformers engine")(test_case)
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
def require_transformer_engine_mxfp8(test_case):
    """
    Decorator for tests that need transformers engine MXFP8 block scaling. The test is skipped when
    `is_transformer_engine_mxfp8_available()` is falsy.
    """
    mxfp8_supported = is_transformer_engine_mxfp8_available()
    skip_unless_mxfp8 = unittest.skipUnless(mxfp8_supported, "test requires transformers engine MXFP8 block scaling")
    return skip_unless_mxfp8(test_case)
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
def require_torchao(test_case):
    """
    Decorator for tests that need torchao. The test is skipped when `is_torchao_available()` is falsy.
    """
    torchao_installed = is_torchao_available()
    return unittest.skipUnless(torchao_installed, "test requires torchao")(test_case)
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
def require_matplotlib(test_case):
    """
    Decorator for tests that need matplotlib. The test is skipped when `is_matplotlib_available()` is falsy.
    """
    matplotlib_installed = is_matplotlib_available()
    return unittest.skipUnless(matplotlib_installed, "test requires matplotlib")(test_case)
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
# True when at least one of wandb / tensorboard / trackio / swanlab is available and comet_ml is not.
# All four tracker probes are evaluated before `any` runs; the comet_ml probe only runs if one of them succeeded.
_atleast_one_tracker_available = (
    any((is_wandb_available(), is_tensorboard_available(), is_trackio_available(), is_swanlab_available()))
    and not is_comet_ml_available()
)
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
def require_trackers(test_case):
    """
    Decorator for tests that need at least one tracking library. The decision comes from the module-level
    `_atleast_one_tracker_available` flag, so the test is skipped when no tracker is available or when `comet_ml` is
    installed.
    """
    skip_unless_tracker = unittest.skipUnless(
        _atleast_one_tracker_available,
        "test requires at least one tracker to be available and for `comet_ml` to not be installed",
    )
    return skip_unless_tracker(test_case)
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
def require_torchdata_stateful_dataloader(test_case):
    """
    Decorator for tests that need `torchdata.stateful_dataloader`.

    The test is skipped when `is_torchdata_stateful_dataloader_available()` is falsy.
    """
    stateful_dataloader_installed = is_torchdata_stateful_dataloader_available()
    skip_unless_installed = unittest.skipUnless(
        stateful_dataloader_installed, "test requires torchdata.stateful_dataloader"
    )
    return skip_unless_installed(test_case)
|
| 596 |
+
|
| 597 |
+
|
| 598 |
+
def run_first(test_case):
    """
    Decorator that tags a test with `pytest.mark.order(1)`. With the pytest-order plugin installed, tests carrying this
    mark are guaranteed to run first.

    This matters in settings such as a Gaudi instance, where a Gaudi device can only be used by one process at a time:
    launching every subprocess-based test first avoids device allocation conflicts.

    When pytest is not installed the test is returned untouched.
    """
    if not is_pytest_available():
        return test_case

    # Imported lazily so this module still imports without pytest.
    import pytest

    return pytest.mark.order(1)(test_case)
|
| 615 |
+
|
| 616 |
+
|
| 617 |
+
class TempDirTestCase(unittest.TestCase):
    """
    A TestCase class that keeps a single temporary directory alive for the duration of the class, wipes its contents
    at the start of every test, and removes the directory once the TestCase has finished.

    Useful when a class or API requires one constant folder throughout its use, such as Weights and Biases.

    The temporary directory location is stored in `self.tmpdir` (a `pathlib.Path`). Set the class attribute
    `clear_on_setup` to `False` to keep the directory contents between tests.
    """

    # When true, `setUp` empties `tmpdir` before each test.
    clear_on_setup = True

    @classmethod
    def setUpClass(cls):
        "Creates a directory with `tempfile.mkdtemp` and stores its `Path` in `cls.tmpdir`"
        cls.tmpdir = Path(tempfile.mkdtemp())

    @classmethod
    def tearDownClass(cls):
        "Remove `cls.tmpdir` after test suite has finished"
        if os.path.exists(cls.tmpdir):
            shutil.rmtree(cls.tmpdir)

    def setUp(self):
        "Destroy all contents in `self.tmpdir`, but not `self.tmpdir`"
        if not self.clear_on_setup:
            return
        # Only the top-level entries need visiting: `rmtree` takes care of everything below a directory. The listing
        # is snapshotted first so nothing is deleted while the directory is still being iterated.
        for entry in list(self.tmpdir.iterdir()):
            if entry.is_dir() and not entry.is_symlink():
                shutil.rmtree(entry)
            else:
                # Regular files and symlinks (including broken ones and links to directories, on which `rmtree`
                # refuses to operate) are removed without touching the link target.
                entry.unlink()
|
| 648 |
+
|
| 649 |
+
|
| 650 |
+
class AccelerateTestCase(unittest.TestCase):
    """
    TestCase base class that resets the accelerator state once each test has finished. Any test that inspects or
    relies on `AcceleratorState` should derive from it, so that state leaking from one test into the next cannot
    cause silent failures.
    """

    def tearDown(self):
        "Run the regular teardown, then reset the `AcceleratorState` singleton"
        super().tearDown()
        AcceleratorState._reset_state(True)
|
| 661 |
+
|
| 662 |
+
|
| 663 |
+
class MockingTestCase(unittest.TestCase):
    """
    TestCase base class for registering mockers dynamically so that they apply to every test, imitating a class-wide
    mock in situations where declaring one the usual way is not possible.

    Handy when a mock depends on information that only exists after `TestCase.setUpClass` has run, for example an
    environment variable that must carry such a value.

    Call `add_mocks` at the end of the `TestCase`'s `setUp`, after `super().setUp()`:
    ```python
    def setUp(self):
        super().setUp()
        mocks = mock.patch.dict(os.environ, {"SOME_ENV_VAR": "SOME_VALUE"})
        self.add_mocks(mocks)
    ```
    """

    def add_mocks(self, mocks: Union[mock.Mock, list[mock.Mock]]):
        """
        Start the given mocks and schedule each one to be stopped when the test is cleaned up. Meant to be called from
        `MockingTestCase.setUp`, after `super().setUp()`.

        Args:
            mocks (`mock.Mock` or list of `mock.Mock`):
                A single object, or a list/tuple of objects, exposing `start()` and `stop()`. They are stored on
                `self.mocks`; a single object is wrapped in a list.
        """
        if isinstance(mocks, (tuple, list)):
            self.mocks = mocks
        else:
            self.mocks = [mocks]
        for patcher in self.mocks:
            patcher.start()
            # `addCleanup` guarantees `stop` runs when the test finishes.
            self.addCleanup(patcher.stop)
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
def are_the_same_tensors(tensor):
    """
    Return `True` when every slice produced by `gather` equals this process's own copy of `tensor`, `False` as soon
    as one differs.
    """
    device = AcceleratorState().device
    # A leading axis is added so that the gathered result can be compared one slice at a time.
    local = tensor[None].clone().to(device)
    gathered = gather(local).cpu()
    reference = local[0].cpu()
    return all(torch.equal(row, reference) for row in gathered)
|
| 705 |
+
|
| 706 |
+
|
| 707 |
+
class _RunOutput:
|
| 708 |
+
def __init__(self, returncode, stdout, stderr):
|
| 709 |
+
self.returncode = returncode
|
| 710 |
+
self.stdout = stdout
|
| 711 |
+
self.stderr = stderr
|
| 712 |
+
|
| 713 |
+
|
| 714 |
+
async def _read_stream(stream, callback):
|
| 715 |
+
while True:
|
| 716 |
+
line = await stream.readline()
|
| 717 |
+
if line:
|
| 718 |
+
callback(line)
|
| 719 |
+
else:
|
| 720 |
+
break
|
| 721 |
+
|
| 722 |
+
|
| 723 |
+
async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> _RunOutput:
    """
    Launch `cmd` as a subprocess, collect its stdout and stderr line by line while it runs, and return a `_RunOutput`
    holding the return code and the two lists of decoded lines.

    Args:
        cmd: Program and arguments; `cmd[0]` is executed with `cmd[1:]` as arguments.
        env: Environment passed to the subprocess.
        stdin: Passed through as the subprocess's stdin.
        timeout: Forwarded to `asyncio.wait` on the two reader tasks (see the note in the body).
        quiet: When true, collected lines are not echoed to this process's stdout/stderr.
        echo: When true, the command line is printed before launching.
    """
    if echo:
        print("\nRunning: ", " ".join(cmd))

    program, *arguments = cmd
    process = await asyncio.create_subprocess_exec(
        program,
        *arguments,
        stdin=stdin,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
        env=env,
    )

    # note: the docs warn about a possible deadlock when using `wait` with huge amounts of data in the pipe
    # https://docs.python.org/3/library/asyncio-subprocess.html#asyncio.asyncio.subprocess.Process.wait
    #
    # Should this start hanging, `process.communicate()` is the fallback. Its drawback is that nothing is shown
    # until the process is done, so a hang would leave no debug output.

    out = []
    err = []

    def tee(line, sink, pipe, label=""):
        # Decode one raw line, remember it, and mirror it to `pipe` unless running quietly.
        line = line.decode("utf-8").rstrip()
        sink.append(line)
        if not quiet:
            print(label, line, file=pipe)

    def on_stdout(raw):
        # `sys.stdout` is looked up on each call, not once up front.
        return tee(raw, out, sys.stdout, label="stdout:")

    def on_stderr(raw):
        return tee(raw, err, sys.stderr, label="stderr:")

    readers = [
        asyncio.create_task(_read_stream(process.stdout, on_stdout)),
        asyncio.create_task(_read_stream(process.stderr, on_stderr)),
    ]
    # XXX: the timeout doesn't seem to make any difference here. NOTE(review): presumably because the process is
    # awaited right below regardless of whether the readers finished in time.
    await asyncio.wait(readers, timeout=timeout)
    return _RunOutput(await process.wait(), out, err)
|
| 762 |
+
|
| 763 |
+
|
| 764 |
+
def execute_subprocess_async(cmd: list, env=None, stdin=None, timeout=180, quiet=False, echo=True) -> _RunOutput:
    """
    Run `cmd` through `_stream_subprocess` on the current event loop and return its `_RunOutput`.

    Args:
        cmd (`list`): Program and arguments. `Path` items are converted to strings on a copy; the caller's list is
            left untouched.
        env: Environment passed to the subprocess.
        stdin: Passed through as the subprocess's stdin.
        timeout: Forwarded to `_stream_subprocess`.
        quiet: Forwarded to `_stream_subprocess`.
        echo: Forwarded to `_stream_subprocess`.

    Raises:
        RuntimeError: If the subprocess finishes with a non-zero return code. The message carries the collected
            stderr lines.
    """
    # Cast every path in `cmd` to a string, without mutating the list owned by the caller
    cmd = [str(c) if isinstance(c, Path) else c for c in cmd]
    loop = asyncio.get_event_loop()
    result = loop.run_until_complete(
        _stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo)
    )

    # A negative return code means the child was terminated by a signal (e.g. -9 for SIGKILL, -11 for a segfault),
    # which is a failure too, so anything other than 0 is reported.
    if result.returncode != 0:
        cmd_str = " ".join(cmd)
        stderr = "\n".join(result.stderr)
        raise RuntimeError(
            f"'{cmd_str}' failed with returncode {result.returncode}\n\n"
            f"The combined stderr from workers follows:\n{stderr}"
        )

    return result
|
| 783 |
+
|
| 784 |
+
|
| 785 |
+
def pytest_xdist_worker_id():
    """
    Returns an int value of worker's numerical id under `pytest-xdist`'s concurrent workers `pytest -n N` regime, or 0
    if `-n 1` or `pytest-xdist` isn't being used.

    The id is read from the `PYTEST_XDIST_WORKER` environment variable (e.g. `"gw3"` gives `3`); when the variable
    is unset, `"gw0"` is assumed.
    """
    worker = os.environ.get("PYTEST_XDIST_WORKER", "gw0")
    # `count` and `flags` are passed by keyword: supplying them positionally is deprecated since Python 3.13.
    worker = re.sub(r"^gw", "", worker, count=0, flags=re.M)
    return int(worker)
|
| 793 |
+
|
| 794 |
+
|
| 795 |
+
def get_torch_dist_unique_port():
    """
    Compute a port number suitable for `torch.distributed.launch`'s `--master_port` argument.

    The base port 29500 is offset by the `pytest-xdist` worker id, so that tests running concurrently in different
    workers do not all try to use the same port.
    """
    base_port = 29500
    return base_port + pytest_xdist_worker_id()
|
| 805 |
+
|
| 806 |
+
|
| 807 |
+
class SubprocessCallException(Exception):
    "Raised by `run_command` when the command it launched fails"
|
| 809 |
+
|
| 810 |
+
|
| 811 |
+
def run_command(command: list[str], return_stdout=False, env=None):
    """
    Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. Will also properly capture
    if an error occurred while running `command`

    Args:
        command (`list[str]`): Program and arguments. `Path` items are converted to strings on a copy; the caller's
            list is left untouched.
        return_stdout (`bool`, defaults to `False`): When true, the captured output (stderr is merged into stdout)
            is returned decoded as UTF-8. Otherwise nothing is returned.
        env: Environment for the subprocess. Defaults to a copy of `os.environ`.

    Raises:
        SubprocessCallException: If the command exits with a non-zero status. The message includes the command
            line and its captured output.
    """
    # Cast every path in `command` to a string, without mutating the list owned by the caller
    command = [str(c) if isinstance(c, Path) else c for c in command]
    if env is None:
        env = os.environ.copy()
    try:
        output = subprocess.check_output(command, stderr=subprocess.STDOUT, env=env)
        if return_stdout:
            if hasattr(output, "decode"):
                output = output.decode("utf-8")
            return output
    except subprocess.CalledProcessError as e:
        raise SubprocessCallException(
            f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
        ) from e
|
| 832 |
+
|
| 833 |
+
|
| 834 |
+
def path_in_accelerate_package(*components: str) -> Path:
    """
    Build a path located inside the directory of the `accelerate` package.

    Args:
        *components: Path components appended, in order, to the package directory.

    Returns:
        `Path`: The package directory joined with `components`.
    """
    package_file = inspect.getfile(accelerate)
    return Path(package_file).parent.joinpath(*components)
|
| 847 |
+
|
| 848 |
+
|
| 849 |
+
@contextmanager
def assert_exception(exception_class: type[Exception], msg: Optional[str] = None):
    """
    Context manager to assert that the right `Exception` class was raised.

    If `msg` is provided, will check that the message is contained in the raised exception.

    Args:
        exception_class: The exception type the managed block is expected to raise.
        msg (`str`, *optional*): Text that must appear in `str()` of the raised exception.

    Raises:
        AssertionError: If the block raises an `Exception` of another type, raises one whose text lacks `msg`, or
            finishes without raising at all.
    """
    try:
        yield
    except Exception as e:
        # Raised explicitly, not via `assert`, so the checks are still performed under `python -O`.
        if not isinstance(e, exception_class):
            raise AssertionError(f"Expected exception of type {exception_class} but got {type(e)}")
        if msg is not None and msg not in str(e):
            raise AssertionError(f"Expected message '{msg}' to be in exception but got '{str(e)}'")
    else:
        raise AssertionError(f"Expected exception of type {exception_class} but ran without issue.")
|
| 866 |
+
|
| 867 |
+
|
| 868 |
+
def capture_call_output(func, *args, **kwargs):
    """
    Takes in a `func` with `args` and `kwargs` and returns the captured stdout as a string

    `sys.stdout` is swapped for an in-memory buffer while `func` runs and is always restored afterwards, including
    when `func` raises (the exception then propagates to the caller). The return value of `func` is discarded.
    """
    captured_output = io.StringIO()
    original_stdout = sys.stdout
    try:
        sys.stdout = captured_output
        func(*args, **kwargs)
    finally:
        sys.stdout = original_stdout
    return captured_output.getvalue()
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/training.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
import torch
|
| 17 |
+
from torch.utils.data import DataLoader
|
| 18 |
+
|
| 19 |
+
from accelerate.utils.dataclasses import DistributedType
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class RegressionDataset:
    """
    Synthetic one-dimensional regression data: `length` float32 samples `x` drawn from a standard normal
    distribution, with targets `y = a * x + b` plus normal noise of scale 0.1. Passing `seed` makes the data
    reproducible.
    """

    def __init__(self, a=2, b=3, length=64, seed=None):
        generator = np.random.default_rng(seed)
        shape = (length,)
        self.length = length
        # The inputs are drawn before the noise; the order matters for reproducibility with a fixed seed.
        self.x = generator.normal(size=shape).astype(np.float32)
        noise = generator.normal(scale=0.1, size=shape).astype(np.float32)
        self.y = a * self.x + b + noise

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        sample = {"x": self.x[i], "y": self.y[i]}
        return sample
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class RegressionModel(torch.nn.Module):
    """
    Minimal linear model computing `x * a + b`, with `a` and `b` stored as float parameters.

    The dtypes of the parameters and of the input are printed on the first forward pass only. The `double_output`
    argument is accepted but not used by this implementation.
    """

    def __init__(self, a=0, b=0, double_output=False):
        super().__init__()
        slope = torch.tensor(a).float()
        intercept = torch.tensor(b).float()
        self.a = torch.nn.Parameter(slope)
        self.b = torch.nn.Parameter(intercept)
        # Cleared after the first call to `forward`, once the dtype report has been printed.
        self.first_batch = True

    def forward(self, x=None):
        if self.first_batch:
            report = f"Model dtype: {self.a.dtype}, {self.b.dtype}. Input dtype: {x.dtype}"
            print(report)
            self.first_batch = False
        scaled = x * self.a
        return scaled + self.b
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def mocked_dataloaders(accelerator, batch_size: int = 16):
    """
    Build a train and a validation `DataLoader` over the MRPC sample CSV files, tokenized as sentence pairs with the
    `bert-base-cased` tokenizer.

    Args:
        accelerator: Object whose `distributed_type` is consulted when padding each batch.
        batch_size (`int`, defaults to 16): NOTE(review): not used by the body, which builds the loaders with
            fixed batch sizes of 2 (train) and 1 (validation).

    Returns:
        A `(train_dataloader, eval_dataloader)` tuple.
    """
    from datasets import load_dataset
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    datasets = load_dataset(
        "csv",
        data_files={"train": "tests/test_samples/MRPC/train.csv", "validation": "tests/test_samples/MRPC/dev.csv"},
    )

    # Map each distinct label of the training split to an integer id.
    label_to_id = {label: index for index, label in enumerate(datasets["train"].unique("label"))}

    def tokenize_function(examples):
        # max_length=None falls back to the model max length (which is the default anyway)
        encoded = tokenizer(
            examples["sentence1"], examples["sentence2"], truncation=True, max_length=None, padding="max_length"
        )
        if "label" in examples:
            encoded["labels"] = [label_to_id[label] for label in examples["label"]]
        return encoded

    # Tokenize every example of every split, dropping the raw columns.
    tokenized_datasets = datasets.map(
        tokenize_function, batched=True, remove_columns=["sentence1", "sentence2", "label"]
    )

    def collate_fn(examples):
        pad_kwargs = {"padding": "longest"}
        # On TPU it's best to pad everything to the same length or training will be very slow.
        if accelerator.distributed_type == DistributedType.XLA:
            pad_kwargs = {"padding": "max_length", "max_length": 128}
        return tokenizer.pad(examples, return_tensors="pt", **pad_kwargs)

    train_split = tokenized_datasets["train"]
    eval_split = tokenized_datasets["validation"]
    train_dataloader = DataLoader(train_split, shuffle=True, collate_fn=collate_fn, batch_size=2)
    eval_dataloader = DataLoader(eval_split, shuffle=False, collate_fn=collate_fn, batch_size=1)
    return train_dataloader, eval_dataloader
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def mocked_dataloaders_for_autoregressive_models(accelerator, batch_size: int = 16):
    """
    Build a train and a validation `DataLoader` over the MRPC sample CSV files for next-token prediction, tokenizing
    only `sentence1` with the `HuggingFaceTB/SmolLM-360M` tokenizer (its EOS token doubles as the pad token).

    Each collated batch holds `input_ids` and `labels`, the labels being the same padded sequence shifted by one
    position, with every position equal to the pad token id replaced by -100.

    Args:
        accelerator: Object providing `main_process_first()`, `distributed_type` and `mixed_precision`.
        batch_size (`int`, defaults to 16): NOTE(review): not used by the body, which builds the loaders with
            fixed batch sizes of 2 (train) and 1 (validation).

    Returns:
        A `(train_dataloader, eval_dataloader)` tuple.
    """
    from datasets import load_dataset
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-360M")
    tokenizer.pad_token = tokenizer.eos_token

    datasets = load_dataset(
        "csv",
        data_files={"train": "tests/test_samples/MRPC/train.csv", "validation": "tests/test_samples/MRPC/dev.csv"},
    )

    def tokenize_function(examples):
        # max_length=None falls back to the model max length (which is the default anyway)
        return tokenizer(examples["sentence1"], truncation=True, max_length=None, return_attention_mask=False)

    # Tokenize every example of every split, letting the main process go first.
    with accelerator.main_process_first():
        tokenized_datasets = datasets.map(
            tokenize_function, batched=True, remove_columns=["sentence1", "sentence2", "label"]
        )

    def collate_fn(examples):
        # On TPU it's best to pad everything to the same length or training will be very slow.
        if accelerator.distributed_type == DistributedType.XLA:
            max_length = 128
        else:
            max_length = max(len(e["input_ids"]) for e in examples)

        # When using mixed precision we want round multiples of 8/16
        pad_to_multiple_of = None
        if accelerator.mixed_precision == "fp8":
            pad_to_multiple_of = 16
        elif accelerator.mixed_precision != "no":
            pad_to_multiple_of = 8

        # One extra position is padded so that inputs and shifted labels both end up with `max_length` columns
        # (before any rounding to `pad_to_multiple_of`).
        batch = tokenizer.pad(
            examples,
            padding="max_length",
            max_length=max_length + 1,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors="pt",
        )

        token_ids = batch["input_ids"]
        shifted = token_ids[:, 1:]
        batch["input_ids"] = token_ids[:, :-1]
        batch["labels"] = torch.where(shifted == tokenizer.pad_token_id, -100, shifted)
        return batch

    train_split = tokenized_datasets["train"]
    eval_split = tokenized_datasets["validation"]
    train_dataloader = DataLoader(train_split, shuffle=False, collate_fn=collate_fn, batch_size=2)
    eval_dataloader = DataLoader(eval_split, shuffle=False, collate_fn=collate_fn, batch_size=1)
    return train_dataloader, eval_dataloader
|
URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (1.14 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/__pycache__/abstract_nodes.cpython-312.pyc
ADDED
|
Binary file (1.14 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/__pycache__/algorithms.cpython-312.pyc
ADDED
|
Binary file (8.64 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/__pycache__/approximations.cpython-312.pyc
ADDED
|
Binary file (9.03 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/__pycache__/ast.cpython-312.pyc
ADDED
|
Binary file (74 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/__pycache__/cfunctions.cpython-312.pyc
ADDED
|
Binary file (19.1 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/__pycache__/cnodes.cpython-312.pyc
ADDED
|
Binary file (5.62 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/__pycache__/cutils.cpython-312.pyc
ADDED
|
Binary file (813 Bytes). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/__pycache__/cxxnodes.cpython-312.pyc
ADDED
|
Binary file (798 Bytes). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/__pycache__/fnodes.cpython-312.pyc
ADDED
|
Binary file (26.2 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/__pycache__/futils.cpython-312.pyc
ADDED
|
Binary file (2.52 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/__pycache__/matrix_nodes.cpython-312.pyc
ADDED
|
Binary file (3.12 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/__pycache__/numpy_nodes.cpython-312.pyc
ADDED
|
Binary file (7.93 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/__pycache__/pynodes.cpython-312.pyc
ADDED
|
Binary file (744 Bytes). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/__pycache__/pyutils.cpython-312.pyc
ADDED
|
Binary file (1.49 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/__pycache__/rewriting.cpython-312.pyc
ADDED
|
Binary file (18.4 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/__pycache__/scipy_nodes.cpython-312.pyc
ADDED
|
Binary file (4.44 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/tests/__init__.py
ADDED
|
File without changes
|
URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/tests/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (211 Bytes). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/tests/__pycache__/test_abstract_nodes.cpython-312.pyc
ADDED
|
Binary file (1.31 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/tests/__pycache__/test_algorithms.cpython-312.pyc
ADDED
|
Binary file (10.9 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/tests/__pycache__/test_applications.cpython-312.pyc
ADDED
|
Binary file (3.57 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/tests/__pycache__/test_approximations.cpython-312.pyc
ADDED
|
Binary file (3.54 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/tests/__pycache__/test_ast.cpython-312.pyc
ADDED
|
Binary file (48.4 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/tests/__pycache__/test_cfunctions.cpython-312.pyc
ADDED
|
Binary file (12.5 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/tests/__pycache__/test_cnodes.cpython-312.pyc
ADDED
|
Binary file (6.72 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/tests/__pycache__/test_cxxnodes.cpython-312.pyc
ADDED
|
Binary file (846 Bytes). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/tests/__pycache__/test_fnodes.cpython-312.pyc
ADDED
|
Binary file (11 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/tests/__pycache__/test_matrix_nodes.cpython-312.pyc
ADDED
|
Binary file (4.14 kB). View file
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/sympy/codegen/tests/__pycache__/test_numpy_nodes.cpython-312.pyc
ADDED
|
Binary file (4.72 kB). View file
|
|
|