Add files using upload-large-folder tool
Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/hooks.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/memory_utils.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/state.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/__init__.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/accelerate_cli.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/env.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/estimate.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/launch.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/merge.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/test.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/to_fsdp2.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/tpu.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/utils.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/__init__.py +52 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/__pycache__/__init__.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/__pycache__/cluster.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/__pycache__/config.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/__pycache__/config_args.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/__pycache__/config_utils.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/__pycache__/default.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/__pycache__/sagemaker.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/__pycache__/update.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/cluster.py +917 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/config.py +89 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/config_args.py +256 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/config_utils.py +122 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/default.py +163 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/sagemaker.py +274 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/update.py +63 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/__init__.py +14 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/__pycache__/__init__.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/__pycache__/cursor.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/__pycache__/helpers.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/__pycache__/input.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/__pycache__/keymap.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/__pycache__/selection_menu.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/cursor.py +65 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/helpers.py +59 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/input.py +84 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/keymap.py +133 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/selection_menu.py +145 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/__pycache__/__init__.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/__pycache__/examples.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/__pycache__/testing.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/__pycache__/training.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/scripts/__init__.py +13 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/scripts/test_ddp_comm_hook.py +85 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/scripts/test_distributed_data_loop.py +410 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/scripts/test_merge_weights.py +158 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/scripts/test_notebook.py +118 -0
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/hooks.cpython-312.pyc
ADDED
|
Binary file (34.4 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/memory_utils.cpython-312.pyc
ADDED
|
Binary file (517 Bytes). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/state.cpython-312.pyc
ADDED
|
Binary file (64.5 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (225 Bytes). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/accelerate_cli.cpython-312.pyc
ADDED
|
Binary file (1.86 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/env.cpython-312.pyc
ADDED
|
Binary file (5.17 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/estimate.cpython-312.pyc
ADDED
|
Binary file (14.1 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/launch.cpython-312.pyc
ADDED
|
Binary file (51.8 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/merge.cpython-312.pyc
ADDED
|
Binary file (2.45 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/test.cpython-312.pyc
ADDED
|
Binary file (2.21 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/to_fsdp2.cpython-312.pyc
ADDED
|
Binary file (6.25 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/tpu.cpython-312.pyc
ADDED
|
Binary file (6.04 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/utils.cpython-312.pyc
ADDED
|
Binary file (5.24 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/__init__.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
|
| 19 |
+
from .config import config_command_parser
|
| 20 |
+
from .config_args import default_config_file, load_config_from_file # noqa: F401
|
| 21 |
+
from .default import default_command_parser
|
| 22 |
+
from .update import update_command_parser
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_config_parser(subparsers=None):
    """Assemble the parser for the config command together with its subcommands.

    Args:
        subparsers: Forwarded unchanged to `config_command_parser`; defaults to `None`.

    Returns:
        The parser returned by `config_command_parser`, after a "subcommands" group has been
        attached to it and the `default` and `update` command parsers registered on that group.
    """
    # Help-less parent parser handed to every subcommand through `parents=[...]`.
    shared_parent = argparse.ArgumentParser(add_help=False, allow_abbrev=False)

    # Main config parser, and the group that subcommands are added to.
    root_parser = config_command_parser(subparsers)
    subcommand_group = root_parser.add_subparsers(title="subcommands", dest="subcommand")

    # Register the subcommands in the same order as before: default first, then update.
    for register_subcommand in (default_command_parser, update_command_parser):
        register_subcommand(subcommand_group, parents=[shared_parent])

    return root_parser
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def main():
    """Parse the command line and dispatch to the selected config subcommand.

    Builds the parser with `get_config_parser()` and parses the process arguments. When the
    parsed namespace carries no `func` attribute, the help text is printed and the process
    exits with status 1; otherwise `args.func(args)` is called.

    Raises:
        SystemExit: With code 1 when no `func` handler is attached to the parsed arguments.
    """
    config_parser = get_config_parser()
    args = config_parser.parse_args()

    if not hasattr(args, "func"):
        config_parser.print_help()
        # The `exit` builtin is injected by the `site` module for interactive sessions and is
        # missing under `python -S` or in frozen/embedded interpreters. Raising SystemExit
        # directly gives the same exit status without relying on it.
        raise SystemExit(1)

    # Run
    args.func(args)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# Script entry point: only run the CLI when this module is executed directly.
if __name__ == "__main__":
    main()
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (1.5 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/__pycache__/cluster.cpython-312.pyc
ADDED
|
Binary file (28.2 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/__pycache__/config.cpython-312.pyc
ADDED
|
Binary file (3.29 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/__pycache__/config_args.cpython-312.pyc
ADDED
|
Binary file (12.2 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/__pycache__/config_utils.cpython-312.pyc
ADDED
|
Binary file (3.97 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/__pycache__/default.cpython-312.pyc
ADDED
|
Binary file (5.99 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/__pycache__/sagemaker.cpython-312.pyc
ADDED
|
Binary file (9.52 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/__pycache__/update.cpython-312.pyc
ADDED
|
Binary file (2.47 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/cluster.py
ADDED
|
@@ -0,0 +1,917 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
|
| 19 |
+
from ...utils import (
|
| 20 |
+
ComputeEnvironment,
|
| 21 |
+
DistributedType,
|
| 22 |
+
is_deepspeed_available,
|
| 23 |
+
is_fp8_available,
|
| 24 |
+
is_hpu_available,
|
| 25 |
+
is_mlu_available,
|
| 26 |
+
is_mps_available,
|
| 27 |
+
is_msamp_available,
|
| 28 |
+
is_musa_available,
|
| 29 |
+
is_npu_available,
|
| 30 |
+
is_sdaa_available,
|
| 31 |
+
is_transformer_engine_available,
|
| 32 |
+
is_transformers_available,
|
| 33 |
+
is_xpu_available,
|
| 34 |
+
)
|
| 35 |
+
from ...utils.constants import (
|
| 36 |
+
DEEPSPEED_MULTINODE_LAUNCHERS,
|
| 37 |
+
FSDP2_STATE_DICT_TYPE,
|
| 38 |
+
FSDP_AUTO_WRAP_POLICY,
|
| 39 |
+
FSDP_BACKWARD_PREFETCH,
|
| 40 |
+
FSDP_SHARDING_STRATEGY,
|
| 41 |
+
FSDP_STATE_DICT_TYPE,
|
| 42 |
+
TORCH_DYNAMO_MODES,
|
| 43 |
+
)
|
| 44 |
+
from .config_args import ClusterConfig
|
| 45 |
+
from .config_utils import (
|
| 46 |
+
DYNAMO_BACKENDS,
|
| 47 |
+
_ask_field,
|
| 48 |
+
_ask_options,
|
| 49 |
+
_convert_distributed_mode,
|
| 50 |
+
_convert_dynamo_backend,
|
| 51 |
+
_convert_fp8_backend,
|
| 52 |
+
_convert_mixed_precision,
|
| 53 |
+
_convert_yes_no_to_bool,
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def get_cluster_input():
|
| 58 |
+
distributed_type = _ask_options(
|
| 59 |
+
"Which type of machine are you using?",
|
| 60 |
+
[
|
| 61 |
+
"No distributed training",
|
| 62 |
+
"multi-CPU",
|
| 63 |
+
"multi-XPU",
|
| 64 |
+
"multi-HPU",
|
| 65 |
+
"multi-GPU",
|
| 66 |
+
"multi-NPU",
|
| 67 |
+
"multi-MLU",
|
| 68 |
+
"multi-SDAA",
|
| 69 |
+
"multi-MUSA",
|
| 70 |
+
"TPU",
|
| 71 |
+
],
|
| 72 |
+
_convert_distributed_mode,
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
machine_rank = 0
|
| 76 |
+
num_machines = 1
|
| 77 |
+
num_processes = 1
|
| 78 |
+
gpu_ids = None
|
| 79 |
+
main_process_ip = None
|
| 80 |
+
main_process_port = None
|
| 81 |
+
rdzv_backend = "static"
|
| 82 |
+
same_network = True
|
| 83 |
+
debug = False
|
| 84 |
+
|
| 85 |
+
if distributed_type in [
|
| 86 |
+
DistributedType.MULTI_GPU,
|
| 87 |
+
DistributedType.MULTI_MLU,
|
| 88 |
+
DistributedType.MULTI_SDAA,
|
| 89 |
+
DistributedType.MULTI_MUSA,
|
| 90 |
+
DistributedType.MULTI_NPU,
|
| 91 |
+
DistributedType.MULTI_XPU,
|
| 92 |
+
DistributedType.MULTI_CPU,
|
| 93 |
+
DistributedType.MULTI_HPU,
|
| 94 |
+
]:
|
| 95 |
+
num_machines = _ask_field(
|
| 96 |
+
"How many different machines will you use (use more than 1 for multi-node training)? [1]: ",
|
| 97 |
+
int,
|
| 98 |
+
default=1,
|
| 99 |
+
)
|
| 100 |
+
if num_machines > 1:
|
| 101 |
+
machine_rank = _ask_options(
|
| 102 |
+
"What is the rank of this machine?",
|
| 103 |
+
list(range(num_machines)),
|
| 104 |
+
int,
|
| 105 |
+
)
|
| 106 |
+
main_process_ip = _ask_field(
|
| 107 |
+
"What is the IP address of the machine that will host the main process? ",
|
| 108 |
+
)
|
| 109 |
+
main_process_port = _ask_field(
|
| 110 |
+
"What is the port you will use to communicate with the main process? ",
|
| 111 |
+
int,
|
| 112 |
+
)
|
| 113 |
+
same_network = _ask_field(
|
| 114 |
+
"Are all the machines on the same local network? Answer `no` if nodes are on the cloud and/or on different network hosts [YES/no]: ",
|
| 115 |
+
_convert_yes_no_to_bool,
|
| 116 |
+
default=True,
|
| 117 |
+
error_message="Please enter yes or no.",
|
| 118 |
+
)
|
| 119 |
+
if not same_network:
|
| 120 |
+
rdzv_backend = _ask_field(
|
| 121 |
+
"What rendezvous backend will you use? ('static', 'c10d', ...): ", default="static"
|
| 122 |
+
)
|
| 123 |
+
debug = _ask_field(
|
| 124 |
+
"Should distributed operations be checked while running for errors? This can avoid timeout issues but will be slower. [yes/NO]: ",
|
| 125 |
+
_convert_yes_no_to_bool,
|
| 126 |
+
default=False,
|
| 127 |
+
error_message="Please enter yes or no.",
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
if distributed_type == DistributedType.NO:
|
| 131 |
+
use_cpu = _ask_field(
|
| 132 |
+
"Do you want to run your training on CPU only (even if a GPU / Apple Silicon / Ascend NPU device is available)? [yes/NO]:",
|
| 133 |
+
_convert_yes_no_to_bool,
|
| 134 |
+
default=False,
|
| 135 |
+
error_message="Please enter yes or no.",
|
| 136 |
+
)
|
| 137 |
+
elif distributed_type == DistributedType.MULTI_CPU:
|
| 138 |
+
use_cpu = True
|
| 139 |
+
else:
|
| 140 |
+
use_cpu = False
|
| 141 |
+
|
| 142 |
+
ipex_config = {}
|
| 143 |
+
mpirun_config = {}
|
| 144 |
+
if use_cpu or is_xpu_available():
|
| 145 |
+
ipex_config["ipex"] = _ask_field(
|
| 146 |
+
"Do you want to use Intel PyTorch Extension (IPEX) to speed up training on CPU/XPU? [yes/NO]:",
|
| 147 |
+
_convert_yes_no_to_bool,
|
| 148 |
+
default=False,
|
| 149 |
+
error_message="Please enter yes or no.",
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
if use_cpu:
|
| 153 |
+
if distributed_type == DistributedType.MULTI_CPU:
|
| 154 |
+
use_mpirun = _ask_field(
|
| 155 |
+
"Do you want accelerate to launch mpirun? [yes/NO]: ",
|
| 156 |
+
_convert_yes_no_to_bool,
|
| 157 |
+
default=False,
|
| 158 |
+
error_message="Please enter yes or no.",
|
| 159 |
+
)
|
| 160 |
+
if use_mpirun:
|
| 161 |
+
mpirun_hostfile = _ask_field(
|
| 162 |
+
"Please enter the path to the hostfile to use with mpirun [~/hostfile]: ",
|
| 163 |
+
str,
|
| 164 |
+
default="~/hostfile",
|
| 165 |
+
)
|
| 166 |
+
mpirun_config["mpirun_hostfile"] = os.path.expanduser(mpirun_hostfile.strip())
|
| 167 |
+
mpirun_config["mpirun_ccl"] = _ask_field("Enter the number of oneCCL worker threads [1]: ", default=1)
|
| 168 |
+
|
| 169 |
+
dynamo_config = {}
|
| 170 |
+
use_dynamo = _ask_field(
|
| 171 |
+
"Do you wish to optimize your script with torch dynamo?[yes/NO]:",
|
| 172 |
+
_convert_yes_no_to_bool,
|
| 173 |
+
default=False,
|
| 174 |
+
error_message="Please enter yes or no.",
|
| 175 |
+
)
|
| 176 |
+
if use_dynamo:
|
| 177 |
+
prefix = "dynamo_"
|
| 178 |
+
dynamo_config[prefix + "backend"] = _ask_options(
|
| 179 |
+
"Which dynamo backend would you like to use?",
|
| 180 |
+
[x.lower() for x in DYNAMO_BACKENDS],
|
| 181 |
+
_convert_dynamo_backend,
|
| 182 |
+
default=2,
|
| 183 |
+
)
|
| 184 |
+
use_custom_options = _ask_field(
|
| 185 |
+
"Do you want to customize the defaults sent to torch.compile? [yes/NO]: ",
|
| 186 |
+
_convert_yes_no_to_bool,
|
| 187 |
+
default=False,
|
| 188 |
+
error_message="Please enter yes or no.",
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
if use_custom_options:
|
| 192 |
+
dynamo_config[prefix + "mode"] = _ask_options(
|
| 193 |
+
"Which mode do you want to use?",
|
| 194 |
+
TORCH_DYNAMO_MODES,
|
| 195 |
+
lambda x: TORCH_DYNAMO_MODES[int(x)],
|
| 196 |
+
default=0,
|
| 197 |
+
)
|
| 198 |
+
dynamo_config[prefix + "use_fullgraph"] = _ask_field(
|
| 199 |
+
"Do you want the fullgraph mode or it is ok to break model into several subgraphs? [yes/NO]: ",
|
| 200 |
+
_convert_yes_no_to_bool,
|
| 201 |
+
default=False,
|
| 202 |
+
error_message="Please enter yes or no.",
|
| 203 |
+
)
|
| 204 |
+
dynamo_config[prefix + "use_dynamic"] = _ask_field(
|
| 205 |
+
"Do you want to enable dynamic shape tracing? [yes/NO]: ",
|
| 206 |
+
_convert_yes_no_to_bool,
|
| 207 |
+
default=False,
|
| 208 |
+
error_message="Please enter yes or no.",
|
| 209 |
+
)
|
| 210 |
+
dynamo_config[prefix + "use_regional_compilation"] = _ask_field(
|
| 211 |
+
"Do you want to enable regional compilation? [yes/NO]: ",
|
| 212 |
+
_convert_yes_no_to_bool,
|
| 213 |
+
default=False,
|
| 214 |
+
error_message="Please enter yes or no.",
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
use_mps = not use_cpu and is_mps_available()
|
| 218 |
+
deepspeed_config = {}
|
| 219 |
+
if (
|
| 220 |
+
distributed_type
|
| 221 |
+
in [
|
| 222 |
+
DistributedType.MULTI_GPU,
|
| 223 |
+
DistributedType.MULTI_XPU,
|
| 224 |
+
DistributedType.MULTI_HPU,
|
| 225 |
+
DistributedType.MULTI_NPU,
|
| 226 |
+
DistributedType.MULTI_MLU,
|
| 227 |
+
DistributedType.MULTI_SDAA,
|
| 228 |
+
DistributedType.MULTI_MUSA,
|
| 229 |
+
DistributedType.NO,
|
| 230 |
+
]
|
| 231 |
+
and not use_mps
|
| 232 |
+
):
|
| 233 |
+
use_deepspeed = _ask_field(
|
| 234 |
+
"Do you want to use DeepSpeed? [yes/NO]: ",
|
| 235 |
+
_convert_yes_no_to_bool,
|
| 236 |
+
default=False,
|
| 237 |
+
error_message="Please enter yes or no.",
|
| 238 |
+
)
|
| 239 |
+
if use_deepspeed:
|
| 240 |
+
distributed_type = DistributedType.DEEPSPEED
|
| 241 |
+
assert is_deepspeed_available(), (
|
| 242 |
+
"DeepSpeed is not installed => run `pip3 install deepspeed` or build it from source"
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
if distributed_type == DistributedType.DEEPSPEED:
|
| 246 |
+
use_deepspeed_config = _ask_field(
|
| 247 |
+
"Do you want to specify a json file to a DeepSpeed config? [yes/NO]: ",
|
| 248 |
+
_convert_yes_no_to_bool,
|
| 249 |
+
default=False,
|
| 250 |
+
error_message="Please enter yes or no.",
|
| 251 |
+
)
|
| 252 |
+
if use_deepspeed_config:
|
| 253 |
+
deepspeed_config["deepspeed_config_file"] = _ask_field(
|
| 254 |
+
"Please enter the path to the json DeepSpeed config file: ",
|
| 255 |
+
str,
|
| 256 |
+
default="none",
|
| 257 |
+
)
|
| 258 |
+
else:
|
| 259 |
+
deepspeed_config["zero_stage"] = _ask_options(
|
| 260 |
+
"What should be your DeepSpeed's ZeRO optimization stage?",
|
| 261 |
+
[0, 1, 2, 3],
|
| 262 |
+
int,
|
| 263 |
+
default=2,
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
deepspeed_devices = ["none", "cpu", "nvme"]
|
| 267 |
+
if deepspeed_config["zero_stage"] >= 2:
|
| 268 |
+
deepspeed_config["offload_optimizer_device"] = _ask_options(
|
| 269 |
+
"Where to offload optimizer states?", deepspeed_devices, lambda x: deepspeed_devices[int(x)]
|
| 270 |
+
)
|
| 271 |
+
deepspeed_config["offload_param_device"] = _ask_options(
|
| 272 |
+
"Where to offload parameters?", deepspeed_devices, lambda x: deepspeed_devices[int(x)]
|
| 273 |
+
)
|
| 274 |
+
if deepspeed_config["offload_param_device"] == "nvme":
|
| 275 |
+
deepspeed_config["offload_param_nvme_path"] = _ask_field(
|
| 276 |
+
"Nvme Path to offload parameters?",
|
| 277 |
+
str,
|
| 278 |
+
default="/nvme",
|
| 279 |
+
)
|
| 280 |
+
if deepspeed_config["offload_optimizer_device"] == "nvme":
|
| 281 |
+
deepspeed_config["offload_optimizer_nvme_path"] = _ask_field(
|
| 282 |
+
"Nvme Path to offload optimizer states?",
|
| 283 |
+
str,
|
| 284 |
+
default="/nvme",
|
| 285 |
+
)
|
| 286 |
+
deepspeed_config["gradient_accumulation_steps"] = _ask_field(
|
| 287 |
+
"How many gradient accumulation steps you're passing in your script? [1]: ",
|
| 288 |
+
int,
|
| 289 |
+
default=1,
|
| 290 |
+
)
|
| 291 |
+
use_gradient_clipping = _ask_field(
|
| 292 |
+
"Do you want to use gradient clipping? [yes/NO]: ",
|
| 293 |
+
_convert_yes_no_to_bool,
|
| 294 |
+
default=False,
|
| 295 |
+
error_message="Please enter yes or no.",
|
| 296 |
+
)
|
| 297 |
+
if use_gradient_clipping:
|
| 298 |
+
deepspeed_config["gradient_clipping"] = _ask_field(
|
| 299 |
+
"What is the gradient clipping value? [1.0]: ",
|
| 300 |
+
float,
|
| 301 |
+
default=1.0,
|
| 302 |
+
)
|
| 303 |
+
if deepspeed_config["zero_stage"] == 3:
|
| 304 |
+
deepspeed_config["zero3_save_16bit_model"] = _ask_field(
|
| 305 |
+
"Do you want to save 16-bit model weights when using ZeRO Stage-3? [yes/NO]: ",
|
| 306 |
+
_convert_yes_no_to_bool,
|
| 307 |
+
default=False,
|
| 308 |
+
error_message="Please enter yes or no.",
|
| 309 |
+
)
|
| 310 |
+
deepspeed_config["zero3_init_flag"] = _ask_field(
|
| 311 |
+
"Do you want to enable `deepspeed.zero.Init` when using ZeRO Stage-3 for constructing massive models? [yes/NO]: ",
|
| 312 |
+
_convert_yes_no_to_bool,
|
| 313 |
+
default=False,
|
| 314 |
+
error_message="Please enter yes or no.",
|
| 315 |
+
)
|
| 316 |
+
if deepspeed_config["zero3_init_flag"]:
|
| 317 |
+
if not is_transformers_available():
|
| 318 |
+
raise Exception(
|
| 319 |
+
"When `zero3_init_flag` is set, it requires Transformers to be installed. "
|
| 320 |
+
"Please run `pip3 install transformers`."
|
| 321 |
+
)
|
| 322 |
+
use_moe = _ask_field(
|
| 323 |
+
"Do you want to enable Mixture-of-Experts training (MoE)? [yes/NO]: ",
|
| 324 |
+
_convert_yes_no_to_bool,
|
| 325 |
+
default=False,
|
| 326 |
+
error_message="Please enter yes or no.",
|
| 327 |
+
)
|
| 328 |
+
if use_moe:
|
| 329 |
+
deepspeed_config["deepspeed_moe_layer_cls_names"] = _ask_field(
|
| 330 |
+
"Specify the comma-separated list of transformers MoE layer class names (case-sensitive), e.g : "
|
| 331 |
+
" `MixtralSparseMoeBlock`, `Qwen2MoeSparseMoeBlock`, `JetMoEAttention,JetMoEBlock` ... : ",
|
| 332 |
+
str,
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
if num_machines > 1:
|
| 336 |
+
launcher_query = "Which Type of launcher do you want to use?"
|
| 337 |
+
deepspeed_config["deepspeed_multinode_launcher"] = _ask_options(
|
| 338 |
+
launcher_query,
|
| 339 |
+
DEEPSPEED_MULTINODE_LAUNCHERS,
|
| 340 |
+
lambda x: DEEPSPEED_MULTINODE_LAUNCHERS[int(x)],
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
if deepspeed_config["deepspeed_multinode_launcher"] != DEEPSPEED_MULTINODE_LAUNCHERS[1]:
|
| 344 |
+
deepspeed_config["deepspeed_hostfile"] = _ask_field(
|
| 345 |
+
"DeepSpeed configures multi-node compute resources with hostfile. "
|
| 346 |
+
"Each row is of the format `hostname slots=[num_gpus]`, e.g., `localhost slots=2`; "
|
| 347 |
+
"for more information please refer official [documentation]"
|
| 348 |
+
"(https://www.deepspeed.ai/getting-started/#resource-configuration-multi-node). "
|
| 349 |
+
"Please specify the location of hostfile: ",
|
| 350 |
+
str,
|
| 351 |
+
)
|
| 352 |
+
|
| 353 |
+
is_exclusion_filter = _ask_field(
|
| 354 |
+
"Do you want to specify exclusion filter string? [yes/NO]: ",
|
| 355 |
+
_convert_yes_no_to_bool,
|
| 356 |
+
default=False,
|
| 357 |
+
error_message="Please enter yes or no.",
|
| 358 |
+
)
|
| 359 |
+
if is_exclusion_filter:
|
| 360 |
+
deepspeed_config["deepspeed_exclusion_filter"] = _ask_field(
|
| 361 |
+
"DeepSpeed exclusion filter string: ",
|
| 362 |
+
str,
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
is_inclusion_filter = _ask_field(
|
| 366 |
+
"Do you want to specify inclusion filter string? [yes/NO]: ",
|
| 367 |
+
_convert_yes_no_to_bool,
|
| 368 |
+
default=False,
|
| 369 |
+
error_message="Please enter yes or no.",
|
| 370 |
+
)
|
| 371 |
+
if is_inclusion_filter:
|
| 372 |
+
deepspeed_config["deepspeed_inclusion_filter"] = _ask_field(
|
| 373 |
+
"DeepSpeed inclusion filter string: ",
|
| 374 |
+
str,
|
| 375 |
+
)
|
| 376 |
+
|
| 377 |
+
fsdp_config = {}
|
| 378 |
+
|
| 379 |
+
if distributed_type in [
|
| 380 |
+
DistributedType.MULTI_GPU,
|
| 381 |
+
DistributedType.MULTI_NPU,
|
| 382 |
+
DistributedType.MULTI_MLU,
|
| 383 |
+
DistributedType.MULTI_SDAA,
|
| 384 |
+
DistributedType.MULTI_MUSA,
|
| 385 |
+
DistributedType.MULTI_XPU,
|
| 386 |
+
DistributedType.MULTI_HPU,
|
| 387 |
+
]:
|
| 388 |
+
use_fsdp = _ask_field(
|
| 389 |
+
"Do you want to use FullyShardedDataParallel? [yes/NO]: ",
|
| 390 |
+
_convert_yes_no_to_bool,
|
| 391 |
+
default=False,
|
| 392 |
+
error_message="Please enter yes or no.",
|
| 393 |
+
)
|
| 394 |
+
if use_fsdp:
|
| 395 |
+
distributed_type = DistributedType.FSDP
|
| 396 |
+
if distributed_type == DistributedType.FSDP:
|
| 397 |
+
fsdp_config["fsdp_version"] = _ask_options(
|
| 398 |
+
"What should be your FSDP version? [2]: ",
|
| 399 |
+
[1, 2],
|
| 400 |
+
lambda x: int(x) + 1,
|
| 401 |
+
default=1,
|
| 402 |
+
)
|
| 403 |
+
fsdp_version = fsdp_config["fsdp_version"] # extract to a variable to simplify usage later
|
| 404 |
+
|
| 405 |
+
if fsdp_version == 1:
|
| 406 |
+
sharding_strategy_query = "What should be your sharding strategy?"
|
| 407 |
+
fsdp_config["fsdp_reshard_after_forward"] = _ask_options(
|
| 408 |
+
sharding_strategy_query,
|
| 409 |
+
FSDP_SHARDING_STRATEGY,
|
| 410 |
+
lambda x: FSDP_SHARDING_STRATEGY[int(x)],
|
| 411 |
+
)
|
| 412 |
+
else:
|
| 413 |
+
fsdp_config["fsdp_reshard_after_forward"] = _ask_field(
|
| 414 |
+
"Do you want to enable resharding after forward? [YES/no]: ",
|
| 415 |
+
_convert_yes_no_to_bool,
|
| 416 |
+
default=True,
|
| 417 |
+
error_message="Please enter yes or no.",
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
fsdp_config["fsdp_offload_params"] = _ask_field(
|
| 421 |
+
"Do you want to offload parameters and gradients to CPU? [yes/NO]: ",
|
| 422 |
+
_convert_yes_no_to_bool,
|
| 423 |
+
default=False,
|
| 424 |
+
error_message="Please enter yes or no.",
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
fsdp_wrap_query = "What should be your auto wrap policy?"
|
| 428 |
+
fsdp_config["fsdp_auto_wrap_policy"] = _ask_options(
|
| 429 |
+
fsdp_wrap_query,
|
| 430 |
+
FSDP_AUTO_WRAP_POLICY,
|
| 431 |
+
lambda x: FSDP_AUTO_WRAP_POLICY[int(x)],
|
| 432 |
+
)
|
| 433 |
+
if fsdp_config["fsdp_auto_wrap_policy"] == FSDP_AUTO_WRAP_POLICY[0]:
|
| 434 |
+
use_no_split_modules = _ask_field(
|
| 435 |
+
"Do you want to use the model's `_no_split_modules` to wrap. Only applicable for 🤗 Transformers [yes/NO]: ",
|
| 436 |
+
_convert_yes_no_to_bool,
|
| 437 |
+
default=False,
|
| 438 |
+
error_message="Please enter yes or no.",
|
| 439 |
+
)
|
| 440 |
+
if not use_no_split_modules:
|
| 441 |
+
fsdp_config["fsdp_transformer_layer_cls_to_wrap"] = _ask_field(
|
| 442 |
+
"Specify the comma-separated list of transformer layer class names (case-sensitive) to wrap ,e.g, :"
|
| 443 |
+
"`BertLayer`, `GPTJBlock`, `T5Block`, `BertLayer,BertEmbeddings,BertSelfOutput` ...? : ",
|
| 444 |
+
str,
|
| 445 |
+
)
|
| 446 |
+
elif fsdp_config["fsdp_auto_wrap_policy"] == FSDP_AUTO_WRAP_POLICY[1]:
|
| 447 |
+
fsdp_config["fsdp_min_num_params"] = _ask_field(
|
| 448 |
+
"What should be your FSDP's minimum number of parameters for Default Auto Wrapping Policy? [1e8]: ",
|
| 449 |
+
int,
|
| 450 |
+
default=100000000,
|
| 451 |
+
)
|
| 452 |
+
# Removed in FSDP2, ask for user input for FSDP1
|
| 453 |
+
if fsdp_version == 1:
|
| 454 |
+
fsdp_backward_prefetch_query = "What should be your FSDP's backward prefetch policy?"
|
| 455 |
+
fsdp_config["fsdp_backward_prefetch"] = _ask_options(
|
| 456 |
+
fsdp_backward_prefetch_query,
|
| 457 |
+
FSDP_BACKWARD_PREFETCH,
|
| 458 |
+
lambda x: FSDP_BACKWARD_PREFETCH[int(x)],
|
| 459 |
+
)
|
| 460 |
+
|
| 461 |
+
fsdp_state_dict_type_query = "What should be your FSDP's state dict type?"
|
| 462 |
+
fsdp_config["fsdp_state_dict_type"] = _ask_options(
|
| 463 |
+
fsdp_state_dict_type_query,
|
| 464 |
+
FSDP_STATE_DICT_TYPE if fsdp_version == 1 else FSDP2_STATE_DICT_TYPE,
|
| 465 |
+
lambda x: FSDP_STATE_DICT_TYPE[int(x)] if fsdp_version == 1 else FSDP2_STATE_DICT_TYPE[int(x)],
|
| 466 |
+
default=0,
|
| 467 |
+
)
|
| 468 |
+
# Not implemented in FSDP2, ask for user input for FSDP1
|
| 469 |
+
if fsdp_version == 1:
|
| 470 |
+
fsdp_config["fsdp_forward_prefetch"] = _ask_field(
|
| 471 |
+
"Do you want to enable FSDP's forward prefetch policy? [yes/NO]: ",
|
| 472 |
+
_convert_yes_no_to_bool,
|
| 473 |
+
default=False,
|
| 474 |
+
error_message="Please enter yes or no.",
|
| 475 |
+
)
|
| 476 |
+
# Obsolete in FSDP2, ask for user input for FSDP1
|
| 477 |
+
if fsdp_version == 1:
|
| 478 |
+
fsdp_config["fsdp_use_orig_params"] = _ask_field(
|
| 479 |
+
"Do you want to enable FSDP's `use_orig_params` feature? [YES/no]: ",
|
| 480 |
+
_convert_yes_no_to_bool,
|
| 481 |
+
default=True,
|
| 482 |
+
error_message="Please enter yes or no.",
|
| 483 |
+
)
|
| 484 |
+
fsdp_config["fsdp_cpu_ram_efficient_loading"] = _ask_field(
|
| 485 |
+
"Do you want to enable CPU RAM efficient model loading? Only applicable for 🤗 Transformers models. [YES/no]: ",
|
| 486 |
+
_convert_yes_no_to_bool,
|
| 487 |
+
default=True,
|
| 488 |
+
error_message="Please enter yes or no.",
|
| 489 |
+
)
|
| 490 |
+
# Obsolete in FSDP2, ask for user input for FSDP1
|
| 491 |
+
if fsdp_version == 1:
|
| 492 |
+
if fsdp_config["fsdp_cpu_ram_efficient_loading"]:
|
| 493 |
+
fsdp_config["fsdp_sync_module_states"] = True
|
| 494 |
+
else:
|
| 495 |
+
fsdp_config["fsdp_sync_module_states"] = _ask_field(
|
| 496 |
+
"Do you want each individually wrapped FSDP unit to broadcast module parameters from rank 0 at the start? [YES/no]: ",
|
| 497 |
+
_convert_yes_no_to_bool,
|
| 498 |
+
default=True,
|
| 499 |
+
error_message="Please enter yes or no.",
|
| 500 |
+
)
|
| 501 |
+
fsdp_config["fsdp_activation_checkpointing"] = _ask_field(
|
| 502 |
+
"Do you want to enable FSDP activation checkpointing? [yes/NO]: ",
|
| 503 |
+
_convert_yes_no_to_bool,
|
| 504 |
+
default=False,
|
| 505 |
+
error_message="Please enter yes or no.",
|
| 506 |
+
)
|
| 507 |
+
|
| 508 |
+
parallelism_config = {}
|
| 509 |
+
|
| 510 |
+
if fsdp_config.get("fsdp_version", 1) == 2:
|
| 511 |
+
use_parallelism_config = _ask_field(
|
| 512 |
+
"Do you want to use the parallelism config? [yes/NO]: ",
|
| 513 |
+
_convert_yes_no_to_bool,
|
| 514 |
+
default=False,
|
| 515 |
+
error_message="Please enter yes or no.",
|
| 516 |
+
)
|
| 517 |
+
|
| 518 |
+
if use_parallelism_config:
|
| 519 |
+
prefix = "parallelism_config_"
|
| 520 |
+
parallelism_config[prefix + "dp_replicate_size"] = _ask_field(
|
| 521 |
+
"What is the data parallelism replicate size? [1]: ",
|
| 522 |
+
int,
|
| 523 |
+
default=1,
|
| 524 |
+
error_message="Please enter an integer.",
|
| 525 |
+
)
|
| 526 |
+
|
| 527 |
+
parallelism_config[prefix + "dp_shard_size"] = _ask_field(
|
| 528 |
+
"What is the FSDP shard size? [1]: ",
|
| 529 |
+
int,
|
| 530 |
+
default=1,
|
| 531 |
+
error_message="Please enter an integer.",
|
| 532 |
+
)
|
| 533 |
+
|
| 534 |
+
parallelism_config[prefix + "tp_size"] = _ask_field(
|
| 535 |
+
"What is the tensor parallelism size? [1]: ",
|
| 536 |
+
int,
|
| 537 |
+
default=1,
|
| 538 |
+
error_message="Please enter an integer.",
|
| 539 |
+
)
|
| 540 |
+
|
| 541 |
+
parallelism_config[prefix + "cp_size"] = _ask_field(
|
| 542 |
+
"What is the context parallelism size? [1]: ",
|
| 543 |
+
int,
|
| 544 |
+
default=1,
|
| 545 |
+
error_message="Please enter an integer.",
|
| 546 |
+
)
|
| 547 |
+
if parallelism_config[prefix + "cp_size"] > 1:
|
| 548 |
+
parallelism_config[prefix + "cp_comm_strategy"] = _ask_options(
|
| 549 |
+
"What is the compute parallelism communication strategy?",
|
| 550 |
+
["allgather", "alltoall"],
|
| 551 |
+
lambda x: ["allgather", "alltoall"][int(x)],
|
| 552 |
+
default=0,
|
| 553 |
+
)
|
| 554 |
+
|
| 555 |
+
megatron_lm_config = {}
|
| 556 |
+
if distributed_type in [DistributedType.MULTI_GPU]:
|
| 557 |
+
use_megatron_lm = _ask_field(
|
| 558 |
+
"Do you want to use Megatron-LM ? [yes/NO]: ",
|
| 559 |
+
_convert_yes_no_to_bool,
|
| 560 |
+
default=False,
|
| 561 |
+
error_message="Please enter yes or no.",
|
| 562 |
+
)
|
| 563 |
+
if use_megatron_lm:
|
| 564 |
+
distributed_type = DistributedType.MEGATRON_LM
|
| 565 |
+
if distributed_type == DistributedType.MEGATRON_LM:
|
| 566 |
+
prefix = "megatron_lm_"
|
| 567 |
+
megatron_lm_config[prefix + "tp_degree"] = _ask_field(
|
| 568 |
+
"What is the Tensor Parallelism degree/size? [1]:",
|
| 569 |
+
int,
|
| 570 |
+
default=1,
|
| 571 |
+
error_message="Please enter an integer.",
|
| 572 |
+
)
|
| 573 |
+
if megatron_lm_config[prefix + "tp_degree"] > 1:
|
| 574 |
+
megatron_lm_config[prefix + "sequence_parallelism"] = _ask_field(
|
| 575 |
+
"Do you want to enable Sequence Parallelism? [YES/no]: ",
|
| 576 |
+
_convert_yes_no_to_bool,
|
| 577 |
+
default=True,
|
| 578 |
+
error_message="Please enter yes or no.",
|
| 579 |
+
)
|
| 580 |
+
|
| 581 |
+
megatron_lm_config[prefix + "pp_degree"] = _ask_field(
|
| 582 |
+
"What is the Pipeline Parallelism degree/size? [1]:",
|
| 583 |
+
int,
|
| 584 |
+
default=1,
|
| 585 |
+
error_message="Please enter an integer.",
|
| 586 |
+
)
|
| 587 |
+
if megatron_lm_config[prefix + "pp_degree"] > 1:
|
| 588 |
+
megatron_lm_config[prefix + "num_micro_batches"] = _ask_field(
|
| 589 |
+
"What is the number of micro-batches? [1]:",
|
| 590 |
+
int,
|
| 591 |
+
default=1,
|
| 592 |
+
error_message="Please enter an integer.",
|
| 593 |
+
)
|
| 594 |
+
|
| 595 |
+
megatron_lm_config[prefix + "recompute_activations"] = _ask_field(
|
| 596 |
+
"Do you want to enable selective activation recomputation? [YES/no]: ",
|
| 597 |
+
_convert_yes_no_to_bool,
|
| 598 |
+
default=True,
|
| 599 |
+
error_message="Please enter yes or no.",
|
| 600 |
+
)
|
| 601 |
+
|
| 602 |
+
megatron_lm_config[prefix + "use_distributed_optimizer"] = _ask_field(
|
| 603 |
+
"Do you want to use distributed optimizer "
|
| 604 |
+
"which shards optimizer state and gradients across data parallel ranks? [YES/no]: ",
|
| 605 |
+
_convert_yes_no_to_bool,
|
| 606 |
+
default=True,
|
| 607 |
+
error_message="Please enter yes or no.",
|
| 608 |
+
)
|
| 609 |
+
|
| 610 |
+
megatron_lm_config[prefix + "gradient_clipping"] = _ask_field(
|
| 611 |
+
"What is the gradient clipping value based on global L2 Norm (0 to disable)? [1.0]: ",
|
| 612 |
+
float,
|
| 613 |
+
default=1.0,
|
| 614 |
+
)
|
| 615 |
+
# TPU specific defaults
|
| 616 |
+
tpu_commands = None
|
| 617 |
+
tpu_command_file = None
|
| 618 |
+
tpu_downcast_bf16 = "no"
|
| 619 |
+
tpu_env = []
|
| 620 |
+
tpu_name = None
|
| 621 |
+
tpu_vm = None
|
| 622 |
+
tpu_zone = None
|
| 623 |
+
tpu_use_sudo = False
|
| 624 |
+
tpu_use_cluster = False
|
| 625 |
+
|
| 626 |
+
if distributed_type in [
|
| 627 |
+
DistributedType.MULTI_CPU,
|
| 628 |
+
DistributedType.MULTI_XPU,
|
| 629 |
+
DistributedType.MULTI_HPU,
|
| 630 |
+
DistributedType.MULTI_GPU,
|
| 631 |
+
DistributedType.MULTI_MLU,
|
| 632 |
+
DistributedType.MULTI_SDAA,
|
| 633 |
+
DistributedType.MULTI_MUSA,
|
| 634 |
+
DistributedType.MULTI_NPU,
|
| 635 |
+
DistributedType.XLA,
|
| 636 |
+
]:
|
| 637 |
+
machine_type = str(distributed_type).split(".")[1].replace("MULTI_", "")
|
| 638 |
+
if machine_type == "TPU":
|
| 639 |
+
machine_type += " cores"
|
| 640 |
+
elif machine_type == "CPU":
|
| 641 |
+
machine_type = "processes"
|
| 642 |
+
else:
|
| 643 |
+
machine_type += "(s)"
|
| 644 |
+
num_processes = _ask_field(
|
| 645 |
+
f"How many {machine_type} should be used for distributed training? [1]:",
|
| 646 |
+
int,
|
| 647 |
+
default=1,
|
| 648 |
+
error_message="Please enter an integer.",
|
| 649 |
+
)
|
| 650 |
+
elif distributed_type in [DistributedType.FSDP, DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM]:
|
| 651 |
+
num_processes = _ask_field(
|
| 652 |
+
"How many GPU(s) should be used for distributed training? [1]:",
|
| 653 |
+
int,
|
| 654 |
+
default=1,
|
| 655 |
+
error_message="Please enter an integer.",
|
| 656 |
+
)
|
| 657 |
+
else:
|
| 658 |
+
num_processes = 1
|
| 659 |
+
|
| 660 |
+
if (distributed_type == DistributedType.MULTI_GPU) and (num_machines == 1) and (num_processes == 1):
|
| 661 |
+
raise ValueError(
|
| 662 |
+
f"Specified distributed type {distributed_type} but only using 1 GPU on a single machine. Please select `No distributed training` for the type of machine you are using."
|
| 663 |
+
)
|
| 664 |
+
|
| 665 |
+
if (
|
| 666 |
+
distributed_type
|
| 667 |
+
in [
|
| 668 |
+
DistributedType.MULTI_GPU,
|
| 669 |
+
DistributedType.MULTI_MLU,
|
| 670 |
+
DistributedType.MULTI_SDAA,
|
| 671 |
+
DistributedType.MULTI_MUSA,
|
| 672 |
+
DistributedType.MULTI_NPU,
|
| 673 |
+
DistributedType.MULTI_XPU,
|
| 674 |
+
DistributedType.MULTI_HPU,
|
| 675 |
+
DistributedType.NO,
|
| 676 |
+
]
|
| 677 |
+
and not use_cpu
|
| 678 |
+
and not use_mps
|
| 679 |
+
):
|
| 680 |
+
if is_npu_available():
|
| 681 |
+
machine_type = "NPU(s)"
|
| 682 |
+
elif is_mlu_available():
|
| 683 |
+
machine_type = "MLU(s)"
|
| 684 |
+
elif is_sdaa_available():
|
| 685 |
+
machine_type = "SDAA(s)"
|
| 686 |
+
elif is_musa_available():
|
| 687 |
+
machine_type = "MUSA(s)"
|
| 688 |
+
elif is_xpu_available():
|
| 689 |
+
machine_type = "XPU(s)"
|
| 690 |
+
elif is_hpu_available():
|
| 691 |
+
machine_type = "HPU(s)"
|
| 692 |
+
else:
|
| 693 |
+
machine_type = "GPU(s)"
|
| 694 |
+
gpu_ids = _ask_field(
|
| 695 |
+
f"What {machine_type} (by id) should be used for training on this machine as a comma-separated list? [all]:",
|
| 696 |
+
default="all",
|
| 697 |
+
)
|
| 698 |
+
|
| 699 |
+
# CPU affinity is only supported on NVIDIA hardware for now
|
| 700 |
+
enable_cpu_affinity = False
|
| 701 |
+
if distributed_type in (DistributedType.NO, DistributedType.MULTI_GPU) and not use_cpu and not use_mps:
|
| 702 |
+
enable_cpu_affinity = _ask_field(
|
| 703 |
+
"Would you like to enable numa efficiency? (Currently only supported on NVIDIA hardware). [yes/NO]: ",
|
| 704 |
+
_convert_yes_no_to_bool,
|
| 705 |
+
default=False,
|
| 706 |
+
error_message="Please enter yes or no.",
|
| 707 |
+
)
|
| 708 |
+
|
| 709 |
+
fp8_config = None
|
| 710 |
+
if distributed_type == DistributedType.XLA:
|
| 711 |
+
mixed_precision = "no"
|
| 712 |
+
main_training_function = _ask_field(
|
| 713 |
+
"What is the name of the function in your script that should be launched in all parallel scripts? [main]: ",
|
| 714 |
+
default="main",
|
| 715 |
+
)
|
| 716 |
+
tpu_use_cluster = _ask_field(
|
| 717 |
+
"Are you using a TPU cluster? [yes/NO]: ",
|
| 718 |
+
_convert_yes_no_to_bool,
|
| 719 |
+
default=False,
|
| 720 |
+
error_message="Please enter yes or no.",
|
| 721 |
+
)
|
| 722 |
+
if tpu_use_cluster:
|
| 723 |
+
tpu_name = _ask_field(
|
| 724 |
+
"What is the name of your TPU cluster? ",
|
| 725 |
+
default=None,
|
| 726 |
+
error_message="Please enter the name of your TPU cluster.",
|
| 727 |
+
)
|
| 728 |
+
tpu_zone = _ask_field(
|
| 729 |
+
"What is the zone of your TPU cluster? ",
|
| 730 |
+
default=None,
|
| 731 |
+
error_message="Please enter the zone of your TPU cluster.",
|
| 732 |
+
)
|
| 733 |
+
tpu_use_sudo = _ask_field(
|
| 734 |
+
"To run a python script in a TPU pod, should `sudo` be used? [yes/NO]: ",
|
| 735 |
+
default=False,
|
| 736 |
+
error_message="Please enter yes or no.",
|
| 737 |
+
)
|
| 738 |
+
run_commands = _ask_field(
|
| 739 |
+
"Do you have code you wish to run on startup in each pod? [yes/NO]: ",
|
| 740 |
+
_convert_yes_no_to_bool,
|
| 741 |
+
default=False,
|
| 742 |
+
error_message="Please enter yes or no.",
|
| 743 |
+
)
|
| 744 |
+
if run_commands:
|
| 745 |
+
use_command_file = _ask_field(
|
| 746 |
+
"Is this code located in a bash script? [yes/NO]: ",
|
| 747 |
+
_convert_yes_no_to_bool,
|
| 748 |
+
default=False,
|
| 749 |
+
error_message="Please enter yes or no.",
|
| 750 |
+
)
|
| 751 |
+
if use_command_file:
|
| 752 |
+
tpu_command_file = _ask_field(
|
| 753 |
+
"What is the path to your bash script? ",
|
| 754 |
+
default=None,
|
| 755 |
+
error_message="Please enter the path to your bash script.",
|
| 756 |
+
)
|
| 757 |
+
tpu_command_file = os.path.abspath(tpu_command_file)
|
| 758 |
+
else:
|
| 759 |
+
print("Please enter each command separately you wish to run on startup in each pod.")
|
| 760 |
+
tpu_commands = []
|
| 761 |
+
another_command = True
|
| 762 |
+
while another_command:
|
| 763 |
+
tpu_commands.append(
|
| 764 |
+
_ask_field(
|
| 765 |
+
"Please enter a single command to be ran ",
|
| 766 |
+
default=None,
|
| 767 |
+
error_message="Please enter the commands you wish to run on startup in each pod as a single string.",
|
| 768 |
+
)
|
| 769 |
+
)
|
| 770 |
+
another_command = _ask_field(
|
| 771 |
+
"Do you wish to add another command? [yes/NO]: ",
|
| 772 |
+
_convert_yes_no_to_bool,
|
| 773 |
+
default=False,
|
| 774 |
+
error_message="Please enter yes or no.",
|
| 775 |
+
)
|
| 776 |
+
tpu_vm = _ask_field(
|
| 777 |
+
"If not using an instance group, what are the names of the Compute VM instances to be used, separated by a comma: ",
|
| 778 |
+
default="",
|
| 779 |
+
).split(",")
|
| 780 |
+
tpu_env = _ask_field(
|
| 781 |
+
"What environment variables do you wish to set in each pod, separated by a comma: ",
|
| 782 |
+
default="",
|
| 783 |
+
).split(",")
|
| 784 |
+
|
| 785 |
+
else:
|
| 786 |
+
main_training_function = "main"
|
| 787 |
+
if distributed_type == DistributedType.DEEPSPEED and use_deepspeed_config:
|
| 788 |
+
mixed_precision = None
|
| 789 |
+
else:
|
| 790 |
+
mixed_precision = _ask_options(
|
| 791 |
+
"Do you wish to use mixed precision?",
|
| 792 |
+
["no", "fp16", "bf16", "fp8"],
|
| 793 |
+
_convert_mixed_precision,
|
| 794 |
+
)
|
| 795 |
+
if mixed_precision == "fp8":
|
| 796 |
+
if not is_fp8_available():
|
| 797 |
+
raise ValueError("FP8 (either Transformer Engine or MSAMP) is not installed on this machine.")
|
| 798 |
+
fp8_config = {}
|
| 799 |
+
fp8_config["backend"] = _ask_options(
|
| 800 |
+
"Which FP8 backend do you want to use?",
|
| 801 |
+
["te", "msamp"],
|
| 802 |
+
_convert_fp8_backend,
|
| 803 |
+
)
|
| 804 |
+
if fp8_config["backend"] == "TE":
|
| 805 |
+
if not is_transformer_engine_available():
|
| 806 |
+
raise ValueError("TransformersEngine was selected, but it is not installed on this machine.")
|
| 807 |
+
fp8_config["use_autocast_during_eval"] = _ask_field(
|
| 808 |
+
"Do you want to use FP8 autocast during eval mode? Generally better metrics are found when this is disabled [yes/NO]: ",
|
| 809 |
+
_convert_yes_no_to_bool,
|
| 810 |
+
default=False,
|
| 811 |
+
)
|
| 812 |
+
fp8_config["margin"] = _ask_field(
|
| 813 |
+
"What margin should be used for gradient scaling? [0]: ",
|
| 814 |
+
int,
|
| 815 |
+
default=0,
|
| 816 |
+
)
|
| 817 |
+
fp8_config["interval"] = _ask_field(
|
| 818 |
+
"What interval should be used for for how often the scaling factor is recomputed? [1]: ",
|
| 819 |
+
int,
|
| 820 |
+
default=1,
|
| 821 |
+
)
|
| 822 |
+
fp8_config["fp8_format"] = _ask_options(
|
| 823 |
+
"Which weight format should be used?",
|
| 824 |
+
["HYBRID", "E4M3", "E5M2"],
|
| 825 |
+
lambda i: ["HYBRID", "E4M3", "E5M2"][i],
|
| 826 |
+
default=0,
|
| 827 |
+
)
|
| 828 |
+
fp8_config["amax_history_length"] = _ask_field(
|
| 829 |
+
"What length of history should be used for the amax scaling factor computation? [1024]: ",
|
| 830 |
+
int,
|
| 831 |
+
default=1024,
|
| 832 |
+
)
|
| 833 |
+
fp8_config["amax_compute_algorithm"] = _ask_options(
|
| 834 |
+
"Which algorithm should be used for the amax scaling factor computation?",
|
| 835 |
+
["max", "most_recent"],
|
| 836 |
+
lambda x: "max" if x == 0 else "most_recent",
|
| 837 |
+
default=0,
|
| 838 |
+
)
|
| 839 |
+
fp8_config["override_linear_precision"] = _ask_field(
|
| 840 |
+
"Do you want to to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision? [yes/NO]: ",
|
| 841 |
+
_convert_yes_no_to_bool,
|
| 842 |
+
default=False,
|
| 843 |
+
)
|
| 844 |
+
if fp8_config["override_linear_precision"]:
|
| 845 |
+
fprop = _ask_field(
|
| 846 |
+
"Should `fprop` be executed in higher precision? [yes/NO]: ",
|
| 847 |
+
_convert_yes_no_to_bool,
|
| 848 |
+
default=False,
|
| 849 |
+
)
|
| 850 |
+
dgrad = _ask_field(
|
| 851 |
+
"Should `dgrad` be executed in higher precision? [yes/NO]: ",
|
| 852 |
+
_convert_yes_no_to_bool,
|
| 853 |
+
default=False,
|
| 854 |
+
)
|
| 855 |
+
wgrad = _ask_field(
|
| 856 |
+
"Should `wgrad` be executed in higher precision? [yes/NO]: ",
|
| 857 |
+
_convert_yes_no_to_bool,
|
| 858 |
+
default=False,
|
| 859 |
+
)
|
| 860 |
+
fp8_config["override_linear_precision"] = (fprop, dgrad, wgrad)
|
| 861 |
+
else:
|
| 862 |
+
fp8_config["override_linear_precision"] = (False, False, False)
|
| 863 |
+
|
| 864 |
+
elif fp8_config["backend"] == "MSAMP":
|
| 865 |
+
if not is_msamp_available():
|
| 866 |
+
raise ValueError("MSAMP was selected, but it is not installed on this machine.")
|
| 867 |
+
fp8_config["optimization_level"] = _ask_options(
|
| 868 |
+
"Which optimization level should be used?",
|
| 869 |
+
["O1", "O2"],
|
| 870 |
+
lambda x: "O1" if x == 0 else "O2",
|
| 871 |
+
default=1,
|
| 872 |
+
)
|
| 873 |
+
|
| 874 |
+
if use_dynamo and mixed_precision == "no" and not use_cpu:
|
| 875 |
+
print(
|
| 876 |
+
"Torch dynamo used without mixed precision requires TF32 to be efficient. Accelerate will enable it by default when launching your scripts."
|
| 877 |
+
)
|
| 878 |
+
|
| 879 |
+
if distributed_type == DistributedType.XLA and mixed_precision == "bf16":
|
| 880 |
+
tpu_downcast_bf16 = _ask_field(
|
| 881 |
+
"Should `torch.float` be cast as `bfloat16` and `torch.double` remain `float32` on TPUs?", default="no"
|
| 882 |
+
)
|
| 883 |
+
|
| 884 |
+
return ClusterConfig(
|
| 885 |
+
compute_environment=ComputeEnvironment.LOCAL_MACHINE,
|
| 886 |
+
distributed_type=distributed_type,
|
| 887 |
+
num_processes=num_processes,
|
| 888 |
+
gpu_ids=gpu_ids,
|
| 889 |
+
mixed_precision=mixed_precision,
|
| 890 |
+
downcast_bf16=tpu_downcast_bf16,
|
| 891 |
+
machine_rank=machine_rank,
|
| 892 |
+
num_machines=num_machines,
|
| 893 |
+
main_process_ip=main_process_ip,
|
| 894 |
+
main_process_port=main_process_port,
|
| 895 |
+
main_training_function=main_training_function,
|
| 896 |
+
fp8_config=fp8_config,
|
| 897 |
+
deepspeed_config=deepspeed_config,
|
| 898 |
+
fsdp_config=fsdp_config,
|
| 899 |
+
parallelism_config=parallelism_config,
|
| 900 |
+
megatron_lm_config=megatron_lm_config,
|
| 901 |
+
ipex_config=ipex_config,
|
| 902 |
+
mpirun_config=mpirun_config,
|
| 903 |
+
use_cpu=use_cpu,
|
| 904 |
+
rdzv_backend=rdzv_backend,
|
| 905 |
+
same_network=same_network,
|
| 906 |
+
commands=tpu_commands,
|
| 907 |
+
command_file=tpu_command_file,
|
| 908 |
+
tpu_env=tpu_env,
|
| 909 |
+
tpu_name=tpu_name,
|
| 910 |
+
tpu_vm=tpu_vm,
|
| 911 |
+
tpu_zone=tpu_zone,
|
| 912 |
+
tpu_use_sudo=tpu_use_sudo,
|
| 913 |
+
tpu_use_cluster=tpu_use_cluster,
|
| 914 |
+
dynamo_config=dynamo_config,
|
| 915 |
+
debug=debug,
|
| 916 |
+
enable_cpu_affinity=enable_cpu_affinity,
|
| 917 |
+
)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/config.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import os
|
| 19 |
+
|
| 20 |
+
from accelerate.utils import ComputeEnvironment
|
| 21 |
+
|
| 22 |
+
from .cluster import get_cluster_input
|
| 23 |
+
from .config_args import cache_dir, default_config_file, default_yaml_config_file, load_config_from_file # noqa: F401
|
| 24 |
+
from .config_utils import _ask_field, _ask_options, _convert_compute_environment # noqa: F401
|
| 25 |
+
from .sagemaker import get_sagemaker_input
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# Help text for the `config` command; passed as `description=` when `config_command_parser` builds its parser.
description = "Launches a series of prompts to create and save a `default_config.yaml` configuration file for your training system. Should always be ran first on your machine"
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def get_user_input():
    """Ask which compute environment is used and run the matching questionnaire.

    Returns:
        The value returned by `get_sagemaker_input()` when the answer converts to
        `ComputeEnvironment.AMAZON_SAGEMAKER`, otherwise the value returned by
        `get_cluster_input()`.
    """
    compute_environment = _ask_options(
        "In which compute environment are you running?",
        ["This machine", "AWS (Amazon SageMaker)"],
        _convert_compute_environment,
    )
    if compute_environment == ComputeEnvironment.AMAZON_SAGEMAKER:
        return get_sagemaker_input()
    return get_cluster_input()
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def config_command_parser(subparsers=None):
    """Build the argument parser for the `config` command.

    Args:
        subparsers: Optional object exposing `add_parser`. When given, the parser is
            registered on it under the name `config` and `config_command` is installed
            as its `func` default. When `None`, a standalone `argparse.ArgumentParser`
            is created instead.

    Returns:
        The parser, which accepts a single optional `--config_file` argument.
    """
    if subparsers is not None:
        parser = subparsers.add_parser("config", description=description)
    else:
        parser = argparse.ArgumentParser("Accelerate config command", description=description)

    parser.add_argument(
        "--config_file",
        default=None,
        help=(
            "The path to use to store the config file. Will default to a file named default_config.yaml in the cache "
            "location, which is the content of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have "
            "such an environment variable, your cache directory ('~/.cache' or the content of `XDG_CACHE_HOME`) suffixed "
            "with 'huggingface'."
        ),
    )

    # Only a sub-command parser needs a dispatch target; the standalone parser is driven by `main()`.
    if subparsers is not None:
        parser.set_defaults(func=config_command)
    return parser
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def config_command(args):
    """Run the interactive questionnaire and write the resulting configuration to disk.

    Args:
        args: Parsed command-line arguments; only `args.config_file` is read. When it is
            `None`, the configuration is saved to `default_yaml_config_file` inside
            `cache_dir`, which is created if needed.

    The file is written with `to_json_file` when the target path ends in `.json`, and
    with `to_yaml_file` otherwise. The final location is printed.
    """
    config = get_user_input()
    if args.config_file is not None:
        config_file = args.config_file
    else:
        # `exist_ok=True` avoids the check-then-create race of a separate `isdir` test.
        os.makedirs(cache_dir, exist_ok=True)
        config_file = default_yaml_config_file

    if config_file.endswith(".json"):
        config.to_json_file(config_file)
    else:
        config.to_yaml_file(config_file)
    print(f"accelerate configuration saved at {config_file}")
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def main():
    """Parse the command line with the standalone `config` parser and run `config_command`."""
    parser = config_command_parser()
    args = parser.parse_args()
    config_command(args)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# Run the `config` command when this module is executed as a script.
if __name__ == "__main__":
    main()
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/config_args.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import json
|
| 18 |
+
import os
|
| 19 |
+
from dataclasses import dataclass
|
| 20 |
+
from enum import Enum
|
| 21 |
+
from typing import Optional, Union
|
| 22 |
+
|
| 23 |
+
import yaml
|
| 24 |
+
|
| 25 |
+
from ...utils import ComputeEnvironment, DistributedType, SageMakerDistributedType
|
| 26 |
+
from ...utils.constants import SAGEMAKER_PYTHON_VERSION, SAGEMAKER_PYTORCH_VERSION, SAGEMAKER_TRANSFORMERS_VERSION
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Root of the Hugging Face cache: `HF_HOME` when set, otherwise
# `<XDG_CACHE_HOME or ~/.cache>/huggingface`, with `~` expanded.
hf_cache_home = os.path.expanduser(
    os.environ.get("HF_HOME", os.path.join(os.environ.get("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
)
# Accelerate stores its configuration in an `accelerate` sub-directory of that cache.
cache_dir = os.path.join(hf_cache_home, "accelerate")
# NOTE(review): both defaults resolve to the same `default_config.yaml` path, so the
# JSON fallback branch below can never be taken. Confirm whether the JSON default was
# meant to point at a `.json` file before changing it; other modules may import it.
default_json_config_file = os.path.join(cache_dir, "default_config.yaml")
default_yaml_config_file = os.path.join(cache_dir, "default_config.yaml")

# For backward compatibility: the default config is the json one if it's the only existing file.
if os.path.isfile(default_yaml_config_file) or not os.path.isfile(default_json_config_file):
    default_config_file = default_yaml_config_file
else:
    default_config_file = default_json_config_file
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def load_config_from_file(config_file):
    """
    Load an Accelerate configuration object from a JSON or YAML file.

    Args:
        config_file (`str` or `None`):
            Path of the configuration file. When `None`, `default_config_file` is used.

    Returns:
        A `ClusterConfig` when the file's `compute_environment` is missing or equal to
        `ComputeEnvironment.LOCAL_MACHINE`, otherwise a `SageMakerConfig`.

    Raises:
        FileNotFoundError: If an explicit `config_file` is given but is not an existing file.
    """
    if config_file is None:
        config_file = default_config_file
    elif not os.path.isfile(config_file):
        raise FileNotFoundError(
            f"The passed configuration file `{config_file}` does not exist. "
            "Please pass an existing file to `accelerate launch`, or use the default one "
            "created through `accelerate config` and run `accelerate launch` "
            "without the `--config_file` argument."
        )

    with open(config_file, encoding="utf-8") as handle:
        # The file is parsed once here only to pick the config class; the class then re-reads it itself.
        is_json = config_file.endswith(".json")
        raw_config = json.load(handle) if is_json else yaml.safe_load(handle)
        environment = raw_config.get("compute_environment", ComputeEnvironment.LOCAL_MACHINE)
        if environment == ComputeEnvironment.LOCAL_MACHINE:
            config_class = ClusterConfig
        else:
            config_class = SageMakerConfig

        if is_json:
            return config_class.from_json_file(json_file=config_file)
        return config_class.from_yaml_file(yaml_file=config_file)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@dataclass
class BaseConfig:
    """
    Fields and (de)serialization helpers shared by the configuration dataclasses in this module.
    """

    # Where the job runs; a string value is converted to `ComputeEnvironment` in `__post_init__`.
    compute_environment: ComputeEnvironment
    # Distribution strategy; a string value is converted to the matching enum in `__post_init__`.
    distributed_type: Union[DistributedType, SageMakerDistributedType]
    mixed_precision: str
    use_cpu: bool
    debug: bool

    def to_dict(self):
        """
        Return a serializable dictionary of this config.

        Enum values are replaced by their underlying `.value`, empty top-level dictionaries and `None` values
        are dropped. A new dictionary is built (including for nested dictionaries), so the instance's own
        attributes are left untouched.
        """

        # For serialization, it's best to convert Enums to strings (or their underlying value type).
        def _convert_enums(value):
            if isinstance(value, Enum):
                return value.value
            if isinstance(value, dict):
                if not value:
                    return None
                # Build a copy rather than writing back into the dict stored on the instance.
                return {key: _convert_enums(item) for key, item in value.items()}
            return value

        converted = {key: _convert_enums(value) for key, value in self.__dict__.items()}
        return {key: value for key, value in converted.items() if value is not None}

    @staticmethod
    def process_config(config_dict):
        """
        Processes `config_dict` and sets default values for any missing keys

        The dictionary is modified in place and also returned. Legacy `fp16` and `dynamo_backend` keys are
        converted to `mixed_precision` and `dynamo_config` respectively.

        Raises:
            ValueError: If `distributed_type` is missing.
        """
        if "compute_environment" not in config_dict:
            config_dict["compute_environment"] = ComputeEnvironment.LOCAL_MACHINE
        if "distributed_type" not in config_dict:
            raise ValueError("A `distributed_type` must be specified in the config file.")
        if "num_processes" not in config_dict and config_dict["distributed_type"] == DistributedType.NO:
            config_dict["num_processes"] = 1
        if "mixed_precision" not in config_dict:
            config_dict["mixed_precision"] = "fp16" if ("fp16" in config_dict and config_dict["fp16"]) else None
        if "fp16" in config_dict:  # Convert the config to the new format.
            del config_dict["fp16"]
        if "dynamo_backend" in config_dict:  # Convert the config to the new format.
            dynamo_backend = config_dict.pop("dynamo_backend")
            config_dict["dynamo_config"] = {} if dynamo_backend == "NO" else {"dynamo_backend": dynamo_backend}
        if "use_cpu" not in config_dict:
            config_dict["use_cpu"] = False
        if "debug" not in config_dict:
            config_dict["debug"] = False
        if "enable_cpu_affinity" not in config_dict:
            config_dict["enable_cpu_affinity"] = False
        return config_dict

    @classmethod
    def _from_config_dict(cls, config_dict, source_file):
        """
        Apply defaults to `config_dict`, reject keys that are not fields of `cls`, and build the instance.

        `source_file` is only used in the error message.
        """
        config_dict = cls.process_config(config_dict)
        extra_keys = sorted(set(config_dict.keys()) - set(cls.__dataclass_fields__.keys()))
        if len(extra_keys) > 0:
            raise ValueError(
                f"The config file at {source_file} had unknown keys ({extra_keys}), please try upgrading your `accelerate`"
                " version or fix (and potentially remove) these keys from your config file."
            )
        return cls(**config_dict)

    @classmethod
    def from_json_file(cls, json_file=None):
        """
        Build a config from a JSON file (`default_json_config_file` when `json_file` is `None`).

        Raises:
            ValueError: If the file contains keys that are not fields of `cls`.
        """
        json_file = default_json_config_file if json_file is None else json_file
        with open(json_file, encoding="utf-8") as f:
            config_dict = json.load(f)
        return cls._from_config_dict(config_dict, json_file)

    def to_json_file(self, json_file):
        """Write `to_dict()` to `json_file` as indented, key-sorted JSON followed by a newline."""
        with open(json_file, "w", encoding="utf-8") as f:
            content = json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
            f.write(content)

    @classmethod
    def from_yaml_file(cls, yaml_file=None):
        """
        Build a config from a YAML file (`default_yaml_config_file` when `yaml_file` is `None`).

        Raises:
            ValueError: If the file contains keys that are not fields of `cls`.
        """
        yaml_file = default_yaml_config_file if yaml_file is None else yaml_file
        with open(yaml_file, encoding="utf-8") as f:
            config_dict = yaml.safe_load(f)
        return cls._from_config_dict(config_dict, yaml_file)

    def to_yaml_file(self, yaml_file):
        """Write `to_dict()` to `yaml_file` using `yaml.safe_dump`."""
        with open(yaml_file, "w", encoding="utf-8") as f:
            yaml.safe_dump(self.to_dict(), f)

    def __post_init__(self):
        # Values loaded from a file arrive as plain strings; turn them into the matching enums.
        if isinstance(self.compute_environment, str):
            self.compute_environment = ComputeEnvironment(self.compute_environment)
        if isinstance(self.distributed_type, str):
            if self.compute_environment == ComputeEnvironment.AMAZON_SAGEMAKER:
                self.distributed_type = SageMakerDistributedType(self.distributed_type)
            else:
                self.distributed_type = DistributedType(self.distributed_type)
        # `dynamo_config` is declared on the subclasses; make sure it is always a dict on the instance.
        if getattr(self, "dynamo_config", None) is None:
            self.dynamo_config = {}
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
@dataclass
class ClusterConfig(BaseConfig):
    """
    Configuration used when `load_config_from_file` selects the local-machine compute environment.
    """

    num_processes: int = -1  # For instance if we use SLURM and the user manually passes it in
    machine_rank: int = 0
    num_machines: int = 1
    gpu_ids: Optional[str] = None
    main_process_ip: Optional[str] = None
    main_process_port: Optional[int] = None
    rdzv_backend: Optional[str] = "static"
    same_network: Optional[bool] = False
    main_training_function: str = "main"
    enable_cpu_affinity: bool = False

    # FP8 training options
    fp8_config: Optional[dict] = None
    # DeepSpeed plugin options
    deepspeed_config: Optional[dict] = None
    # FSDP options
    fsdp_config: Optional[dict] = None
    # parallelism config options
    parallelism_config: Optional[dict] = None
    # Megatron-LM options
    megatron_lm_config: Optional[dict] = None
    # IPEX options
    ipex_config: Optional[dict] = None
    # mpirun options
    mpirun_config: Optional[dict] = None
    # TPU options
    downcast_bf16: bool = False

    # TPU pod options
    tpu_name: Optional[str] = None
    tpu_zone: Optional[str] = None
    tpu_use_cluster: bool = False
    tpu_use_sudo: bool = False
    command_file: Optional[str] = None
    commands: list[str] = None
    tpu_vm: list[str] = None
    tpu_env: list[str] = None

    # dynamo options
    dynamo_config: Optional[dict] = None

    def __post_init__(self):
        # Replace every unset plugin section with an empty dict so it can always be treated as a mapping.
        for section in (
            "deepspeed_config",
            "fsdp_config",
            "megatron_lm_config",
            "ipex_config",
            "mpirun_config",
            "fp8_config",
            "parallelism_config",
        ):
            if getattr(self, section) is None:
                setattr(self, section, {})
        return super().__post_init__()
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
@dataclass
class SageMakerConfig(BaseConfig):
    """
    Configuration used when `load_config_from_file` finds a compute environment other than
    `ComputeEnvironment.LOCAL_MACHINE`.
    """

    # Required fields (no default).
    ec2_instance_type: str
    iam_role_name: str
    image_uri: Optional[str] = None
    profile: Optional[str] = None
    region: str = "us-east-1"
    num_machines: int = 1
    gpu_ids: str = "all"
    # NOTE: this f-string is evaluated once, while the class body executes, using the class-level default
    # `num_machines = 1` defined just above; it does not follow a per-instance `num_machines`.
    base_job_name: str = f"accelerate-sagemaker-{num_machines}"
    # Version defaults come from the `SAGEMAKER_*` constants imported at the top of the module.
    pytorch_version: str = SAGEMAKER_PYTORCH_VERSION
    transformers_version: str = SAGEMAKER_TRANSFORMERS_VERSION
    py_version: str = SAGEMAKER_PYTHON_VERSION
    sagemaker_inputs_file: Optional[str] = None
    sagemaker_metrics_file: Optional[str] = None
    additional_args: Optional[dict] = None
    # Left as `None` here; `BaseConfig.__post_init__` replaces a `None` value with `{}`.
    dynamo_config: Optional[dict] = None
    enable_cpu_affinity: bool = False
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/config_utils.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
|
| 19 |
+
from ...utils.dataclasses import (
|
| 20 |
+
ComputeEnvironment,
|
| 21 |
+
DistributedType,
|
| 22 |
+
DynamoBackend,
|
| 23 |
+
FP8BackendType,
|
| 24 |
+
PrecisionType,
|
| 25 |
+
SageMakerDistributedType,
|
| 26 |
+
)
|
| 27 |
+
from ..menu import BulletMenu
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# Backend names offered for torch dynamo. `_convert_dynamo_backend` indexes into this list and passes the
# selected name to `DynamoBackend`, so the order of the entries is significant.
DYNAMO_BACKENDS = [
    "EAGER",
    "AOT_EAGER",
    "INDUCTOR",
    "AOT_TS_NVFUSER",
    "NVPRIMS_NVFUSER",
    "CUDAGRAPHS",
    "OFI",
    "FX2TRT",
    "ONNXRT",
    "TENSORRT",
    "AOT_TORCHXLA_TRACE_ONCE",
    # Was misspelled "TORHCHXLA_TRACE_ONCE"; spelling now matches the AOT_ entry above.
    "TORCHXLA_TRACE_ONCE",
    "IPEX",
    "TVM",
]
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _ask_field(input_text, convert_value=None, default=None, error_message=None):
    """
    Prompt on stdin until an acceptable answer is obtained.

    Args:
        input_text: Prompt passed to `input`.
        convert_value: Optional callable applied to the raw answer; if it raises, the prompt is repeated.
        default: Returned as-is when the answer is empty and `default` is not `None`.
        error_message: Optional message printed each time conversion fails.

    Returns:
        `default`, the converted answer, or the raw answer string when no converter is given.
    """
    while True:
        answer = input(input_text)
        try:
            if default is not None and answer == "":
                return default
            if convert_value is None:
                return answer
            return convert_value(answer)
        except Exception:
            # Any conversion failure means "ask again"; only the optional message is shown.
            if error_message is not None:
                print(error_message)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _ask_options(input_text, options=None, convert_value=None, default=0):
    """
    Show a `BulletMenu` and return the user's selection.

    Args:
        input_text: Prompt handed to `BulletMenu`.
        options: Choices handed to `BulletMenu`. Defaults to an empty list.
        convert_value: Optional callable applied to the value returned by `menu.run`.
        default: Passed to `menu.run` as `default_choice`.

    Returns:
        The menu result, converted with `convert_value` when one is given.
    """
    # A fresh list per call avoids sharing one mutable default object between calls.
    options = [] if options is None else options
    menu = BulletMenu(input_text, options)
    result = menu.run(default_choice=default)
    return convert_value(result) if convert_value is not None else result
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _convert_compute_environment(value):
    """Map a menu index (anything `int()` accepts) to the corresponding `ComputeEnvironment`."""
    environment_names = ("LOCAL_MACHINE", "AMAZON_SAGEMAKER")
    return ComputeEnvironment(environment_names[int(value)])
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def _convert_distributed_mode(value):
    """Map a menu index (anything `int()` accepts) to the corresponding `DistributedType`."""
    # Order matters: the index selected in the menu is used directly as a position in this sequence.
    mode_names = (
        "NO",
        "MULTI_CPU",
        "MULTI_XPU",
        "MULTI_HPU",
        "MULTI_GPU",
        "MULTI_NPU",
        "MULTI_MLU",
        "MULTI_SDAA",
        "MULTI_MUSA",
        "XLA",
    )
    selected = mode_names[int(value)]
    return DistributedType(selected)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def _convert_dynamo_backend(value):
    """Map an index into `DYNAMO_BACKENDS` to the `.value` of the matching `DynamoBackend` member."""
    backend_name = DYNAMO_BACKENDS[int(value)]
    backend = DynamoBackend(backend_name)
    return backend.value
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def _convert_mixed_precision(value):
    """Map a menu index (anything `int()` accepts) to the corresponding `PrecisionType`."""
    precision_names = ("no", "fp16", "bf16", "fp8")
    return PrecisionType(precision_names[int(value)])
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def _convert_sagemaker_distributed_mode(value):
    """Map a menu index (anything `int()` accepts) to the corresponding `SageMakerDistributedType`."""
    mode_names = ("NO", "DATA_PARALLEL", "MODEL_PARALLEL")
    return SageMakerDistributedType(mode_names[int(value)])
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def _convert_fp8_backend(value):
    """Map a menu index (anything `int()` accepts) to the corresponding `FP8BackendType`."""
    backend_names = ("TE", "MSAMP")
    return FP8BackendType(backend_names[int(value)])
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def _convert_yes_no_to_bool(value):
    """
    Convert a case-insensitive "yes"/"no" answer to a boolean.

    Raises:
        KeyError: If the lower-cased answer is neither "yes" nor "no".
    """
    answer = value.lower()
    if answer == "yes":
        return True
    if answer == "no":
        return False
    raise KeyError(answer)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class SubcommandHelpFormatter(argparse.RawDescriptionHelpFormatter):
    """
    Help formatter for subcommands that strips the `<command> [<args>] ` placeholder from the usage text.
    """

    def _format_usage(self, usage, actions, groups, prefix):
        # Let argparse render the usage text first, then drop the generic placeholder from it.
        rendered = super()._format_usage(usage, actions, groups, prefix)
        return rendered.replace("<command> [<args>] ", "")
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/default.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
|
| 21 |
+
from ...utils import (
|
| 22 |
+
is_hpu_available,
|
| 23 |
+
is_mlu_available,
|
| 24 |
+
is_musa_available,
|
| 25 |
+
is_npu_available,
|
| 26 |
+
is_sdaa_available,
|
| 27 |
+
is_xpu_available,
|
| 28 |
+
)
|
| 29 |
+
from .config_args import ClusterConfig, default_json_config_file
|
| 30 |
+
from .config_utils import SubcommandHelpFormatter
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# Help text for the `default` subcommand; passed as `help=` when the sub-parser is created below.
description = "Create a default config file for Accelerate with only a few flags set."
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def write_basic_config(mixed_precision="no", save_location: str = default_json_config_file):
    """
    Creates and saves a basic cluster config to be used on a local machine with potentially multiple GPUs. Will also
    set CPU if it is a CPU-only machine.

    Args:
        mixed_precision (`str`, *optional*, defaults to "no"):
            Mixed Precision to use. Should be one of "no", "fp16", "bf16", or "fp8" (case-insensitive).
        save_location (`str`, *optional*, defaults to `default_json_config_file`):
            Optional custom save location. Should be passed to `--config_file` when using `accelerate launch`. Default
            location is inside the huggingface cache folder (`~/.cache/huggingface`) but can be overridden by setting
            the `HF_HOME` environmental variable, followed by `accelerate/default_config.yaml`.

    Returns:
        The `pathlib.Path` the config was written to, or `False` when a file already exists at `save_location`
        (in which case nothing is written).

    Raises:
        ValueError: If `mixed_precision` is not one of the accepted values.
    """
    path = Path(save_location)
    path.parent.mkdir(parents=True, exist_ok=True)
    if path.exists():
        print(
            f"Configuration already exists at {save_location}, will not override. Run `accelerate config` manually or pass a different `save_location`."
        )
        return False
    mixed_precision = mixed_precision.lower()
    if mixed_precision not in ["no", "fp16", "bf16", "fp8"]:
        raise ValueError(
            f"`mixed_precision` should be one of 'no', 'fp16', 'bf16', or 'fp8'. Received {mixed_precision}"
        )
    config = {
        "compute_environment": "LOCAL_MACHINE",
        "mixed_precision": mixed_precision,
    }

    # (availability check, `torch` sub-module name, multi-device distributed type), probed in priority order.
    # The first available accelerator wins. Previously the MLU check was a separate `if` ahead of the
    # `if/elif/else` chain, so on an MLU-only machine the trailing `else` overwrote the MLU settings with CPU.
    accelerators = (
        (is_mlu_available, "mlu", "MULTI_MLU"),
        (is_sdaa_available, "sdaa", "MULTI_SDAA"),
        (is_musa_available, "musa", "MULTI_MUSA"),
        (is_hpu_available, "hpu", "MULTI_HPU"),
        (torch.cuda.is_available, "cuda", "MULTI_GPU"),
        (is_xpu_available, "xpu", "MULTI_XPU"),
        (is_npu_available, "npu", "MULTI_NPU"),
    )
    for is_available, backend, multi_device_type in accelerators:
        if not is_available():
            continue
        num_devices = getattr(torch, backend).device_count()
        config["num_processes"] = num_devices
        config["use_cpu"] = False
        config["distributed_type"] = multi_device_type if num_devices > 1 else "NO"
        break
    else:
        # No accelerator detected: single-process CPU run.
        config["use_cpu"] = True
        config["num_processes"] = 1
        config["distributed_type"] = "NO"
    config["debug"] = False
    config["enable_cpu_affinity"] = False
    config = ClusterConfig(**config)
    config.to_json_file(path)
    return path
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def default_command_parser(parser, parents):
    """
    Register the `default` subcommand on `parser` and return the created sub-parser.

    The sub-parser accepts `--config_file` (stored as `save_location`) and `--mixed_precision`, and dispatches
    to `default_config_command` through its `func` default.
    """
    config_file_help = (
        "The path to use to store the config file. Will default to a file named default_config.yaml in the cache "
        "location, which is the content of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have "
        "such an environment variable, your cache directory ('~/.cache' or the content of `XDG_CACHE_HOME`) suffixed "
        "with 'huggingface'."
    )
    mixed_precision_help = (
        "Whether or not to use mixed precision training. "
        "Choose between FP16 and BF16 (bfloat16) training. "
        "BF16 training is only supported on Nvidia Ampere GPUs and PyTorch 1.10 or later."
    )

    default_parser = parser.add_parser(
        "default", parents=parents, help=description, formatter_class=SubcommandHelpFormatter
    )
    default_parser.add_argument(
        "--config_file",
        default=default_json_config_file,
        help=config_file_help,
        dest="save_location",
    )
    default_parser.add_argument(
        "--mixed_precision",
        choices=["no", "fp16", "bf16"],
        type=str,
        help=mixed_precision_help,
        default="no",
    )
    default_parser.set_defaults(func=default_config_command)
    return default_parser
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def default_config_command(args):
    """Write the basic config described by `args` and report where it was saved, if it was written."""
    saved_to = write_basic_config(args.mixed_precision, args.save_location)
    if not saved_to:
        # `write_basic_config` returned a falsy value, so there is no saved path to report.
        return
    print(f"accelerate configuration saved at {saved_to}")
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/sagemaker.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
import json
|
| 17 |
+
import os
|
| 18 |
+
|
| 19 |
+
from ...utils.constants import SAGEMAKER_PARALLEL_EC2_INSTANCES, TORCH_DYNAMO_MODES
|
| 20 |
+
from ...utils.dataclasses import ComputeEnvironment, SageMakerDistributedType
|
| 21 |
+
from ...utils.imports import is_boto3_available
|
| 22 |
+
from .config_args import SageMakerConfig
|
| 23 |
+
from .config_utils import (
|
| 24 |
+
DYNAMO_BACKENDS,
|
| 25 |
+
_ask_field,
|
| 26 |
+
_ask_options,
|
| 27 |
+
_convert_dynamo_backend,
|
| 28 |
+
_convert_mixed_precision,
|
| 29 |
+
_convert_sagemaker_distributed_mode,
|
| 30 |
+
_convert_yes_no_to_bool,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
if is_boto3_available():
|
| 35 |
+
import boto3 # noqa: F401
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _create_iam_role_for_sagemaker(role_name):
    """
    Create the IAM role `role_name` with a SageMaker trust policy and attach an inline permission policy.

    If IAM reports that the entity already exists, a message is printed and the existing role is left as is.
    """
    iam_client = boto3.client("iam")

    # Trust policy: lets the SageMaker service assume the role.
    trust_policy_json = json.dumps(
        {
            "Version": "2012-10-17",
            "Statement": [
                {"Effect": "Allow", "Principal": {"Service": "sagemaker.amazonaws.com"}, "Action": "sts:AssumeRole"}
            ],
        },
        indent=2,
    )

    # Inline permission policy attached to the role after it is created.
    allowed_actions = [
        "sagemaker:*",
        "ecr:GetDownloadUrlForLayer",
        "ecr:BatchGetImage",
        "ecr:BatchCheckLayerAvailability",
        "ecr:GetAuthorizationToken",
        "cloudwatch:PutMetricData",
        "cloudwatch:GetMetricData",
        "cloudwatch:GetMetricStatistics",
        "cloudwatch:ListMetrics",
        "logs:CreateLogGroup",
        "logs:CreateLogStream",
        "logs:DescribeLogStreams",
        "logs:PutLogEvents",
        "logs:GetLogEvents",
        "s3:CreateBucket",
        "s3:ListBucket",
        "s3:GetBucketLocation",
        "s3:GetObject",
        "s3:PutObject",
    ]
    permission_policy_json = json.dumps(
        {
            "Version": "2012-10-17",
            "Statement": [
                {
                    "Effect": "Allow",
                    "Action": allowed_actions,
                    "Resource": "*",
                }
            ],
        },
        indent=2,
    )

    try:
        # create the role, associated with the chosen trust policy
        iam_client.create_role(RoleName=role_name, AssumeRolePolicyDocument=trust_policy_json)
        # attach policy to role
        iam_client.put_role_policy(
            RoleName=role_name,
            PolicyName=f"{role_name}_policy_permission",
            PolicyDocument=permission_policy_json,
        )
    except iam_client.exceptions.EntityAlreadyExistsException:
        print(f"role {role_name} already exists. Using existing one")
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def _get_iam_role_arn(role_name):
    """Return the ARN found under `["Role"]["Arn"]` in the IAM `get_role` response for `role_name`."""
    role_response = boto3.client("iam").get_role(RoleName=role_name)
    return role_response["Role"]["Arn"]
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def get_sagemaker_input():
    """Interactively collect the settings for an Amazon SageMaker run and return them as a ``SageMakerConfig``.

    The prompts cover, in order: AWS authorization, region, IAM role, optional custom Docker image, optional
    input-channel and metrics TSV files, distributed mode, torch dynamo options, EC2 instance type, debug checks,
    machine count and mixed precision.

    Side effects:
        * Sets ``AWS_PROFILE`` or ``AWS_ACCESS_KEY_ID`` / ``AWS_SECRET_ACCESS_KEY`` in ``os.environ``, depending on
          the chosen authorization method, and always sets ``AWS_DEFAULT_REGION``.
        * Calls ``_create_iam_role_for_sagemaker`` when the user asks for a new IAM role.

    Returns:
        A ``SageMakerConfig`` populated with the answers.
    """
    credentials_configuration = _ask_options(
        "How do you want to authorize?",
        ["AWS Profile", "Credentials (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY) "],
        int,
    )
    aws_profile = None
    if credentials_configuration == 0:
        aws_profile = _ask_field("Enter your AWS Profile name: [default] ", default="default")
        os.environ["AWS_PROFILE"] = aws_profile
    else:
        print(
            "Note you will need to provide AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY when you launch you training script with,"
            "`accelerate launch --aws_access_key_id XXX --aws_secret_access_key YYY`"
        )
        aws_access_key_id = _ask_field("AWS Access Key ID: ")
        os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id

        aws_secret_access_key = _ask_field("AWS Secret Access Key: ")
        os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key

    aws_region = _ask_field("Enter your AWS Region: [us-east-1]", default="us-east-1")
    os.environ["AWS_DEFAULT_REGION"] = aws_region

    role_management = _ask_options(
        "Do you already have an IAM Role for executing Amazon SageMaker Training Jobs?",
        ["Provide IAM Role name", "Create new IAM role using credentials"],
        int,
    )
    if role_management == 0:
        iam_role_name = _ask_field("Enter your IAM role name: ")
    else:
        iam_role_name = "accelerate_sagemaker_execution_role"
        print(f'Accelerate will create an iam role "{iam_role_name}" using the provided credentials')
        _create_iam_role_for_sagemaker(iam_role_name)

    is_custom_docker_image = _ask_field(
        "Do you want to use custom Docker image? [yes/NO]: ",
        _convert_yes_no_to_bool,
        default=False,
        error_message="Please enter yes or no.",
    )
    docker_image = None
    if is_custom_docker_image:
        docker_image = _ask_field("Enter your Docker image: ", lambda x: str(x).lower())

    is_sagemaker_inputs_enabled = _ask_field(
        "Do you want to provide SageMaker input channels with data locations? [yes/NO]: ",
        _convert_yes_no_to_bool,
        default=False,
        error_message="Please enter yes or no.",
    )
    sagemaker_inputs_file = None
    if is_sagemaker_inputs_enabled:
        sagemaker_inputs_file = _ask_field(
            "Enter the path to the SageMaker inputs TSV file with columns (channel_name, data_location): ",
            lambda x: str(x).lower(),
        )

    is_sagemaker_metrics_enabled = _ask_field(
        "Do you want to enable SageMaker metrics? [yes/NO]: ",
        _convert_yes_no_to_bool,
        default=False,
        error_message="Please enter yes or no.",
    )
    sagemaker_metrics_file = None
    if is_sagemaker_metrics_enabled:
        sagemaker_metrics_file = _ask_field(
            "Enter the path to the SageMaker metrics TSV file with columns (metric_name, metric_regex): ",
            lambda x: str(x).lower(),
        )

    distributed_type = _ask_options(
        "What is the distributed mode?",
        ["No distributed training", "Data parallelism"],
        _convert_sagemaker_distributed_mode,
    )
    dynamo_config = {}
    use_dynamo = _ask_field(
        "Do you wish to optimize your script with torch dynamo?[yes/NO]:",
        _convert_yes_no_to_bool,
        default=False,
        error_message="Please enter yes or no.",
    )
    if use_dynamo:
        prefix = "dynamo_"
        dynamo_config[prefix + "backend"] = _ask_options(
            "Which dynamo backend would you like to use?",
            [x.lower() for x in DYNAMO_BACKENDS],
            _convert_dynamo_backend,
            default=2,
        )
        use_custom_options = _ask_field(
            "Do you want to customize the defaults sent to torch.compile? [yes/NO]: ",
            _convert_yes_no_to_bool,
            default=False,
            error_message="Please enter yes or no.",
        )

        if use_custom_options:
            dynamo_config[prefix + "mode"] = _ask_options(
                "Which mode do you want to use?",
                TORCH_DYNAMO_MODES,
                lambda x: TORCH_DYNAMO_MODES[int(x)],
                # `default` is the preselected choice *index* (compare `default=2` for the backend menu above, and
                # `BulletMenu.run(default_choice: int)`), not a mode name. The previous value, the string
                # "default", is not a valid index. Index 0 preselects the first mode in the list.
                # NOTE(review): assumes TORCH_DYNAMO_MODES[0] is the "default" mode -- confirm against its definition.
                default=0,
            )
            dynamo_config[prefix + "use_fullgraph"] = _ask_field(
                "Do you want the fullgraph mode or it is ok to break model into several subgraphs? [yes/NO]: ",
                _convert_yes_no_to_bool,
                default=False,
                error_message="Please enter yes or no.",
            )
            dynamo_config[prefix + "use_dynamic"] = _ask_field(
                "Do you want to enable dynamic shape tracing? [yes/NO]: ",
                _convert_yes_no_to_bool,
                default=False,
                error_message="Please enter yes or no.",
            )
            dynamo_config[prefix + "use_regional_compilation"] = _ask_field(
                "Do you want to enable regional compilation? [yes/NO]: ",
                _convert_yes_no_to_bool,
                default=False,
                error_message="Please enter yes or no.",
            )

    ec2_instance_query = "Which EC2 instance type you want to use for your training?"
    if distributed_type != SageMakerDistributedType.NO:
        # Distributed runs pick from a fixed list of instance types.
        ec2_instance_type = _ask_options(
            ec2_instance_query, SAGEMAKER_PARALLEL_EC2_INSTANCES, lambda x: SAGEMAKER_PARALLEL_EC2_INSTANCES[int(x)]
        )
    else:
        # Non-distributed runs accept any instance type as free text.
        ec2_instance_query += "? [ml.p3.2xlarge]:"
        ec2_instance_type = _ask_field(ec2_instance_query, lambda x: str(x).lower(), default="ml.p3.2xlarge")

    debug = False
    if distributed_type != SageMakerDistributedType.NO:
        debug = _ask_field(
            "Should distributed operations be checked while running for errors? This can avoid timeout issues but will be slower. [yes/NO]: ",
            _convert_yes_no_to_bool,
            default=False,
            error_message="Please enter yes or no.",
        )

    num_machines = 1
    if distributed_type in (SageMakerDistributedType.DATA_PARALLEL, SageMakerDistributedType.MODEL_PARALLEL):
        num_machines = _ask_field(
            "How many machines do you want use? [1]: ",
            int,
            default=1,
        )

    mixed_precision = _ask_options(
        "Do you wish to use FP16 or BF16 (mixed precision)?",
        ["no", "fp16", "bf16", "fp8"],
        _convert_mixed_precision,
    )

    if use_dynamo and mixed_precision == "no":
        print(
            "Torch dynamo used without mixed precision requires TF32 to be efficient. Accelerate will enable it by default when launching your scripts."
        )

    return SageMakerConfig(
        image_uri=docker_image,
        compute_environment=ComputeEnvironment.AMAZON_SAGEMAKER,
        distributed_type=distributed_type,
        use_cpu=False,
        dynamo_config=dynamo_config,
        ec2_instance_type=ec2_instance_type,
        profile=aws_profile,
        region=aws_region,
        iam_role_name=iam_role_name,
        mixed_precision=mixed_precision,
        num_machines=num_machines,
        sagemaker_inputs_file=sagemaker_inputs_file,
        sagemaker_metrics_file=sagemaker_metrics_file,
        debug=debug,
    )
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/update.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
|
| 19 |
+
from .config_args import default_config_file, load_config_from_file
|
| 20 |
+
from .config_utils import SubcommandHelpFormatter
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
description = "Update an existing config file with the latest defaults while maintaining the old configuration."
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def update_config(args):
    """
    Update an existing config file with the latest defaults while maintaining the old configuration.

    Args:
        args: Parsed CLI namespace. Only `args.config_file` is read; `None` selects `default_config_file`.

    Returns:
        The path of the config file that was loaded and written back.

    Raises:
        ValueError: If the selected config file (passed explicitly or the default one) does not exist.
    """
    config_file = args.config_file
    if config_file is None:
        # No path was passed: fall back to the default location. Checking its existence here (instead of
        # falling through to `Path(None)`) turns a confusing TypeError into the intended ValueError.
        config_file = default_config_file
        if not Path(config_file).exists():
            raise ValueError(
                f"No config file was passed and the default config file located at {config_file} doesn't exist."
            )
    elif not Path(config_file).exists():
        raise ValueError(f"The passed config file located at {config_file} doesn't exist.")
    config = load_config_from_file(config_file)

    # Write back in the same format the file was read from; `str()` keeps this working for path-like values.
    if str(config_file).endswith(".json"):
        config.to_json_file(config_file)
    else:
        config.to_yaml_file(config_file)
    return config_file
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def update_command_parser(parser, parents):
    """Register the `update` subcommand and return its parser.

    Args:
        parser: Object exposing `add_parser` (the sub-parsers action the subcommand is attached to).
        parents: Parent parsers forwarded to `add_parser`.

    Returns:
        The parser created for the `update` subcommand, with `func` defaulting to `update_config_command`.
    """
    config_file_help = (
        "The path to the config file to update. Will default to a file named default_config.yaml in the cache "
        "location, which is the content of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have "
        "such an environment variable, your cache directory ('~/.cache' or the content of `XDG_CACHE_HOME`) suffixed "
        "with 'huggingface'."
    )
    update_parser = parser.add_parser(
        "update", parents=parents, help=description, formatter_class=SubcommandHelpFormatter
    )
    update_parser.add_argument("--config_file", default=None, help=config_file_help)
    update_parser.set_defaults(func=update_config_command)
    return update_parser
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def update_config_command(args):
    """Entry point of the `update` subcommand: rewrite the config file and report its location."""
    updated_path = update_config(args)
    print(f"Successfully updated the configuration file at {updated_path}.")
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from .selection_menu import BulletMenu
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (284 Bytes). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/__pycache__/cursor.cpython-312.pyc
ADDED
|
Binary file (3.06 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/__pycache__/helpers.cpython-312.pyc
ADDED
|
Binary file (2.21 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/__pycache__/input.cpython-312.pyc
ADDED
|
Binary file (3.17 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/__pycache__/keymap.cpython-312.pyc
ADDED
|
Binary file (4.52 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/__pycache__/selection_menu.cpython-312.pyc
ADDED
|
Binary file (7.47 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/cursor.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team and Brian Chao. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
A utility for showing and hiding the terminal cursor on Windows and Linux, based on https://github.com/bchao1/bullet
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import os
|
| 20 |
+
import sys
|
| 21 |
+
from contextlib import contextmanager
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# Windows only
|
| 25 |
+
if os.name == "nt":
    # These imports and the structure below are only defined on Windows (os.name == "nt").
    import ctypes
    import msvcrt  # noqa

    class CursorInfo(ctypes.Structure):
        """Structure passed by reference to Get/SetConsoleCursorInfo in `hide_cursor` and `show_cursor`."""

        # `_fields_` is the attribute ctypes reads to lay out the structure.
        _fields_ = [("size", ctypes.c_int), ("visible", ctypes.c_byte)]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def hide_cursor():
    """Make the terminal cursor invisible.

    On Windows the console cursor info is fetched, its `visible` field cleared, and written back. On POSIX the
    `ESC[?25l` sequence is written to stdout and flushed. Any other `os.name` is a no-op.
    """
    platform = os.name
    if platform == "posix":
        sys.stdout.write("\033[?25l")
        sys.stdout.flush()
    elif platform == "nt":
        cursor_info = CursorInfo()
        # -11 is STD_OUTPUT_HANDLE
        console = ctypes.windll.kernel32.GetStdHandle(-11)
        ctypes.windll.kernel32.GetConsoleCursorInfo(console, ctypes.byref(cursor_info))
        cursor_info.visible = False
        ctypes.windll.kernel32.SetConsoleCursorInfo(console, ctypes.byref(cursor_info))
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def show_cursor():
    """Make the terminal cursor visible again.

    On Windows the console cursor info is fetched, its `visible` field set, and written back. On POSIX the
    `ESC[?25h` sequence is written to stdout and flushed. Any other `os.name` is a no-op.
    """
    platform = os.name
    if platform == "posix":
        sys.stdout.write("\033[?25h")
        sys.stdout.flush()
    elif platform == "nt":
        cursor_info = CursorInfo()
        # -11 is STD_OUTPUT_HANDLE
        console = ctypes.windll.kernel32.GetStdHandle(-11)
        ctypes.windll.kernel32.GetConsoleCursorInfo(console, ctypes.byref(cursor_info))
        cursor_info.visible = True
        ctypes.windll.kernel32.SetConsoleCursorInfo(console, ctypes.byref(cursor_info))
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@contextmanager
def hide():
    """Context manager that hides the terminal cursor for the duration of the `with` block.

    `show_cursor()` runs in the `finally` clause, so the cursor is restored even if the body of the `with` block
    (or `hide_cursor()` itself) raises.
    """
    try:
        hide_cursor()
        yield
    finally:
        show_cursor()
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/helpers.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team and Brian Chao. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
A variety of helper functions and constants when dealing with terminal menu choices, based on
|
| 17 |
+
https://github.com/bchao1/bullet
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import enum
|
| 21 |
+
import shutil
|
| 22 |
+
import sys
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# Terminal width in columns, read once when this module is imported (the line count is discarded).
TERMINAL_WIDTH, _ = shutil.get_terminal_size()

# Final letter of the `ESC[<n><letter>` cursor-movement sequence emitted by `move_cursor`, keyed by direction name.
CURSOR_TO_CHAR = {"UP": "A", "DOWN": "B", "RIGHT": "C", "LEFT": "D"}
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class Direction(enum.Enum):
    """Vertical movement directions; the member names match the "UP"/"DOWN" keys of `CURSOR_TO_CHAR`."""

    UP = 0
    DOWN = 1
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def forceWrite(content, end=""):
    """Write `content` (converted with `str`) followed by `end` to stdout and flush immediately.

    Args:
        content: Object to write; it is passed through `str()` first.
        end: String appended after the content.
    """
    text = str(content) + end
    stream = sys.stdout
    stream.write(text)
    stream.flush()
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def writeColor(content, color, end=""):
    """Write `content` wrapped in the `ESC[<color>m` ... `ESC[0m` escape sequences.

    Args:
        content: Object to write.
        color: Code placed between `ESC[` and `m`.
        end: String appended after the closing escape sequence.
    """
    colored = f"\u001b[{color}m{content}\u001b[0m"
    forceWrite(colored, end)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def reset_cursor():
    """Write a carriage return to stdout."""
    forceWrite("\r")
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def move_cursor(num_lines: int, direction: str):
    """Emit the `ESC[<num_lines><letter>` cursor-movement sequence.

    Args:
        num_lines: Count placed in the escape sequence.
        direction: Case-insensitive key of `CURSOR_TO_CHAR` ("up", "down", "right" or "left").

    Raises:
        KeyError: If `direction` is not a key of `CURSOR_TO_CHAR`.
    """
    final_letter = CURSOR_TO_CHAR[direction.upper()]
    forceWrite(f"\033[{num_lines}{final_letter}")
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def clear_line():
    """Overwrite the current line with `TERMINAL_WIDTH` spaces, then write a carriage return."""
    blank_line = " " * TERMINAL_WIDTH
    forceWrite(blank_line)
    reset_cursor()
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def linebreak():
    """Write a carriage return followed by a rule of `TERMINAL_WIDTH` dashes."""
    reset_cursor()
    rule = "-" * TERMINAL_WIDTH
    forceWrite(rule)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/input.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team and Brian Chao. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
This file contains utilities for handling input from the user and registering specific keys to specific functions,
|
| 17 |
+
based on https://github.com/bchao1/bullet
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from .keymap import KEYMAP, get_character
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def mark(key: str):
    """
    Build a decorator that tags a function with `key` so `KeyHandler` can register it.

    The key is appended to the function's `handle_key` list (created on first use); the function itself is
    returned unchanged.
    """

    def decorator(func):
        registered = getattr(func, "handle_key", [])
        registered.append(key)
        func.handle_key = registered
        return func

    return decorator
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def mark_multiple(*keys: list[str]):
    """
    Build a decorator that tags a function with every key in `keys` so `KeyHandler` can register it.

    The keys are appended, in order, to the function's `handle_key` list (created on first use); the function
    itself is returned unchanged.
    """

    def decorator(func):
        registered = getattr(func, "handle_key", [])
        registered.extend(keys)
        func.handle_key = registered
        return func

    return decorator
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class KeyHandler(type):
    """
    Metaclass that gathers the attributes tagged with `handle_key` into the class's `key_handler` table and
    installs `handle_input` on the class.
    """

    def __new__(cls, name, bases, attrs):
        built = super().__new__(cls, name, bases, attrs)
        # Reuse a table that is already reachable (e.g. inherited from a base); otherwise start a new one.
        if not hasattr(built, "key_handler"):
            built.key_handler = {}
        built.handle_input = KeyHandler.handle_input

        # Map every declared key code to its handler; later attributes win on duplicate key codes.
        built.key_handler.update(
            {key: member for member in attrs.values() for key in getattr(member, "handle_key", [])}
        )
        return built

    @staticmethod
    def handle_input(cls):
        """Read one key and dispatch it to the registered handler.

        Returns the handler's result, or `None` when no handler is registered for the key. Before dispatching,
        the key code is stored on `current_selection`.
        """
        pressed = get_character()
        # `get_character` returns either a character or the KEYMAP["undefined"] code; only characters need `ord`.
        if pressed != KEYMAP["undefined"]:
            pressed = ord(pressed)
        callback = cls.key_handler.get(pressed)
        if not callback:
            return None
        cls.current_selection = pressed
        return callback(cls)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def register(cls):
    """Recreate `cls` through the `KeyHandler` metaclass, using its name, bases and a copy of its namespace."""
    namespace = cls.__dict__.copy()
    return KeyHandler(cls.__name__, cls.__bases__, namespace)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/keymap.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team and Brian Chao. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Utilities relating to parsing raw characters from the keyboard, based on https://github.com/bchao1/bullet
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import os
|
| 20 |
+
import string
|
| 21 |
+
import sys
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# Bit added to arrow-key codes so they do not collide with the plain letters "A"-"D" (65-68).
ARROW_KEY_FLAG = 1 << 8

# Key name -> integer key code.
KEYMAP = {
    "tab": ord("\t"),
    "newline": ord("\r"),
    "esc": 27,
    "up": 65 + ARROW_KEY_FLAG,
    "down": 66 + ARROW_KEY_FLAG,
    "right": 67 + ARROW_KEY_FLAG,
    "left": 68 + ARROW_KEY_FLAG,
    "mod_int": 91,
    "undefined": sys.maxsize,
    "interrupt": 3,
    "insert": 50,
    "delete": 51,
    "pg_up": 53,
    "pg_down": 54,
}

# Inclusive bounds of the arrow-key code range checked in `get_character`.
KEYMAP["arrow_begin"] = KEYMAP["up"]
KEYMAP["arrow_end"] = KEYMAP["left"]

if sys.platform == "win32":
    # Characters queued by `get_raw_chars` to be returned on its following calls.
    WIN_CH_BUFFER = []
    # Two-byte Windows arrow-key codes (prefix 0xE0 or 0x00) -> un-flagged arrow letter codes (65-68).
    WIN_KEYMAP = {
        b"\xe0H": KEYMAP["up"] - ARROW_KEY_FLAG,
        b"\x00H": KEYMAP["up"] - ARROW_KEY_FLAG,
        b"\xe0P": KEYMAP["down"] - ARROW_KEY_FLAG,
        b"\x00P": KEYMAP["down"] - ARROW_KEY_FLAG,
        b"\xe0M": KEYMAP["right"] - ARROW_KEY_FLAG,
        b"\x00M": KEYMAP["right"] - ARROW_KEY_FLAG,
        b"\xe0K": KEYMAP["left"] - ARROW_KEY_FLAG,
        b"\x00K": KEYMAP["left"] - ARROW_KEY_FLAG,
    }

# Digit keys "0"-"9" map to their own character codes.
for i in range(10):
    KEYMAP[str(i)] = ord(str(i))
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def get_raw_chars():
    """Read a single raw character from the keyboard.

    On Windows (`os.name == "nt"`) the key is read with `msvcrt.getch()`. A two-byte arrow-key code found in
    `WIN_KEYMAP` is translated into the three-character sequence ESC, "[", letter: ESC is returned now and the
    remaining characters are queued in `WIN_CH_BUFFER` for the following calls. On POSIX the terminal is switched
    to raw mode for the duration of one `sys.stdin.read(1)` and then restored.

    Returns:
        A one-character string.
    """
    if os.name == "nt":
        import msvcrt

        encoding = "mbcs"
        # Flush the keyboard buffer
        while msvcrt.kbhit():
            msvcrt.getch()
        if not WIN_CH_BUFFER:
            # Read the keystroke
            ch = msvcrt.getch()

            # If it is a prefix char, get second part
            if ch in (b"\x00", b"\xe0"):
                ch2 = ch + msvcrt.getch()
                # Translate actual Win chars to bullet char types
                code = WIN_KEYMAP.get(ch2)
                if code is None:
                    # Unmapped special key: return its second byte as a character. Indexing `bytes` yields an
                    # int, so it must go through `chr()` -- callers apply `ord()` to the returned value.
                    ch = chr(ch2[1])
                else:
                    chx = chr(code)
                    WIN_CH_BUFFER.append(chr(KEYMAP["mod_int"]))
                    WIN_CH_BUFFER.append(chx)
                    # Insert/delete/page keys are followed by a trailing "~" (126) in their escape sequence.
                    # The former `KEYMAP[...] - 1 << 9` parsed as `(KEYMAP[...] - 1) << 9` and could never match.
                    if ord(chx) in (
                        KEYMAP["insert"],
                        KEYMAP["delete"],
                        KEYMAP["pg_up"],
                        KEYMAP["pg_down"],
                    ):
                        WIN_CH_BUFFER.append(chr(126))
                    ch = chr(KEYMAP["esc"])
            else:
                ch = ch.decode(encoding)
        else:
            ch = WIN_CH_BUFFER.pop(0)
    elif os.name == "posix":
        import termios
        import tty

        fd = sys.stdin.fileno()
        old_settings = termios.tcgetattr(fd)
        try:
            tty.setraw(fd)
            ch = sys.stdin.read(1)
        finally:
            # Always restore the terminal settings, even if the read is interrupted.
            termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
    return ch
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def get_character():
    """Read one key press and return it as a character, or `KEYMAP["undefined"]` when it is not recognised.

    Interrupt and newline are returned as read. An ESC, "[", letter sequence whose letter lies in the arrow range
    is returned as `chr(letter code + ARROW_KEY_FLAG)`. ESC followed by anything other than "[" returns the next
    raw character. Any other key is returned only if it is in `string.printable`.
    """
    char = get_raw_chars()
    code = ord(char)
    if code in [KEYMAP["interrupt"], KEYMAP["newline"]]:
        return char

    if code != KEYMAP["esc"]:
        return char if char in string.printable else KEYMAP["undefined"]

    combo = get_raw_chars()
    if ord(combo) != KEYMAP["mod_int"]:
        return get_raw_chars()

    key = get_raw_chars()
    first_arrow = KEYMAP["arrow_begin"] - ARROW_KEY_FLAG
    last_arrow = KEYMAP["arrow_end"] - ARROW_KEY_FLAG
    if first_arrow <= ord(key) <= last_arrow:
        return chr(ord(key) + ARROW_KEY_FLAG)
    return KEYMAP["undefined"]
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/selection_menu.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team and Brian Chao. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Main driver for the selection menu, based on https://github.com/bchao1/bullet
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import builtins
|
| 20 |
+
import sys
|
| 21 |
+
from typing import Optional
|
| 22 |
+
|
| 23 |
+
from ...utils.imports import _is_package_available
|
| 24 |
+
from . import cursor, input
|
| 25 |
+
from .helpers import Direction, clear_line, forceWrite, linebreak, move_cursor, reset_cursor, writeColor
|
| 26 |
+
from .keymap import KEYMAP
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Whether the `google.colab` package is available; a failed lookup counts as "not available".
try:
    in_colab = _is_package_available("google.colab")
except ModuleNotFoundError:
    in_colab = False
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@input.register
class BulletMenu:
    """
    A CLI menu to select a choice from a list of choices using the keyboard.

    Args:
        prompt (`str`, *optional*):
            Text written above the list of choices when the menu is run.
        choices (`list`, *optional*):
            The entries to display. Defaults to a new empty list.

    NOTE(review): `handle_input` and `current_selection` are used below but are not defined in this class;
    presumably the `input.register` decorator supplies them - confirm against the `input` module.
    """

    def __init__(self, prompt: Optional[str] = None, choices: Optional[list] = None):
        self.position = 0
        # A `None` sentinel avoids sharing a single default list between all menu instances.
        self.choices = [] if choices is None else choices
        self.prompt = prompt
        # The marker drawn in front of the highlighted row differs on Windows.
        self.arrow_char = "*" if sys.platform == "win32" else "➔ "

    def write_choice(self, index, end: str = ""):
        "Writes the choice at the given index, through `writeColor` (code 32) except on Windows"
        if sys.platform != "win32":
            writeColor(self.choices[index], 32, end)
        else:
            forceWrite(self.choices[index], end)

    def print_choice(self, index: int):
        "Prints the choice at the given index"
        if index == self.position:
            forceWrite(f" {self.arrow_char} ")
            self.write_choice(index)
        else:
            forceWrite(f" {self.choices[index]}")
        reset_cursor()

    def move_direction(self, direction: Direction, num_spaces: int = 1):
        "Should not be directly called, used to move a direction of either up or down"
        old_position = self.position
        if direction == Direction.DOWN:
            # Ignore the move when it would land past the last choice.
            if self.position + num_spaces >= len(self.choices):
                return
            self.position += num_spaces
        else:
            # Ignore the move when it would land before the first choice.
            if self.position - num_spaces < 0:
                return
            self.position -= num_spaces
        # Redraw the previously highlighted row without the marker, then draw the new one.
        clear_line()
        self.print_choice(old_position)
        move_cursor(num_spaces, direction.name)
        self.print_choice(self.position)

    @input.mark(KEYMAP["up"])
    def move_up(self):
        "Moves the highlight one row up"
        self.move_direction(Direction.UP)

    @input.mark(KEYMAP["down"])
    def move_down(self):
        "Moves the highlight one row down"
        self.move_direction(Direction.DOWN)

    @input.mark(KEYMAP["newline"])
    def select(self):
        "Moves the cursor below the menu and returns the highlighted position"
        move_cursor(len(self.choices) - self.position, "DOWN")
        return self.position

    @input.mark(KEYMAP["interrupt"])
    def interrupt(self):
        "Moves the cursor below the menu and raises `KeyboardInterrupt`"
        move_cursor(len(self.choices) - self.position, "DOWN")
        raise KeyboardInterrupt

    @input.mark_multiple(*[KEYMAP[str(number)] for number in range(10)])
    def select_row(self):
        "Moves the highlight to the row given by the digit in `current_selection`, if such a row exists"
        index = int(chr(self.current_selection))
        # Nothing to do when the row is already highlighted or the digit has no matching row.
        if index == self.position or index >= len(self.choices):
            return
        if index < self.position:
            self.move_direction(Direction.UP, self.position - index)
        else:
            self.move_direction(Direction.DOWN, index - self.position)

    def run(self, default_choice: int = 0):
        "Start the menu and return the selected choice"
        if self.prompt:
            linebreak()
            forceWrite(self.prompt, "\n")
            if in_colab:
                forceWrite("Please input a choice index (starting from 0), and press enter", "\n")
            else:
                forceWrite("Please select a choice using the arrow or number keys, and selecting with enter", "\n")
        self.position = default_choice
        for i in range(len(self.choices)):
            self.print_choice(i)
            forceWrite("\n")
        # Bring the cursor back up to the row of the default choice.
        move_cursor(len(self.choices) - self.position, "UP")
        with cursor.hide():
            while True:
                if in_colab:
                    try:
                        choice = int(builtins.input())
                    except ValueError:
                        choice = default_choice
                    else:
                        # An index outside the list is treated like any other invalid entry instead of
                        # crashing in `write_choice` or silently selecting from the end of the list.
                        if not 0 <= choice < len(self.choices):
                            choice = default_choice
                else:
                    choice = self.handle_input()
                if choice is not None:
                    reset_cursor()
                    # Erase the rows above the cursor (one per choice plus one more) before echoing the choice.
                    for _ in range(len(self.choices) + 1):
                        move_cursor(1, "UP")
                        clear_line()
                    self.write_choice(choice, "\n")
                    return choice
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (1.78 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/__pycache__/examples.cpython-312.pyc
ADDED
|
Binary file (6.93 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/__pycache__/testing.cpython-312.pyc
ADDED
|
Binary file (42.3 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/__pycache__/training.cpython-312.pyc
ADDED
|
Binary file (7.46 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/scripts/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/scripts/test_ddp_comm_hook.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import torch
|
| 15 |
+
|
| 16 |
+
from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs, PartialState
|
| 17 |
+
from accelerate.utils import is_hpu_available
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class MockModel(torch.nn.Module):
    """Single-parameter module whose output depends on the rank passed to `forward`."""

    def __init__(self):
        super().__init__()
        # Seed before sampling so that every instance starts from the same parameter values.
        torch.manual_seed(0)
        initial_values = torch.randn(40, 20)
        self.p = torch.nn.Parameter(initial_values)

    def forward(self, x, rank):
        """Return `p * x ** (1 + rank)`."""
        exponent = 1 + rank
        return self.p * (x**exponent)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _run_and_get_grads(model, rank):
|
| 31 |
+
torch.manual_seed(2024)
|
| 32 |
+
input = torch.randn(40, 20)
|
| 33 |
+
output = model(input, rank)
|
| 34 |
+
output.mean().backward()
|
| 35 |
+
param = next(model.parameters())
|
| 36 |
+
return param.grad
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def test_ddp_comm_hook(comm_hook, comm_wrapper, comm_state_option):
    """
    Check that gradients from an `Accelerator`-prepared model configured with the given DDP communication
    hook settings are close (rtol/atol of 1e-2) to those of a directly constructed `DistributedDataParallel`
    model.
    """
    accelerator = Accelerator(
        kwargs_handlers=[
            DistributedDataParallelKwargs(
                comm_hook=comm_hook,
                comm_wrapper=comm_wrapper,
                comm_state_option=comm_state_option,
            )
        ]
    )
    local_rank = accelerator.local_process_index

    # Model under test: wrapped by the accelerator, so the configured hook applies.
    hooked_model = accelerator.prepare(MockModel())
    hook_grads = _run_and_get_grads(hooked_model, local_rank)

    # Reference: DDP built directly, without any communication hook arguments.
    reference_model = torch.nn.parallel.DistributedDataParallel(
        MockModel().to(accelerator.device),
        device_ids=[local_rank],
        output_device=local_rank,
    )
    reference_grads = _run_and_get_grads(reference_model, local_rank)

    torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-2, atol=1e-2)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def main():
    """Run `test_ddp_comm_hook` over a fixed list of hook/wrapper/state combinations, skipping FP16/BF16 on HPU."""
    configurations = [
        (DDPCommunicationHookType.NO, DDPCommunicationHookType.NO, {}),
        (DDPCommunicationHookType.FP16, DDPCommunicationHookType.NO, {}),
        (DDPCommunicationHookType.BF16, DDPCommunicationHookType.NO, {}),
        (DDPCommunicationHookType.POWER_SGD, DDPCommunicationHookType.NO, {}),
        (DDPCommunicationHookType.POWER_SGD, DDPCommunicationHookType.FP16, {}),
        (DDPCommunicationHookType.POWER_SGD, DDPCommunicationHookType.BF16, {}),
        (DDPCommunicationHookType.POWER_SGD, DDPCommunicationHookType.NO, {"matrix_approximation_rank": 2}),
        (DDPCommunicationHookType.BATCHED_POWER_SGD, DDPCommunicationHookType.NO, {}),
        (DDPCommunicationHookType.BATCHED_POWER_SGD, DDPCommunicationHookType.FP16, {}),
        (DDPCommunicationHookType.BATCHED_POWER_SGD, DDPCommunicationHookType.BF16, {}),
    ]
    # Hook types that are skipped, as hook or as wrapper, when running on HPU.
    hpu_unsupported = {DDPCommunicationHookType.FP16, DDPCommunicationHookType.BF16}

    for comm_hook, comm_wrapper, comm_state_option in configurations:
        if is_hpu_available() and (comm_hook in hpu_unsupported or comm_wrapper in hpu_unsupported):
            print(f"Skipping test DDP comm hook: {comm_hook}, comm wrapper: {comm_wrapper} on HPU")
            continue

        print(f"Test DDP comm hook: {comm_hook}, comm wrapper: {comm_wrapper}")
        test_ddp_comm_hook(comm_hook, comm_wrapper, comm_state_option)

    PartialState().destroy_process_group()


if __name__ == "__main__":
    main()
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/scripts/test_distributed_data_loop.py
ADDED
|
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import pickle
|
| 18 |
+
import tempfile
|
| 19 |
+
import warnings
|
| 20 |
+
from unittest.mock import Mock
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
from torch.utils.data import (
|
| 24 |
+
BatchSampler,
|
| 25 |
+
DataLoader,
|
| 26 |
+
Dataset,
|
| 27 |
+
IterableDataset,
|
| 28 |
+
RandomSampler,
|
| 29 |
+
TensorDataset,
|
| 30 |
+
default_collate,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
from accelerate.accelerator import Accelerator, DataLoaderConfiguration
|
| 34 |
+
from accelerate.utils.dataclasses import DistributedType
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# Length reported by `DummyDataset`; `test_data_loader` expects this many distinct indices.
NUM_ELEMENTS = 22
# `num_workers` passed to the DataLoaders built in `main`.
NUM_WORKERS = 4
# Batch size used for the DataLoaders and BatchSamplers built in `main`.
BATCH_SIZE = 4
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class DummyDataset(Dataset):
    """
    Map-style dataset of `NUM_ELEMENTS` synthetic samples.

    Each sample is a dict with the keys `"index"` (the requested position), `"label"` (`index % 2`) and
    `"random_augmentation"` (a fresh `torch.rand` value drawn on every access).
    """

    def __len__(self):
        return NUM_ELEMENTS

    def __getitem__(self, index):
        """
        Return the sample(s) at `index`.

        Args:
            index: An `int` (a single sample dict is returned), a `slice`, or any iterable of ints (a list
                of sample dicts is returned for the latter two).
        """
        squeeze = False

        if isinstance(index, int):
            index = [index]
            squeeze = True
        elif isinstance(index, slice):
            # Resolve the slice against the dataset length; the class has no `size` attribute.
            index = list(range(*index.indices(len(self))))
        else:
            index = list(index)

        batch = [{"index": i, "label": i % 2, "random_augmentation": torch.rand(1).item()} for i in index]

        if squeeze:
            batch = batch[0]

        return batch
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class DummyIterableDataset(IterableDataset):
    """Iterable-style dataset that yields the elements of the wrapped `data` in order."""

    def __init__(self, data):
        self.data = data

    def __iter__(self):
        for element in self.data:
            yield element
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def create_accelerator(even_batches=True):
    """
    Build an `Accelerator` whose dataloader configuration uses the given `even_batches` setting, asserting
    that exactly two processes are in use.
    """
    accelerator = Accelerator(dataloader_config=DataLoaderConfiguration(even_batches=even_batches))
    assert accelerator.num_processes == 2, "this script expects that two GPUs are available"
    return accelerator
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def create_dataloader(
    accelerator: Accelerator, dataset_size: int, batch_size: int, iterable: bool = False, shuffle: bool = False
):
    """
    Create a simple DataLoader to use during the test cases

    The dataset holds the integers `0 .. dataset_size - 1`, randomly permuted when `shuffle` is set, and is
    either a `DummyIterableDataset` (`iterable=True`) or a `TensorDataset`. The DataLoader is passed through
    `accelerator.prepare` before being returned.
    """
    values = torch.as_tensor(range(dataset_size))
    if shuffle:
        values = values[torch.randperm(values.size(0))]

    # Both dataset flavours are built from `values`, so `shuffle` is honoured for map-style datasets too
    # (previously the permutation was discarded on that branch).
    dataset = DummyIterableDataset(values) if iterable else TensorDataset(values)

    return accelerator.prepare(DataLoader(dataset, batch_size=batch_size))
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def verify_dataloader_batch_sizes(
    accelerator: Accelerator,
    dataset_size: int,
    batch_size: int,
    process_0_expected_batch_sizes: list[int],
    process_1_expected_batch_sizes: list[int],
):
    """
    A helper function for verifying the batch sizes coming from a prepared dataloader in each process

    Only process indices 0 and 1 are checked; any other process performs no assertion.
    """
    dl = create_dataloader(accelerator=accelerator, dataset_size=dataset_size, batch_size=batch_size)
    observed_sizes = [len(batch[0]) for batch in dl]

    rank = accelerator.process_index
    if rank == 0:
        assert observed_sizes == process_0_expected_batch_sizes
        return
    if rank == 1:
        assert observed_sizes == process_1_expected_batch_sizes
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def test_default_ensures_even_batch_sizes():
    """Check that, with the default accelerator, both processes receive the same batch sizes."""
    accelerator = create_accelerator()

    scenarios = (
        # without padding, we would expect a different number of batches
        {"dataset_size": 3, "batch_size": 1, "process_0": [1, 1], "process_1": [1, 1]},
        # without padding, we would expect the same number of batches, but different sizes
        {"dataset_size": 7, "batch_size": 2, "process_0": [2, 2], "process_1": [2, 2]},
    )
    for scenario in scenarios:
        verify_dataloader_batch_sizes(
            accelerator,
            dataset_size=scenario["dataset_size"],
            batch_size=scenario["batch_size"],
            process_0_expected_batch_sizes=scenario["process_0"],
            process_1_expected_batch_sizes=scenario["process_1"],
        )
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def test_can_disable_even_batches():
    """Check the per-process batch sizes when the accelerator is created with `even_batches=False`."""
    accelerator = create_accelerator(even_batches=False)

    scenarios = (
        # process 1 gets one batch fewer
        {"dataset_size": 3, "batch_size": 1, "process_0": [1, 1], "process_1": [1]},
        # process 1 gets a smaller final batch
        {"dataset_size": 7, "batch_size": 2, "process_0": [2, 2], "process_1": [2, 1]},
    )
    for scenario in scenarios:
        verify_dataloader_batch_sizes(
            accelerator,
            dataset_size=scenario["dataset_size"],
            batch_size=scenario["batch_size"],
            process_0_expected_batch_sizes=scenario["process_0"],
            process_1_expected_batch_sizes=scenario["process_1"],
        )
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def test_can_join_uneven_inputs():
    """
    Train over a 3-sample dataloader inside `join_uneven_inputs` with `even_batches=False` and check which
    batch indices each process saw.
    """
    accelerator = create_accelerator(even_batches=False)
    ddp_model = accelerator.prepare(torch.nn.Linear(1, 1))
    dl = create_dataloader(accelerator, dataset_size=3, batch_size=1)

    seen_batch_idxs = []
    with accelerator.join_uneven_inputs([ddp_model]):
        for batch_idx, batch in enumerate(dl):
            ddp_model(batch[0].float()).sum().backward()
            seen_batch_idxs.append(batch_idx)

    accelerator.wait_for_everyone()

    rank = accelerator.process_index
    if rank == 0:
        assert seen_batch_idxs == [0, 1]
    elif rank == 1:
        assert seen_batch_idxs == [0]
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def test_join_raises_warning_for_non_ddp_distributed(accelerator):
    """Check that the last warning recorded around `join_uneven_inputs` is the multi-GPU `UserWarning`."""
    with warnings.catch_warnings(record=True) as recorded:
        with accelerator.join_uneven_inputs([Mock()]):
            pass

    last_warning = recorded[-1]
    assert issubclass(last_warning.category, UserWarning)
    assert "only supported for multi-GPU" in str(last_warning.message)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def test_join_can_override_even_batches():
    """
    Check that `join_uneven_inputs(even_batches=...)` overrides `even_batches` on the batch samplers of the
    prepared dataloaders inside the context, and that the original value is back afterwards.
    """
    default_even_batches = True
    overridden_even_batches = False

    accelerator = create_accelerator(even_batches=default_even_batches)
    ddp_model = accelerator.prepare(torch.nn.Linear(1, 1))
    train_dl = create_dataloader(accelerator, dataset_size=3, batch_size=1)
    valid_dl = create_dataloader(accelerator, dataset_size=3, batch_size=1)
    dataloaders = (train_dl, valid_dl)

    with accelerator.join_uneven_inputs([ddp_model], even_batches=overridden_even_batches):
        values_inside_context = [dl.batch_sampler.even_batches for dl in dataloaders]

    for value in values_inside_context:
        assert value == overridden_even_batches
    for dl in dataloaders:
        assert dl.batch_sampler.even_batches == default_even_batches
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def test_join_can_override_for_mixed_type_dataloaders():
    """
    Check that overriding `even_batches` works for a map-style dataloader when an iterable-style dataloader
    has also been prepared, without an `AttributeError` being raised.
    """
    default_even_batches = True
    overridden_even_batches = False

    accelerator = create_accelerator(even_batches=default_even_batches)
    ddp_model = accelerator.prepare(torch.nn.Linear(1, 1))
    # Created only for its side effect of being prepared by the accelerator.
    create_dataloader(accelerator, dataset_size=3, batch_size=1, iterable=True)
    batch_dl = create_dataloader(accelerator, dataset_size=3, batch_size=1)

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")
        try:
            with accelerator.join_uneven_inputs([ddp_model], even_batches=overridden_even_batches):
                value_inside_context = batch_dl.batch_sampler.even_batches
        except AttributeError:
            # ensure attribute error is not raised when processing iterable dl
            raise AssertionError

    assert value_inside_context == overridden_even_batches
    assert batch_dl.batch_sampler.even_batches == default_even_batches
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def test_join_raises_warning_for_iterable_when_overriding_even_batches():
    """
    Check that overriding `even_batches` after an iterable-style dataloader has been prepared ends with the
    map-style-only `UserWarning`.
    """
    accelerator = create_accelerator()
    ddp_model = accelerator.prepare(torch.nn.Linear(1, 1))
    # Created only for its side effect of being prepared by the accelerator.
    create_dataloader(accelerator, dataset_size=3, batch_size=1, iterable=True)

    with warnings.catch_warnings(record=True) as recorded:
        with accelerator.join_uneven_inputs([ddp_model], even_batches=False):
            pass

    last_warning = recorded[-1]
    assert issubclass(last_warning.category, UserWarning)
    assert "only supported for map-style datasets" in str(last_warning.message)
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def test_pickle_accelerator():
    """Check that an accelerator survives a pickle round trip with an identical `state.__dict__`."""
    accelerator = create_accelerator()
    data_loader = create_dataloader(accelerator, dataset_size=32, batch_size=4)
    _ = accelerator.prepare(data_loader)

    restored_accelerator = pickle.loads(pickle.dumps(accelerator))

    # TODO: Maybe this should be implemented as __eq__ for AcceleratorState?
    assert accelerator.state.__dict__ == restored_accelerator.state.__dict__
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def test_data_loader(data_loader, accelerator):
    """
    Iterate the prepared `data_loader` once, gathering the `"index"` field of every batch with
    `gather_for_metrics`, and check that `NUM_ELEMENTS` distinct indices were seen.
    """
    prepared_loader = accelerator.prepare(data_loader)

    gathered_indices = []
    for batch in prepared_loader:
        index, _ = accelerator.gather_for_metrics((batch["index"], batch["label"]))
        gathered_indices.extend(index.detach().cpu().numpy().tolist())

    # Duplicates collapse in the set, so the count equals the number of distinct samples seen.
    distinct_count = len(set(gathered_indices))
    assert distinct_count == NUM_ELEMENTS, (
        "Not all the dataset elements have been iterated in an epoch due to duplication of samples across processes."
    )
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def test_stateful_dataloader(accelerator):
    """
    Tests that a stateful dataloader can be iterated over, have its `state_dict` captured one step before the
    last batch, and then be resumed from that state with `load_state_dict`.

    The batches produced after resuming should equal the batches that were iterated after the capture.
    """
    previous_config = accelerator.dataloader_config
    try:
        accelerator.dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
        prepared_dl = create_dataloader(
            accelerator, dataset_size=32 * accelerator.num_processes, batch_size=4, iterable=True, shuffle=True
        )

        # Calculate what step that will be
        total_batches = 32 * accelerator.num_processes // (4 * accelerator.num_processes)
        last_batch_num = total_batches - 1

        batches_after_capture = []
        for step, batch in enumerate(prepared_dl):
            # Step just before
            if step == last_batch_num - 1:
                state_dict = prepared_dl.state_dict()
            # Otherwise grab the "unseen" batches
            if step >= last_batch_num:
                batches_after_capture.append(batch)
        not_skipped_batches = accelerator.gather(batches_after_capture)

        prepared_dl.load_state_dict(state_dict)
        resumed_batches = accelerator.gather([batch for batch in prepared_dl])

        for b1, b2 in zip(not_skipped_batches, resumed_batches):
            for v1, v2 in zip(b1, b2):
                assert torch.equal(v1, v2), f"Batch {b1} and {b2} are not equal"
    finally:
        # Always put the caller's dataloader configuration back.
        accelerator.dataloader_config = previous_config
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def test_stateful_dataloader_save_state(accelerator):
    """
    Tests that a stateful dataloader can be iterated over, checkpointed one step before the last batch with
    `Accelerator.save_state`, and then resumed with `Accelerator.load_state`.

    The batches produced after resuming should equal the batches that were iterated after the checkpoint.
    """
    previous_config = accelerator.dataloader_config
    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            accelerator.dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
            prepared_dl = create_dataloader(
                accelerator, dataset_size=32 * accelerator.num_processes, batch_size=4, iterable=True, shuffle=True
            )

            # Calculate what step that will be
            total_batches = 32 * accelerator.num_processes // (4 * accelerator.num_processes)
            last_batch_num = total_batches - 1

            batches_after_checkpoint = []
            for step, batch in enumerate(prepared_dl):
                # Step just before
                if step == last_batch_num - 1:
                    accelerator.save_state(tmpdir)
                # Otherwise grab the "unseen" batches
                if step >= last_batch_num:
                    batches_after_checkpoint.append(batch)
            not_skipped_batches = accelerator.gather(batches_after_checkpoint)

            accelerator.load_state(tmpdir)
            resumed_batches = accelerator.gather([batch for batch in prepared_dl])

            for b1, b2 in zip(not_skipped_batches, resumed_batches):
                for v1, v2 in zip(b1, b2):
                    assert torch.equal(v1, v2), f"Batch {b1} and {b2} are not equal"
    finally:
        # Always put the caller's dataloader configuration back.
        accelerator.dataloader_config = previous_config
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
def main():
    """Run every check of this script in order on a two-process accelerator, then end training."""
    accelerator = create_accelerator()
    torch.manual_seed(accelerator.process_index)

    # Checks that take no arguments, each announced before it runs.
    announced_checks = (
        (
            "Test that even_batches variable ensures uniform batches across processes",
            test_default_ensures_even_batch_sizes,
        ),
        ("Run tests with even_batches disabled", test_can_disable_even_batches),
        ("Test joining uneven inputs", test_can_join_uneven_inputs),
        ("Test overriding even_batches when joining uneven inputs", test_join_can_override_even_batches),
        (
            "Test overriding even_batches for mixed dataloader types",
            test_join_can_override_for_mixed_type_dataloaders,
        ),
        (
            "Test overriding even_batches raises a warning for iterable dataloaders",
            test_join_raises_warning_for_iterable_when_overriding_even_batches,
        ),
    )
    for message, check in announced_checks:
        accelerator.print(message)
        check()

    accelerator.print("Test join with non DDP distributed raises warning")
    # Temporarily report FSDP as the distributed type for this check, then restore the real value.
    saved_distributed_type = accelerator.state.distributed_type
    accelerator.state.distributed_type = DistributedType.FSDP
    test_join_raises_warning_for_non_ddp_distributed(accelerator)
    accelerator.state.distributed_type = saved_distributed_type

    accelerator.print("Test pickling an accelerator")
    test_pickle_accelerator()

    dataset = DummyDataset()

    accelerator.print("Test DataLoader with shuffle=False")
    sequential_loader = DataLoader(dataset, shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
    test_data_loader(sequential_loader, accelerator)

    accelerator.print("Test DataLoader with shuffle=True")
    shuffled_loader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
    test_data_loader(shuffled_loader, accelerator)

    accelerator.print("Test DataLoader with batch_sampler")
    batch_sampler = BatchSampler(RandomSampler(dataset), batch_size=BATCH_SIZE, drop_last=False)
    batch_sampler_loader = DataLoader(dataset, batch_sampler=batch_sampler, num_workers=NUM_WORKERS)
    test_data_loader(batch_sampler_loader, accelerator)

    accelerator.print("Test DataLoader with sampler as an instance of `BatchSampler`")
    sampler = BatchSampler(RandomSampler(dataset), batch_size=BATCH_SIZE, drop_last=False)
    sampler_loader = DataLoader(
        dataset, sampler=sampler, batch_size=None, collate_fn=default_collate, num_workers=NUM_WORKERS
    )
    test_data_loader(sampler_loader, accelerator)

    test_stateful_dataloader(accelerator)
    test_stateful_dataloader_save_state(accelerator)

    accelerator.end_training()


if __name__ == "__main__":
    main()
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/scripts/test_merge_weights.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import gc
|
| 16 |
+
import logging
|
| 17 |
+
import shutil
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
from safetensors.torch import load_file
|
| 22 |
+
from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy, StateDictType
|
| 23 |
+
from torch.utils.data import DataLoader
|
| 24 |
+
|
| 25 |
+
from accelerate import Accelerator, FullyShardedDataParallelPlugin
|
| 26 |
+
from accelerate.commands.merge import merge_command, merge_command_parser
|
| 27 |
+
from accelerate.state import AcceleratorState
|
| 28 |
+
from accelerate.test_utils import torch_device
|
| 29 |
+
from accelerate.test_utils.training import RegressionDataset
|
| 30 |
+
from accelerate.utils import merge_fsdp_weights, patch_environment, save_fsdp_model
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# Configure root logging at INFO level for this script.
logging.basicConfig(level=logging.INFO)
|
| 34 |
+
|
| 35 |
+
# Argument parser for the merge command; the command-based tests in this file parse their CLI arguments with it.
parser = merge_command_parser()
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class TinyModel(torch.nn.Module):
    """Small network computing ``Linear(16, 16) -> ReLU -> Linear(16, 16)``.

    A ``Softmax`` module is also registered as ``self.softmax``, but
    ``forward`` does not apply it.
    """

    def __init__(self):
        super().__init__()
        layers = torch.nn
        self.linear1 = layers.Linear(16, 16)
        self.activation = layers.ReLU()
        self.linear2 = layers.Linear(16, 16)
        self.softmax = layers.Softmax()

    def forward(self, x):
        """Return the output of the second linear layer for input ``x``."""
        hidden = self.linear1(x)
        hidden = self.activation(hidden)
        return self.linear2(hidden)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def setup():
    """Create a `TinyModel` and prepare it with an FSDP-configured `Accelerator`.

    Returns:
        A `(model, plugin, accelerator)` tuple, where `model` is the value returned by `accelerator.prepare`.
    """
    # Reset any accelerator state that is already populated before building a new one.
    if AcceleratorState._shared_state != {}:
        AcceleratorState()._reset_state()

    fsdp_plugin = FullyShardedDataParallelPlugin(
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        state_dict_type=StateDictType.SHARDED_STATE_DICT,
    )
    tiny_model = TinyModel()

    # The environment override is only active while the wrap policy is being set.
    with patch_environment(fsdp_auto_wrap_policy="SIZE_BASED_WRAP"):
        fsdp_plugin.set_auto_wrap_policy(tiny_model)

    accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
    prepared_model = accelerator.prepare(tiny_model)
    return prepared_model, fsdp_plugin, accelerator
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def mock_training(accelerator, model):
    """Run a short SGD loop over a `RegressionDataset` and return the trained model.

    Args:
        accelerator: Object providing `prepare` and `backward`.
        model: Module to optimise; it is passed through `accelerator.prepare`
            together with the dataloader and the optimizer.

    Returns:
        The model returned by `accelerator.prepare`, after three passes over the data.
    """
    dataset = RegressionDataset(length=128, seed=42)
    loader = DataLoader(dataset, batch_size=16, shuffle=False)
    sgd = torch.optim.SGD(model.parameters(), lr=0.1)
    loader, model, sgd = accelerator.prepare(loader, model, sgd)

    num_epochs = 3
    for _epoch in range(num_epochs):
        for batch in loader:
            model.zero_grad()
            prediction = model(batch["x"])
            loss = torch.nn.functional.mse_loss(prediction, batch["y"])
            accelerator.backward(loss)
            sgd.step()
    return model
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def check_weights(operation, state_1, state_2):
    """Assert that two state dicts hold matching or differing weights.

    Values are paired positionally (in iteration order) and compared with
    `torch.allclose`.

    Args:
        operation: `"same"` to require every pair to be close, `"diff"` to
            require every pair to differ.
        state_1: Mapping whose values are tensors.
        state_2: Mapping whose values are tensors, paired positionally with `state_1`.

    Raises:
        ValueError: If `operation` is neither `"same"` nor `"diff"`.
        AssertionError: If any pair violates the requested relation.
    """
    # Reject unknown operations up front: otherwise a typo such as "Same"
    # would silently run the inverse ("diff") check.
    if operation not in ("same", "diff"):
        raise ValueError(f"`operation` must be 'same' or 'diff', got {operation!r}")

    expect_close = operation == "same"
    for (name, weight_1), weight_2 in zip(state_1.items(), state_2.values()):
        is_close = torch.allclose(weight_1, weight_2)
        if expect_close:
            assert is_close, f"Expected weights for `{name}` to match, but they differ"
        else:
            assert not is_close, f"Expected weights for `{name}` to differ, but they match"
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def check_safetensors_weights(path, model):
    """Check that `path / "model.safetensors"` reproduces the weights of `model`.

    A fresh `TinyModel` must differ from `model` before the file is loaded
    into it and match `model` afterwards.
    """
    merged_weights = load_file(path / "model.safetensors")
    fresh_model = TinyModel().to(torch_device)

    # Before loading: a newly initialised model must not already match.
    check_weights("diff", model.state_dict(), fresh_model.state_dict())

    fresh_model.load_state_dict(merged_weights)

    # After loading: the merged weights must match the reference model.
    check_weights("same", model.state_dict(), fresh_model.state_dict())
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def check_pytorch_weights(path, model):
    """Check that `path / "pytorch_model.bin"` reproduces the weights of `model`.

    A fresh `TinyModel` must differ from `model` before the file is loaded
    into it and match `model` afterwards.
    """
    merged_weights = torch.load(path / "pytorch_model.bin", weights_only=True)
    fresh_model = TinyModel().to(torch_device)

    # Before loading: a newly initialised model must not already match.
    check_weights("diff", model.state_dict(), fresh_model.state_dict())

    fresh_model.load_state_dict(merged_weights)

    # After loading: the merged weights must match the reference model.
    check_weights("same", model.state_dict(), fresh_model.state_dict())
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def test_merge_weights_safetensors(model, path):
    """Merge the FSDP shards via `merge_fsdp_weights` into safetensors format and verify them."""
    shard_dir = path / "pytorch_model_fsdp_0"
    merge_fsdp_weights(shard_dir, path, safe_serialization=True)
    # The verification step reads the merged result from `path / "model.safetensors"`.
    check_safetensors_weights(path, model)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def test_merge_weights_command_safetensors(model, path):
    """Merge the FSDP shards through the merge command and verify the safetensors output."""
    cli_args = [str(path / "pytorch_model_fsdp_0"), str(path)]
    parsed = parser.parse_args(cli_args)
    merge_command(parsed)
    check_safetensors_weights(path, model)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def test_merge_weights_pytorch(model, path):
    """Merge the FSDP shards via `merge_fsdp_weights` without safe serialization and verify them."""
    shard_dir = path / "pytorch_model_fsdp_0"
    merge_fsdp_weights(shard_dir, path, safe_serialization=False)
    # The verification step reads the merged result from `path / "pytorch_model.bin"`.
    check_pytorch_weights(path, model)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def test_merge_weights_command_pytorch(model, path):
    """Merge the FSDP shards through the merge command with `--unsafe_serialization` and verify the output."""
    cli_args = [str(path / "pytorch_model_fsdp_0"), str(path), "--unsafe_serialization"]
    parsed = parser.parse_args(cli_args)
    merge_command(parsed)
    check_pytorch_weights(path, model)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
if __name__ == "__main__":
    # Note: this test requires at least two accelerators!
    model, plugin, accelerator = setup()
    if accelerator.num_processes > 1:
        # Defined before the `try` so the cleanup in `finally` can always refer to it.
        out_path = Path("test_merge_weights_fsdp_weights")
        try:
            # `exist_ok=True` already tolerates an existing directory, including one
            # created concurrently by another process, so no prior `exists()` check is needed.
            out_path.mkdir(parents=True, exist_ok=True)

            # Train briefly so the weights are no longer the baseline
            model = mock_training(accelerator, model)
            accelerator.wait_for_everyone()

            gc.collect()  # Needed for some lingering refs after training
            save_fsdp_model(plugin, accelerator, model, out_path)
            accelerator.wait_for_everyone()

            # Finally we can test
            test_merge_weights_safetensors(model, out_path)
            test_merge_weights_command_safetensors(model, out_path)
            test_merge_weights_pytorch(model, out_path)
            test_merge_weights_command_pytorch(model, out_path)
        finally:
            # Cleanup in case of any failures; exceptions still propagate after this block.
            # Only remove the directory if it exists, so a failure before it was created
            # is not masked by a `FileNotFoundError` raised from the cleanup itself.
            if accelerator.is_main_process and out_path.exists():
                shutil.rmtree(out_path)
            accelerator.wait_for_everyone()
    # NOTE(review): indentation was lost in the source view; this call is assumed to sit
    # at the top level of the `__main__` block - confirm against the upstream file.
    accelerator.end_training()
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/scripts/test_notebook.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
Test file to ensure that, in general, certain situational notebook setups work.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import os
|
| 19 |
+
import time
|
| 20 |
+
from multiprocessing import Queue
|
| 21 |
+
|
| 22 |
+
from pytest import mark, raises
|
| 23 |
+
from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
|
| 24 |
+
|
| 25 |
+
from accelerate import PartialState, notebook_launcher
|
| 26 |
+
from accelerate.test_utils import require_bnb
|
| 27 |
+
from accelerate.utils import is_bnb_available
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def basic_function():
    """Print the current `PartialState`."""
    message = f"PartialState:\n{PartialState()}"
    print(message)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def tough_nut_function(queue: Queue):
    """Fail until the retry counter stored in `queue` has been used up.

    Takes one counter value from `queue` without blocking. If the queue holds
    nothing, the function returns immediately. If the counter is positive, it
    is decremented, put back, and a `RuntimeError` is raised; otherwise the
    current `PartialState` is printed.

    Args:
        queue: Queue holding the number of remaining failures.

    Raises:
        RuntimeError: While the counter taken from `queue` is still positive.
    """
    # Imported locally: only the exception class is needed here.
    from queue import Empty

    try:
        # A non-blocking read replaces the `empty()` check followed by a blocking
        # `get()`: with several processes sharing one queue, a process could see
        # a non-empty queue, lose the item to a sibling, and then block forever.
        trial = queue.get_nowait()
    except Empty:
        return

    if trial > 0:
        queue.put(trial - 1)
        raise RuntimeError("The nut hasn't cracked yet! Try again.")

    print(f"PartialState:\n{PartialState()}")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def bipolar_sleep_function(sleep_sec: int):
    """Raise on even-indexed processes and sleep for `sleep_sec` seconds on odd-indexed ones.

    Raises:
        RuntimeError: When the current process index is even.
    """
    process_index = PartialState().process_index
    is_odd = process_index % 2 != 0
    if is_odd:
        time.sleep(sleep_sec)
        return
    raise RuntimeError("I'm an even process. I don't like to sleep.")
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# Number of processes to launch; read from `ACCELERATE_NUM_PROCESSES`, defaulting to 1 when unset.
NUM_PROCESSES = int(os.environ.get("ACCELERATE_NUM_PROCESSES", 1))
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def test_can_initialize():
    """Launch `basic_function` with no arguments on `NUM_PROCESSES` processes."""
    no_args = ()
    notebook_launcher(basic_function, no_args, num_processes=NUM_PROCESSES)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
@mark.skipif(NUM_PROCESSES < 2, reason="Need at least 2 processes to test static rendezvous backends")
def test_static_rdzv_backend():
    """Launch `basic_function` with `rdzv_backend="static"`."""
    no_args = ()
    notebook_launcher(basic_function, no_args, num_processes=NUM_PROCESSES, rdzv_backend="static")
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@mark.skipif(NUM_PROCESSES < 2, reason="Need at least 2 processes to test c10d rendezvous backends")
def test_c10d_rdzv_backend():
    """Launch `basic_function` with `rdzv_backend="c10d"`."""
    no_args = ()
    notebook_launcher(basic_function, no_args, num_processes=NUM_PROCESSES, rdzv_backend="c10d")
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@mark.skipif(NUM_PROCESSES < 2, reason="Need at least 2 processes to test fault tolerance")
def test_fault_tolerant(max_restarts: int = 3):
    """Launch `tough_nut_function` with a queue seeded with `max_restarts` and the same restart budget."""
    retry_budget = Queue()
    retry_budget.put(max_restarts)
    launcher_args = (retry_budget,)
    notebook_launcher(tough_nut_function, launcher_args, num_processes=NUM_PROCESSES, max_restarts=max_restarts)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
@mark.skipif(NUM_PROCESSES < 2, reason="Need at least 2 processes to test monitoring")
def test_monitoring(monitor_interval: float = 0.01, sleep_sec: int = 100):
    """Expect the launch to fail with `ChildFailedError` in less than `sleep_sec` seconds.

    Args:
        monitor_interval: Value forwarded to `notebook_launcher` as `monitor_interval`.
        sleep_sec: Sleep duration passed to `bipolar_sleep_function`; also the upper bound on elapsed time.
    """
    start_time = time.time()
    launch_kwargs = {"num_processes": NUM_PROCESSES, "monitor_interval": monitor_interval}
    with raises(ChildFailedError, match="I'm an even process. I don't like to sleep."):
        notebook_launcher(bipolar_sleep_function, (sleep_sec,), **launch_kwargs)
    elapsed = time.time() - start_time
    assert elapsed < sleep_sec, "Monitoring did not stop the process in time."
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
@require_bnb
def test_problematic_imports():
    """Expect `notebook_launcher` to raise `RuntimeError` once `bitsandbytes` has been imported."""
    no_args = ()
    with raises(RuntimeError, match="Please keep these imports"):
        import bitsandbytes as bnb  # noqa: F401

        notebook_launcher(basic_function, no_args, num_processes=NUM_PROCESSES)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def main():
    """Run every notebook-launcher check in sequence, announcing each one first.

    The bitsandbytes check only runs when `is_bnb_available()` is true, and the
    process group is destroyed at the end when more than one process is used.
    """
    stages = (
        ("Test basic notebook can be ran", test_can_initialize),
        ("Test static rendezvous backend", test_static_rdzv_backend),
        ("Test c10d rendezvous backend", test_c10d_rdzv_backend),
        ("Test fault tolerant", test_fault_tolerant),
        ("Test monitoring", test_monitoring),
    )
    for banner, run_check in stages:
        print(banner)
        run_check()

    if is_bnb_available():
        print("Test problematic imports (bnb)")
        test_problematic_imports()

    if NUM_PROCESSES > 1:
        PartialState().destroy_process_group()
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# Run all checks when this file is executed as a script.
if __name__ == "__main__":
    main()
|