BryanW committed on
Commit
1254814
·
verified ·
1 Parent(s): 24c31ad

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/hooks.cpython-312.pyc +0 -0
  2. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/memory_utils.cpython-312.pyc +0 -0
  3. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/state.cpython-312.pyc +0 -0
  4. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/__init__.cpython-312.pyc +0 -0
  5. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/accelerate_cli.cpython-312.pyc +0 -0
  6. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/env.cpython-312.pyc +0 -0
  7. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/estimate.cpython-312.pyc +0 -0
  8. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/launch.cpython-312.pyc +0 -0
  9. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/merge.cpython-312.pyc +0 -0
  10. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/test.cpython-312.pyc +0 -0
  11. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/to_fsdp2.cpython-312.pyc +0 -0
  12. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/tpu.cpython-312.pyc +0 -0
  13. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/utils.cpython-312.pyc +0 -0
  14. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/__init__.py +52 -0
  15. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/__pycache__/__init__.cpython-312.pyc +0 -0
  16. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/__pycache__/cluster.cpython-312.pyc +0 -0
  17. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/__pycache__/config.cpython-312.pyc +0 -0
  18. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/__pycache__/config_args.cpython-312.pyc +0 -0
  19. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/__pycache__/config_utils.cpython-312.pyc +0 -0
  20. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/__pycache__/default.cpython-312.pyc +0 -0
  21. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/__pycache__/sagemaker.cpython-312.pyc +0 -0
  22. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/__pycache__/update.cpython-312.pyc +0 -0
  23. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/cluster.py +917 -0
  24. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/config.py +89 -0
  25. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/config_args.py +256 -0
  26. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/config_utils.py +122 -0
  27. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/default.py +163 -0
  28. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/sagemaker.py +274 -0
  29. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/update.py +63 -0
  30. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/__init__.py +14 -0
  31. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/__pycache__/__init__.cpython-312.pyc +0 -0
  32. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/__pycache__/cursor.cpython-312.pyc +0 -0
  33. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/__pycache__/helpers.cpython-312.pyc +0 -0
  34. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/__pycache__/input.cpython-312.pyc +0 -0
  35. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/__pycache__/keymap.cpython-312.pyc +0 -0
  36. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/__pycache__/selection_menu.cpython-312.pyc +0 -0
  37. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/cursor.py +65 -0
  38. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/helpers.py +59 -0
  39. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/input.py +84 -0
  40. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/keymap.py +133 -0
  41. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/selection_menu.py +145 -0
  42. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/__pycache__/__init__.cpython-312.pyc +0 -0
  43. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/__pycache__/examples.cpython-312.pyc +0 -0
  44. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/__pycache__/testing.cpython-312.pyc +0 -0
  45. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/__pycache__/training.cpython-312.pyc +0 -0
  46. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/scripts/__init__.py +13 -0
  47. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/scripts/test_ddp_comm_hook.py +85 -0
  48. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/scripts/test_distributed_data_loop.py +410 -0
  49. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/scripts/test_merge_weights.py +158 -0
  50. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/scripts/test_notebook.py +118 -0
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/hooks.cpython-312.pyc ADDED
Binary file (34.4 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/memory_utils.cpython-312.pyc ADDED
Binary file (517 Bytes). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/state.cpython-312.pyc ADDED
Binary file (64.5 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (225 Bytes). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/accelerate_cli.cpython-312.pyc ADDED
Binary file (1.86 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/env.cpython-312.pyc ADDED
Binary file (5.17 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/estimate.cpython-312.pyc ADDED
Binary file (14.1 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/launch.cpython-312.pyc ADDED
Binary file (51.8 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/merge.cpython-312.pyc ADDED
Binary file (2.45 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/test.cpython-312.pyc ADDED
Binary file (2.21 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/to_fsdp2.cpython-312.pyc ADDED
Binary file (6.25 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/tpu.cpython-312.pyc ADDED
Binary file (6.04 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/__pycache__/utils.cpython-312.pyc ADDED
Binary file (5.24 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/__init__.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import argparse
18
+
19
+ from .config import config_command_parser
20
+ from .config_args import default_config_file, load_config_from_file # noqa: F401
21
+ from .default import default_command_parser
22
+ from .update import update_command_parser
23
+
24
+
25
def get_config_parser(subparsers=None):
    """Build the argument parser for the config command and its subcommands.

    Args:
        subparsers: Optional sub-parser collection, forwarded unchanged to
            `config_command_parser`. Defaults to `None`.

    Returns:
        The parser produced by `config_command_parser`, extended with a
        "subcommands" group on which `default_command_parser` and
        `update_command_parser` have been invoked.
    """
    # The main config parser
    config_parser = config_command_parser(subparsers)

    # Help-less parent parser handed to every subcommand through `parents=`
    shared_parent = argparse.ArgumentParser(add_help=False, allow_abbrev=False)

    # The subparser group to add commands to; the selected name lands in `args.subcommand`
    subcommand_group = config_parser.add_subparsers(title="subcommands", dest="subcommand")

    # Register each subcommand, giving every one its own `parents` list
    for register_subcommand in (default_command_parser, update_command_parser):
        register_subcommand(subcommand_group, parents=[shared_parent])

    return config_parser
37
+
38
+
39
def main():
    """Parse the command line and dispatch to the selected handler.

    Builds the parser with `get_config_parser`, parses the process
    arguments, and calls `args.func(args)`. When the parsed arguments carry
    no `func` attribute, the parser help is printed and the process exits
    with status 1.

    Raises:
        SystemExit: With code 1 when the parsed arguments have no `func`.
    """
    config_parser = get_config_parser()
    args = config_parser.parse_args()

    if not hasattr(args, "func"):
        config_parser.print_help()
        # Raise SystemExit directly instead of calling `exit(1)`: the `exit`
        # builtin is injected by the `site` module for interactive use and is
        # not defined when the interpreter is started with `-S`.
        raise SystemExit(1)

    # Run
    args.func(args)


if __name__ == "__main__":
    main()
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.5 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/__pycache__/cluster.cpython-312.pyc ADDED
Binary file (28.2 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/__pycache__/config.cpython-312.pyc ADDED
Binary file (3.29 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/__pycache__/config_args.cpython-312.pyc ADDED
Binary file (12.2 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/__pycache__/config_utils.cpython-312.pyc ADDED
Binary file (3.97 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/__pycache__/default.cpython-312.pyc ADDED
Binary file (5.99 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/__pycache__/sagemaker.cpython-312.pyc ADDED
Binary file (9.52 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/__pycache__/update.cpython-312.pyc ADDED
Binary file (2.47 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/cluster.py ADDED
@@ -0,0 +1,917 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import os
18
+
19
+ from ...utils import (
20
+ ComputeEnvironment,
21
+ DistributedType,
22
+ is_deepspeed_available,
23
+ is_fp8_available,
24
+ is_hpu_available,
25
+ is_mlu_available,
26
+ is_mps_available,
27
+ is_msamp_available,
28
+ is_musa_available,
29
+ is_npu_available,
30
+ is_sdaa_available,
31
+ is_transformer_engine_available,
32
+ is_transformers_available,
33
+ is_xpu_available,
34
+ )
35
+ from ...utils.constants import (
36
+ DEEPSPEED_MULTINODE_LAUNCHERS,
37
+ FSDP2_STATE_DICT_TYPE,
38
+ FSDP_AUTO_WRAP_POLICY,
39
+ FSDP_BACKWARD_PREFETCH,
40
+ FSDP_SHARDING_STRATEGY,
41
+ FSDP_STATE_DICT_TYPE,
42
+ TORCH_DYNAMO_MODES,
43
+ )
44
+ from .config_args import ClusterConfig
45
+ from .config_utils import (
46
+ DYNAMO_BACKENDS,
47
+ _ask_field,
48
+ _ask_options,
49
+ _convert_distributed_mode,
50
+ _convert_dynamo_backend,
51
+ _convert_fp8_backend,
52
+ _convert_mixed_precision,
53
+ _convert_yes_no_to_bool,
54
+ )
55
+
56
+
57
+ def get_cluster_input():
58
+ distributed_type = _ask_options(
59
+ "Which type of machine are you using?",
60
+ [
61
+ "No distributed training",
62
+ "multi-CPU",
63
+ "multi-XPU",
64
+ "multi-HPU",
65
+ "multi-GPU",
66
+ "multi-NPU",
67
+ "multi-MLU",
68
+ "multi-SDAA",
69
+ "multi-MUSA",
70
+ "TPU",
71
+ ],
72
+ _convert_distributed_mode,
73
+ )
74
+
75
+ machine_rank = 0
76
+ num_machines = 1
77
+ num_processes = 1
78
+ gpu_ids = None
79
+ main_process_ip = None
80
+ main_process_port = None
81
+ rdzv_backend = "static"
82
+ same_network = True
83
+ debug = False
84
+
85
+ if distributed_type in [
86
+ DistributedType.MULTI_GPU,
87
+ DistributedType.MULTI_MLU,
88
+ DistributedType.MULTI_SDAA,
89
+ DistributedType.MULTI_MUSA,
90
+ DistributedType.MULTI_NPU,
91
+ DistributedType.MULTI_XPU,
92
+ DistributedType.MULTI_CPU,
93
+ DistributedType.MULTI_HPU,
94
+ ]:
95
+ num_machines = _ask_field(
96
+ "How many different machines will you use (use more than 1 for multi-node training)? [1]: ",
97
+ int,
98
+ default=1,
99
+ )
100
+ if num_machines > 1:
101
+ machine_rank = _ask_options(
102
+ "What is the rank of this machine?",
103
+ list(range(num_machines)),
104
+ int,
105
+ )
106
+ main_process_ip = _ask_field(
107
+ "What is the IP address of the machine that will host the main process? ",
108
+ )
109
+ main_process_port = _ask_field(
110
+ "What is the port you will use to communicate with the main process? ",
111
+ int,
112
+ )
113
+ same_network = _ask_field(
114
+ "Are all the machines on the same local network? Answer `no` if nodes are on the cloud and/or on different network hosts [YES/no]: ",
115
+ _convert_yes_no_to_bool,
116
+ default=True,
117
+ error_message="Please enter yes or no.",
118
+ )
119
+ if not same_network:
120
+ rdzv_backend = _ask_field(
121
+ "What rendezvous backend will you use? ('static', 'c10d', ...): ", default="static"
122
+ )
123
+ debug = _ask_field(
124
+ "Should distributed operations be checked while running for errors? This can avoid timeout issues but will be slower. [yes/NO]: ",
125
+ _convert_yes_no_to_bool,
126
+ default=False,
127
+ error_message="Please enter yes or no.",
128
+ )
129
+
130
+ if distributed_type == DistributedType.NO:
131
+ use_cpu = _ask_field(
132
+ "Do you want to run your training on CPU only (even if a GPU / Apple Silicon / Ascend NPU device is available)? [yes/NO]:",
133
+ _convert_yes_no_to_bool,
134
+ default=False,
135
+ error_message="Please enter yes or no.",
136
+ )
137
+ elif distributed_type == DistributedType.MULTI_CPU:
138
+ use_cpu = True
139
+ else:
140
+ use_cpu = False
141
+
142
+ ipex_config = {}
143
+ mpirun_config = {}
144
+ if use_cpu or is_xpu_available():
145
+ ipex_config["ipex"] = _ask_field(
146
+ "Do you want to use Intel PyTorch Extension (IPEX) to speed up training on CPU/XPU? [yes/NO]:",
147
+ _convert_yes_no_to_bool,
148
+ default=False,
149
+ error_message="Please enter yes or no.",
150
+ )
151
+
152
+ if use_cpu:
153
+ if distributed_type == DistributedType.MULTI_CPU:
154
+ use_mpirun = _ask_field(
155
+ "Do you want accelerate to launch mpirun? [yes/NO]: ",
156
+ _convert_yes_no_to_bool,
157
+ default=False,
158
+ error_message="Please enter yes or no.",
159
+ )
160
+ if use_mpirun:
161
+ mpirun_hostfile = _ask_field(
162
+ "Please enter the path to the hostfile to use with mpirun [~/hostfile]: ",
163
+ str,
164
+ default="~/hostfile",
165
+ )
166
+ mpirun_config["mpirun_hostfile"] = os.path.expanduser(mpirun_hostfile.strip())
167
+ mpirun_config["mpirun_ccl"] = _ask_field("Enter the number of oneCCL worker threads [1]: ", default=1)
168
+
169
+ dynamo_config = {}
170
+ use_dynamo = _ask_field(
171
+ "Do you wish to optimize your script with torch dynamo?[yes/NO]:",
172
+ _convert_yes_no_to_bool,
173
+ default=False,
174
+ error_message="Please enter yes or no.",
175
+ )
176
+ if use_dynamo:
177
+ prefix = "dynamo_"
178
+ dynamo_config[prefix + "backend"] = _ask_options(
179
+ "Which dynamo backend would you like to use?",
180
+ [x.lower() for x in DYNAMO_BACKENDS],
181
+ _convert_dynamo_backend,
182
+ default=2,
183
+ )
184
+ use_custom_options = _ask_field(
185
+ "Do you want to customize the defaults sent to torch.compile? [yes/NO]: ",
186
+ _convert_yes_no_to_bool,
187
+ default=False,
188
+ error_message="Please enter yes or no.",
189
+ )
190
+
191
+ if use_custom_options:
192
+ dynamo_config[prefix + "mode"] = _ask_options(
193
+ "Which mode do you want to use?",
194
+ TORCH_DYNAMO_MODES,
195
+ lambda x: TORCH_DYNAMO_MODES[int(x)],
196
+ default=0,
197
+ )
198
+ dynamo_config[prefix + "use_fullgraph"] = _ask_field(
199
+ "Do you want the fullgraph mode or it is ok to break model into several subgraphs? [yes/NO]: ",
200
+ _convert_yes_no_to_bool,
201
+ default=False,
202
+ error_message="Please enter yes or no.",
203
+ )
204
+ dynamo_config[prefix + "use_dynamic"] = _ask_field(
205
+ "Do you want to enable dynamic shape tracing? [yes/NO]: ",
206
+ _convert_yes_no_to_bool,
207
+ default=False,
208
+ error_message="Please enter yes or no.",
209
+ )
210
+ dynamo_config[prefix + "use_regional_compilation"] = _ask_field(
211
+ "Do you want to enable regional compilation? [yes/NO]: ",
212
+ _convert_yes_no_to_bool,
213
+ default=False,
214
+ error_message="Please enter yes or no.",
215
+ )
216
+
217
+ use_mps = not use_cpu and is_mps_available()
218
+ deepspeed_config = {}
219
+ if (
220
+ distributed_type
221
+ in [
222
+ DistributedType.MULTI_GPU,
223
+ DistributedType.MULTI_XPU,
224
+ DistributedType.MULTI_HPU,
225
+ DistributedType.MULTI_NPU,
226
+ DistributedType.MULTI_MLU,
227
+ DistributedType.MULTI_SDAA,
228
+ DistributedType.MULTI_MUSA,
229
+ DistributedType.NO,
230
+ ]
231
+ and not use_mps
232
+ ):
233
+ use_deepspeed = _ask_field(
234
+ "Do you want to use DeepSpeed? [yes/NO]: ",
235
+ _convert_yes_no_to_bool,
236
+ default=False,
237
+ error_message="Please enter yes or no.",
238
+ )
239
+ if use_deepspeed:
240
+ distributed_type = DistributedType.DEEPSPEED
241
+ assert is_deepspeed_available(), (
242
+ "DeepSpeed is not installed => run `pip3 install deepspeed` or build it from source"
243
+ )
244
+
245
+ if distributed_type == DistributedType.DEEPSPEED:
246
+ use_deepspeed_config = _ask_field(
247
+ "Do you want to specify a json file to a DeepSpeed config? [yes/NO]: ",
248
+ _convert_yes_no_to_bool,
249
+ default=False,
250
+ error_message="Please enter yes or no.",
251
+ )
252
+ if use_deepspeed_config:
253
+ deepspeed_config["deepspeed_config_file"] = _ask_field(
254
+ "Please enter the path to the json DeepSpeed config file: ",
255
+ str,
256
+ default="none",
257
+ )
258
+ else:
259
+ deepspeed_config["zero_stage"] = _ask_options(
260
+ "What should be your DeepSpeed's ZeRO optimization stage?",
261
+ [0, 1, 2, 3],
262
+ int,
263
+ default=2,
264
+ )
265
+
266
+ deepspeed_devices = ["none", "cpu", "nvme"]
267
+ if deepspeed_config["zero_stage"] >= 2:
268
+ deepspeed_config["offload_optimizer_device"] = _ask_options(
269
+ "Where to offload optimizer states?", deepspeed_devices, lambda x: deepspeed_devices[int(x)]
270
+ )
271
+ deepspeed_config["offload_param_device"] = _ask_options(
272
+ "Where to offload parameters?", deepspeed_devices, lambda x: deepspeed_devices[int(x)]
273
+ )
274
+ if deepspeed_config["offload_param_device"] == "nvme":
275
+ deepspeed_config["offload_param_nvme_path"] = _ask_field(
276
+ "Nvme Path to offload parameters?",
277
+ str,
278
+ default="/nvme",
279
+ )
280
+ if deepspeed_config["offload_optimizer_device"] == "nvme":
281
+ deepspeed_config["offload_optimizer_nvme_path"] = _ask_field(
282
+ "Nvme Path to offload optimizer states?",
283
+ str,
284
+ default="/nvme",
285
+ )
286
+ deepspeed_config["gradient_accumulation_steps"] = _ask_field(
287
+ "How many gradient accumulation steps you're passing in your script? [1]: ",
288
+ int,
289
+ default=1,
290
+ )
291
+ use_gradient_clipping = _ask_field(
292
+ "Do you want to use gradient clipping? [yes/NO]: ",
293
+ _convert_yes_no_to_bool,
294
+ default=False,
295
+ error_message="Please enter yes or no.",
296
+ )
297
+ if use_gradient_clipping:
298
+ deepspeed_config["gradient_clipping"] = _ask_field(
299
+ "What is the gradient clipping value? [1.0]: ",
300
+ float,
301
+ default=1.0,
302
+ )
303
+ if deepspeed_config["zero_stage"] == 3:
304
+ deepspeed_config["zero3_save_16bit_model"] = _ask_field(
305
+ "Do you want to save 16-bit model weights when using ZeRO Stage-3? [yes/NO]: ",
306
+ _convert_yes_no_to_bool,
307
+ default=False,
308
+ error_message="Please enter yes or no.",
309
+ )
310
+ deepspeed_config["zero3_init_flag"] = _ask_field(
311
+ "Do you want to enable `deepspeed.zero.Init` when using ZeRO Stage-3 for constructing massive models? [yes/NO]: ",
312
+ _convert_yes_no_to_bool,
313
+ default=False,
314
+ error_message="Please enter yes or no.",
315
+ )
316
+ if deepspeed_config["zero3_init_flag"]:
317
+ if not is_transformers_available():
318
+ raise Exception(
319
+ "When `zero3_init_flag` is set, it requires Transformers to be installed. "
320
+ "Please run `pip3 install transformers`."
321
+ )
322
+ use_moe = _ask_field(
323
+ "Do you want to enable Mixture-of-Experts training (MoE)? [yes/NO]: ",
324
+ _convert_yes_no_to_bool,
325
+ default=False,
326
+ error_message="Please enter yes or no.",
327
+ )
328
+ if use_moe:
329
+ deepspeed_config["deepspeed_moe_layer_cls_names"] = _ask_field(
330
+ "Specify the comma-separated list of transformers MoE layer class names (case-sensitive), e.g : "
331
+ " `MixtralSparseMoeBlock`, `Qwen2MoeSparseMoeBlock`, `JetMoEAttention,JetMoEBlock` ... : ",
332
+ str,
333
+ )
334
+
335
+ if num_machines > 1:
336
+ launcher_query = "Which Type of launcher do you want to use?"
337
+ deepspeed_config["deepspeed_multinode_launcher"] = _ask_options(
338
+ launcher_query,
339
+ DEEPSPEED_MULTINODE_LAUNCHERS,
340
+ lambda x: DEEPSPEED_MULTINODE_LAUNCHERS[int(x)],
341
+ )
342
+
343
+ if deepspeed_config["deepspeed_multinode_launcher"] != DEEPSPEED_MULTINODE_LAUNCHERS[1]:
344
+ deepspeed_config["deepspeed_hostfile"] = _ask_field(
345
+ "DeepSpeed configures multi-node compute resources with hostfile. "
346
+ "Each row is of the format `hostname slots=[num_gpus]`, e.g., `localhost slots=2`; "
347
+ "for more information please refer official [documentation]"
348
+ "(https://www.deepspeed.ai/getting-started/#resource-configuration-multi-node). "
349
+ "Please specify the location of hostfile: ",
350
+ str,
351
+ )
352
+
353
+ is_exclusion_filter = _ask_field(
354
+ "Do you want to specify exclusion filter string? [yes/NO]: ",
355
+ _convert_yes_no_to_bool,
356
+ default=False,
357
+ error_message="Please enter yes or no.",
358
+ )
359
+ if is_exclusion_filter:
360
+ deepspeed_config["deepspeed_exclusion_filter"] = _ask_field(
361
+ "DeepSpeed exclusion filter string: ",
362
+ str,
363
+ )
364
+
365
+ is_inclusion_filter = _ask_field(
366
+ "Do you want to specify inclusion filter string? [yes/NO]: ",
367
+ _convert_yes_no_to_bool,
368
+ default=False,
369
+ error_message="Please enter yes or no.",
370
+ )
371
+ if is_inclusion_filter:
372
+ deepspeed_config["deepspeed_inclusion_filter"] = _ask_field(
373
+ "DeepSpeed inclusion filter string: ",
374
+ str,
375
+ )
376
+
377
+ fsdp_config = {}
378
+
379
+ if distributed_type in [
380
+ DistributedType.MULTI_GPU,
381
+ DistributedType.MULTI_NPU,
382
+ DistributedType.MULTI_MLU,
383
+ DistributedType.MULTI_SDAA,
384
+ DistributedType.MULTI_MUSA,
385
+ DistributedType.MULTI_XPU,
386
+ DistributedType.MULTI_HPU,
387
+ ]:
388
+ use_fsdp = _ask_field(
389
+ "Do you want to use FullyShardedDataParallel? [yes/NO]: ",
390
+ _convert_yes_no_to_bool,
391
+ default=False,
392
+ error_message="Please enter yes or no.",
393
+ )
394
+ if use_fsdp:
395
+ distributed_type = DistributedType.FSDP
396
+ if distributed_type == DistributedType.FSDP:
397
+ fsdp_config["fsdp_version"] = _ask_options(
398
+ "What should be your FSDP version? [2]: ",
399
+ [1, 2],
400
+ lambda x: int(x) + 1,
401
+ default=1,
402
+ )
403
+ fsdp_version = fsdp_config["fsdp_version"] # extract to a variable to simplify usage later
404
+
405
+ if fsdp_version == 1:
406
+ sharding_strategy_query = "What should be your sharding strategy?"
407
+ fsdp_config["fsdp_reshard_after_forward"] = _ask_options(
408
+ sharding_strategy_query,
409
+ FSDP_SHARDING_STRATEGY,
410
+ lambda x: FSDP_SHARDING_STRATEGY[int(x)],
411
+ )
412
+ else:
413
+ fsdp_config["fsdp_reshard_after_forward"] = _ask_field(
414
+ "Do you want to enable resharding after forward? [YES/no]: ",
415
+ _convert_yes_no_to_bool,
416
+ default=True,
417
+ error_message="Please enter yes or no.",
418
+ )
419
+
420
+ fsdp_config["fsdp_offload_params"] = _ask_field(
421
+ "Do you want to offload parameters and gradients to CPU? [yes/NO]: ",
422
+ _convert_yes_no_to_bool,
423
+ default=False,
424
+ error_message="Please enter yes or no.",
425
+ )
426
+
427
+ fsdp_wrap_query = "What should be your auto wrap policy?"
428
+ fsdp_config["fsdp_auto_wrap_policy"] = _ask_options(
429
+ fsdp_wrap_query,
430
+ FSDP_AUTO_WRAP_POLICY,
431
+ lambda x: FSDP_AUTO_WRAP_POLICY[int(x)],
432
+ )
433
+ if fsdp_config["fsdp_auto_wrap_policy"] == FSDP_AUTO_WRAP_POLICY[0]:
434
+ use_no_split_modules = _ask_field(
435
+ "Do you want to use the model's `_no_split_modules` to wrap. Only applicable for 🤗 Transformers [yes/NO]: ",
436
+ _convert_yes_no_to_bool,
437
+ default=False,
438
+ error_message="Please enter yes or no.",
439
+ )
440
+ if not use_no_split_modules:
441
+ fsdp_config["fsdp_transformer_layer_cls_to_wrap"] = _ask_field(
442
+ "Specify the comma-separated list of transformer layer class names (case-sensitive) to wrap ,e.g, :"
443
+ "`BertLayer`, `GPTJBlock`, `T5Block`, `BertLayer,BertEmbeddings,BertSelfOutput` ...? : ",
444
+ str,
445
+ )
446
+ elif fsdp_config["fsdp_auto_wrap_policy"] == FSDP_AUTO_WRAP_POLICY[1]:
447
+ fsdp_config["fsdp_min_num_params"] = _ask_field(
448
+ "What should be your FSDP's minimum number of parameters for Default Auto Wrapping Policy? [1e8]: ",
449
+ int,
450
+ default=100000000,
451
+ )
452
+ # Removed in FSDP2, ask for user input for FSDP1
453
+ if fsdp_version == 1:
454
+ fsdp_backward_prefetch_query = "What should be your FSDP's backward prefetch policy?"
455
+ fsdp_config["fsdp_backward_prefetch"] = _ask_options(
456
+ fsdp_backward_prefetch_query,
457
+ FSDP_BACKWARD_PREFETCH,
458
+ lambda x: FSDP_BACKWARD_PREFETCH[int(x)],
459
+ )
460
+
461
+ fsdp_state_dict_type_query = "What should be your FSDP's state dict type?"
462
+ fsdp_config["fsdp_state_dict_type"] = _ask_options(
463
+ fsdp_state_dict_type_query,
464
+ FSDP_STATE_DICT_TYPE if fsdp_version == 1 else FSDP2_STATE_DICT_TYPE,
465
+ lambda x: FSDP_STATE_DICT_TYPE[int(x)] if fsdp_version == 1 else FSDP2_STATE_DICT_TYPE[int(x)],
466
+ default=0,
467
+ )
468
+ # Not implemented in FSDP2, ask for user input for FSDP1
469
+ if fsdp_version == 1:
470
+ fsdp_config["fsdp_forward_prefetch"] = _ask_field(
471
+ "Do you want to enable FSDP's forward prefetch policy? [yes/NO]: ",
472
+ _convert_yes_no_to_bool,
473
+ default=False,
474
+ error_message="Please enter yes or no.",
475
+ )
476
+ # Obsolete in FSDP2, ask for user input for FSDP1
477
+ if fsdp_version == 1:
478
+ fsdp_config["fsdp_use_orig_params"] = _ask_field(
479
+ "Do you want to enable FSDP's `use_orig_params` feature? [YES/no]: ",
480
+ _convert_yes_no_to_bool,
481
+ default=True,
482
+ error_message="Please enter yes or no.",
483
+ )
484
+ fsdp_config["fsdp_cpu_ram_efficient_loading"] = _ask_field(
485
+ "Do you want to enable CPU RAM efficient model loading? Only applicable for 🤗 Transformers models. [YES/no]: ",
486
+ _convert_yes_no_to_bool,
487
+ default=True,
488
+ error_message="Please enter yes or no.",
489
+ )
490
+ # Obsolete in FSDP2, ask for user input for FSDP1
491
+ if fsdp_version == 1:
492
+ if fsdp_config["fsdp_cpu_ram_efficient_loading"]:
493
+ fsdp_config["fsdp_sync_module_states"] = True
494
+ else:
495
+ fsdp_config["fsdp_sync_module_states"] = _ask_field(
496
+ "Do you want each individually wrapped FSDP unit to broadcast module parameters from rank 0 at the start? [YES/no]: ",
497
+ _convert_yes_no_to_bool,
498
+ default=True,
499
+ error_message="Please enter yes or no.",
500
+ )
501
+ fsdp_config["fsdp_activation_checkpointing"] = _ask_field(
502
+ "Do you want to enable FSDP activation checkpointing? [yes/NO]: ",
503
+ _convert_yes_no_to_bool,
504
+ default=False,
505
+ error_message="Please enter yes or no.",
506
+ )
507
+
508
+ parallelism_config = {}
509
+
510
+ if fsdp_config.get("fsdp_version", 1) == 2:
511
+ use_parallelism_config = _ask_field(
512
+ "Do you want to use the parallelism config? [yes/NO]: ",
513
+ _convert_yes_no_to_bool,
514
+ default=False,
515
+ error_message="Please enter yes or no.",
516
+ )
517
+
518
+ if use_parallelism_config:
519
+ prefix = "parallelism_config_"
520
+ parallelism_config[prefix + "dp_replicate_size"] = _ask_field(
521
+ "What is the data parallelism replicate size? [1]: ",
522
+ int,
523
+ default=1,
524
+ error_message="Please enter an integer.",
525
+ )
526
+
527
+ parallelism_config[prefix + "dp_shard_size"] = _ask_field(
528
+ "What is the FSDP shard size? [1]: ",
529
+ int,
530
+ default=1,
531
+ error_message="Please enter an integer.",
532
+ )
533
+
534
+ parallelism_config[prefix + "tp_size"] = _ask_field(
535
+ "What is the tensor parallelism size? [1]: ",
536
+ int,
537
+ default=1,
538
+ error_message="Please enter an integer.",
539
+ )
540
+
541
+ parallelism_config[prefix + "cp_size"] = _ask_field(
542
+ "What is the context parallelism size? [1]: ",
543
+ int,
544
+ default=1,
545
+ error_message="Please enter an integer.",
546
+ )
547
+ if parallelism_config[prefix + "cp_size"] > 1:
548
+ parallelism_config[prefix + "cp_comm_strategy"] = _ask_options(
549
+ "What is the compute parallelism communication strategy?",
550
+ ["allgather", "alltoall"],
551
+ lambda x: ["allgather", "alltoall"][int(x)],
552
+ default=0,
553
+ )
554
+
555
+ megatron_lm_config = {}
556
+ if distributed_type in [DistributedType.MULTI_GPU]:
557
+ use_megatron_lm = _ask_field(
558
+ "Do you want to use Megatron-LM ? [yes/NO]: ",
559
+ _convert_yes_no_to_bool,
560
+ default=False,
561
+ error_message="Please enter yes or no.",
562
+ )
563
+ if use_megatron_lm:
564
+ distributed_type = DistributedType.MEGATRON_LM
565
+ if distributed_type == DistributedType.MEGATRON_LM:
566
+ prefix = "megatron_lm_"
567
+ megatron_lm_config[prefix + "tp_degree"] = _ask_field(
568
+ "What is the Tensor Parallelism degree/size? [1]:",
569
+ int,
570
+ default=1,
571
+ error_message="Please enter an integer.",
572
+ )
573
+ if megatron_lm_config[prefix + "tp_degree"] > 1:
574
+ megatron_lm_config[prefix + "sequence_parallelism"] = _ask_field(
575
+ "Do you want to enable Sequence Parallelism? [YES/no]: ",
576
+ _convert_yes_no_to_bool,
577
+ default=True,
578
+ error_message="Please enter yes or no.",
579
+ )
580
+
581
+ megatron_lm_config[prefix + "pp_degree"] = _ask_field(
582
+ "What is the Pipeline Parallelism degree/size? [1]:",
583
+ int,
584
+ default=1,
585
+ error_message="Please enter an integer.",
586
+ )
587
+ if megatron_lm_config[prefix + "pp_degree"] > 1:
588
+ megatron_lm_config[prefix + "num_micro_batches"] = _ask_field(
589
+ "What is the number of micro-batches? [1]:",
590
+ int,
591
+ default=1,
592
+ error_message="Please enter an integer.",
593
+ )
594
+
595
+ megatron_lm_config[prefix + "recompute_activations"] = _ask_field(
596
+ "Do you want to enable selective activation recomputation? [YES/no]: ",
597
+ _convert_yes_no_to_bool,
598
+ default=True,
599
+ error_message="Please enter yes or no.",
600
+ )
601
+
602
+ megatron_lm_config[prefix + "use_distributed_optimizer"] = _ask_field(
603
+ "Do you want to use distributed optimizer "
604
+ "which shards optimizer state and gradients across data parallel ranks? [YES/no]: ",
605
+ _convert_yes_no_to_bool,
606
+ default=True,
607
+ error_message="Please enter yes or no.",
608
+ )
609
+
610
+ megatron_lm_config[prefix + "gradient_clipping"] = _ask_field(
611
+ "What is the gradient clipping value based on global L2 Norm (0 to disable)? [1.0]: ",
612
+ float,
613
+ default=1.0,
614
+ )
615
+ # TPU specific defaults
616
+ tpu_commands = None
617
+ tpu_command_file = None
618
+ tpu_downcast_bf16 = "no"
619
+ tpu_env = []
620
+ tpu_name = None
621
+ tpu_vm = None
622
+ tpu_zone = None
623
+ tpu_use_sudo = False
624
+ tpu_use_cluster = False
625
+
626
+ if distributed_type in [
627
+ DistributedType.MULTI_CPU,
628
+ DistributedType.MULTI_XPU,
629
+ DistributedType.MULTI_HPU,
630
+ DistributedType.MULTI_GPU,
631
+ DistributedType.MULTI_MLU,
632
+ DistributedType.MULTI_SDAA,
633
+ DistributedType.MULTI_MUSA,
634
+ DistributedType.MULTI_NPU,
635
+ DistributedType.XLA,
636
+ ]:
637
+ machine_type = str(distributed_type).split(".")[1].replace("MULTI_", "")
638
+ if machine_type == "TPU":
639
+ machine_type += " cores"
640
+ elif machine_type == "CPU":
641
+ machine_type = "processes"
642
+ else:
643
+ machine_type += "(s)"
644
+ num_processes = _ask_field(
645
+ f"How many {machine_type} should be used for distributed training? [1]:",
646
+ int,
647
+ default=1,
648
+ error_message="Please enter an integer.",
649
+ )
650
+ elif distributed_type in [DistributedType.FSDP, DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM]:
651
+ num_processes = _ask_field(
652
+ "How many GPU(s) should be used for distributed training? [1]:",
653
+ int,
654
+ default=1,
655
+ error_message="Please enter an integer.",
656
+ )
657
+ else:
658
+ num_processes = 1
659
+
660
+ if (distributed_type == DistributedType.MULTI_GPU) and (num_machines == 1) and (num_processes == 1):
661
+ raise ValueError(
662
+ f"Specified distributed type {distributed_type} but only using 1 GPU on a single machine. Please select `No distributed training` for the type of machine you are using."
663
+ )
664
+
665
+ if (
666
+ distributed_type
667
+ in [
668
+ DistributedType.MULTI_GPU,
669
+ DistributedType.MULTI_MLU,
670
+ DistributedType.MULTI_SDAA,
671
+ DistributedType.MULTI_MUSA,
672
+ DistributedType.MULTI_NPU,
673
+ DistributedType.MULTI_XPU,
674
+ DistributedType.MULTI_HPU,
675
+ DistributedType.NO,
676
+ ]
677
+ and not use_cpu
678
+ and not use_mps
679
+ ):
680
+ if is_npu_available():
681
+ machine_type = "NPU(s)"
682
+ elif is_mlu_available():
683
+ machine_type = "MLU(s)"
684
+ elif is_sdaa_available():
685
+ machine_type = "SDAA(s)"
686
+ elif is_musa_available():
687
+ machine_type = "MUSA(s)"
688
+ elif is_xpu_available():
689
+ machine_type = "XPU(s)"
690
+ elif is_hpu_available():
691
+ machine_type = "HPU(s)"
692
+ else:
693
+ machine_type = "GPU(s)"
694
+ gpu_ids = _ask_field(
695
+ f"What {machine_type} (by id) should be used for training on this machine as a comma-separated list? [all]:",
696
+ default="all",
697
+ )
698
+
699
+ # CPU affinity is only supported on NVIDIA hardware for now
700
+ enable_cpu_affinity = False
701
+ if distributed_type in (DistributedType.NO, DistributedType.MULTI_GPU) and not use_cpu and not use_mps:
702
+ enable_cpu_affinity = _ask_field(
703
+ "Would you like to enable numa efficiency? (Currently only supported on NVIDIA hardware). [yes/NO]: ",
704
+ _convert_yes_no_to_bool,
705
+ default=False,
706
+ error_message="Please enter yes or no.",
707
+ )
708
+
709
+ fp8_config = None
710
+ if distributed_type == DistributedType.XLA:
711
+ mixed_precision = "no"
712
+ main_training_function = _ask_field(
713
+ "What is the name of the function in your script that should be launched in all parallel scripts? [main]: ",
714
+ default="main",
715
+ )
716
+ tpu_use_cluster = _ask_field(
717
+ "Are you using a TPU cluster? [yes/NO]: ",
718
+ _convert_yes_no_to_bool,
719
+ default=False,
720
+ error_message="Please enter yes or no.",
721
+ )
722
+ if tpu_use_cluster:
723
+ tpu_name = _ask_field(
724
+ "What is the name of your TPU cluster? ",
725
+ default=None,
726
+ error_message="Please enter the name of your TPU cluster.",
727
+ )
728
+ tpu_zone = _ask_field(
729
+ "What is the zone of your TPU cluster? ",
730
+ default=None,
731
+ error_message="Please enter the zone of your TPU cluster.",
732
+ )
733
+ tpu_use_sudo = _ask_field(
734
+ "To run a python script in a TPU pod, should `sudo` be used? [yes/NO]: ",
735
+ default=False,
736
+ error_message="Please enter yes or no.",
737
+ )
738
+ run_commands = _ask_field(
739
+ "Do you have code you wish to run on startup in each pod? [yes/NO]: ",
740
+ _convert_yes_no_to_bool,
741
+ default=False,
742
+ error_message="Please enter yes or no.",
743
+ )
744
+ if run_commands:
745
+ use_command_file = _ask_field(
746
+ "Is this code located in a bash script? [yes/NO]: ",
747
+ _convert_yes_no_to_bool,
748
+ default=False,
749
+ error_message="Please enter yes or no.",
750
+ )
751
+ if use_command_file:
752
+ tpu_command_file = _ask_field(
753
+ "What is the path to your bash script? ",
754
+ default=None,
755
+ error_message="Please enter the path to your bash script.",
756
+ )
757
+ tpu_command_file = os.path.abspath(tpu_command_file)
758
+ else:
759
+ print("Please enter each command separately you wish to run on startup in each pod.")
760
+ tpu_commands = []
761
+ another_command = True
762
+ while another_command:
763
+ tpu_commands.append(
764
+ _ask_field(
765
+ "Please enter a single command to be ran ",
766
+ default=None,
767
+ error_message="Please enter the commands you wish to run on startup in each pod as a single string.",
768
+ )
769
+ )
770
+ another_command = _ask_field(
771
+ "Do you wish to add another command? [yes/NO]: ",
772
+ _convert_yes_no_to_bool,
773
+ default=False,
774
+ error_message="Please enter yes or no.",
775
+ )
776
+ tpu_vm = _ask_field(
777
+ "If not using an instance group, what are the names of the Compute VM instances to be used, separated by a comma: ",
778
+ default="",
779
+ ).split(",")
780
+ tpu_env = _ask_field(
781
+ "What environment variables do you wish to set in each pod, separated by a comma: ",
782
+ default="",
783
+ ).split(",")
784
+
785
+ else:
786
+ main_training_function = "main"
787
+ if distributed_type == DistributedType.DEEPSPEED and use_deepspeed_config:
788
+ mixed_precision = None
789
+ else:
790
+ mixed_precision = _ask_options(
791
+ "Do you wish to use mixed precision?",
792
+ ["no", "fp16", "bf16", "fp8"],
793
+ _convert_mixed_precision,
794
+ )
795
+ if mixed_precision == "fp8":
796
+ if not is_fp8_available():
797
+ raise ValueError("FP8 (either Transformer Engine or MSAMP) is not installed on this machine.")
798
+ fp8_config = {}
799
+ fp8_config["backend"] = _ask_options(
800
+ "Which FP8 backend do you want to use?",
801
+ ["te", "msamp"],
802
+ _convert_fp8_backend,
803
+ )
804
+ if fp8_config["backend"] == "TE":
805
+ if not is_transformer_engine_available():
806
+ raise ValueError("TransformersEngine was selected, but it is not installed on this machine.")
807
+ fp8_config["use_autocast_during_eval"] = _ask_field(
808
+ "Do you want to use FP8 autocast during eval mode? Generally better metrics are found when this is disabled [yes/NO]: ",
809
+ _convert_yes_no_to_bool,
810
+ default=False,
811
+ )
812
+ fp8_config["margin"] = _ask_field(
813
+ "What margin should be used for gradient scaling? [0]: ",
814
+ int,
815
+ default=0,
816
+ )
817
+ fp8_config["interval"] = _ask_field(
818
+ "What interval should be used for for how often the scaling factor is recomputed? [1]: ",
819
+ int,
820
+ default=1,
821
+ )
822
+ fp8_config["fp8_format"] = _ask_options(
823
+ "Which weight format should be used?",
824
+ ["HYBRID", "E4M3", "E5M2"],
825
+ lambda i: ["HYBRID", "E4M3", "E5M2"][i],
826
+ default=0,
827
+ )
828
+ fp8_config["amax_history_length"] = _ask_field(
829
+ "What length of history should be used for the amax scaling factor computation? [1024]: ",
830
+ int,
831
+ default=1024,
832
+ )
833
+ fp8_config["amax_compute_algorithm"] = _ask_options(
834
+ "Which algorithm should be used for the amax scaling factor computation?",
835
+ ["max", "most_recent"],
836
+ lambda x: "max" if x == 0 else "most_recent",
837
+ default=0,
838
+ )
839
+ fp8_config["override_linear_precision"] = _ask_field(
840
+ "Do you want to to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision? [yes/NO]: ",
841
+ _convert_yes_no_to_bool,
842
+ default=False,
843
+ )
844
+ if fp8_config["override_linear_precision"]:
845
+ fprop = _ask_field(
846
+ "Should `fprop` be executed in higher precision? [yes/NO]: ",
847
+ _convert_yes_no_to_bool,
848
+ default=False,
849
+ )
850
+ dgrad = _ask_field(
851
+ "Should `dgrad` be executed in higher precision? [yes/NO]: ",
852
+ _convert_yes_no_to_bool,
853
+ default=False,
854
+ )
855
+ wgrad = _ask_field(
856
+ "Should `wgrad` be executed in higher precision? [yes/NO]: ",
857
+ _convert_yes_no_to_bool,
858
+ default=False,
859
+ )
860
+ fp8_config["override_linear_precision"] = (fprop, dgrad, wgrad)
861
+ else:
862
+ fp8_config["override_linear_precision"] = (False, False, False)
863
+
864
+ elif fp8_config["backend"] == "MSAMP":
865
+ if not is_msamp_available():
866
+ raise ValueError("MSAMP was selected, but it is not installed on this machine.")
867
+ fp8_config["optimization_level"] = _ask_options(
868
+ "Which optimization level should be used?",
869
+ ["O1", "O2"],
870
+ lambda x: "O1" if x == 0 else "O2",
871
+ default=1,
872
+ )
873
+
874
+ if use_dynamo and mixed_precision == "no" and not use_cpu:
875
+ print(
876
+ "Torch dynamo used without mixed precision requires TF32 to be efficient. Accelerate will enable it by default when launching your scripts."
877
+ )
878
+
879
+ if distributed_type == DistributedType.XLA and mixed_precision == "bf16":
880
+ tpu_downcast_bf16 = _ask_field(
881
+ "Should `torch.float` be cast as `bfloat16` and `torch.double` remain `float32` on TPUs?", default="no"
882
+ )
883
+
884
+ return ClusterConfig(
885
+ compute_environment=ComputeEnvironment.LOCAL_MACHINE,
886
+ distributed_type=distributed_type,
887
+ num_processes=num_processes,
888
+ gpu_ids=gpu_ids,
889
+ mixed_precision=mixed_precision,
890
+ downcast_bf16=tpu_downcast_bf16,
891
+ machine_rank=machine_rank,
892
+ num_machines=num_machines,
893
+ main_process_ip=main_process_ip,
894
+ main_process_port=main_process_port,
895
+ main_training_function=main_training_function,
896
+ fp8_config=fp8_config,
897
+ deepspeed_config=deepspeed_config,
898
+ fsdp_config=fsdp_config,
899
+ parallelism_config=parallelism_config,
900
+ megatron_lm_config=megatron_lm_config,
901
+ ipex_config=ipex_config,
902
+ mpirun_config=mpirun_config,
903
+ use_cpu=use_cpu,
904
+ rdzv_backend=rdzv_backend,
905
+ same_network=same_network,
906
+ commands=tpu_commands,
907
+ command_file=tpu_command_file,
908
+ tpu_env=tpu_env,
909
+ tpu_name=tpu_name,
910
+ tpu_vm=tpu_vm,
911
+ tpu_zone=tpu_zone,
912
+ tpu_use_sudo=tpu_use_sudo,
913
+ tpu_use_cluster=tpu_use_cluster,
914
+ dynamo_config=dynamo_config,
915
+ debug=debug,
916
+ enable_cpu_affinity=enable_cpu_affinity,
917
+ )
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/config.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import argparse
18
+ import os
19
+
20
+ from accelerate.utils import ComputeEnvironment
21
+
22
+ from .cluster import get_cluster_input
23
+ from .config_args import cache_dir, default_config_file, default_yaml_config_file, load_config_from_file # noqa: F401
24
+ from .config_utils import _ask_field, _ask_options, _convert_compute_environment # noqa: F401
25
+ from .sagemaker import get_sagemaker_input
26
+
27
+
28
+ description = "Launches a series of prompts to create and save a `default_config.yaml` configuration file for your training system. Should always be ran first on your machine"
29
+
30
+
31
def get_user_input():
    """
    Asks which compute environment is in use and runs the matching questionnaire.

    Returns:
        The configuration object produced by `get_sagemaker_input()` when Amazon SageMaker is selected, otherwise
        the one produced by `get_cluster_input()`.
    """
    chosen_environment = _ask_options(
        "In which compute environment are you running?",
        ["This machine", "AWS (Amazon SageMaker)"],
        _convert_compute_environment,
    )
    # Guard clause for the SageMaker path; every other answer falls through to the local/cluster questionnaire.
    if chosen_environment == ComputeEnvironment.AMAZON_SAGEMAKER:
        return get_sagemaker_input()
    return get_cluster_input()
42
+
43
+
44
def config_command_parser(subparsers=None):
    """
    Builds the argument parser for the `config` command.

    Args:
        subparsers (*optional*):
            When given, the parser is registered on it as the `config` subcommand and `config_command` is set as its
            `func` default. When `None`, a standalone `argparse.ArgumentParser` is created instead.

    Returns:
        The parser, with the `--config_file` argument registered.
    """
    standalone = subparsers is None
    if standalone:
        parser = argparse.ArgumentParser("Accelerate config command", description=description)
    else:
        parser = subparsers.add_parser("config", description=description)

    config_file_help = (
        "The path to use to store the config file. Will default to a file named default_config.yaml in the cache "
        "location, which is the content of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have "
        "such an environment variable, your cache directory ('~/.cache' or the content of `XDG_CACHE_HOME`) suffixed "
        "with 'huggingface'."
    )
    parser.add_argument("--config_file", default=None, help=config_file_help)

    # Only a subcommand parser needs a dispatch target; the standalone parser is driven directly by `main()`.
    if not standalone:
        parser.set_defaults(func=config_command)
    return parser
64
+
65
+
66
def config_command(args):
    """
    Runs the interactive questionnaire and writes the resulting configuration to disk.

    Args:
        args:
            Parsed command-line arguments. Only `args.config_file` is read: when it is not `None` the config is
            saved there, otherwise it is saved to `default_yaml_config_file` inside `cache_dir`.

    The file is written as JSON when the target path ends with `.json`, and as YAML otherwise.
    """
    config = get_user_input()
    if args.config_file is not None:
        config_file = args.config_file
    else:
        # `exist_ok=True` replaces the previous "check `isdir`, then create" sequence, which could raise
        # `FileExistsError` if another process created the directory between the check and the call.
        os.makedirs(cache_dir, exist_ok=True)
        config_file = default_yaml_config_file

    if config_file.endswith(".json"):
        config.to_json_file(config_file)
    else:
        config.to_yaml_file(config_file)
    print(f"accelerate configuration saved at {config_file}")
80
+
81
+
82
def main():
    """
    Standalone entry point: builds the parser, parses the command line and runs the config command.
    """
    config_command(config_command_parser().parse_args())


if __name__ == "__main__":
    main()
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/config_args.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import json
18
+ import os
19
+ from dataclasses import dataclass
20
+ from enum import Enum
21
+ from typing import Optional, Union
22
+
23
+ import yaml
24
+
25
+ from ...utils import ComputeEnvironment, DistributedType, SageMakerDistributedType
26
+ from ...utils.constants import SAGEMAKER_PYTHON_VERSION, SAGEMAKER_PYTORCH_VERSION, SAGEMAKER_TRANSFORMERS_VERSION
27
+
28
+
29
# Root of the Hugging Face cache: `HF_HOME` when set, otherwise `<XDG_CACHE_HOME or ~/.cache>/huggingface`.
# A leading `~` is expanded to the user's home directory.
hf_cache_home = os.path.expanduser(
    os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
)
# Directory holding Accelerate's own files inside the Hugging Face cache.
cache_dir = os.path.join(hf_cache_home, "accelerate")
# NOTE(review): both defaults below resolve to the same `default_config.yaml` path, so the JSON fallback selected
# underneath can never point at a different file - confirm whether this is intentional.
default_json_config_file = os.path.join(cache_dir, "default_config.yaml")
default_yaml_config_file = os.path.join(cache_dir, "default_config.yaml")

# For backward compatibility: the default config is the json one if it's the only existing file.
default_config_file = (
    default_json_config_file
    if os.path.isfile(default_json_config_file) and not os.path.isfile(default_yaml_config_file)
    else default_yaml_config_file
)
41
+
42
+
43
def load_config_from_file(config_file):
    """
    Loads a saved configuration and returns it as the matching config class instance.

    Args:
        config_file (`str` or `None`):
            Path of the config file to read. When `None`, `default_config_file` is used.

    Returns:
        A `ClusterConfig` when the file's `compute_environment` is `ComputeEnvironment.LOCAL_MACHINE` (also the
        value assumed when the key is absent), otherwise a `SageMakerConfig`.

    Raises:
        FileNotFoundError: If an explicit `config_file` is passed and it is not an existing file.
    """
    if config_file is None:
        config_file = default_config_file
    elif not os.path.isfile(config_file):
        raise FileNotFoundError(
            f"The passed configuration file `{config_file}` does not exist. "
            "Please pass an existing file to `accelerate launch`, or use the default one "
            "created through `accelerate config` and run `accelerate launch` "
            "without the `--config_file` argument."
        )

    # Files ending in `.json` are parsed as JSON; everything else is treated as YAML.
    is_json = config_file.endswith(".json")
    with open(config_file, encoding="utf-8") as handle:
        raw_config = json.load(handle) if is_json else yaml.safe_load(handle)
        # This first parse only decides which class to build; the class then re-reads the file itself.
        environment = raw_config.get("compute_environment", ComputeEnvironment.LOCAL_MACHINE)
        config_class = ClusterConfig if environment == ComputeEnvironment.LOCAL_MACHINE else SageMakerConfig
        if is_json:
            return config_class.from_json_file(json_file=config_file)
        return config_class.from_yaml_file(yaml_file=config_file)
73
+
74
+
75
@dataclass
class BaseConfig:
    """
    Fields and (de)serialization helpers shared by the configuration classes in this module.
    """

    # Where training runs; a plain string is converted to `ComputeEnvironment` in `__post_init__`.
    compute_environment: ComputeEnvironment
    # Distributed strategy; a plain string is converted in `__post_init__` to `SageMakerDistributedType` when the
    # compute environment is Amazon SageMaker and to `DistributedType` otherwise.
    distributed_type: Union[DistributedType, SageMakerDistributedType]
    mixed_precision: str
    use_cpu: bool
    debug: bool

    def to_dict(self):
        """
        Returns a serializable dictionary view of the config.

        Enum values (top-level or inside nested dicts) are replaced by their underlying `.value`, empty dicts become
        `None`, and top-level entries whose value is `None` are dropped. The instance itself is left untouched: the
        result and every nested dict in it are freshly built, so serializing a config does not alter its attributes.
        """

        def _convert_enums(value):
            if isinstance(value, Enum):
                return value.value
            if isinstance(value, dict):
                if not value:
                    return None
                # Build a new dict instead of rewriting `value` in place, so the attribute keeps its enums.
                return {key: _convert_enums(item) for key, item in value.items()}
            return value

        converted = {key: _convert_enums(value) for key, value in self.__dict__.items()}
        return {key: value for key, value in converted.items() if value is not None}

    @staticmethod
    def process_config(config_dict):
        """
        Processes `config_dict` in place, setting default values for missing keys and converting legacy keys
        (`fp16`, `dynamo_backend`) to the current format.

        Returns:
            The same `config_dict` object, updated.

        Raises:
            ValueError: If `distributed_type` is missing.
        """
        config_dict.setdefault("compute_environment", ComputeEnvironment.LOCAL_MACHINE)
        if "distributed_type" not in config_dict:
            raise ValueError("A `distributed_type` must be specified in the config file.")
        if "num_processes" not in config_dict and config_dict["distributed_type"] == DistributedType.NO:
            config_dict["num_processes"] = 1
        if "mixed_precision" not in config_dict:
            config_dict["mixed_precision"] = "fp16" if config_dict.get("fp16") else None
        # Convert the config to the new format: the legacy `fp16` flag is always removed.
        config_dict.pop("fp16", None)
        if "dynamo_backend" in config_dict:  # Convert the config to the new format.
            dynamo_backend = config_dict.pop("dynamo_backend")
            config_dict["dynamo_config"] = {} if dynamo_backend == "NO" else {"dynamo_backend": dynamo_backend}
        config_dict.setdefault("use_cpu", False)
        config_dict.setdefault("debug", False)
        config_dict.setdefault("enable_cpu_affinity", False)
        return config_dict

    @classmethod
    def _from_config_dict(cls, config_dict, config_file):
        """
        Applies `process_config`, rejects keys that are not dataclass fields of `cls`, and instantiates `cls`.

        `config_file` is only used in the error message.
        """
        config_dict = cls.process_config(config_dict)
        extra_keys = sorted(set(config_dict.keys()) - set(cls.__dataclass_fields__.keys()))
        if len(extra_keys) > 0:
            raise ValueError(
                f"The config file at {config_file} had unknown keys ({extra_keys}), please try upgrading your `accelerate`"
                " version or fix (and potentially remove) these keys from your config file."
            )
        return cls(**config_dict)

    @classmethod
    def from_json_file(cls, json_file=None):
        """
        Builds a config from a JSON file (`default_json_config_file` when `json_file` is `None`).

        Raises:
            ValueError: If the file lacks `distributed_type` or contains keys unknown to `cls`.
        """
        json_file = default_json_config_file if json_file is None else json_file
        with open(json_file, encoding="utf-8") as f:
            config_dict = json.load(f)
        return cls._from_config_dict(config_dict, json_file)

    def to_json_file(self, json_file):
        """
        Writes `to_dict()` to `json_file` as indented JSON with sorted keys and a trailing newline.
        """
        with open(json_file, "w", encoding="utf-8") as f:
            content = json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
            f.write(content)

    @classmethod
    def from_yaml_file(cls, yaml_file=None):
        """
        Builds a config from a YAML file (`default_yaml_config_file` when `yaml_file` is `None`).

        Raises:
            ValueError: If the file lacks `distributed_type` or contains keys unknown to `cls`.
        """
        yaml_file = default_yaml_config_file if yaml_file is None else yaml_file
        with open(yaml_file, encoding="utf-8") as f:
            config_dict = yaml.safe_load(f)
        return cls._from_config_dict(config_dict, yaml_file)

    def to_yaml_file(self, yaml_file):
        """
        Writes `to_dict()` to `yaml_file` using `yaml.safe_dump`.
        """
        with open(yaml_file, "w", encoding="utf-8") as f:
            yaml.safe_dump(self.to_dict(), f)

    def __post_init__(self):
        # Values loaded from a file arrive as plain strings; turn them into the matching enums.
        if isinstance(self.compute_environment, str):
            self.compute_environment = ComputeEnvironment(self.compute_environment)
        if isinstance(self.distributed_type, str):
            if self.compute_environment == ComputeEnvironment.AMAZON_SAGEMAKER:
                self.distributed_type = SageMakerDistributedType(self.distributed_type)
            else:
                self.distributed_type = DistributedType(self.distributed_type)
        # `dynamo_config` is declared by subclasses; `getattr` keeps this safe when it is absent.
        if getattr(self, "dynamo_config", None) is None:
            self.dynamo_config = {}
176
+
177
+
178
@dataclass
class ClusterConfig(BaseConfig):
    """
    Configuration used when the compute environment is the local machine.

    The optional `*_config` dictionaries are normalized to empty dicts in `__post_init__`, so they can always be
    read as dicts on an instance.
    """

    num_processes: int = -1  # For instance if we use SLURM and the user manually passes it in
    machine_rank: int = 0
    num_machines: int = 1
    gpu_ids: Optional[str] = None
    main_process_ip: Optional[str] = None
    main_process_port: Optional[int] = None
    rdzv_backend: Optional[str] = "static"
    same_network: Optional[bool] = False
    main_training_function: str = "main"
    enable_cpu_affinity: bool = False

    # args for FP8 training
    fp8_config: Optional[dict] = None
    # args for deepspeed_plugin
    deepspeed_config: Optional[dict] = None
    # args for fsdp
    fsdp_config: Optional[dict] = None
    # args for parallelism config
    parallelism_config: Optional[dict] = None
    # args for megatron_lm
    megatron_lm_config: Optional[dict] = None
    # args for ipex
    ipex_config: Optional[dict] = None
    # args for mpirun
    mpirun_config: Optional[dict] = None
    # args for TPU
    downcast_bf16: bool = False

    # args for TPU pods
    tpu_name: Optional[str] = None
    tpu_zone: Optional[str] = None
    tpu_use_cluster: bool = False
    tpu_use_sudo: bool = False
    command_file: Optional[str] = None
    commands: list[str] = None
    tpu_vm: list[str] = None
    tpu_env: list[str] = None

    # args for dynamo
    dynamo_config: Optional[dict] = None

    def __post_init__(self):
        # Replace every unset sub-config with its own fresh empty dict before the base class finishes the setup.
        optional_dict_fields = (
            "deepspeed_config",
            "fsdp_config",
            "megatron_lm_config",
            "ipex_config",
            "mpirun_config",
            "fp8_config",
            "parallelism_config",
        )
        for field_name in optional_dict_fields:
            if getattr(self, field_name) is None:
                setattr(self, field_name, {})
        return super().__post_init__()
237
+
238
+
239
@dataclass
class SageMakerConfig(BaseConfig):
    # Configuration selected by `load_config_from_file` when the file's `compute_environment` is not
    # `ComputeEnvironment.LOCAL_MACHINE`. It adds the fields below to the ones inherited from `BaseConfig`.

    # Required fields (no default).
    ec2_instance_type: str
    iam_role_name: str
    image_uri: Optional[str] = None
    profile: Optional[str] = None
    region: str = "us-east-1"
    num_machines: int = 1
    gpu_ids: str = "all"
    # NOTE(review): this f-string is evaluated once, while the class body executes, using the class-level default
    # of `num_machines` (1). The default is therefore always "accelerate-sagemaker-1", whatever `num_machines` an
    # instance is created with - confirm whether a per-instance name was intended.
    base_job_name: str = f"accelerate-sagemaker-{num_machines}"
    # Version defaults come from the `SAGEMAKER_*` constants imported at the top of the module.
    pytorch_version: str = SAGEMAKER_PYTORCH_VERSION
    transformers_version: str = SAGEMAKER_TRANSFORMERS_VERSION
    py_version: str = SAGEMAKER_PYTHON_VERSION
    sagemaker_inputs_file: Optional[str] = None
    sagemaker_metrics_file: Optional[str] = None
    additional_args: Optional[dict] = None
    # When left as `None`, `BaseConfig.__post_init__` replaces this with an empty dict.
    dynamo_config: Optional[dict] = None
    enable_cpu_affinity: bool = False
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/config_utils.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import argparse
18
+
19
+ from ...utils.dataclasses import (
20
+ ComputeEnvironment,
21
+ DistributedType,
22
+ DynamoBackend,
23
+ FP8BackendType,
24
+ PrecisionType,
25
+ SageMakerDistributedType,
26
+ )
27
+ from ..menu import BulletMenu
28
+
29
+
30
# Dynamo backend names, in the order `_convert_dynamo_backend` indexes them. Each entry is passed to
# `DynamoBackend(...)`, so it must be spelled exactly like the corresponding enum value.
DYNAMO_BACKENDS = [
    "EAGER",
    "AOT_EAGER",
    "INDUCTOR",
    "AOT_TS_NVFUSER",
    "NVPRIMS_NVFUSER",
    "CUDAGRAPHS",
    "OFI",
    "FX2TRT",
    "ONNXRT",
    "TENSORRT",
    "AOT_TORCHXLA_TRACE_ONCE",
    # Previously misspelled "TORHCHXLA_TRACE_ONCE", unlike the "AOT_TORCHXLA_TRACE_ONCE" entry just above.
    "TORCHXLA_TRACE_ONCE",
    "IPEX",
    "TVM",
]
46
+
47
+
48
+ def _ask_field(input_text, convert_value=None, default=None, error_message=None):
49
+ ask_again = True
50
+ while ask_again:
51
+ result = input(input_text)
52
+ try:
53
+ if default is not None and len(result) == 0:
54
+ return default
55
+ return convert_value(result) if convert_value is not None else result
56
+ except Exception:
57
+ if error_message is not None:
58
+ print(error_message)
59
+
60
+
61
def _ask_options(input_text, options=None, convert_value=None, default=0):
    """
    Shows a `BulletMenu` with `options` and returns the user's selection.

    Args:
        input_text (`str`): Prompt handed to the menu.
        options (`list`, *optional*): Choices to display. Defaults to an empty list.
        convert_value (callable, *optional*): Applied to the value returned by the menu.
        default (`int`, *optional*, defaults to 0): Passed to the menu as `default_choice`.

    Returns:
        The converted selection, or the raw value returned by `BulletMenu.run` when `convert_value` is `None`.
    """
    # A fresh list per call: a mutable `[]` default would be shared by every call that omits `options`.
    if options is None:
        options = []
    selection = BulletMenu(input_text, options).run(default_choice=default)
    return selection if convert_value is None else convert_value(selection)
65
+
66
+
67
def _convert_compute_environment(value):
    """
    Converts an index (anything `int()` accepts) into the corresponding `ComputeEnvironment` member.
    """
    environment_names = ["LOCAL_MACHINE", "AMAZON_SAGEMAKER"]
    return ComputeEnvironment(environment_names[int(value)])
70
+
71
+
72
def _convert_distributed_mode(value):
    """
    Converts an index (anything `int()` accepts) into the corresponding `DistributedType` member.
    """
    mode_names = [
        "NO",
        "MULTI_CPU",
        "MULTI_XPU",
        "MULTI_HPU",
        "MULTI_GPU",
        "MULTI_NPU",
        "MULTI_MLU",
        "MULTI_SDAA",
        "MULTI_MUSA",
        "XLA",
    ]
    return DistributedType(mode_names[int(value)])
88
+
89
+
90
def _convert_dynamo_backend(value):
    """
    Converts an index into `DYNAMO_BACKENDS` into the `.value` of the corresponding `DynamoBackend` member.
    """
    backend_name = DYNAMO_BACKENDS[int(value)]
    return DynamoBackend(backend_name).value
93
+
94
+
95
def _convert_mixed_precision(value):
    """
    Converts an index (anything `int()` accepts) into the corresponding `PrecisionType` member.
    """
    precision_names = ["no", "fp16", "bf16", "fp8"]
    return PrecisionType(precision_names[int(value)])
98
+
99
+
100
def _convert_sagemaker_distributed_mode(value):
    """
    Converts an index (anything `int()` accepts) into the corresponding `SageMakerDistributedType` member.
    """
    mode_names = ["NO", "DATA_PARALLEL", "MODEL_PARALLEL"]
    return SageMakerDistributedType(mode_names[int(value)])
103
+
104
+
105
def _convert_fp8_backend(value):
    """
    Converts an index (anything `int()` accepts) into the corresponding `FP8BackendType` member.
    """
    backend_names = ["TE", "MSAMP"]
    return FP8BackendType(backend_names[int(value)])
108
+
109
+
110
+ def _convert_yes_no_to_bool(value):
111
+ return {"yes": True, "no": False}[value.lower()]
112
+
113
+
114
class SubcommandHelpFormatter(argparse.RawDescriptionHelpFormatter):
    """
    Help formatter for subcommands that strips the `<command> [<args>] ` placeholder from the usage line.
    """

    def _format_usage(self, usage, actions, groups, prefix):
        # Let argparse render the usage text first, then drop the placeholder wherever it appears.
        rendered = super()._format_usage(usage, actions, groups, prefix)
        return rendered.replace("<command> [<args>] ", "")
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/default.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from pathlib import Path
18
+
19
+ import torch
20
+
21
+ from ...utils import (
22
+ is_hpu_available,
23
+ is_mlu_available,
24
+ is_musa_available,
25
+ is_npu_available,
26
+ is_sdaa_available,
27
+ is_xpu_available,
28
+ )
29
+ from .config_args import ClusterConfig, default_json_config_file
30
+ from .config_utils import SubcommandHelpFormatter
31
+
32
+
33
+ description = "Create a default config file for Accelerate with only a few flags set."
34
+
35
+
36
def write_basic_config(mixed_precision="no", save_location: str = default_json_config_file):
    """
    Creates and saves a basic cluster config to be used on a local machine with potentially multiple accelerators.
    Will also set CPU if no accelerator is detected.

    Args:
        mixed_precision (`str`, *optional*, defaults to "no"):
            Mixed Precision to use. Should be one of "no", "fp16", "bf16", or "fp8" (case-insensitive).
        save_location (`str`, *optional*, defaults to `default_json_config_file`):
            Optional custom save location. Should be passed to `--config_file` when using `accelerate launch`. Default
            location is inside the huggingface cache folder (`~/.cache/huggingface`) but can be overridden by setting
            the `HF_HOME` environmental variable, followed by `accelerate/default_config.yaml`.

    Returns:
        `False` if a file already exists at `save_location` (nothing is written), otherwise the `Path` that was
        written.

    Raises:
        ValueError: if `mixed_precision` is not one of the accepted values.
    """
    path = Path(save_location)
    path.parent.mkdir(parents=True, exist_ok=True)
    if path.exists():
        print(
            f"Configuration already exists at {save_location}, will not override. Run `accelerate config` manually or pass a different `save_location`."
        )
        return False
    mixed_precision = mixed_precision.lower()
    if mixed_precision not in ["no", "fp16", "bf16", "fp8"]:
        raise ValueError(
            f"`mixed_precision` should be one of 'no', 'fp16', 'bf16', or 'fp8'. Received {mixed_precision}"
        )
    config = {
        "compute_environment": "LOCAL_MACHINE",
        "mixed_precision": mixed_precision,
    }
    # Probed in priority order; the first available backend wins. The device counters are wrapped in lambdas so
    # that `torch.<backend>` is only touched once the matching availability check has succeeded.
    accelerators = (
        (is_mlu_available, lambda: torch.mlu.device_count(), "MULTI_MLU"),
        (is_sdaa_available, lambda: torch.sdaa.device_count(), "MULTI_SDAA"),
        (is_musa_available, lambda: torch.musa.device_count(), "MULTI_MUSA"),
        (is_hpu_available, lambda: torch.hpu.device_count(), "MULTI_HPU"),
        (torch.cuda.is_available, lambda: torch.cuda.device_count(), "MULTI_GPU"),
        (is_xpu_available, lambda: torch.xpu.device_count(), "MULTI_XPU"),
        (is_npu_available, lambda: torch.npu.device_count(), "MULTI_NPU"),
    )
    for is_available, count_devices, multi_device_type in accelerators:
        if is_available():
            num_devices = count_devices()
            config["num_processes"] = num_devices
            config["use_cpu"] = False
            config["distributed_type"] = multi_device_type if num_devices > 1 else "NO"
            # Stop at the first hit so a detected accelerator is never overwritten by a later probe
            # or by the CPU fallback below.
            break
    else:
        # No accelerator detected: single-process CPU run.
        config["use_cpu"] = True
        config["num_processes"] = 1
        config["distributed_type"] = "NO"
    config["debug"] = False
    config["enable_cpu_affinity"] = False
    config = ClusterConfig(**config)
    config.to_json_file(path)
    return path
131
+
132
+
133
def default_command_parser(parser, parents):
    """
    Register the `default` subcommand on `parser` and return the new subparser.

    The subcommand accepts `--config_file` (stored as `save_location`) and `--mixed_precision`, and dispatches to
    `default_config_command`.
    """
    subparser = parser.add_parser(
        "default", parents=parents, help=description, formatter_class=SubcommandHelpFormatter
    )

    config_file_help = (
        "The path to use to store the config file. Will default to a file named default_config.yaml in the cache "
        "location, which is the content of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have "
        "such an environment variable, your cache directory ('~/.cache' or the content of `XDG_CACHE_HOME`) suffixed "
        "with 'huggingface'."
    )
    subparser.add_argument(
        "--config_file",
        default=default_json_config_file,
        help=config_file_help,
        dest="save_location",
    )

    mixed_precision_help = (
        "Whether or not to use mixed precision training. "
        "Choose between FP16 and BF16 (bfloat16) training. "
        "BF16 training is only supported on Nvidia Ampere GPUs and PyTorch 1.10 or later."
    )
    subparser.add_argument(
        "--mixed_precision",
        choices=["no", "fp16", "bf16"],
        type=str,
        help=mixed_precision_help,
        default="no",
    )

    subparser.set_defaults(func=default_config_command)
    return subparser
158
+
159
+
160
def default_config_command(args):
    """Write the basic config from the parsed CLI `args` and report where it was saved, if it was written."""
    saved_path = write_basic_config(args.mixed_precision, args.save_location)
    if not saved_path:
        # Nothing was written (write_basic_config returned a falsy value), so there is nothing to report.
        return
    print(f"accelerate configuration saved at {saved_path}")
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/sagemaker.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import json
17
+ import os
18
+
19
+ from ...utils.constants import SAGEMAKER_PARALLEL_EC2_INSTANCES, TORCH_DYNAMO_MODES
20
+ from ...utils.dataclasses import ComputeEnvironment, SageMakerDistributedType
21
+ from ...utils.imports import is_boto3_available
22
+ from .config_args import SageMakerConfig
23
+ from .config_utils import (
24
+ DYNAMO_BACKENDS,
25
+ _ask_field,
26
+ _ask_options,
27
+ _convert_dynamo_backend,
28
+ _convert_mixed_precision,
29
+ _convert_sagemaker_distributed_mode,
30
+ _convert_yes_no_to_bool,
31
+ )
32
+
33
+
34
+ if is_boto3_available():
35
+ import boto3 # noqa: F401
36
+
37
+
38
def _create_iam_role_for_sagemaker(role_name):
    """
    Create an IAM role named `role_name` that SageMaker can assume, and attach an inline permission policy to it.

    If the role already exists, a message is printed and the existing role is left as it is.
    """
    iam = boto3.client("iam")

    # Trust policy: lets the SageMaker service assume the role.
    trust_policy = {
        "Version": "2012-10-17",
        "Statement": [
            {"Effect": "Allow", "Principal": {"Service": "sagemaker.amazonaws.com"}, "Action": "sts:AssumeRole"}
        ],
    }
    # Inline permissions granted to the role.
    allowed_actions = [
        "sagemaker:*",
        "ecr:GetDownloadUrlForLayer",
        "ecr:BatchGetImage",
        "ecr:BatchCheckLayerAvailability",
        "ecr:GetAuthorizationToken",
        "cloudwatch:PutMetricData",
        "cloudwatch:GetMetricData",
        "cloudwatch:GetMetricStatistics",
        "cloudwatch:ListMetrics",
        "logs:CreateLogGroup",
        "logs:CreateLogStream",
        "logs:DescribeLogStreams",
        "logs:PutLogEvents",
        "logs:GetLogEvents",
        "s3:CreateBucket",
        "s3:ListBucket",
        "s3:GetBucketLocation",
        "s3:GetObject",
        "s3:PutObject",
    ]
    permission_policy = {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Action": allowed_actions,
                "Resource": "*",
            }
        ],
    }

    try:
        # Create the role first; the policy is only attached when creation succeeded.
        iam.create_role(RoleName=role_name, AssumeRolePolicyDocument=json.dumps(trust_policy, indent=2))
        iam.put_role_policy(
            RoleName=role_name,
            PolicyName=f"{role_name}_policy_permission",
            PolicyDocument=json.dumps(permission_policy, indent=2),
        )
    except iam.exceptions.EntityAlreadyExistsException:
        print(f"role {role_name} already exists. Using existing one")
90
+
91
+
92
def _get_iam_role_arn(role_name):
    """Look up the IAM role `role_name` and return the `Arn` field of its description."""
    role_description = boto3.client("iam").get_role(RoleName=role_name)
    return role_description["Role"]["Arn"]
95
+
96
+
97
def get_sagemaker_input():
    """
    Interactively ask for the settings of an Amazon SageMaker run and return them as a `SageMakerConfig`.

    Side effects: depending on the answers, sets `AWS_PROFILE` or `AWS_ACCESS_KEY_ID`/`AWS_SECRET_ACCESS_KEY`, and
    always sets `AWS_DEFAULT_REGION`, in `os.environ`. When asked to, creates an IAM role through
    `_create_iam_role_for_sagemaker`.

    Returns:
        A `SageMakerConfig` built from the collected answers.
    """

    def ask_yes_no(prompt):
        # Every yes/no question in this flow defaults to "no" and shares the same retry message.
        return _ask_field(
            prompt,
            _convert_yes_no_to_bool,
            default=False,
            error_message="Please enter yes or no.",
        )

    credentials_configuration = _ask_options(
        "How do you want to authorize?",
        ["AWS Profile", "Credentials (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY) "],
        int,
    )
    aws_profile = None
    if credentials_configuration == 0:
        aws_profile = _ask_field("Enter your AWS Profile name: [default] ", default="default")
        os.environ["AWS_PROFILE"] = aws_profile
    else:
        print(
            "Note you will need to provide AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY when you launch you training script with,"
            "`accelerate launch --aws_access_key_id XXX --aws_secret_access_key YYY`"
        )
        aws_access_key_id = _ask_field("AWS Access Key ID: ")
        os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id

        aws_secret_access_key = _ask_field("AWS Secret Access Key: ")
        os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key

    aws_region = _ask_field("Enter your AWS Region: [us-east-1]", default="us-east-1")
    os.environ["AWS_DEFAULT_REGION"] = aws_region

    role_management = _ask_options(
        "Do you already have an IAM Role for executing Amazon SageMaker Training Jobs?",
        ["Provide IAM Role name", "Create new IAM role using credentials"],
        int,
    )
    if role_management == 0:
        iam_role_name = _ask_field("Enter your IAM role name: ")
    else:
        iam_role_name = "accelerate_sagemaker_execution_role"
        print(f'Accelerate will create an iam role "{iam_role_name}" using the provided credentials')
        _create_iam_role_for_sagemaker(iam_role_name)

    docker_image = None
    if ask_yes_no("Do you want to use custom Docker image? [yes/NO]: "):
        docker_image = _ask_field("Enter your Docker image: ", lambda x: str(x).lower())

    sagemaker_inputs_file = None
    if ask_yes_no("Do you want to provide SageMaker input channels with data locations? [yes/NO]: "):
        sagemaker_inputs_file = _ask_field(
            "Enter the path to the SageMaker inputs TSV file with columns (channel_name, data_location): ",
            lambda x: str(x).lower(),
        )

    sagemaker_metrics_file = None
    if ask_yes_no("Do you want to enable SageMaker metrics? [yes/NO]: "):
        sagemaker_metrics_file = _ask_field(
            "Enter the path to the SageMaker metrics TSV file with columns (metric_name, metric_regex): ",
            lambda x: str(x).lower(),
        )

    distributed_type = _ask_options(
        "What is the distributed mode?",
        ["No distributed training", "Data parallelism"],
        _convert_sagemaker_distributed_mode,
    )
    dynamo_config = {}
    use_dynamo = ask_yes_no("Do you wish to optimize your script with torch dynamo?[yes/NO]:")
    if use_dynamo:
        prefix = "dynamo_"
        dynamo_config[prefix + "backend"] = _ask_options(
            "Which dynamo backend would you like to use?",
            [x.lower() for x in DYNAMO_BACKENDS],
            _convert_dynamo_backend,
            default=2,
        )
        use_custom_options = ask_yes_no("Do you want to customize the defaults sent to torch.compile? [yes/NO]: ")

        if use_custom_options:
            dynamo_config[prefix + "mode"] = _ask_options(
                "Which mode do you want to use?",
                TORCH_DYNAMO_MODES,
                lambda x: TORCH_DYNAMO_MODES[int(x)],
                # `default` is the index of the pre-selected menu entry (it is forwarded as the menu's
                # `default_choice`), not the entry's text, so the first mode is selected by position.
                # NOTE(review): presumably the first entry of TORCH_DYNAMO_MODES is "default" - confirm.
                default=0,
            )
            dynamo_config[prefix + "use_fullgraph"] = ask_yes_no(
                "Do you want the fullgraph mode or it is ok to break model into several subgraphs? [yes/NO]: "
            )
            dynamo_config[prefix + "use_dynamic"] = ask_yes_no(
                "Do you want to enable dynamic shape tracing? [yes/NO]: "
            )
            dynamo_config[prefix + "use_regional_compilation"] = ask_yes_no(
                "Do you want to enable regional compilation? [yes/NO]: "
            )

    ec2_instance_query = "Which EC2 instance type you want to use for your training?"
    if distributed_type != SageMakerDistributedType.NO:
        # Distributed runs pick from a fixed list of instance types.
        ec2_instance_type = _ask_options(
            ec2_instance_query, SAGEMAKER_PARALLEL_EC2_INSTANCES, lambda x: SAGEMAKER_PARALLEL_EC2_INSTANCES[int(x)]
        )
    else:
        ec2_instance_query += "? [ml.p3.2xlarge]:"
        ec2_instance_type = _ask_field(ec2_instance_query, lambda x: str(x).lower(), default="ml.p3.2xlarge")

    debug = False
    if distributed_type != SageMakerDistributedType.NO:
        debug = ask_yes_no(
            "Should distributed operations be checked while running for errors? This can avoid timeout issues but will be slower. [yes/NO]: "
        )

    num_machines = 1
    if distributed_type in (SageMakerDistributedType.DATA_PARALLEL, SageMakerDistributedType.MODEL_PARALLEL):
        num_machines = _ask_field(
            "How many machines do you want use? [1]: ",
            int,
            default=1,
        )

    mixed_precision = _ask_options(
        "Do you wish to use FP16 or BF16 (mixed precision)?",
        ["no", "fp16", "bf16", "fp8"],
        _convert_mixed_precision,
    )

    if use_dynamo and mixed_precision == "no":
        print(
            "Torch dynamo used without mixed precision requires TF32 to be efficient. Accelerate will enable it by default when launching your scripts."
        )

    return SageMakerConfig(
        image_uri=docker_image,
        compute_environment=ComputeEnvironment.AMAZON_SAGEMAKER,
        distributed_type=distributed_type,
        use_cpu=False,
        dynamo_config=dynamo_config,
        ec2_instance_type=ec2_instance_type,
        profile=aws_profile,
        region=aws_region,
        iam_role_name=iam_role_name,
        mixed_precision=mixed_precision,
        num_machines=num_machines,
        sagemaker_inputs_file=sagemaker_inputs_file,
        sagemaker_metrics_file=sagemaker_metrics_file,
        debug=debug,
    )
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/config/update.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from pathlib import Path
18
+
19
+ from .config_args import default_config_file, load_config_from_file
20
+ from .config_utils import SubcommandHelpFormatter
21
+
22
+
23
+ description = "Update an existing config file with the latest defaults while maintaining the old configuration."
24
+
25
+
26
def update_config(args):
    """
    Update an existing config file with the latest defaults while maintaining the old configuration.

    The file is loaded with `load_config_from_file` and written back to the same path, as JSON when the path ends
    with ".json" and as YAML otherwise.

    Args:
        args: Parsed CLI arguments; `args.config_file` is the file to update, or `None` to use
            `default_config_file`.

    Returns:
        The path of the config file that was rewritten.

    Raises:
        ValueError: if the selected config file does not exist.
    """
    config_file = args.config_file
    if config_file is None:
        # No explicit path: fall back to the default location, and fail with a clear error (rather than a
        # `TypeError` from `Path(None)`) when that file is missing too.
        config_file = default_config_file
        if not Path(config_file).exists():
            raise ValueError(
                f"No config file was passed and the default config file located at {config_file} doesn't exist."
            )
    elif not Path(config_file).exists():
        raise ValueError(f"The passed config file located at {config_file} doesn't exist.")
    config = load_config_from_file(config_file)

    # `str()` keeps the suffix check working if a path-like object is supplied.
    if str(config_file).endswith(".json"):
        config.to_json_file(config_file)
    else:
        config.to_yaml_file(config_file)
    return config_file
42
+
43
+
44
def update_command_parser(parser, parents):
    """
    Register the `update` subcommand on `parser` and return the new subparser.

    The subcommand accepts an optional `--config_file` and dispatches to `update_config_command`.
    """
    subparser = parser.add_parser(
        "update", parents=parents, help=description, formatter_class=SubcommandHelpFormatter
    )

    config_file_help = (
        "The path to the config file to update. Will default to a file named default_config.yaml in the cache "
        "location, which is the content of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have "
        "such an environment variable, your cache directory ('~/.cache' or the content of `XDG_CACHE_HOME`) suffixed "
        "with 'huggingface'."
    )
    subparser.add_argument("--config_file", default=None, help=config_file_help)

    subparser.set_defaults(func=update_config_command)
    return subparser
59
+
60
+
61
def update_config_command(args):
    """Run `update_config` for the parsed CLI `args` and print the path that was updated."""
    updated_path = update_config(args)
    message = f"Successfully updated the configuration file at {updated_path}."
    print(message)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from .selection_menu import BulletMenu
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (284 Bytes). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/__pycache__/cursor.cpython-312.pyc ADDED
Binary file (3.06 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/__pycache__/helpers.cpython-312.pyc ADDED
Binary file (2.21 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/__pycache__/input.cpython-312.pyc ADDED
Binary file (3.17 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/__pycache__/keymap.cpython-312.pyc ADDED
Binary file (4.52 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/__pycache__/selection_menu.cpython-312.pyc ADDED
Binary file (7.47 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/cursor.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team and Brian Chao. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ A utility for showing and hiding the terminal cursor on Windows and Linux, based on https://github.com/bchao1/bullet
17
+ """
18
+
19
+ import os
20
+ import sys
21
+ from contextlib import contextmanager
22
+
23
+
24
# The ctypes plumbing below is only set up when running on Windows (`os.name == "nt"`).
if os.name == "nt":
    import ctypes
    import msvcrt  # noqa

    class CursorInfo(ctypes.Structure):
        # ctypes reads the struct layout from the `_fields_` attribute
        _fields_ = [
            ("size", ctypes.c_int),
            ("visible", ctypes.c_byte),
        ]
32
+
33
+
34
def hide_cursor():
    """
    Make the terminal cursor invisible.

    On `posix` an escape sequence is written to stdout; on `nt` the console cursor info is read, marked invisible
    and written back. Other platforms are left untouched.
    """
    platform = os.name
    if platform == "posix":
        sys.stdout.write("\033[?25l")
        sys.stdout.flush()
        return
    if platform != "nt":
        return
    info = CursorInfo()
    kernel32 = ctypes.windll.kernel32
    # -11 is STD_OUTPUT_HANDLE in the Win32 console API
    handle = kernel32.GetStdHandle(-11)
    kernel32.GetConsoleCursorInfo(handle, ctypes.byref(info))
    info.visible = False
    kernel32.SetConsoleCursorInfo(handle, ctypes.byref(info))
44
+
45
+
46
def show_cursor():
    """
    Make the terminal cursor visible again.

    On `posix` an escape sequence is written to stdout; on `nt` the console cursor info is read, marked visible
    and written back. Other platforms are left untouched.
    """
    platform = os.name
    if platform == "posix":
        sys.stdout.write("\033[?25h")
        sys.stdout.flush()
        return
    if platform != "nt":
        return
    info = CursorInfo()
    kernel32 = ctypes.windll.kernel32
    # -11 is STD_OUTPUT_HANDLE in the Win32 console API
    handle = kernel32.GetStdHandle(-11)
    kernel32.GetConsoleCursorInfo(handle, ctypes.byref(info))
    info.visible = True
    kernel32.SetConsoleCursorInfo(handle, ctypes.byref(info))
56
+
57
+
58
@contextmanager
def hide():
    "Context manager to hide the terminal cursor"
    # `hide_cursor()` is called inside the `try`, so `show_cursor()` in the `finally` runs whether the
    # managed block finishes normally, raises, or the hide call itself fails.
    try:
        hide_cursor()
        yield
    finally:
        show_cursor()
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/helpers.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team and Brian Chao. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ A variety of helper functions and constants when dealing with terminal menu choices, based on
17
+ https://github.com/bchao1/bullet
18
+ """
19
+
20
+ import enum
21
+ import shutil
22
+ import sys
23
+
24
+
25
# Sampled once at import time; only the column count is used below.
TERMINAL_WIDTH, _ = shutil.get_terminal_size()

# Final letter of the ANSI cursor-movement escape sequence for each direction name.
CURSOR_TO_CHAR = {"UP": "A", "DOWN": "B", "RIGHT": "C", "LEFT": "D"}


class Direction(enum.Enum):
    """Vertical directions."""

    UP = 0
    DOWN = 1


def forceWrite(content, end=""):
    """Write `str(content) + end` to stdout and flush straight away."""
    stream = sys.stdout
    stream.write(str(content) + end)
    stream.flush()


def writeColor(content, color, end=""):
    """Write `content` wrapped in the ANSI sequence for `color`, followed by a reset sequence."""
    colored = f"\u001b[{color}m{content}\u001b[0m"
    forceWrite(colored, end)


def reset_cursor():
    """Send a carriage return, moving the cursor back to the start of the current line."""
    forceWrite("\r")


def move_cursor(num_lines: int, direction: str):
    """Move the cursor `num_lines` steps in `direction` (a key of `CURSOR_TO_CHAR`, case-insensitive)."""
    suffix = CURSOR_TO_CHAR[direction.upper()]
    forceWrite(f"\033[{num_lines}{suffix}")


def clear_line():
    """Overwrite the current line with `TERMINAL_WIDTH` spaces, then return to its start."""
    blank = " " * TERMINAL_WIDTH
    forceWrite(blank)
    reset_cursor()


def linebreak():
    """Return to the start of the line and draw a rule of `TERMINAL_WIDTH` dashes."""
    reset_cursor()
    rule = "-" * TERMINAL_WIDTH
    forceWrite(rule)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/input.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team and Brian Chao. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ This file contains utilities for handling input from the user and registering specific keys to specific functions,
17
+ based on https://github.com/bchao1/bullet
18
+ """
19
+
20
+ from .keymap import KEYMAP, get_character
21
+
22
+
23
def mark(key: str):
    """
    Build a decorator that records `key` on the decorated function's `handle_key` list, so the function can
    later be picked up as the handler for that key.
    """

    def decorator(func):
        # Reuse the list left by an earlier marker, if any, so stacked decorators accumulate keys.
        registered = getattr(func, "handle_key", [])
        registered.append(key)
        func.handle_key = registered
        return func

    return decorator
35
+
36
+
37
def mark_multiple(*keys: list[str]):
    """
    Build a decorator that records every one of `keys` on the decorated function's `handle_key` list, so the
    function can later be picked up as the handler for each of them.
    """

    def decorator(func):
        # Reuse the list left by an earlier marker, if any, so stacked decorators accumulate keys.
        registered = getattr(func, "handle_key", [])
        registered.extend(keys)
        func.handle_key = registered
        return func

    return decorator
49
+
50
+
51
class KeyHandler(type):
    """
    Metaclass that collects the key-marked attributes of a class into its `key_handler` mapping and gives the
    class a `handle_input` method.
    """

    def __new__(cls, name, bases, attrs):
        new_cls = super().__new__(cls, name, bases, attrs)
        if not hasattr(new_cls, "key_handler"):
            new_cls.key_handler = {}
            new_cls.handle_input = KeyHandler.handle_input

        # Any attribute carrying a `handle_key` list becomes the handler for each key in that list.
        for attribute in attrs.values():
            for key in getattr(attribute, "handle_key", []):
                new_cls.key_handler[key] = attribute
        return new_cls

    @staticmethod
    def handle_input(cls):
        "Finds and returns the selected character if it exists in the handler"
        char = get_character()
        if char != KEYMAP["undefined"]:
            char = ord(char)
        handler = cls.key_handler.get(char)
        if not handler:
            return None
        # Remember which key code triggered the handler before calling it.
        cls.current_selection = char
        return handler(cls)
80
+
81
+
82
def register(cls):
    """Rebuild `cls` through the `KeyHandler` metaclass, using its name, bases and a copy of its namespace."""
    namespace = cls.__dict__.copy()
    return KeyHandler(cls.__name__, cls.__bases__, namespace)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/keymap.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team and Brian Chao. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Utilities relating to parsing raw characters from the keyboard, based on https://github.com/bchao1/bullet
17
+ """
18
+
19
+ import os
20
+ import string
21
+ import sys
22
+
23
+
24
+ ARROW_KEY_FLAG = 1 << 8
25
+
26
+ KEYMAP = {
27
+ "tab": ord("\t"),
28
+ "newline": ord("\r"),
29
+ "esc": 27,
30
+ "up": 65 + ARROW_KEY_FLAG,
31
+ "down": 66 + ARROW_KEY_FLAG,
32
+ "right": 67 + ARROW_KEY_FLAG,
33
+ "left": 68 + ARROW_KEY_FLAG,
34
+ "mod_int": 91,
35
+ "undefined": sys.maxsize,
36
+ "interrupt": 3,
37
+ "insert": 50,
38
+ "delete": 51,
39
+ "pg_up": 53,
40
+ "pg_down": 54,
41
+ }
42
+
43
+ KEYMAP["arrow_begin"] = KEYMAP["up"]
44
+ KEYMAP["arrow_end"] = KEYMAP["left"]
45
+
46
+ if sys.platform == "win32":
47
+ WIN_CH_BUFFER = []
48
+ WIN_KEYMAP = {
49
+ b"\xe0H": KEYMAP["up"] - ARROW_KEY_FLAG,
50
+ b"\x00H": KEYMAP["up"] - ARROW_KEY_FLAG,
51
+ b"\xe0P": KEYMAP["down"] - ARROW_KEY_FLAG,
52
+ b"\x00P": KEYMAP["down"] - ARROW_KEY_FLAG,
53
+ b"\xe0M": KEYMAP["right"] - ARROW_KEY_FLAG,
54
+ b"\x00M": KEYMAP["right"] - ARROW_KEY_FLAG,
55
+ b"\xe0K": KEYMAP["left"] - ARROW_KEY_FLAG,
56
+ b"\x00K": KEYMAP["left"] - ARROW_KEY_FLAG,
57
+ }
58
+
59
+ for i in range(10):
60
+ KEYMAP[str(i)] = ord(str(i))
61
+
62
+
63
def get_raw_chars():
    """
    Read one raw character from the keyboard and return it as a one-character string.

    On `nt` the key is read with `msvcrt`; two-byte special keys found in `WIN_KEYMAP` are translated into an
    escape sequence (esc, `mod_int`, key code) whose remaining characters are queued in `WIN_CH_BUFFER` and
    handed out by the following calls. On `posix` a single character is read from stdin in raw mode, and the
    previous terminal settings are restored afterwards.
    """
    if os.name == "nt":
        import msvcrt

        encoding = "mbcs"
        # Flush the keyboard buffer
        while msvcrt.kbhit():
            msvcrt.getch()
        if len(WIN_CH_BUFFER) == 0:
            # Read the keystroke
            ch = msvcrt.getch()

            # If it is a prefix char, get second part
            if ch in (b"\x00", b"\xe0"):
                ch2 = ch + msvcrt.getch()
                # Translate actual Win chars to bullet char types
                try:
                    chx = chr(WIN_KEYMAP[ch2])
                    WIN_CH_BUFFER.append(chr(KEYMAP["mod_int"]))
                    WIN_CH_BUFFER.append(chx)
                    # NOTE(review): `a - 1 << 9` parses as `(a - 1) << 9`, which no single-byte key code
                    # can equal, so this branch looks unreachable. Left as is; confirm the intended value.
                    if ord(chx) in (
                        KEYMAP["insert"] - 1 << 9,
                        KEYMAP["delete"] - 1 << 9,
                        KEYMAP["pg_up"] - 1 << 9,
                        KEYMAP["pg_down"] - 1 << 9,
                    ):
                        WIN_CH_BUFFER.append(chr(126))
                    ch = chr(KEYMAP["esc"])
                except KeyError:
                    # Indexing `bytes` yields an int; convert it so callers can keep using `ord()` on
                    # the returned value.
                    ch = chr(ch2[1])
            else:
                ch = ch.decode(encoding)
        else:
            ch = WIN_CH_BUFFER.pop(0)
    elif os.name == "posix":
        import termios
        import tty

        fd = sys.stdin.fileno()
        old_settings = termios.tcgetattr(fd)
        try:
            tty.setraw(fd)
            ch = sys.stdin.read(1)
        finally:
            # Always restore the terminal, even if the read is interrupted.
            termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
    return ch
110
+
111
+
112
def get_character():
    """
    Read a key press and return it.

    Interrupt, newline and printable characters are returned as read. After an escape character, a `mod_int`
    followed by an arrow code is returned as `chr(code + ARROW_KEY_FLAG)`, and any other follow-up character
    causes one more raw read whose result is returned. Everything else yields `KEYMAP["undefined"]`.
    """
    first = get_raw_chars()
    code = ord(first)
    if code in [KEYMAP["interrupt"], KEYMAP["newline"]]:
        return first

    if code != KEYMAP["esc"]:
        # Plain key: only printable characters are passed through.
        return first if first in string.printable else KEYMAP["undefined"]

    # Escape sequence: inspect the next character.
    second = get_raw_chars()
    if ord(second) != KEYMAP["mod_int"]:
        return get_raw_chars()

    third = get_raw_chars()
    lowest_arrow = KEYMAP["arrow_begin"] - ARROW_KEY_FLAG
    highest_arrow = KEYMAP["arrow_end"] - ARROW_KEY_FLAG
    if lowest_arrow <= ord(third) <= highest_arrow:
        return chr(ord(third) + ARROW_KEY_FLAG)
    return KEYMAP["undefined"]
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/menu/selection_menu.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team and Brian Chao. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Main driver for the selection menu, based on https://github.com/bchao1/bullet
17
+ """
18
+
19
+ import builtins
20
+ import sys
21
+ from typing import Optional
22
+
23
+ from ...utils.imports import _is_package_available
24
+ from . import cursor, input
25
+ from .helpers import Direction, clear_line, forceWrite, linebreak, move_cursor, reset_cursor, writeColor
26
+ from .keymap import KEYMAP
27
+
28
+
29
# `in_colab` records whether the `google.colab` package is reported as available; a failed lookup counts as "no".
try:
    in_colab = _is_package_available("google.colab")
except ModuleNotFoundError:
    in_colab = False
34
+
35
+
36
@input.register
class BulletMenu:
    """
    A CLI menu to select a choice from a list of choices using the keyboard.

    Args:
        prompt (`str`, *optional*):
            Text written above the choices when the menu runs.
        choices (`list`, *optional*):
            The entries to choose from. Defaults to a new empty list for every instance.
    """

    def __init__(self, prompt: Optional[str] = None, choices: Optional[list] = None):
        self.position = 0
        # A literal `[]` default would be one list object shared by every menu created without `choices`.
        self.choices = [] if choices is None else choices
        self.prompt = prompt
        if sys.platform == "win32":
            self.arrow_char = "*"
        else:
            self.arrow_char = "➔ "

    def write_choice(self, index, end: str = ""):
        "Writes the choice at `index`, colored everywhere except on win32"
        if sys.platform != "win32":
            writeColor(self.choices[index], 32, end)
        else:
            forceWrite(self.choices[index], end)

    def print_choice(self, index: int):
        "Prints the choice at the given index"
        if index == self.position:
            forceWrite(f" {self.arrow_char} ")
            self.write_choice(index)
        else:
            # NOTE(review): padding reproduced exactly as it appears in this copy of the file - confirm width.
            forceWrite(f" {self.choices[index]}")
        reset_cursor()

    def move_direction(self, direction: Direction, num_spaces: int = 1):
        "Should not be directly called, used to move a direction of either up or down"
        old_position = self.position
        if direction == Direction.DOWN:
            # Already on the last entry: nothing to do.
            if self.position + 1 >= len(self.choices):
                return
            self.position += num_spaces
        else:
            # Already on the first entry: nothing to do.
            if self.position - 1 < 0:
                return
            self.position -= num_spaces
        # Redraw the previously highlighted row without the arrow, then draw the new row with it.
        clear_line()
        self.print_choice(old_position)
        move_cursor(num_spaces, direction.name)
        self.print_choice(self.position)

    @input.mark(KEYMAP["up"])
    def move_up(self):
        self.move_direction(Direction.UP)

    @input.mark(KEYMAP["down"])
    def move_down(self):
        self.move_direction(Direction.DOWN)

    @input.mark(KEYMAP["newline"])
    def select(self):
        "Moves the cursor below the menu and returns the highlighted position"
        move_cursor(len(self.choices) - self.position, "DOWN")
        return self.position

    @input.mark(KEYMAP["interrupt"])
    def interrupt(self):
        "Moves the cursor below the menu and raises `KeyboardInterrupt`"
        move_cursor(len(self.choices) - self.position, "DOWN")
        raise KeyboardInterrupt

    @input.mark_multiple(*[KEYMAP[str(number)] for number in range(10)])
    def select_row(self):
        "Jumps the highlight to the row whose digit was pressed, ignoring digits past the last choice"
        # `current_selection` is not assigned in this class; presumably the input handler sets it - TODO confirm.
        index = int(chr(self.current_selection))
        movement = index - self.position
        if index == self.position or index >= len(self.choices):
            return
        if movement < 0:
            self.move_direction(Direction.UP, -movement)
        else:
            self.move_direction(Direction.DOWN, movement)

    def run(self, default_choice: int = 0):
        "Start the menu and return the selected choice"
        if self.prompt:
            linebreak()
            forceWrite(self.prompt, "\n")
            if in_colab:
                forceWrite("Please input a choice index (starting from 0), and press enter", "\n")
            else:
                forceWrite("Please select a choice using the arrow or number keys, and selecting with enter", "\n")
        self.position = default_choice
        for i in range(len(self.choices)):
            self.print_choice(i)
            forceWrite("\n")
        # Put the cursor back on the row of the default choice.
        move_cursor(len(self.choices) - self.position, "UP")
        with cursor.hide():
            while True:
                if in_colab:
                    # Line-based input: anything that is not an integer falls back to the default.
                    try:
                        choice = int(builtins.input())
                    except ValueError:
                        choice = default_choice
                else:
                    choice = self.handle_input()
                if choice is not None:
                    # Erase the menu and leave only the selected entry on screen.
                    reset_cursor()
                    for _ in range(len(self.choices) + 1):
                        move_cursor(1, "UP")
                        clear_line()
                    self.write_choice(choice, "\n")
                    return choice
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.78 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/__pycache__/examples.cpython-312.pyc ADDED
Binary file (6.93 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/__pycache__/testing.cpython-312.pyc ADDED
Binary file (42.3 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/__pycache__/training.cpython-312.pyc ADDED
Binary file (7.46 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/scripts/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/scripts/test_ddp_comm_hook.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+
16
+ from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs, PartialState
17
+ from accelerate.utils import is_hpu_available
18
+
19
+
20
class MockModel(torch.nn.Module):
    """Single-parameter module: a (40, 20) weight drawn after seeding torch with 0."""

    def __init__(self):
        super().__init__()
        # Seeding here makes every instance start from identical weights.
        torch.manual_seed(0)
        initial_weights = torch.randn(40, 20)
        self.p = torch.nn.Parameter(initial_weights)

    def forward(self, x, rank):
        # The exponent grows with `rank`, so different ranks yield different outputs for the same `x`.
        exponent = 1 + rank
        return self.p * (x**exponent)
28
+
29
+
30
+ def _run_and_get_grads(model, rank):
31
+ torch.manual_seed(2024)
32
+ input = torch.randn(40, 20)
33
+ output = model(input, rank)
34
+ output.mean().backward()
35
+ param = next(model.parameters())
36
+ return param.grad
37
+
38
+
39
def test_ddp_comm_hook(comm_hook, comm_wrapper, comm_state_option):
    """
    Compare gradients from a model prepared by an `Accelerator` configured with the given DDP communication hook
    settings against gradients from a directly constructed `DistributedDataParallel` model.
    """
    accelerator = Accelerator(
        kwargs_handlers=[
            DistributedDataParallelKwargs(
                comm_hook=comm_hook,
                comm_wrapper=comm_wrapper,
                comm_state_option=comm_state_option,
            )
        ]
    )

    hooked_model = accelerator.prepare(MockModel())
    hook_grads = _run_and_get_grads(hooked_model, accelerator.local_process_index)

    plain_ddp_model = torch.nn.parallel.DistributedDataParallel(
        MockModel().to(accelerator.device),
        device_ids=[accelerator.local_process_index],
        output_device=accelerator.local_process_index,
    )
    reference_grads = _run_and_get_grads(plain_ddp_model, accelerator.local_process_index)

    # Loose tolerances are used for the comparison.
    torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-2, atol=1e-2)
58
+
59
+
60
def main():
    """Run `test_ddp_comm_hook` for each hook/wrapper/state combination, skipping FP16/BF16 ones on HPU."""
    hook = DDPCommunicationHookType
    combinations = [
        (hook.NO, hook.NO, {}),
        (hook.FP16, hook.NO, {}),
        (hook.BF16, hook.NO, {}),
        (hook.POWER_SGD, hook.NO, {}),
        (hook.POWER_SGD, hook.FP16, {}),
        (hook.POWER_SGD, hook.BF16, {}),
        (hook.POWER_SGD, hook.NO, {"matrix_approximation_rank": 2}),
        (hook.BATCHED_POWER_SGD, hook.NO, {}),
        (hook.BATCHED_POWER_SGD, hook.FP16, {}),
        (hook.BATCHED_POWER_SGD, hook.BF16, {}),
    ]
    hpu_unsupported = {hook.FP16, hook.BF16}

    for comm_hook, comm_wrapper, comm_state_option in combinations:
        if is_hpu_available() and (comm_hook in hpu_unsupported or comm_wrapper in hpu_unsupported):
            print(f"Skipping test DDP comm hook: {comm_hook}, comm wrapper: {comm_wrapper} on HPU")
            continue

        print(f"Test DDP comm hook: {comm_hook}, comm wrapper: {comm_wrapper}")
        test_ddp_comm_hook(comm_hook, comm_wrapper, comm_state_option)
    PartialState().destroy_process_group()
82
+
83
+
84
+ if __name__ == "__main__":
85
+ main()
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/scripts/test_distributed_data_loop.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import pickle
18
+ import tempfile
19
+ import warnings
20
+ from unittest.mock import Mock
21
+
22
+ import torch
23
+ from torch.utils.data import (
24
+ BatchSampler,
25
+ DataLoader,
26
+ Dataset,
27
+ IterableDataset,
28
+ RandomSampler,
29
+ TensorDataset,
30
+ default_collate,
31
+ )
32
+
33
+ from accelerate.accelerator import Accelerator, DataLoaderConfiguration
34
+ from accelerate.utils.dataclasses import DistributedType
35
+
36
+
37
NUM_ELEMENTS = 22
NUM_WORKERS = 4
BATCH_SIZE = 4


class DummyDataset(Dataset):
    """
    Map-style dataset of `NUM_ELEMENTS` synthetic samples.

    Each sample is a dict with the keys `"index"` (the position), `"label"` (`index % 2`) and
    `"random_augmentation"` (a fresh `torch.rand(1).item()` value on every access).
    """

    def __len__(self):
        return NUM_ELEMENTS

    def __getitem__(self, index):
        """
        Return one sample for an `int` index, or a list of samples for a slice or any iterable of indices.
        """
        squeeze = False

        if isinstance(index, int):
            index = [index]
            squeeze = True
        elif isinstance(index, slice):
            # Resolve the slice against the dataset length (the class has no `size` attribute).
            index = list(range(*index.indices(len(self))))
        else:
            index = list(index)

        batch = [{"index": i, "label": i % 2, "random_augmentation": torch.rand(1).item()} for i in index]

        if squeeze:
            batch = batch[0]

        return batch
63
+
64
+
65
class DummyIterableDataset(IterableDataset):
    """Iterable-style dataset that yields the items of the wrapped `data` in order."""

    def __init__(self, data):
        self.data = data

    def __iter__(self):
        for item in self.data:
            yield item
71
+
72
+
73
def create_accelerator(even_batches=True):
    """Build an `Accelerator` with the requested `even_batches` setting; the script requires two processes."""
    accelerator = Accelerator(dataloader_config=DataLoaderConfiguration(even_batches=even_batches))
    assert accelerator.num_processes == 2, "this script expects that two GPUs are available"
    return accelerator
78
+
79
+
80
def create_dataloader(
    accelerator: Accelerator, dataset_size: int, batch_size: int, iterable: bool = False, shuffle: bool = False
):
    """
    Create a simple DataLoader to use during the test cases

    Args:
        accelerator: Object whose `prepare` method wraps the DataLoader before it is returned.
        dataset_size: Number of integer samples (`0 .. dataset_size - 1`).
        batch_size: Batch size handed to the DataLoader.
        iterable: Wrap the values in a `DummyIterableDataset` instead of a `TensorDataset`.
        shuffle: Randomly permute the values once, up front. Applies to both dataset kinds.
    """
    values = torch.as_tensor(range(dataset_size))
    if shuffle:
        values = values[torch.randperm(values.size(0))]

    # Both dataset kinds are built from `values`, so `shuffle` is honoured for map-style datasets as well.
    dataset = DummyIterableDataset(values) if iterable else TensorDataset(values)

    return accelerator.prepare(DataLoader(dataset, batch_size=batch_size))
98
+
99
+
100
def verify_dataloader_batch_sizes(
    accelerator: Accelerator,
    dataset_size: int,
    batch_size: int,
    process_0_expected_batch_sizes: list[int],
    process_1_expected_batch_sizes: list[int],
):
    """
    Build a prepared dataloader and assert that the batch sizes seen by this process match the expectation for
    its process index. Only indices 0 and 1 are checked.
    """
    dl = create_dataloader(accelerator=accelerator, dataset_size=dataset_size, batch_size=batch_size)

    observed_sizes = []
    for batch in dl:
        observed_sizes.append(len(batch[0]))

    rank = accelerator.process_index
    if rank == 0:
        assert observed_sizes == process_0_expected_batch_sizes
    elif rank == 1:
        assert observed_sizes == process_1_expected_batch_sizes
118
+
119
+
120
def test_default_ensures_even_batch_sizes():
    """With the default `even_batches=True`, both processes must see the same batch sizes."""
    accelerator = create_accelerator()

    scenarios = (
        # without padding, we would expect a different number of batches
        (3, 1, [1, 1]),
        # without padding, we would expect the same number of batches, but different sizes
        (7, 2, [2, 2]),
    )
    for dataset_size, batch_size, expected_sizes in scenarios:
        verify_dataloader_batch_sizes(
            accelerator,
            dataset_size=dataset_size,
            batch_size=batch_size,
            process_0_expected_batch_sizes=list(expected_sizes),
            process_1_expected_batch_sizes=list(expected_sizes),
        )
140
+
141
+
142
def test_can_disable_even_batches():
    """With `even_batches=False`, the second process is expected to receive fewer or smaller batches."""
    accelerator = create_accelerator(even_batches=False)

    scenarios = (
        (3, 1, [1, 1], [1]),
        (7, 2, [2, 2], [2, 1]),
    )
    for dataset_size, batch_size, sizes_on_rank_0, sizes_on_rank_1 in scenarios:
        verify_dataloader_batch_sizes(
            accelerator,
            dataset_size=dataset_size,
            batch_size=batch_size,
            process_0_expected_batch_sizes=sizes_on_rank_0,
            process_1_expected_batch_sizes=sizes_on_rank_1,
        )
160
+
161
+
162
def test_can_join_uneven_inputs():
    """Iterate an uneven dataloader inside `join_uneven_inputs` and check which batch indices each process ran."""
    accelerator = create_accelerator(even_batches=False)

    ddp_model = accelerator.prepare(torch.nn.Linear(1, 1))
    dl = create_dataloader(accelerator, dataset_size=3, batch_size=1)

    seen_steps = []
    with accelerator.join_uneven_inputs([ddp_model]):
        for step, batch in enumerate(dl):
            prediction = ddp_model(batch[0].float())
            prediction.sum().backward()
            seen_steps.append(step)

    accelerator.wait_for_everyone()

    # Process 0 is expected to run two steps, process 1 only one; other indices are not checked.
    expected_steps = {0: [0, 1], 1: [0]}.get(accelerator.process_index)
    if expected_steps is not None:
        assert seen_steps == expected_steps
184
+
185
+
186
def test_join_raises_warning_for_non_ddp_distributed(accelerator):
    """
    Check that entering `join_uneven_inputs` on the given accelerator emits a `UserWarning` mentioning that the
    feature is "only supported for multi-GPU".
    """
    with warnings.catch_warnings(record=True) as caught:
        # Record every warning regardless of filters installed elsewhere (e.g. `-W ignore`).
        warnings.simplefilter("always")
        with accelerator.join_uneven_inputs([Mock()]):
            pass

    # Search all recorded warnings instead of indexing the last one, so that an empty record or an unrelated
    # trailing warning produces a clear assertion failure rather than an IndexError or a false negative.
    assert any(
        issubclass(warning.category, UserWarning) and "only supported for multi-GPU" in str(warning.message)
        for warning in caught
    ), f"expected a multi-GPU UserWarning, got: {[str(warning.message) for warning in caught]}"
193
+
194
+
195
def test_join_can_override_even_batches():
    """`join_uneven_inputs(even_batches=...)` must override the setting inside the context and restore it after."""
    configured_value, override_value = True, False
    accelerator = create_accelerator(even_batches=configured_value)
    ddp_model = accelerator.prepare(torch.nn.Linear(1, 1))
    train_dl = create_dataloader(accelerator, dataset_size=3, batch_size=1)
    valid_dl = create_dataloader(accelerator, dataset_size=3, batch_size=1)

    with accelerator.join_uneven_inputs([ddp_model], even_batches=override_value):
        train_value_inside = train_dl.batch_sampler.even_batches
        valid_value_inside = valid_dl.batch_sampler.even_batches

    # Inside the context the override applies ...
    assert train_value_inside == override_value
    assert valid_value_inside == override_value
    # ... and on exit the configured value is back.
    assert train_dl.batch_sampler.even_batches == configured_value
    assert valid_dl.batch_sampler.even_batches == configured_value
212
+
213
+
214
def test_join_can_override_for_mixed_type_dataloaders():
    """
    Overriding `even_batches` must work for the map-style dataloader and must not raise `AttributeError` when an
    iterable-style dataloader has also been prepared.
    """
    configured_value, override_value = True, False
    accelerator = create_accelerator(even_batches=configured_value)
    ddp_model = accelerator.prepare(torch.nn.Linear(1, 1))
    # The iterable dataloader is only created for its effect on the accelerator; it is not used directly.
    create_dataloader(accelerator, dataset_size=3, batch_size=1, iterable=True)
    batch_dl = create_dataloader(accelerator, dataset_size=3, batch_size=1)

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")
        try:
            with accelerator.join_uneven_inputs([ddp_model], even_batches=override_value):
                value_inside_context = batch_dl.batch_sampler.even_batches
        except AttributeError:
            # ensure attribute error is not raised when processing iterable dl
            raise AssertionError

    assert value_inside_context == override_value
    assert batch_dl.batch_sampler.even_batches == configured_value
234
+
235
+
236
def test_join_raises_warning_for_iterable_when_overriding_even_batches():
    """
    Check that overriding `even_batches` in `join_uneven_inputs` after preparing an iterable-style dataloader
    emits a `UserWarning` mentioning that the override is "only supported for map-style datasets".
    """
    accelerator = create_accelerator()
    model = torch.nn.Linear(1, 1)
    ddp_model = accelerator.prepare(model)
    create_dataloader(accelerator, dataset_size=3, batch_size=1, iterable=True)

    with warnings.catch_warnings(record=True) as caught:
        # Record every warning regardless of filters installed elsewhere (e.g. `-W ignore`).
        warnings.simplefilter("always")
        with accelerator.join_uneven_inputs([ddp_model], even_batches=False):
            pass

    # Search all recorded warnings instead of indexing the last one, so that an empty record or an unrelated
    # trailing warning produces a clear assertion failure rather than an IndexError or a false negative.
    assert any(
        issubclass(warning.category, UserWarning) and "only supported for map-style datasets" in str(warning.message)
        for warning in caught
    ), f"expected a map-style UserWarning, got: {[str(warning.message) for warning in caught]}"
248
+
249
+
250
def test_pickle_accelerator():
    """Round-trip an accelerator through pickle and compare the `__dict__` of its state before and after."""
    accelerator = create_accelerator()
    data_loader = create_dataloader(accelerator, dataset_size=32, batch_size=4)
    accelerator.prepare(data_loader)
    restored_accelerator = pickle.loads(pickle.dumps(accelerator))
    # TODO: Maybe this should be implemented as __eq__ for AcceleratorState?
    assert accelerator.state.__dict__ == restored_accelerator.state.__dict__
258
+
259
+
260
def test_data_loader(data_loader, accelerator):
    """
    Prepare `data_loader`, gather the `"index"` field of every batch with `gather_for_metrics`, and assert that
    the number of distinct indices equals `NUM_ELEMENTS`.
    """
    prepared_loader = accelerator.prepare(data_loader)

    gathered_indices = []
    for batch in prepared_loader:
        index, _ = accelerator.gather_for_metrics((batch["index"], batch["label"]))
        gathered_indices += index.detach().cpu().numpy().tolist()

    # Duplicates collapse in the set, so only distinct sample indices are counted.
    distinct_indices = set(gathered_indices)
    assert len(distinct_indices) == NUM_ELEMENTS, (
        "Not all the dataset elements have been iterated in an epoch due to duplication of samples across processes."
    )
276
+
277
+
278
def test_stateful_dataloader(accelerator):
    """
    Iterate a stateful dataloader, capture its `state_dict` one step before the final batch, then restore that
    state with `load_state_dict` and iterate again.

    The batches produced after restoring must equal the batches seen from the final step onward in the first
    pass. The accelerator's original dataloader config is restored afterwards.
    """
    saved_config = accelerator.dataloader_config
    try:
        accelerator.dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
        prepared_dl = create_dataloader(
            accelerator, dataset_size=32 * accelerator.num_processes, batch_size=4, iterable=True, shuffle=True
        )
        # Index of the last step of the first pass.
        final_step = 32 * accelerator.num_processes // (4 * accelerator.num_processes) - 1

        tail_batches = []
        for step, batch in enumerate(prepared_dl):
            if step == final_step - 1:
                # Snapshot taken on the step just before the final one.
                state_dict = prepared_dl.state_dict()
            if step >= final_step:
                # Batches seen after the snapshot was taken.
                tail_batches.append(batch)
        not_skipped_batches = accelerator.gather(tail_batches)

        prepared_dl.load_state_dict(state_dict)
        resumed_batches = accelerator.gather([batch for batch in prepared_dl])

        for b1, b2 in zip(not_skipped_batches, resumed_batches):
            for v1, v2 in zip(b1, b2):
                assert torch.equal(v1, v2), f"Batch {b1} and {b2} are not equal"
    finally:
        accelerator.dataloader_config = saved_config
313
+
314
+
315
def test_stateful_dataloader_save_state(accelerator):
    """
    Iterate a stateful dataloader, call `Accelerator.save_state` one step before the final batch, then call
    `Accelerator.load_state` and iterate again.

    The batches produced after loading must equal the batches seen from the final step onward in the first pass.
    The accelerator's original dataloader config is restored afterwards.
    """
    saved_config = accelerator.dataloader_config
    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            accelerator.dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
            prepared_dl = create_dataloader(
                accelerator, dataset_size=32 * accelerator.num_processes, batch_size=4, iterable=True, shuffle=True
            )
            # Index of the last step of the first pass.
            final_step = 32 * accelerator.num_processes // (4 * accelerator.num_processes) - 1

            tail_batches = []
            for step, batch in enumerate(prepared_dl):
                if step == final_step - 1:
                    # Checkpoint written on the step just before the final one.
                    accelerator.save_state(tmpdir)
                if step >= final_step:
                    # Batches seen after the checkpoint was written.
                    tail_batches.append(batch)
            not_skipped_batches = accelerator.gather(tail_batches)

            accelerator.load_state(tmpdir)
            resumed_batches = accelerator.gather([batch for batch in prepared_dl])

            for b1, b2 in zip(not_skipped_batches, resumed_batches):
                for v1, v2 in zip(b1, b2):
                    assert torch.equal(v1, v2), f"Batch {b1} and {b2} are not equal"
    finally:
        accelerator.dataloader_config = saved_config
351
+
352
+
353
def main():
    """Run every check of this script in order, printing a short description before each one."""
    accelerator = create_accelerator()
    torch.manual_seed(accelerator.process_index)

    # Checks that build their own accelerator and take no arguments.
    standalone_checks = (
        (
            "Test that even_batches variable ensures uniform batches across processes",
            test_default_ensures_even_batch_sizes,
        ),
        ("Run tests with even_batches disabled", test_can_disable_even_batches),
        ("Test joining uneven inputs", test_can_join_uneven_inputs),
        ("Test overriding even_batches when joining uneven inputs", test_join_can_override_even_batches),
        (
            "Test overriding even_batches for mixed dataloader types",
            test_join_can_override_for_mixed_type_dataloaders,
        ),
        (
            "Test overriding even_batches raises a warning for iterable dataloaders",
            test_join_raises_warning_for_iterable_when_overriding_even_batches,
        ),
    )
    for description, check in standalone_checks:
        accelerator.print(description)
        check()

    accelerator.print("Test join with non DDP distributed raises warning")
    # Temporarily report FSDP as the distributed type, then put the real value back.
    real_distributed_type = accelerator.state.distributed_type
    accelerator.state.distributed_type = DistributedType.FSDP
    test_join_raises_warning_for_non_ddp_distributed(accelerator)
    accelerator.state.distributed_type = real_distributed_type

    accelerator.print("Test pickling an accelerator")
    test_pickle_accelerator()

    dataset = DummyDataset()

    accelerator.print("Test DataLoader with shuffle=False")
    test_data_loader(
        DataLoader(dataset, shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS), accelerator
    )

    accelerator.print("Test DataLoader with shuffle=True")
    test_data_loader(
        DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS), accelerator
    )

    accelerator.print("Test DataLoader with batch_sampler")
    batch_sampler = BatchSampler(RandomSampler(dataset), batch_size=BATCH_SIZE, drop_last=False)
    test_data_loader(DataLoader(dataset, batch_sampler=batch_sampler, num_workers=NUM_WORKERS), accelerator)

    accelerator.print("Test DataLoader with sampler as an instance of `BatchSampler`")
    batch_sampler = BatchSampler(RandomSampler(dataset), batch_size=BATCH_SIZE, drop_last=False)
    test_data_loader(
        DataLoader(
            dataset, sampler=batch_sampler, batch_size=None, collate_fn=default_collate, num_workers=NUM_WORKERS
        ),
        accelerator,
    )
    test_stateful_dataloader(accelerator)
    test_stateful_dataloader_save_state(accelerator)

    accelerator.end_training()
407
+
408
+
409
+ if __name__ == "__main__":
410
+ main()
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/scripts/test_merge_weights.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import gc
16
+ import logging
17
+ import shutil
18
+ from pathlib import Path
19
+
20
+ import torch
21
+ from safetensors.torch import load_file
22
+ from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy, StateDictType
23
+ from torch.utils.data import DataLoader
24
+
25
+ from accelerate import Accelerator, FullyShardedDataParallelPlugin
26
+ from accelerate.commands.merge import merge_command, merge_command_parser
27
+ from accelerate.state import AcceleratorState
28
+ from accelerate.test_utils import torch_device
29
+ from accelerate.test_utils.training import RegressionDataset
30
+ from accelerate.utils import merge_fsdp_weights, patch_environment, save_fsdp_model
31
+
32
+
33
+ logging.basicConfig(level=logging.INFO)
34
+
35
+ parser = merge_command_parser()
36
+
37
+
38
class TinyModel(torch.nn.Module):
    """Two 16x16 linear layers with a ReLU in between; the `softmax` submodule is created but not used in `forward`."""

    def __init__(self):
        super().__init__()
        # Submodules are created in this exact order, which fixes both parameter initialisation and
        # `state_dict` key order.
        self.linear1 = torch.nn.Linear(16, 16)
        self.activation = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(16, 16)
        self.softmax = torch.nn.Softmax()

    def forward(self, x):
        hidden = self.linear1(x)
        hidden = self.activation(hidden)
        return self.linear2(hidden)
48
+
49
+
50
def setup():
    """
    Reset any existing `AcceleratorState`, then build a `TinyModel` prepared by an `Accelerator` that uses a
    full-shard FSDP plugin with sharded state dicts.

    Returns:
        The tuple `(model, plugin, accelerator)`.
    """
    if AcceleratorState._shared_state != {}:
        AcceleratorState()._reset_state()

    plugin = FullyShardedDataParallelPlugin(
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        state_dict_type=StateDictType.SHARDED_STATE_DICT,
    )
    tiny_model = TinyModel()
    # The wrap policy is chosen while the environment is patched, before the accelerator is created.
    with patch_environment(fsdp_auto_wrap_policy="SIZE_BASED_WRAP"):
        plugin.set_auto_wrap_policy(tiny_model)

    accelerator = Accelerator(fsdp_plugin=plugin)
    prepared_model = accelerator.prepare(tiny_model)
    return prepared_model, plugin, accelerator
62
+
63
+
64
def mock_training(accelerator, model):
    """Train `model` with SGD for three passes over a 128-sample `RegressionDataset` and return the prepared model."""
    dataset = RegressionDataset(length=128, seed=42)
    loader = DataLoader(dataset, batch_size=16, shuffle=False)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    loader, model, optimizer = accelerator.prepare(loader, model, optimizer)

    num_passes = 3
    for _ in range(num_passes):
        for batch in loader:
            model.zero_grad()
            prediction = model(batch["x"])
            accelerator.backward(torch.nn.functional.mse_loss(prediction, batch["y"]))
            optimizer.step()
    return model
78
+
79
+
80
def check_weights(operation, state_1, state_2):
    """
    Compare two state dicts value by value, in iteration order.

    With `operation == "same"` every pair must satisfy `torch.allclose`; with any other value every pair must
    fail it.
    """
    expect_match = operation == "same"
    for left, right in zip(state_1.values(), state_2.values()):
        weights_match = torch.allclose(left, right)
        if expect_match:
            assert weights_match
        else:
            assert not weights_match
86
+
87
+
88
def check_safetensors_weights(path, model):
    """
    Load `path / "model.safetensors"` into a fresh `TinyModel` and check that its weights differ from `model`
    before loading and match afterwards.
    """
    merged_state = load_file(path / "model.safetensors")
    fresh_model = TinyModel().to(torch_device)
    # Before loading, the fresh model must not match the reference model.
    check_weights("diff", model.state_dict(), fresh_model.state_dict())
    fresh_model.load_state_dict(merged_state)
    check_weights("same", model.state_dict(), fresh_model.state_dict())
94
+
95
+
96
def check_pytorch_weights(path, model):
    """
    Load `path / "pytorch_model.bin"` into a fresh `TinyModel` and check that its weights differ from `model`
    before loading and match afterwards.
    """
    merged_state = torch.load(path / "pytorch_model.bin", weights_only=True)
    fresh_model = TinyModel().to(torch_device)
    # Before loading, the fresh model must not match the reference model.
    check_weights("diff", model.state_dict(), fresh_model.state_dict())
    fresh_model.load_state_dict(merged_state)
    check_weights("same", model.state_dict(), fresh_model.state_dict())
102
+
103
+
104
def test_merge_weights_safetensors(model, path):
    """Merge the sharded checkpoint with `merge_fsdp_weights` using safe serialization and verify the result."""
    sharded_dir = path / "pytorch_model_fsdp_0"
    merge_fsdp_weights(sharded_dir, path, safe_serialization=True)
    # The verification helper reads `path / "model.safetensors"`.
    check_safetensors_weights(path, model)
108
+
109
+
110
def test_merge_weights_command_safetensors(model, path):
    """Merge the sharded checkpoint through the `merge` command with default options and verify the result."""
    cli_arguments = [str(path / "pytorch_model_fsdp_0"), str(path)]
    merge_command(parser.parse_args(cli_arguments))
    check_safetensors_weights(path, model)
114
+
115
+
116
def test_merge_weights_pytorch(model, path):
    """Merge the sharded checkpoint with `merge_fsdp_weights` without safe serialization and verify the result."""
    sharded_dir = path / "pytorch_model_fsdp_0"
    merge_fsdp_weights(sharded_dir, path, safe_serialization=False)
    # The verification helper reads `path / "pytorch_model.bin"`.
    check_pytorch_weights(path, model)
120
+
121
+
122
def test_merge_weights_command_pytorch(model, path):
    """Merge the sharded checkpoint through the `merge` command with `--unsafe_serialization` and verify it."""
    cli_arguments = [str(path / "pytorch_model_fsdp_0"), str(path), "--unsafe_serialization"]
    merge_command(parser.parse_args(cli_arguments))
    check_pytorch_weights(path, model)
126
+
127
+
128
if __name__ == "__main__":
    # Note this test requires at least two accelerators!
    model, plugin, accelerator = setup()
    if accelerator.num_processes > 1:
        try:
            # Directory that receives the sharded and merged checkpoints.
            out_path = Path("test_merge_weights_fsdp_weights")
            if not out_path.exists():
                out_path.mkdir(parents=True, exist_ok=True)

            # Train briefly so the weights differ from a freshly initialised model.
            model = mock_training(accelerator, model)
            accelerator.wait_for_everyone()

            gc.collect()  # Needed for some lingering refs after training
            save_fsdp_model(plugin, accelerator, model, out_path)
            accelerator.wait_for_everyone()

            # Run the four merge checks in order; any exception propagates after the cleanup below.
            for merge_check in (
                test_merge_weights_safetensors,
                test_merge_weights_command_safetensors,
                test_merge_weights_pytorch,
                test_merge_weights_command_pytorch,
            ):
                merge_check(model, out_path)
        finally:
            # Cleanup in case of any failures
            if accelerator.is_main_process:
                shutil.rmtree(out_path)
            accelerator.wait_for_everyone()
    # NOTE(review): indentation is lost in this copy; `end_training()` is assumed to run unconditionally at the
    # end of the script - confirm against the original file.
    accelerator.end_training()
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/scripts/test_notebook.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Test file to ensure that in general certain situational setups for notebooks work.
16
+ """
17
+
18
+ import os
19
+ import time
20
+ from multiprocessing import Queue
21
+
22
+ from pytest import mark, raises
23
+ from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
24
+
25
+ from accelerate import PartialState, notebook_launcher
26
+ from accelerate.test_utils import require_bnb
27
+ from accelerate.utils import is_bnb_available
28
+
29
+
30
def basic_function():
    """Print the ``PartialState`` of the calling process."""
    state = PartialState()
    print(f"PartialState:\n{state}")
33
+
34
+
35
def tough_nut_function(queue: Queue):
    """Fail until the counter stored in ``queue`` has been used up.

    Does nothing when ``queue`` is empty. Otherwise takes one value from it:
    a positive value is put back decremented by one and a ``RuntimeError`` is
    raised; any other value lets the call succeed and print the
    ``PartialState``.

    Raises:
        RuntimeError: if the value taken from ``queue`` is greater than zero.
    """
    if not queue.empty():
        attempts_left = queue.get()
        if attempts_left > 0:
            # Record that one attempt has been consumed before failing.
            queue.put(attempts_left - 1)
            raise RuntimeError("The nut hasn't cracked yet! Try again.")

        print(f"PartialState:\n{PartialState()}")
44
+
45
+
46
def bipolar_sleep_function(sleep_sec: int):
    """Sleep on odd process indices and fail immediately on even ones.

    Args:
        sleep_sec: how long, in seconds, an odd-indexed process sleeps.

    Raises:
        RuntimeError: when ``PartialState().process_index`` is even.
    """
    on_odd_process = PartialState().process_index % 2 != 0
    if on_odd_process:
        time.sleep(sleep_sec)
        return
    raise RuntimeError("I'm an even process. I don't like to sleep.")
52
+
53
+
54
# Process count used by every launcher call below; read from the
# ACCELERATE_NUM_PROCESSES environment variable, defaulting to 1 when unset.
NUM_PROCESSES = int(os.getenv("ACCELERATE_NUM_PROCESSES", "1"))
55
+
56
+
57
def test_can_initialize():
    """Launch ``basic_function`` on ``NUM_PROCESSES`` processes with default settings."""
    no_args = ()
    notebook_launcher(basic_function, no_args, num_processes=NUM_PROCESSES)
59
+
60
+
61
@mark.skipif(NUM_PROCESSES < 2, reason="Need at least 2 processes to test static rendezvous backends")
def test_static_rdzv_backend():
    """Launch ``basic_function`` with ``rdzv_backend="static"``."""
    launch_options = {"num_processes": NUM_PROCESSES, "rdzv_backend": "static"}
    notebook_launcher(basic_function, (), **launch_options)
64
+
65
+
66
@mark.skipif(NUM_PROCESSES < 2, reason="Need at least 2 processes to test c10d rendezvous backends")
def test_c10d_rdzv_backend():
    """Launch ``basic_function`` with ``rdzv_backend="c10d"``."""
    launch_options = {"num_processes": NUM_PROCESSES, "rdzv_backend": "c10d"}
    notebook_launcher(basic_function, (), **launch_options)
69
+
70
+
71
@mark.skipif(NUM_PROCESSES < 2, reason="Need at least 2 processes to test fault tolerance")
def test_fault_tolerant(max_restarts: int = 3):
    """Launch ``tough_nut_function`` with a restart allowance of ``max_restarts``.

    A queue seeded with ``max_restarts`` is passed to the launched function,
    and the same number is given to the launcher as its ``max_restarts``.

    Args:
        max_restarts: value placed in the queue and passed to the launcher.
    """
    failure_budget = Queue()
    failure_budget.put(max_restarts)
    launch_options = {"num_processes": NUM_PROCESSES, "max_restarts": max_restarts}
    notebook_launcher(tough_nut_function, (failure_budget,), **launch_options)
76
+
77
+
78
@mark.skipif(NUM_PROCESSES < 2, reason="Need at least 2 processes to test monitoring")
def test_monitoring(monitor_interval: float = 0.01, sleep_sec: int = 100):
    """Check that a failing child ends the launch before the sleeping ones finish.

    ``bipolar_sleep_function`` raises on even process indices and sleeps for
    ``sleep_sec`` seconds on odd ones. The launch is expected to raise
    ``ChildFailedError`` carrying the even-process message, and to do so in
    less than ``sleep_sec`` seconds.

    Args:
        monitor_interval: forwarded to ``notebook_launcher`` as ``monitor_interval``.
        sleep_sec: sleep duration for odd processes; also the upper bound on
            how long the launch may take.
    """
    # Use the monotonic clock: wall-clock time can jump (NTP, manual changes)
    # and would make the elapsed-time check unreliable.
    start_time = time.monotonic()
    with raises(ChildFailedError, match="I'm an even process. I don't like to sleep."):
        notebook_launcher(
            bipolar_sleep_function,
            (sleep_sec,),
            num_processes=NUM_PROCESSES,
            monitor_interval=monitor_interval,
        )
    elapsed = time.monotonic() - start_time
    assert elapsed < sleep_sec, "Monitoring did not stop the process in time."
89
+
90
+
91
@require_bnb
def test_problematic_imports():
    """Expect the launcher to raise once ``bitsandbytes`` has been imported.

    The import is deliberately performed inside the ``raises`` block, right
    before the launch; the expected ``RuntimeError`` message contains
    "Please keep these imports".
    """
    no_args = ()
    with raises(RuntimeError, match="Please keep these imports"):
        import bitsandbytes as bnb  # noqa: F401

        notebook_launcher(basic_function, no_args, num_processes=NUM_PROCESSES)
97
+
98
+
99
def main():
    """Run every notebook-launcher check in order, announcing each one first.

    The bitsandbytes check runs only when ``is_bnb_available()`` is true, and
    the process group is destroyed at the end when more than one process is
    in use.
    """
    unconditional_steps = (
        ("Test basic notebook can be ran", test_can_initialize),
        ("Test static rendezvous backend", test_static_rdzv_backend),
        ("Test c10d rendezvous backend", test_c10d_rdzv_backend),
        ("Test fault tolerant", test_fault_tolerant),
        ("Test monitoring", test_monitoring),
    )
    for banner, check in unconditional_steps:
        print(banner)
        check()

    # Availability is queried only after the checks above have run.
    if is_bnb_available():
        print("Test problematic imports (bnb)")
        test_problematic_imports()

    if NUM_PROCESSES > 1:
        PartialState().destroy_process_group()
115
+
116
+
117
# Script entry point: run all checks when the file is executed directly.
if __name__ == "__main__":
    main()