salmankhanpm committed on
Commit
4f4376a
·
verified ·
1 Parent(s): 69e1a8d

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. accelerate/commands/__init__.py +13 -0
  2. accelerate/commands/accelerate_cli.py +54 -0
  3. accelerate/commands/config/__init__.py +52 -0
  4. accelerate/commands/config/cluster.py +939 -0
  5. accelerate/commands/config/config.py +89 -0
  6. accelerate/commands/config/config_args.py +252 -0
  7. accelerate/commands/config/config_utils.py +122 -0
  8. accelerate/commands/config/default.py +172 -0
  9. accelerate/commands/config/sagemaker.py +274 -0
  10. accelerate/commands/config/update.py +63 -0
  11. accelerate/commands/env.py +143 -0
  12. accelerate/commands/estimate.py +318 -0
  13. accelerate/commands/launch.py +1415 -0
  14. accelerate/commands/menu/__init__.py +14 -0
  15. accelerate/commands/menu/cursor.py +65 -0
  16. accelerate/commands/menu/helpers.py +59 -0
  17. accelerate/commands/menu/input.py +84 -0
  18. accelerate/commands/menu/keymap.py +133 -0
  19. accelerate/commands/menu/selection_menu.py +145 -0
  20. accelerate/commands/merge.py +69 -0
  21. accelerate/commands/test.py +65 -0
  22. accelerate/commands/to_fsdp2.py +172 -0
  23. accelerate/commands/tpu.py +157 -0
  24. accelerate/commands/utils.py +123 -0
  25. accelerate/test_utils/__init__.py +66 -0
  26. accelerate/test_utils/examples.py +148 -0
  27. accelerate/test_utils/scripts/__init__.py +13 -0
  28. accelerate/test_utils/scripts/external_deps/__init__.py +13 -0
  29. accelerate/test_utils/scripts/external_deps/test_checkpointing.py +269 -0
  30. accelerate/test_utils/scripts/external_deps/test_ds_alst_ulysses_sp.py +131 -0
  31. accelerate/test_utils/scripts/external_deps/test_ds_multiple_model.py +331 -0
  32. accelerate/test_utils/scripts/external_deps/test_metrics.py +307 -0
  33. accelerate/test_utils/scripts/external_deps/test_peak_memory_usage.py +323 -0
  34. accelerate/test_utils/scripts/external_deps/test_performance.py +299 -0
  35. accelerate/test_utils/scripts/external_deps/test_pippy.py +117 -0
  36. accelerate/test_utils/scripts/external_deps/test_zero3_integration.py +59 -0
  37. accelerate/test_utils/scripts/test_cli.py +32 -0
  38. accelerate/test_utils/scripts/test_ddp_comm_hook.py +85 -0
  39. accelerate/test_utils/scripts/test_distributed_data_loop.py +429 -0
  40. accelerate/test_utils/scripts/test_merge_weights.py +158 -0
  41. accelerate/test_utils/scripts/test_notebook.py +125 -0
  42. accelerate/test_utils/scripts/test_ops.py +181 -0
  43. accelerate/test_utils/scripts/test_script.py +909 -0
  44. accelerate/test_utils/scripts/test_sync.py +413 -0
  45. accelerate/test_utils/testing.py +889 -0
  46. accelerate/test_utils/training.py +150 -0
  47. accelerate/utils/__init__.py +304 -0
  48. accelerate/utils/ao.py +143 -0
  49. accelerate/utils/bnb.py +464 -0
  50. accelerate/utils/constants.py +108 -0
accelerate/commands/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
accelerate/commands/accelerate_cli.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from accelerate.commands.config import get_config_parser
18
+ from accelerate.commands.env import env_command_parser
19
+ from accelerate.commands.estimate import estimate_command_parser
20
+ from accelerate.commands.launch import launch_command_parser
21
+ from accelerate.commands.merge import merge_command_parser
22
+ from accelerate.commands.test import test_command_parser
23
+ from accelerate.commands.to_fsdp2 import to_fsdp2_command_parser
24
+ from accelerate.commands.tpu import tpu_command_parser
25
+ from accelerate.commands.utils import CustomArgumentParser
26
+
27
+
28
def main():
    """Entry point of the ``accelerate`` command-line tool.

    Builds the top-level parser, registers every sub-command on it, parses the
    process arguments and calls ``args.func(args)``.

    Raises:
        SystemExit: with status ``1`` (after printing the help text) when the
            parsed namespace has no ``func`` attribute, i.e. no runnable
            sub-command was selected. ``parse_args`` may also exit on its own
            for invalid arguments.
    """
    # Function-scope import: `sys.exit` is used instead of the `site`-injected
    # `exit` helper, which is intended for the interactive shell and is not
    # defined when the interpreter runs without `site` (e.g. `python -S`).
    import sys

    parser = CustomArgumentParser("Accelerate CLI tool", usage="accelerate <command> [<args>]", allow_abbrev=False)
    subparsers = parser.add_subparsers(help="accelerate command helpers")

    # Register commands
    get_config_parser(subparsers=subparsers)
    estimate_command_parser(subparsers=subparsers)
    env_command_parser(subparsers=subparsers)
    launch_command_parser(subparsers=subparsers)
    merge_command_parser(subparsers=subparsers)
    tpu_command_parser(subparsers=subparsers)
    test_command_parser(subparsers=subparsers)
    to_fsdp2_command_parser(subparsers=subparsers)

    # Let's go
    args = parser.parse_args()

    # `func` is presumably attached by the selected sub-command parser; when it
    # is absent there is nothing to run, so show usage and fail.
    if not hasattr(args, "func"):
        parser.print_help()
        sys.exit(1)

    # Run
    args.func(args)


if __name__ == "__main__":
    main()
accelerate/commands/config/__init__.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import argparse
18
+
19
+ from .config import config_command_parser
20
+ from .config_args import default_config_file, load_config_from_file # noqa: F401
21
+ from .default import default_command_parser
22
+ from .update import update_command_parser
23
+
24
+
25
def get_config_parser(subparsers=None):
    """Build the ``config`` command parser together with its sub-commands.

    Args:
        subparsers: Forwarded unchanged to ``config_command_parser``; defaults
            to ``None``.

    Returns:
        The parser returned by ``config_command_parser``, after a
        ``subcommands`` group (stored under ``dest="subcommand"``) has been
        added to it and the ``default`` and ``update`` command parsers have
        been registered on that group.
    """
    # Help-less parent parser handed to every nested sub-command.
    shared_parent = argparse.ArgumentParser(add_help=False, allow_abbrev=False)

    # The main config parser and the group that nested commands attach to.
    root_parser = config_command_parser(subparsers)
    nested_commands = root_parser.add_subparsers(title="subcommands", dest="subcommand")

    # Register the nested commands in a fixed order, each with the shared parent.
    for register in (default_command_parser, update_command_parser):
        register(nested_commands, parents=[shared_parent])

    return root_parser
37
+
38
+
39
def main():
    """Run the ``config`` command as a standalone program.

    Obtains the parser from ``get_config_parser()``, parses the process
    arguments and calls ``args.func(args)``.

    Raises:
        SystemExit: with status ``1`` (after printing the help text) when the
            parsed namespace has no ``func`` attribute. ``parse_args`` may also
            exit on its own for invalid arguments.
    """
    # Function-scope import: `sys.exit` is used instead of the `site`-injected
    # `exit` helper, which is intended for the interactive shell and is not
    # defined when the interpreter runs without `site` (e.g. `python -S`).
    import sys

    config_parser = get_config_parser()
    args = config_parser.parse_args()

    # Nothing runnable was selected: show usage and fail.
    if not hasattr(args, "func"):
        config_parser.print_help()
        sys.exit(1)

    # Run
    args.func(args)


if __name__ == "__main__":
    main()
accelerate/commands/config/cluster.py ADDED
@@ -0,0 +1,939 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import os
18
+
19
+ from ...utils import (
20
+ ComputeEnvironment,
21
+ DistributedType,
22
+ is_deepspeed_available,
23
+ is_fp8_available,
24
+ is_hpu_available,
25
+ is_mlu_available,
26
+ is_mps_available,
27
+ is_msamp_available,
28
+ is_musa_available,
29
+ is_neuron_available,
30
+ is_npu_available,
31
+ is_sdaa_available,
32
+ is_torchao_available,
33
+ is_transformer_engine_available,
34
+ is_transformers_available,
35
+ is_xpu_available,
36
+ )
37
+ from ...utils.constants import (
38
+ DEEPSPEED_MULTINODE_LAUNCHERS,
39
+ FSDP2_STATE_DICT_TYPE,
40
+ FSDP_AUTO_WRAP_POLICY,
41
+ FSDP_BACKWARD_PREFETCH,
42
+ FSDP_SHARDING_STRATEGY,
43
+ FSDP_STATE_DICT_TYPE,
44
+ TORCH_DYNAMO_MODES,
45
+ )
46
+ from .config_args import ClusterConfig
47
+ from .config_utils import (
48
+ DYNAMO_BACKENDS,
49
+ _ask_field,
50
+ _ask_options,
51
+ _convert_distributed_mode,
52
+ _convert_dynamo_backend,
53
+ _convert_fp8_backend,
54
+ _convert_mixed_precision,
55
+ _convert_yes_no_to_bool,
56
+ )
57
+
58
+
59
+ def get_cluster_input():
60
+ distributed_type = _ask_options(
61
+ "Which type of machine are you using?",
62
+ [
63
+ "No distributed training",
64
+ "multi-CPU",
65
+ "multi-XPU",
66
+ "multi-HPU",
67
+ "multi-GPU",
68
+ "multi-NPU",
69
+ "multi-MLU",
70
+ "multi-SDAA",
71
+ "multi-MUSA",
72
+ "multi-NEURON",
73
+ "TPU",
74
+ ],
75
+ _convert_distributed_mode,
76
+ )
77
+
78
+ machine_rank = 0
79
+ num_machines = 1
80
+ num_processes = 1
81
+ gpu_ids = None
82
+ main_process_ip = None
83
+ main_process_port = None
84
+ rdzv_backend = "static"
85
+ same_network = True
86
+ debug = False
87
+
88
+ if distributed_type in [
89
+ DistributedType.MULTI_GPU,
90
+ DistributedType.MULTI_MLU,
91
+ DistributedType.MULTI_SDAA,
92
+ DistributedType.MULTI_MUSA,
93
+ DistributedType.MULTI_NPU,
94
+ DistributedType.MULTI_XPU,
95
+ DistributedType.MULTI_CPU,
96
+ DistributedType.MULTI_HPU,
97
+ DistributedType.MULTI_NEURON,
98
+ ]:
99
+ num_machines = _ask_field(
100
+ "How many different machines will you use (use more than 1 for multi-node training)? [1]: ",
101
+ int,
102
+ default=1,
103
+ )
104
+ if num_machines > 1:
105
+ machine_rank = _ask_options(
106
+ "What is the rank of this machine?",
107
+ list(range(num_machines)),
108
+ int,
109
+ )
110
+ main_process_ip = _ask_field(
111
+ "What is the IP address of the machine that will host the main process? ",
112
+ )
113
+ main_process_port = _ask_field(
114
+ "What is the port you will use to communicate with the main process? ",
115
+ int,
116
+ )
117
+ same_network = _ask_field(
118
+ "Are all the machines on the same local network? Answer `no` if nodes are on the cloud and/or on different network hosts [YES/no]: ",
119
+ _convert_yes_no_to_bool,
120
+ default=True,
121
+ error_message="Please enter yes or no.",
122
+ )
123
+ if not same_network:
124
+ rdzv_backend = _ask_field(
125
+ "What rendezvous backend will you use? ('static', 'c10d', ...): ", default="static"
126
+ )
127
+ debug = _ask_field(
128
+ "Should distributed operations be checked while running for errors? This can avoid timeout issues but will be slower. [yes/NO]: ",
129
+ _convert_yes_no_to_bool,
130
+ default=False,
131
+ error_message="Please enter yes or no.",
132
+ )
133
+
134
+ if distributed_type == DistributedType.NO:
135
+ use_cpu = _ask_field(
136
+ "Do you want to run your training on CPU only (even if a GPU / Apple Silicon / Ascend NPU device is available)? [yes/NO]:",
137
+ _convert_yes_no_to_bool,
138
+ default=False,
139
+ error_message="Please enter yes or no.",
140
+ )
141
+ elif distributed_type == DistributedType.MULTI_CPU:
142
+ use_cpu = True
143
+ else:
144
+ use_cpu = False
145
+
146
+ mpirun_config = {}
147
+
148
+ if use_cpu:
149
+ if distributed_type == DistributedType.MULTI_CPU:
150
+ use_mpirun = _ask_field(
151
+ "Do you want accelerate to launch mpirun? [yes/NO]: ",
152
+ _convert_yes_no_to_bool,
153
+ default=False,
154
+ error_message="Please enter yes or no.",
155
+ )
156
+ if use_mpirun:
157
+ mpirun_hostfile = _ask_field(
158
+ "Please enter the path to the hostfile to use with mpirun [~/hostfile]: ",
159
+ str,
160
+ default="~/hostfile",
161
+ )
162
+ mpirun_config["mpirun_hostfile"] = os.path.expanduser(mpirun_hostfile.strip())
163
+
164
+ dynamo_config = {}
165
+ use_dynamo = _ask_field(
166
+ "Do you wish to optimize your script with torch dynamo?[yes/NO]:",
167
+ _convert_yes_no_to_bool,
168
+ default=False,
169
+ error_message="Please enter yes or no.",
170
+ )
171
+ if use_dynamo:
172
+ prefix = "dynamo_"
173
+ dynamo_config[prefix + "backend"] = _ask_options(
174
+ "Which dynamo backend would you like to use?",
175
+ [x.lower() for x in DYNAMO_BACKENDS],
176
+ _convert_dynamo_backend,
177
+ default=2,
178
+ )
179
+ use_custom_options = _ask_field(
180
+ "Do you want to customize the defaults sent to torch.compile? [yes/NO]: ",
181
+ _convert_yes_no_to_bool,
182
+ default=False,
183
+ error_message="Please enter yes or no.",
184
+ )
185
+
186
+ if use_custom_options:
187
+ dynamo_config[prefix + "mode"] = _ask_options(
188
+ "Which mode do you want to use?",
189
+ TORCH_DYNAMO_MODES,
190
+ lambda x: TORCH_DYNAMO_MODES[int(x)],
191
+ default=0,
192
+ )
193
+ dynamo_config[prefix + "use_fullgraph"] = _ask_field(
194
+ "Do you want the fullgraph mode or it is ok to break model into several subgraphs? [yes/NO]: ",
195
+ _convert_yes_no_to_bool,
196
+ default=False,
197
+ error_message="Please enter yes or no.",
198
+ )
199
+ dynamo_config[prefix + "use_dynamic"] = _ask_field(
200
+ "Do you want to enable dynamic shape tracing? [yes/NO]: ",
201
+ _convert_yes_no_to_bool,
202
+ default=False,
203
+ error_message="Please enter yes or no.",
204
+ )
205
+ dynamo_config[prefix + "use_regional_compilation"] = _ask_field(
206
+ "Do you want to enable regional compilation? [yes/NO]: ",
207
+ _convert_yes_no_to_bool,
208
+ default=False,
209
+ error_message="Please enter yes or no.",
210
+ )
211
+
212
+ use_mps = not use_cpu and is_mps_available()
213
+ deepspeed_config = {}
214
+ if (
215
+ distributed_type
216
+ in [
217
+ DistributedType.MULTI_GPU,
218
+ DistributedType.MULTI_XPU,
219
+ DistributedType.MULTI_HPU,
220
+ DistributedType.MULTI_NPU,
221
+ DistributedType.MULTI_MLU,
222
+ DistributedType.MULTI_SDAA,
223
+ DistributedType.MULTI_MUSA,
224
+ DistributedType.MULTI_NEURON,
225
+ DistributedType.NO,
226
+ ]
227
+ and not use_mps
228
+ ):
229
+ use_deepspeed = _ask_field(
230
+ "Do you want to use DeepSpeed? [yes/NO]: ",
231
+ _convert_yes_no_to_bool,
232
+ default=False,
233
+ error_message="Please enter yes or no.",
234
+ )
235
+ if use_deepspeed:
236
+ if distributed_type is DistributedType.MULTI_NEURON:
237
+ raise RuntimeError("DeepSpeed is not supported on Neuron devices.")
238
+
239
+ distributed_type = DistributedType.DEEPSPEED
240
+ assert is_deepspeed_available(), (
241
+ "DeepSpeed is not installed => run `pip3 install deepspeed` or build it from source"
242
+ )
243
+
244
+ if distributed_type == DistributedType.DEEPSPEED:
245
+ use_deepspeed_config = _ask_field(
246
+ "Do you want to specify a json file to a DeepSpeed config? [yes/NO]: ",
247
+ _convert_yes_no_to_bool,
248
+ default=False,
249
+ error_message="Please enter yes or no.",
250
+ )
251
+ if use_deepspeed_config:
252
+ deepspeed_config["deepspeed_config_file"] = _ask_field(
253
+ "Please enter the path to the json DeepSpeed config file: ",
254
+ str,
255
+ default="none",
256
+ )
257
+ else:
258
+ deepspeed_config["zero_stage"] = _ask_options(
259
+ "What should be your DeepSpeed's ZeRO optimization stage?",
260
+ [0, 1, 2, 3],
261
+ int,
262
+ default=2,
263
+ )
264
+
265
+ deepspeed_devices = ["none", "cpu", "nvme"]
266
+ if deepspeed_config["zero_stage"] >= 2:
267
+ deepspeed_config["offload_optimizer_device"] = _ask_options(
268
+ "Where to offload optimizer states?", deepspeed_devices, lambda x: deepspeed_devices[int(x)]
269
+ )
270
+ deepspeed_config["offload_param_device"] = _ask_options(
271
+ "Where to offload parameters?", deepspeed_devices, lambda x: deepspeed_devices[int(x)]
272
+ )
273
+ if deepspeed_config["offload_param_device"] == "nvme":
274
+ deepspeed_config["offload_param_nvme_path"] = _ask_field(
275
+ "Nvme Path to offload parameters?",
276
+ str,
277
+ default="/nvme",
278
+ )
279
+ if deepspeed_config["offload_optimizer_device"] == "nvme":
280
+ deepspeed_config["offload_optimizer_nvme_path"] = _ask_field(
281
+ "Nvme Path to offload optimizer states?",
282
+ str,
283
+ default="/nvme",
284
+ )
285
+ deepspeed_config["gradient_accumulation_steps"] = _ask_field(
286
+ "How many gradient accumulation steps you're passing in your script? [1]: ",
287
+ int,
288
+ default=1,
289
+ )
290
+ use_gradient_clipping = _ask_field(
291
+ "Do you want to use gradient clipping? [yes/NO]: ",
292
+ _convert_yes_no_to_bool,
293
+ default=False,
294
+ error_message="Please enter yes or no.",
295
+ )
296
+ if use_gradient_clipping:
297
+ deepspeed_config["gradient_clipping"] = _ask_field(
298
+ "What is the gradient clipping value? [1.0]: ",
299
+ float,
300
+ default=1.0,
301
+ )
302
+ if deepspeed_config["zero_stage"] == 3:
303
+ deepspeed_config["zero3_save_16bit_model"] = _ask_field(
304
+ "Do you want to save 16-bit model weights when using ZeRO Stage-3? [yes/NO]: ",
305
+ _convert_yes_no_to_bool,
306
+ default=False,
307
+ error_message="Please enter yes or no.",
308
+ )
309
+ deepspeed_config["zero3_init_flag"] = _ask_field(
310
+ "Do you want to enable `deepspeed.zero.Init` when using ZeRO Stage-3 for constructing massive models? [yes/NO]: ",
311
+ _convert_yes_no_to_bool,
312
+ default=False,
313
+ error_message="Please enter yes or no.",
314
+ )
315
+ if deepspeed_config["zero3_init_flag"]:
316
+ if not is_transformers_available():
317
+ raise Exception(
318
+ "When `zero3_init_flag` is set, it requires Transformers to be installed. "
319
+ "Please run `pip3 install transformers`."
320
+ )
321
+ use_moe = _ask_field(
322
+ "Do you want to enable Mixture-of-Experts training (MoE)? [yes/NO]: ",
323
+ _convert_yes_no_to_bool,
324
+ default=False,
325
+ error_message="Please enter yes or no.",
326
+ )
327
+ if use_moe:
328
+ deepspeed_config["deepspeed_moe_layer_cls_names"] = _ask_field(
329
+ "Specify the comma-separated list of transformers MoE layer class names (case-sensitive), e.g : "
330
+ " `MixtralSparseMoeBlock`, `Qwen2MoeSparseMoeBlock`, `JetMoEAttention,JetMoEBlock` ... : ",
331
+ str,
332
+ )
333
+
334
+ if num_machines > 1:
335
+ launcher_query = "Which Type of launcher do you want to use?"
336
+ deepspeed_config["deepspeed_multinode_launcher"] = _ask_options(
337
+ launcher_query,
338
+ DEEPSPEED_MULTINODE_LAUNCHERS,
339
+ lambda x: DEEPSPEED_MULTINODE_LAUNCHERS[int(x)],
340
+ )
341
+
342
+ if deepspeed_config["deepspeed_multinode_launcher"] != DEEPSPEED_MULTINODE_LAUNCHERS[1]:
343
+ deepspeed_config["deepspeed_hostfile"] = _ask_field(
344
+ "DeepSpeed configures multi-node compute resources with hostfile. "
345
+ "Each row is of the format `hostname slots=[num_gpus]`, e.g., `localhost slots=2`; "
346
+ "for more information please refer official [documentation]"
347
+ "(https://www.deepspeed.ai/getting-started/#resource-configuration-multi-node). "
348
+ "Please specify the location of hostfile: ",
349
+ str,
350
+ )
351
+
352
+ is_exclusion_filter = _ask_field(
353
+ "Do you want to specify exclusion filter string? [yes/NO]: ",
354
+ _convert_yes_no_to_bool,
355
+ default=False,
356
+ error_message="Please enter yes or no.",
357
+ )
358
+ if is_exclusion_filter:
359
+ deepspeed_config["deepspeed_exclusion_filter"] = _ask_field(
360
+ "DeepSpeed exclusion filter string: ",
361
+ str,
362
+ )
363
+
364
+ is_inclusion_filter = _ask_field(
365
+ "Do you want to specify inclusion filter string? [yes/NO]: ",
366
+ _convert_yes_no_to_bool,
367
+ default=False,
368
+ error_message="Please enter yes or no.",
369
+ )
370
+ if is_inclusion_filter:
371
+ deepspeed_config["deepspeed_inclusion_filter"] = _ask_field(
372
+ "DeepSpeed inclusion filter string: ",
373
+ str,
374
+ )
375
+
376
+ fsdp_config = {}
377
+
378
+ if distributed_type in [
379
+ DistributedType.MULTI_GPU,
380
+ DistributedType.MULTI_NPU,
381
+ DistributedType.MULTI_MLU,
382
+ DistributedType.MULTI_SDAA,
383
+ DistributedType.MULTI_MUSA,
384
+ DistributedType.MULTI_XPU,
385
+ DistributedType.MULTI_HPU,
386
+ DistributedType.MULTI_NEURON,
387
+ ]:
388
+ use_fsdp = _ask_field(
389
+ "Do you want to use FullyShardedDataParallel? [yes/NO]: ",
390
+ _convert_yes_no_to_bool,
391
+ default=False,
392
+ error_message="Please enter yes or no.",
393
+ )
394
+ if use_fsdp:
395
+ if distributed_type is DistributedType.MULTI_NEURON:
396
+ raise NotImplementedError("FSDP is not currently supported on Neuron devices.")
397
+ distributed_type = DistributedType.FSDP
398
+
399
+ if distributed_type == DistributedType.FSDP:
400
+ fsdp_config["fsdp_version"] = _ask_options(
401
+ "What should be your FSDP version? [2]: ",
402
+ [1, 2],
403
+ lambda x: int(x) + 1,
404
+ default=1,
405
+ )
406
+ fsdp_version = fsdp_config["fsdp_version"] # extract to a variable to simplify usage later
407
+
408
+ if fsdp_version == 1:
409
+ sharding_strategy_query = "What should be your sharding strategy?"
410
+ fsdp_config["fsdp_reshard_after_forward"] = _ask_options(
411
+ sharding_strategy_query,
412
+ FSDP_SHARDING_STRATEGY,
413
+ lambda x: FSDP_SHARDING_STRATEGY[int(x)],
414
+ )
415
+ else:
416
+ fsdp_config["fsdp_reshard_after_forward"] = _ask_field(
417
+ "Do you want to enable resharding after forward? [YES/no]: ",
418
+ _convert_yes_no_to_bool,
419
+ default=True,
420
+ error_message="Please enter yes or no.",
421
+ )
422
+
423
+ fsdp_config["fsdp_offload_params"] = _ask_field(
424
+ "Do you want to offload parameters and gradients to CPU? [yes/NO]: ",
425
+ _convert_yes_no_to_bool,
426
+ default=False,
427
+ error_message="Please enter yes or no.",
428
+ )
429
+
430
+ fsdp_wrap_query = "What should be your auto wrap policy?"
431
+ fsdp_config["fsdp_auto_wrap_policy"] = _ask_options(
432
+ fsdp_wrap_query,
433
+ FSDP_AUTO_WRAP_POLICY,
434
+ lambda x: FSDP_AUTO_WRAP_POLICY[int(x)],
435
+ )
436
+ if fsdp_config["fsdp_auto_wrap_policy"] == FSDP_AUTO_WRAP_POLICY[0]:
437
+ use_no_split_modules = _ask_field(
438
+ "Do you want to use the model's `_no_split_modules` to wrap. Only applicable for 🤗 Transformers [yes/NO]: ",
439
+ _convert_yes_no_to_bool,
440
+ default=False,
441
+ error_message="Please enter yes or no.",
442
+ )
443
+ if not use_no_split_modules:
444
+ fsdp_config["fsdp_transformer_layer_cls_to_wrap"] = _ask_field(
445
+ "Specify the comma-separated list of transformer layer class names (case-sensitive) to wrap ,e.g, :"
446
+ "`BertLayer`, `GPTJBlock`, `T5Block`, `BertLayer,BertEmbeddings,BertSelfOutput` ...? : ",
447
+ str,
448
+ )
449
+ elif fsdp_config["fsdp_auto_wrap_policy"] == FSDP_AUTO_WRAP_POLICY[1]:
450
+ fsdp_config["fsdp_min_num_params"] = _ask_field(
451
+ "What should be your FSDP's minimum number of parameters for Default Auto Wrapping Policy? [1e8]: ",
452
+ int,
453
+ default=100000000,
454
+ )
455
+ # Removed in FSDP2, ask for user input for FSDP1
456
+ if fsdp_version == 1:
457
+ fsdp_backward_prefetch_query = "What should be your FSDP's backward prefetch policy?"
458
+ fsdp_config["fsdp_backward_prefetch"] = _ask_options(
459
+ fsdp_backward_prefetch_query,
460
+ FSDP_BACKWARD_PREFETCH,
461
+ lambda x: FSDP_BACKWARD_PREFETCH[int(x)],
462
+ )
463
+
464
+ fsdp_state_dict_type_query = "What should be your FSDP's state dict type?"
465
+ fsdp_config["fsdp_state_dict_type"] = _ask_options(
466
+ fsdp_state_dict_type_query,
467
+ FSDP_STATE_DICT_TYPE if fsdp_version == 1 else FSDP2_STATE_DICT_TYPE,
468
+ lambda x: FSDP_STATE_DICT_TYPE[int(x)] if fsdp_version == 1 else FSDP2_STATE_DICT_TYPE[int(x)],
469
+ default=0,
470
+ )
471
+ # Not implemented in FSDP2, ask for user input for FSDP1
472
+ if fsdp_version == 1:
473
+ fsdp_config["fsdp_forward_prefetch"] = _ask_field(
474
+ "Do you want to enable FSDP's forward prefetch policy? [yes/NO]: ",
475
+ _convert_yes_no_to_bool,
476
+ default=False,
477
+ error_message="Please enter yes or no.",
478
+ )
479
+ # Obsolete in FSDP2, ask for user input for FSDP1
480
+ if fsdp_version == 1:
481
+ fsdp_config["fsdp_use_orig_params"] = _ask_field(
482
+ "Do you want to enable FSDP's `use_orig_params` feature? [YES/no]: ",
483
+ _convert_yes_no_to_bool,
484
+ default=True,
485
+ error_message="Please enter yes or no.",
486
+ )
487
+ fsdp_config["fsdp_cpu_ram_efficient_loading"] = _ask_field(
488
+ "Do you want to enable CPU RAM efficient model loading? Only applicable for 🤗 Transformers models. [YES/no]: ",
489
+ _convert_yes_no_to_bool,
490
+ default=True,
491
+ error_message="Please enter yes or no.",
492
+ )
493
+ # Obsolete in FSDP2, ask for user input for FSDP1
494
+ if fsdp_version == 1:
495
+ if fsdp_config["fsdp_cpu_ram_efficient_loading"]:
496
+ fsdp_config["fsdp_sync_module_states"] = True
497
+ else:
498
+ fsdp_config["fsdp_sync_module_states"] = _ask_field(
499
+ "Do you want each individually wrapped FSDP unit to broadcast module parameters from rank 0 at the start? [YES/no]: ",
500
+ _convert_yes_no_to_bool,
501
+ default=True,
502
+ error_message="Please enter yes or no.",
503
+ )
504
+ fsdp_config["fsdp_activation_checkpointing"] = _ask_field(
505
+ "Do you want to enable FSDP activation checkpointing? [yes/NO]: ",
506
+ _convert_yes_no_to_bool,
507
+ default=False,
508
+ error_message="Please enter yes or no.",
509
+ )
510
+
511
+ parallelism_config = {}
512
+
513
+ if fsdp_config.get("fsdp_version", 1) == 2:
514
+ use_parallelism_config = _ask_field(
515
+ "Do you want to use the parallelism config? [yes/NO]: ",
516
+ _convert_yes_no_to_bool,
517
+ default=False,
518
+ error_message="Please enter yes or no.",
519
+ )
520
+
521
+ if use_parallelism_config:
522
+ prefix = "parallelism_config_"
523
+ parallelism_config[prefix + "dp_replicate_size"] = _ask_field(
524
+ "What is the data parallelism replicate size? [1]: ",
525
+ int,
526
+ default=1,
527
+ error_message="Please enter an integer.",
528
+ )
529
+
530
+ parallelism_config[prefix + "dp_shard_size"] = _ask_field(
531
+ "What is the FSDP shard size? [1]: ",
532
+ int,
533
+ default=1,
534
+ error_message="Please enter an integer.",
535
+ )
536
+
537
+ parallelism_config[prefix + "tp_size"] = _ask_field(
538
+ "What is the tensor parallelism size? [1]: ",
539
+ int,
540
+ default=1,
541
+ error_message="Please enter an integer.",
542
+ )
543
+
544
+ parallelism_config[prefix + "cp_size"] = _ask_field(
545
+ "What is the context parallelism size? [1]: ",
546
+ int,
547
+ default=1,
548
+ error_message="Please enter an integer.",
549
+ )
550
+ if parallelism_config[prefix + "cp_size"] > 1:
551
+ parallelism_config[prefix + "cp_comm_strategy"] = _ask_options(
552
+ "What is the compute parallelism communication strategy?",
553
+ ["allgather", "alltoall"],
554
+ lambda x: ["allgather", "alltoall"][int(x)],
555
+ default=0,
556
+ )
557
+
558
+ megatron_lm_config = {}
559
+ if distributed_type in [DistributedType.MULTI_GPU]:
560
+ use_megatron_lm = _ask_field(
561
+ "Do you want to use Megatron-LM ? [yes/NO]: ",
562
+ _convert_yes_no_to_bool,
563
+ default=False,
564
+ error_message="Please enter yes or no.",
565
+ )
566
+ if use_megatron_lm:
567
+ distributed_type = DistributedType.MEGATRON_LM
568
+ if distributed_type == DistributedType.MEGATRON_LM:
569
+ prefix = "megatron_lm_"
570
+ megatron_lm_config[prefix + "tp_degree"] = _ask_field(
571
+ "What is the Tensor Parallelism degree/size? [1]:",
572
+ int,
573
+ default=1,
574
+ error_message="Please enter an integer.",
575
+ )
576
+ if megatron_lm_config[prefix + "tp_degree"] > 1:
577
+ megatron_lm_config[prefix + "sequence_parallelism"] = _ask_field(
578
+ "Do you want to enable Sequence Parallelism? [YES/no]: ",
579
+ _convert_yes_no_to_bool,
580
+ default=True,
581
+ error_message="Please enter yes or no.",
582
+ )
583
+
584
+ megatron_lm_config[prefix + "pp_degree"] = _ask_field(
585
+ "What is the Pipeline Parallelism degree/size? [1]:",
586
+ int,
587
+ default=1,
588
+ error_message="Please enter an integer.",
589
+ )
590
+ if megatron_lm_config[prefix + "pp_degree"] > 1:
591
+ megatron_lm_config[prefix + "num_micro_batches"] = _ask_field(
592
+ "What is the number of micro-batches? [1]:",
593
+ int,
594
+ default=1,
595
+ error_message="Please enter an integer.",
596
+ )
597
+
598
+ megatron_lm_config[prefix + "recompute_activations"] = _ask_field(
599
+ "Do you want to enable selective activation recomputation? [YES/no]: ",
600
+ _convert_yes_no_to_bool,
601
+ default=True,
602
+ error_message="Please enter yes or no.",
603
+ )
604
+
605
+ megatron_lm_config[prefix + "use_distributed_optimizer"] = _ask_field(
606
+ "Do you want to use distributed optimizer "
607
+ "which shards optimizer state and gradients across data parallel ranks? [YES/no]: ",
608
+ _convert_yes_no_to_bool,
609
+ default=True,
610
+ error_message="Please enter yes or no.",
611
+ )
612
+
613
+ megatron_lm_config[prefix + "gradient_clipping"] = _ask_field(
614
+ "What is the gradient clipping value based on global L2 Norm (0 to disable)? [1.0]: ",
615
+ float,
616
+ default=1.0,
617
+ )
618
+ # TPU specific defaults
619
+ tpu_commands = None
620
+ tpu_command_file = None
621
+ tpu_downcast_bf16 = "no"
622
+ tpu_env = []
623
+ tpu_name = None
624
+ tpu_vm = None
625
+ tpu_zone = None
626
+ tpu_use_sudo = False
627
+ tpu_use_cluster = False
628
+
629
+ if distributed_type in [
630
+ DistributedType.MULTI_CPU,
631
+ DistributedType.MULTI_XPU,
632
+ DistributedType.MULTI_HPU,
633
+ DistributedType.MULTI_GPU,
634
+ DistributedType.MULTI_MLU,
635
+ DistributedType.MULTI_SDAA,
636
+ DistributedType.MULTI_MUSA,
637
+ DistributedType.MULTI_NPU,
638
+ DistributedType.MULTI_NEURON,
639
+ DistributedType.XLA,
640
+ ]:
641
+ machine_type = str(distributed_type).split(".")[1].replace("MULTI_", "")
642
+ if machine_type in ["TPU", "NEURON"]:
643
+ machine_type += " cores"
644
+ elif machine_type == "CPU":
645
+ machine_type = "processes"
646
+ else:
647
+ machine_type += "(s)"
648
+ num_processes = _ask_field(
649
+ f"How many {machine_type} should be used for distributed training? [1]:",
650
+ int,
651
+ default=1,
652
+ error_message="Please enter an integer.",
653
+ )
654
+ elif distributed_type in [DistributedType.FSDP, DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM]:
655
+ num_processes = _ask_field(
656
+ "How many GPU(s) should be used for distributed training? [1]:",
657
+ int,
658
+ default=1,
659
+ error_message="Please enter an integer.",
660
+ )
661
+ else:
662
+ num_processes = 1
663
+
664
+ if (distributed_type == DistributedType.MULTI_GPU) and (num_machines == 1) and (num_processes == 1):
665
+ raise ValueError(
666
+ f"Specified distributed type {distributed_type} but only using 1 GPU on a single machine. Please select `No distributed training` for the type of machine you are using."
667
+ )
668
+
669
+ if (
670
+ distributed_type
671
+ in [
672
+ DistributedType.MULTI_GPU,
673
+ DistributedType.MULTI_MLU,
674
+ DistributedType.MULTI_SDAA,
675
+ DistributedType.MULTI_MUSA,
676
+ DistributedType.MULTI_NPU,
677
+ DistributedType.MULTI_XPU,
678
+ DistributedType.MULTI_HPU,
679
+ DistributedType.MULTI_NEURON,
680
+ DistributedType.NO,
681
+ ]
682
+ and not use_cpu
683
+ and not use_mps
684
+ ):
685
+ if is_npu_available():
686
+ machine_type = "NPU(s)"
687
+ elif is_mlu_available():
688
+ machine_type = "MLU(s)"
689
+ elif is_sdaa_available():
690
+ machine_type = "SDAA(s)"
691
+ elif is_musa_available():
692
+ machine_type = "MUSA(s)"
693
+ elif is_xpu_available():
694
+ machine_type = "XPU(s)"
695
+ elif is_hpu_available():
696
+ machine_type = "HPU(s)"
697
+ elif is_neuron_available():
698
+ machine_type = "Neuron cores"
699
+ else:
700
+ machine_type = "GPU(s)"
701
+ gpu_ids = _ask_field(
702
+ f"What {machine_type} (by id) should be used for training on this machine as a comma-separated list? [all]:",
703
+ default="all",
704
+ )
705
+
706
+ # CPU affinity is only supported on NVIDIA hardware for now
707
+ enable_cpu_affinity = False
708
+ if distributed_type in (DistributedType.NO, DistributedType.MULTI_GPU) and not use_cpu and not use_mps:
709
+ enable_cpu_affinity = _ask_field(
710
+ "Would you like to enable numa efficiency? (Currently only supported on NVIDIA hardware). [yes/NO]: ",
711
+ _convert_yes_no_to_bool,
712
+ default=False,
713
+ error_message="Please enter yes or no.",
714
+ )
715
+
716
+ fp8_config = None
717
+ if distributed_type == DistributedType.XLA:
718
+ mixed_precision = "no"
719
+ main_training_function = _ask_field(
720
+ "What is the name of the function in your script that should be launched in all parallel scripts? [main]: ",
721
+ default="main",
722
+ )
723
+ tpu_use_cluster = _ask_field(
724
+ "Are you using a TPU cluster? [yes/NO]: ",
725
+ _convert_yes_no_to_bool,
726
+ default=False,
727
+ error_message="Please enter yes or no.",
728
+ )
729
+ if tpu_use_cluster:
730
+ tpu_name = _ask_field(
731
+ "What is the name of your TPU cluster? ",
732
+ default=None,
733
+ error_message="Please enter the name of your TPU cluster.",
734
+ )
735
+ tpu_zone = _ask_field(
736
+ "What is the zone of your TPU cluster? ",
737
+ default=None,
738
+ error_message="Please enter the zone of your TPU cluster.",
739
+ )
740
+ tpu_use_sudo = _ask_field(
741
+ "To run a python script in a TPU pod, should `sudo` be used? [yes/NO]: ",
742
+ default=False,
743
+ error_message="Please enter yes or no.",
744
+ )
745
+ run_commands = _ask_field(
746
+ "Do you have code you wish to run on startup in each pod? [yes/NO]: ",
747
+ _convert_yes_no_to_bool,
748
+ default=False,
749
+ error_message="Please enter yes or no.",
750
+ )
751
+ if run_commands:
752
+ use_command_file = _ask_field(
753
+ "Is this code located in a bash script? [yes/NO]: ",
754
+ _convert_yes_no_to_bool,
755
+ default=False,
756
+ error_message="Please enter yes or no.",
757
+ )
758
+ if use_command_file:
759
+ tpu_command_file = _ask_field(
760
+ "What is the path to your bash script? ",
761
+ default=None,
762
+ error_message="Please enter the path to your bash script.",
763
+ )
764
+ tpu_command_file = os.path.abspath(tpu_command_file)
765
+ else:
766
+ print("Please enter each command separately you wish to run on startup in each pod.")
767
+ tpu_commands = []
768
+ another_command = True
769
+ while another_command:
770
+ tpu_commands.append(
771
+ _ask_field(
772
+ "Please enter a single command to be ran ",
773
+ default=None,
774
+ error_message="Please enter the commands you wish to run on startup in each pod as a single string.",
775
+ )
776
+ )
777
+ another_command = _ask_field(
778
+ "Do you wish to add another command? [yes/NO]: ",
779
+ _convert_yes_no_to_bool,
780
+ default=False,
781
+ error_message="Please enter yes or no.",
782
+ )
783
+ tpu_vm = _ask_field(
784
+ "If not using an instance group, what are the names of the Compute VM instances to be used, separated by a comma: ",
785
+ default="",
786
+ ).split(",")
787
+ tpu_env = _ask_field(
788
+ "What environment variables do you wish to set in each pod, separated by a comma: ",
789
+ default="",
790
+ ).split(",")
791
+
792
+ else:
793
+ main_training_function = "main"
794
+ if distributed_type == DistributedType.DEEPSPEED and use_deepspeed_config:
795
+ mixed_precision = None
796
+ else:
797
+ mixed_precision = _ask_options(
798
+ "Do you wish to use mixed precision?",
799
+ ["no", "fp16", "bf16", "fp8"],
800
+ _convert_mixed_precision,
801
+ )
802
+ if mixed_precision == "fp8":
803
+ if not is_fp8_available():
804
+ raise ValueError(
805
+ "FP8 (either torchao, Transformer Engine or MSAMP) is not installed on this machine."
806
+ )
807
+ fp8_config = {}
808
+ fp8_config["backend"] = _ask_options(
809
+ "Which FP8 backend do you want to use?",
810
+ ["ao", "te", "msamp"],
811
+ _convert_fp8_backend,
812
+ )
813
+ if fp8_config["backend"] == "TE":
814
+ if not is_transformer_engine_available():
815
+ raise ValueError("TransformersEngine was selected, but it is not installed on this machine.")
816
+ fp8_config["use_autocast_during_eval"] = _ask_field(
817
+ "Do you want to use FP8 autocast during eval mode? Generally better metrics are found when this is disabled [yes/NO]: ",
818
+ _convert_yes_no_to_bool,
819
+ default=False,
820
+ )
821
+ fp8_config["margin"] = _ask_field(
822
+ "What margin should be used for gradient scaling? [0]: ",
823
+ int,
824
+ default=0,
825
+ )
826
+ fp8_config["interval"] = _ask_field(
827
+ "What interval should be used for for how often the scaling factor is recomputed? [1]: ",
828
+ int,
829
+ default=1,
830
+ )
831
+ fp8_config["fp8_format"] = _ask_options(
832
+ "Which weight format should be used?",
833
+ ["HYBRID", "E4M3", "E5M2"],
834
+ lambda i: ["HYBRID", "E4M3", "E5M2"][i],
835
+ default=0,
836
+ )
837
+ fp8_config["amax_history_length"] = _ask_field(
838
+ "What length of history should be used for the amax scaling factor computation? [1024]: ",
839
+ int,
840
+ default=1024,
841
+ )
842
+ fp8_config["amax_compute_algorithm"] = _ask_options(
843
+ "Which algorithm should be used for the amax scaling factor computation?",
844
+ ["max", "most_recent"],
845
+ lambda x: "max" if x == 0 else "most_recent",
846
+ default=0,
847
+ )
848
+ fp8_config["override_linear_precision"] = _ask_field(
849
+ "Do you want to to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision? [yes/NO]: ",
850
+ _convert_yes_no_to_bool,
851
+ default=False,
852
+ )
853
+ if fp8_config["override_linear_precision"]:
854
+ fprop = _ask_field(
855
+ "Should `fprop` be executed in higher precision? [yes/NO]: ",
856
+ _convert_yes_no_to_bool,
857
+ default=False,
858
+ )
859
+ dgrad = _ask_field(
860
+ "Should `dgrad` be executed in higher precision? [yes/NO]: ",
861
+ _convert_yes_no_to_bool,
862
+ default=False,
863
+ )
864
+ wgrad = _ask_field(
865
+ "Should `wgrad` be executed in higher precision? [yes/NO]: ",
866
+ _convert_yes_no_to_bool,
867
+ default=False,
868
+ )
869
+ fp8_config["override_linear_precision"] = (fprop, dgrad, wgrad)
870
+ else:
871
+ fp8_config["override_linear_precision"] = (False, False, False)
872
+
873
+ elif fp8_config["backend"] == "MSAMP":
874
+ if not is_msamp_available():
875
+ raise ValueError("MSAMP was selected, but it is not installed on this machine.")
876
+ fp8_config["optimization_level"] = _ask_options(
877
+ "Which optimization level should be used?",
878
+ ["O1", "O2"],
879
+ lambda x: "O1" if x == 0 else "O2",
880
+ default=1,
881
+ )
882
+
883
+ elif fp8_config["backend"] == "AO":
884
+ if not is_torchao_available():
885
+ raise ValueError("torchao was selected, but it is not installed on this machine.")
886
+ fp8_config["enable_fsdp_float8_all_gather"] = _ask_field(
887
+ "Do you want to enable FSDP2 float8 all gather? This is recommended for better performance if using FSDP2. [YES/no]: ",
888
+ _convert_yes_no_to_bool,
889
+ default=True,
890
+ )
891
+ fp8_config["pad_inner_dim"] = _ask_field(
892
+ "Do you want to pad the inner dimension of weight matrices before float8 matmuls? This is required for _scaled_mm which has strict alignment requirements. Note: padding may cause memory spikes. [YES/no]: ",
893
+ _convert_yes_no_to_bool,
894
+ default=True,
895
+ )
896
+
897
+ if use_dynamo and mixed_precision == "no" and not use_cpu:
898
+ print(
899
+ "Torch dynamo used without mixed precision requires TF32 to be efficient. Accelerate will enable it by default when launching your scripts."
900
+ )
901
+
902
+ if distributed_type == DistributedType.XLA and mixed_precision == "bf16":
903
+ tpu_downcast_bf16 = _ask_field(
904
+ "Should `torch.float` be cast as `bfloat16` and `torch.double` remain `float32` on TPUs?", default="no"
905
+ )
906
+
907
+ return ClusterConfig(
908
+ compute_environment=ComputeEnvironment.LOCAL_MACHINE,
909
+ distributed_type=distributed_type,
910
+ num_processes=num_processes,
911
+ gpu_ids=gpu_ids,
912
+ mixed_precision=mixed_precision,
913
+ downcast_bf16=tpu_downcast_bf16,
914
+ machine_rank=machine_rank,
915
+ num_machines=num_machines,
916
+ main_process_ip=main_process_ip,
917
+ main_process_port=main_process_port,
918
+ main_training_function=main_training_function,
919
+ fp8_config=fp8_config,
920
+ deepspeed_config=deepspeed_config,
921
+ fsdp_config=fsdp_config,
922
+ parallelism_config=parallelism_config,
923
+ megatron_lm_config=megatron_lm_config,
924
+ mpirun_config=mpirun_config,
925
+ use_cpu=use_cpu,
926
+ rdzv_backend=rdzv_backend,
927
+ same_network=same_network,
928
+ commands=tpu_commands,
929
+ command_file=tpu_command_file,
930
+ tpu_env=tpu_env,
931
+ tpu_name=tpu_name,
932
+ tpu_vm=tpu_vm,
933
+ tpu_zone=tpu_zone,
934
+ tpu_use_sudo=tpu_use_sudo,
935
+ tpu_use_cluster=tpu_use_cluster,
936
+ dynamo_config=dynamo_config,
937
+ debug=debug,
938
+ enable_cpu_affinity=enable_cpu_affinity,
939
+ )
accelerate/commands/config/config.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import argparse
18
+ import os
19
+
20
+ from accelerate.utils import ComputeEnvironment
21
+
22
+ from .cluster import get_cluster_input
23
+ from .config_args import cache_dir, default_config_file, default_yaml_config_file, load_config_from_file # noqa: F401
24
+ from .config_utils import _ask_field, _ask_options, _convert_compute_environment # noqa: F401
25
+ from .sagemaker import get_sagemaker_input
26
+
27
+
28
# Description passed to the argument parser of the `accelerate config` command.
description = (
    "Launches a series of prompts to create and save a `default_config.yaml` configuration file for your training "
    "system. Should always be ran first on your machine"
)
29
+
30
+
31
def get_user_input():
    """
    Ask which compute environment is used and collect the matching configuration.

    Returns:
        The result of `get_sagemaker_input()` when Amazon SageMaker is selected, otherwise the result of
        `get_cluster_input()`.
    """
    chosen_environment = _ask_options(
        "In which compute environment are you running?",
        ["This machine", "AWS (Amazon SageMaker)"],
        _convert_compute_environment,
    )
    if chosen_environment == ComputeEnvironment.AMAZON_SAGEMAKER:
        return get_sagemaker_input()
    return get_cluster_input()
42
+
43
+
44
def config_command_parser(subparsers=None):
    """
    Build the argument parser of the `config` command.

    Args:
        subparsers: Optional sub-parsers action. When given, a `config` sub-command is registered on it and
            `config_command` is set as its `func` default; otherwise a standalone parser is created.

    Returns:
        The parser, with the `--config_file` option added.
    """
    config_file_help = (
        "The path to use to store the config file. Will default to a file named default_config.yaml in the cache "
        "location, which is the content of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have "
        "such an environment variable, your cache directory ('~/.cache' or the content of `XDG_CACHE_HOME`) suffixed "
        "with 'huggingface'."
    )

    if subparsers is None:
        parser = argparse.ArgumentParser("Accelerate config command", description=description)
    else:
        parser = subparsers.add_parser("config", description=description)
        # Only a sub-command needs a dispatch target; the standalone parser is driven by `main`.
        parser.set_defaults(func=config_command)

    parser.add_argument("--config_file", default=None, help=config_file_help)
    return parser
64
+
65
+
66
def config_command(args):
    """
    Run the interactive questionnaire and save the resulting configuration.

    Args:
        args: Parsed arguments; `args.config_file` is the destination path, or `None` to use
            `default_yaml_config_file` inside `cache_dir`.

    The configuration is written as JSON when the destination ends with `.json` and as YAML otherwise.
    """
    config = get_user_input()

    if args.config_file is not None:
        config_file = args.config_file
        target_dir = os.path.dirname(os.path.abspath(config_file))
    else:
        config_file = default_yaml_config_file
        target_dir = cache_dir
    # Create the destination folder up front: `exist_ok=True` avoids the check-then-create race, and creating the
    # parent of a user-supplied path prevents losing all the answers to a `FileNotFoundError` at write time.
    os.makedirs(target_dir, exist_ok=True)

    if config_file.endswith(".json"):
        config.to_json_file(config_file)
    else:
        config.to_yaml_file(config_file)
    print(f"accelerate configuration saved at {config_file}")
80
+
81
+
82
def main():
    """Parse the command line with the standalone `config` parser and run `config_command`."""
    config_command(config_command_parser().parse_args())


if __name__ == "__main__":
    main()
accelerate/commands/config/config_args.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import json
18
+ import os
19
+ from dataclasses import dataclass
20
+ from enum import Enum
21
+ from typing import Optional, Union
22
+
23
+ import yaml
24
+
25
+ from ...utils import ComputeEnvironment, DistributedType, SageMakerDistributedType
26
+ from ...utils.constants import SAGEMAKER_PYTHON_VERSION, SAGEMAKER_PYTORCH_VERSION, SAGEMAKER_TRANSFORMERS_VERSION
27
+
28
+
29
# Root of the Hugging Face cache: `HF_HOME` when set, otherwise `huggingface` inside `XDG_CACHE_HOME`
# (which itself falls back to `~/.cache`). `~` is expanded to the user's home directory.
hf_cache_home = os.path.expanduser(
    os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
)
cache_dir = os.path.join(hf_cache_home, "accelerate")
# NOTE(review): both defaults point at the same `default_config.yaml` path, so the json fallback selected below
# can never be taken as written — confirm whether that is intended before changing either name.
default_json_config_file = os.path.join(cache_dir, "default_config.yaml")
default_yaml_config_file = os.path.join(cache_dir, "default_config.yaml")

# For backward compatibility: the default config is the json one if it's the only existing file.
default_config_file = (
    default_json_config_file
    if os.path.isfile(default_json_config_file) and not os.path.isfile(default_yaml_config_file)
    else default_yaml_config_file
)
41
+
42
+
43
def load_config_from_file(config_file):
    """
    Load a saved configuration.

    Args:
        config_file: Path of the configuration file, or `None` to use `default_config_file`.

    Returns:
        A `ClusterConfig` when the file's `compute_environment` (defaulting to `LOCAL_MACHINE`) is the local
        machine, otherwise a `SageMakerConfig`. Files ending with `.json` are read as JSON, everything else as YAML.

    Raises:
        FileNotFoundError: If an explicitly passed `config_file` does not exist.
    """
    if config_file is None:
        config_file = default_config_file
    elif not os.path.isfile(config_file):
        raise FileNotFoundError(
            f"The passed configuration file `{config_file}` does not exist. "
            "Please pass an existing file to `accelerate launch`, or use the default one "
            "created through `accelerate config` and run `accelerate launch` "
            "without the `--config_file` argument."
        )

    is_json = config_file.endswith(".json")
    with open(config_file, encoding="utf-8") as f:
        # First pass: only peek at the compute environment to choose the config class.
        raw_config = json.load(f) if is_json else yaml.safe_load(f)
        environment = raw_config.get("compute_environment", ComputeEnvironment.LOCAL_MACHINE)
        if environment == ComputeEnvironment.LOCAL_MACHINE:
            config_class = ClusterConfig
        else:
            config_class = SageMakerConfig
        # Second pass: the chosen class re-reads the file and validates it.
        if is_json:
            return config_class.from_json_file(json_file=config_file)
        return config_class.from_yaml_file(yaml_file=config_file)
73
+
74
+
75
@dataclass
class BaseConfig:
    """
    Options shared by every configuration class, together with JSON/YAML loading and saving helpers.
    """

    compute_environment: ComputeEnvironment
    distributed_type: Union[DistributedType, SageMakerDistributedType]
    mixed_precision: str
    use_cpu: bool
    debug: bool

    def to_dict(self):
        """
        Return a serializable copy of the configuration.

        Enum members are replaced by their `value`, empty dictionaries become `None`, and top-level entries whose
        value is `None` are dropped. The instance and its nested dictionaries are left untouched.
        """

        # For serialization, it's best to convert Enums to strings (or their underlying value type).
        def _convert_enums(value):
            if isinstance(value, Enum):
                return value.value
            if isinstance(value, dict):
                if not value:
                    return None
                # Build a new dict instead of rewriting the config's own nested dictionary.
                return {key: _convert_enums(item) for key, item in value.items()}
            return value

        # Iterate over `self.__dict__` without assigning into it, so serializing never changes the instance.
        converted = {key: _convert_enums(value) for key, value in self.__dict__.items()}
        return {key: value for key, value in converted.items() if value is not None}

    @staticmethod
    def process_config(config_dict):
        """
        Processes `config_dict` and sets default values for any missing keys

        The dictionary is updated in place and returned. Legacy `fp16` and `dynamo_backend` entries are converted
        to `mixed_precision` and `dynamo_config`.

        Raises:
            ValueError: If `distributed_type` is missing.
        """
        if "compute_environment" not in config_dict:
            config_dict["compute_environment"] = ComputeEnvironment.LOCAL_MACHINE
        if "distributed_type" not in config_dict:
            raise ValueError("A `distributed_type` must be specified in the config file.")
        if "num_processes" not in config_dict and config_dict["distributed_type"] == DistributedType.NO:
            config_dict["num_processes"] = 1
        if "mixed_precision" not in config_dict:
            config_dict["mixed_precision"] = "fp16" if config_dict.get("fp16") else None
        # Convert the config to the new format.
        config_dict.pop("fp16", None)
        if "dynamo_backend" in config_dict:
            dynamo_backend = config_dict.pop("dynamo_backend")
            config_dict["dynamo_config"] = {} if dynamo_backend == "NO" else {"dynamo_backend": dynamo_backend}
        config_dict.setdefault("use_cpu", False)
        config_dict.setdefault("debug", False)
        config_dict.setdefault("enable_cpu_affinity", False)
        return config_dict

    @classmethod
    def _from_config_dict(cls, config_dict, source):
        """Validate the parsed content of the file `source` and build an instance of `cls` from it."""
        if not isinstance(config_dict, dict):
            # An empty file parses to `None`; fail with an explicit message instead of an opaque `TypeError`.
            raise ValueError(
                f"The config file at {source} must contain a mapping of option names to values, but its content "
                f"was parsed as `{type(config_dict).__name__}`."
            )
        config_dict = cls.process_config(config_dict)
        extra_keys = sorted(set(config_dict) - set(cls.__dataclass_fields__))
        if extra_keys:
            raise ValueError(
                f"The config file at {source} had unknown keys ({extra_keys}), please try upgrading your `accelerate`"
                " version or fix (and potentially remove) these keys from your config file."
            )
        return cls(**config_dict)

    @classmethod
    def from_json_file(cls, json_file=None):
        """
        Build a config from a JSON file (`default_json_config_file` when `json_file` is `None`).

        Raises:
            ValueError: If the file does not hold a mapping, lacks `distributed_type`, or has unknown keys.
        """
        json_file = default_json_config_file if json_file is None else json_file
        with open(json_file, encoding="utf-8") as f:
            config_dict = json.load(f)
        return cls._from_config_dict(config_dict, json_file)

    def to_json_file(self, json_file):
        """Write `to_dict()` to `json_file` as indented JSON with sorted keys and a trailing newline."""
        with open(json_file, "w", encoding="utf-8") as f:
            content = json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
            f.write(content)

    @classmethod
    def from_yaml_file(cls, yaml_file=None):
        """
        Build a config from a YAML file (`default_yaml_config_file` when `yaml_file` is `None`).

        Raises:
            ValueError: If the file does not hold a mapping, lacks `distributed_type`, or has unknown keys.
        """
        yaml_file = default_yaml_config_file if yaml_file is None else yaml_file
        with open(yaml_file, encoding="utf-8") as f:
            config_dict = yaml.safe_load(f)
        return cls._from_config_dict(config_dict, yaml_file)

    def to_yaml_file(self, yaml_file):
        """Write `to_dict()` to `yaml_file` as YAML."""
        with open(yaml_file, "w", encoding="utf-8") as f:
            yaml.safe_dump(self.to_dict(), f)

    def __post_init__(self):
        # Values read from a file arrive as plain strings; turn them into the matching enum members.
        if isinstance(self.compute_environment, str):
            self.compute_environment = ComputeEnvironment(self.compute_environment)
        if isinstance(self.distributed_type, str):
            if self.compute_environment == ComputeEnvironment.AMAZON_SAGEMAKER:
                self.distributed_type = SageMakerDistributedType(self.distributed_type)
            else:
                self.distributed_type = DistributedType(self.distributed_type)
        # `dynamo_config` is declared by the subclasses; make sure it is always a dictionary.
        if getattr(self, "dynamo_config", None) is None:
            self.dynamo_config = {}
176
+
177
+
178
@dataclass
class ClusterConfig(BaseConfig):
    """Configuration for running on the local machine or cluster, on top of the options of `BaseConfig`."""

    num_processes: int = -1  # For instance if we use SLURM and the user manually passes it in
    machine_rank: int = 0
    num_machines: int = 1
    gpu_ids: Optional[str] = None
    main_process_ip: Optional[str] = None
    main_process_port: Optional[int] = None
    rdzv_backend: Optional[str] = "static"
    same_network: Optional[bool] = False
    main_training_function: str = "main"
    enable_cpu_affinity: bool = False

    # args for FP8 training
    fp8_config: Optional[dict] = None
    # args for deepspeed_plugin
    deepspeed_config: Optional[dict] = None
    # args for fsdp
    fsdp_config: Optional[dict] = None
    # args for parallelism config
    parallelism_config: Optional[dict] = None
    # args for megatron_lm
    megatron_lm_config: Optional[dict] = None
    # args for mpirun
    mpirun_config: Optional[dict] = None
    # args for TPU
    # NOTE(review): the interactive questionnaire appears to pass string values such as "no" here — confirm the
    # intended type.
    downcast_bf16: bool = False

    # args for TPU pods
    tpu_name: Optional[str] = None
    tpu_zone: Optional[str] = None
    tpu_use_cluster: bool = False
    tpu_use_sudo: bool = False
    command_file: Optional[str] = None
    commands: list[str] = None
    tpu_vm: list[str] = None
    tpu_env: list[str] = None

    # args for dynamo
    dynamo_config: Optional[dict] = None

    def __post_init__(self):
        # Every optional sub-configuration left as `None` becomes an empty dictionary.
        sub_config_names = (
            "deepspeed_config",
            "fsdp_config",
            "megatron_lm_config",
            "mpirun_config",
            "fp8_config",
            "parallelism_config",
        )
        for sub_config_name in sub_config_names:
            if getattr(self, sub_config_name) is None:
                setattr(self, sub_config_name, {})
        return super().__post_init__()
233
+
234
+
235
@dataclass
class SageMakerConfig(BaseConfig):
    """Configuration for launching on Amazon SageMaker, on top of the options of `BaseConfig`."""

    # Fields without a default must be provided.
    ec2_instance_type: str
    iam_role_name: str

    image_uri: Optional[str] = None
    profile: Optional[str] = None
    region: str = "us-east-1"
    num_machines: int = 1
    gpu_ids: str = "all"
    # NOTE: this f-string is evaluated once, while the class body runs, so it uses the `num_machines` default
    # declared above (1) rather than the value later given to an instance.
    base_job_name: str = f"accelerate-sagemaker-{num_machines}"

    # Default versions come from the SAGEMAKER_* constants.
    pytorch_version: str = SAGEMAKER_PYTORCH_VERSION
    transformers_version: str = SAGEMAKER_TRANSFORMERS_VERSION
    py_version: str = SAGEMAKER_PYTHON_VERSION

    sagemaker_inputs_file: Optional[str] = None
    sagemaker_metrics_file: Optional[str] = None
    additional_args: Optional[dict] = None
    dynamo_config: Optional[dict] = None
    enable_cpu_affinity: bool = False
accelerate/commands/config/config_utils.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import argparse
18
+
19
+ from ...utils.dataclasses import (
20
+ ComputeEnvironment,
21
+ DistributedType,
22
+ DynamoBackend,
23
+ FP8BackendType,
24
+ PrecisionType,
25
+ SageMakerDistributedType,
26
+ )
27
+ from ..menu import BulletMenu
28
+
29
+
30
# Dynamo backend names. The position of each entry matters: `_convert_dynamo_backend` looks a backend up by its
# index in this list.
DYNAMO_BACKENDS = [
    "EAGER",
    "AOT_EAGER",
    "INDUCTOR",
    "AOT_TS_NVFUSER",
    "NVPRIMS_NVFUSER",
    "CUDAGRAPHS",
    "OFI",
    "FX2TRT",
    "ONNXRT",
    "TENSORRT",
    "AOT_TORCHXLA_TRACE_ONCE",
    # Was misspelled "TORHCHXLA_TRACE_ONCE", which is not a valid backend name.
    "TORCHXLA_TRACE_ONCE",
    "TVM",
]
45
+
46
+
47
def _ask_field(input_text, convert_value=None, default=None, error_message=None):
    """
    Prompt with `input_text` until an answer can be returned.

    Args:
        input_text: Prompt passed to `input`.
        convert_value: Optional callable applied to the raw answer.
        default: Returned for an empty answer, unless it is `None`.
        error_message: Printed (when not `None`) each time `convert_value` raises, before asking again.

    Returns:
        `default`, the converted answer, or the raw answer when no converter is given.
    """
    while True:
        answer = input(input_text)
        if default is not None and not answer:
            return default
        if convert_value is None:
            return answer
        try:
            return convert_value(answer)
        except Exception:
            # Any conversion failure means the answer was not acceptable: report it and ask again.
            if error_message is not None:
                print(error_message)
58
+
59
+
60
def _ask_options(input_text, options=None, convert_value=None, default=0):
    """
    Show a `BulletMenu` and return the user's choice.

    Args:
        input_text: Prompt handed to the menu.
        options: Choices handed to the menu; `None` (the default) means an empty list.
        convert_value: Optional callable applied to the value returned by the menu.
        default: Passed to the menu as `default_choice`.

    Returns:
        The menu result, converted with `convert_value` when one is given.
    """
    # A `None` default replaces the former shared mutable `[]` default argument.
    menu = BulletMenu(input_text, [] if options is None else options)
    choice = menu.run(default_choice=default)
    return convert_value(choice) if convert_value is not None else choice
64
+
65
+
66
def _convert_compute_environment(value):
    """Map a menu index (anything `int()` accepts) to the corresponding `ComputeEnvironment`."""
    environment_names = ("LOCAL_MACHINE", "AMAZON_SAGEMAKER")
    return ComputeEnvironment(environment_names[int(value)])
69
+
70
+
71
def _convert_distributed_mode(value):
    """Map a menu index (anything `int()` accepts) to the corresponding `DistributedType`."""
    distributed_type_names = (
        "NO",
        "MULTI_CPU",
        "MULTI_XPU",
        "MULTI_HPU",
        "MULTI_GPU",
        "MULTI_NPU",
        "MULTI_MLU",
        "MULTI_SDAA",
        "MULTI_MUSA",
        "MULTI_NEURON",
        "XLA",
    )
    return DistributedType(distributed_type_names[int(value)])
88
+
89
+
90
def _convert_dynamo_backend(value):
    """Map an index into `DYNAMO_BACKENDS` to the `value` of the corresponding `DynamoBackend` member."""
    backend_name = DYNAMO_BACKENDS[int(value)]
    return DynamoBackend(backend_name).value
93
+
94
+
95
def _convert_mixed_precision(value):
    """Map a menu index (anything `int()` accepts) to the corresponding `PrecisionType`."""
    precision_names = ("no", "fp16", "bf16", "fp8")
    return PrecisionType(precision_names[int(value)])
98
+
99
+
100
def _convert_sagemaker_distributed_mode(value):
    """Map a menu index (anything `int()` accepts) to the corresponding `SageMakerDistributedType`."""
    sagemaker_mode_names = ("NO", "DATA_PARALLEL", "MODEL_PARALLEL")
    return SageMakerDistributedType(sagemaker_mode_names[int(value)])
103
+
104
+
105
def _convert_fp8_backend(value):
    """Map a menu index (anything `int()` accepts) to the corresponding `FP8BackendType`."""
    fp8_backend_names = ("AO", "TE", "MSAMP")
    return FP8BackendType(fp8_backend_names[int(value)])
108
+
109
+
110
def _convert_yes_no_to_bool(value):
    """
    Convert a yes/no answer to a boolean.

    The comparison ignores case and surrounding whitespace; `y` and `n` are accepted as shorthands.

    Raises:
        KeyError: If the answer is none of `yes`, `y`, `no`, `n`.
    """
    answers = {"yes": True, "y": True, "no": False, "n": False}
    # Strip first so that an answer typed with a stray space is not rejected.
    return answers[value.strip().lower()]
112
+
113
+
114
class SubcommandHelpFormatter(argparse.RawDescriptionHelpFormatter):
    """
    Help formatter for subcommands that drops the `<command> [<args>] ` placeholder from the usage line.
    """

    def _format_usage(self, usage, actions, groups, prefix):
        # Let argparse build the usage text, then strip the placeholder inherited from the top-level command.
        formatted_usage = super()._format_usage(usage, actions, groups, prefix)
        return formatted_usage.replace("<command> [<args>] ", "")
accelerate/commands/config/default.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from pathlib import Path
18
+
19
+ import torch
20
+
21
+ from ...utils import (
22
+ is_hpu_available,
23
+ is_mlu_available,
24
+ is_musa_available,
25
+ is_neuron_available,
26
+ is_npu_available,
27
+ is_sdaa_available,
28
+ is_xpu_available,
29
+ )
30
+ from .config_args import ClusterConfig, default_json_config_file
31
+ from .config_utils import SubcommandHelpFormatter
32
+
33
+
34
+ description = "Create a default config file for Accelerate with only a few flags set."
35
+
36
+
37
def write_basic_config(mixed_precision="no", save_location: str = default_json_config_file):
    """
    Creates and saves a basic cluster config to be used on a local machine with potentially multiple GPUs. Will also
    set CPU if it is a CPU-only machine.

    Args:
        mixed_precision (`str`, *optional*, defaults to "no"):
            Mixed Precision to use. Should be one of "no", "fp16", "bf16", or "fp8" (case-insensitive).
        save_location (`str`, *optional*, defaults to `default_json_config_file`):
            Optional custom save location. Should be passed to `--config_file` when using `accelerate launch`. Default
            location is inside the huggingface cache folder (`~/.cache/huggingface`) but can be overridden by setting
            the `HF_HOME` environmental variable, followed by `accelerate/default_config.yaml`.

    Returns:
        `False` if a file already exists at `save_location` (nothing is written in that case), otherwise the
        `pathlib.Path` the config was written to.

    Raises:
        `ValueError`: If `mixed_precision` is not one of the supported values.
    """
    path = Path(save_location)
    path.parent.mkdir(parents=True, exist_ok=True)
    if path.exists():
        print(
            f"Configuration already exists at {save_location}, will not override. Run `accelerate config` manually or pass a different `save_location`."
        )
        return False
    mixed_precision = mixed_precision.lower()
    if mixed_precision not in ["no", "fp16", "bf16", "fp8"]:
        raise ValueError(
            f"`mixed_precision` should be one of 'no', 'fp16', 'bf16', or 'fp8'. Received {mixed_precision}"
        )
    config = {
        "compute_environment": "LOCAL_MACHINE",
        "mixed_precision": mixed_precision,
    }
    # Backends are probed in priority order and the first available one wins. The device counters are wrapped in
    # lambdas so that `torch.<backend>` is only touched once the matching availability check has succeeded.
    # Previously the SDAA probe started a new `if` chain, so an MLU-only machine fell through to the CPU fallback
    # and had its MLU settings overwritten.
    accelerator_probes = (
        (is_mlu_available, lambda: torch.mlu.device_count(), "MULTI_MLU"),
        (is_sdaa_available, lambda: torch.sdaa.device_count(), "MULTI_SDAA"),
        (is_musa_available, lambda: torch.musa.device_count(), "MULTI_MUSA"),
        (is_hpu_available, lambda: torch.hpu.device_count(), "MULTI_HPU"),
        (torch.cuda.is_available, lambda: torch.cuda.device_count(), "MULTI_GPU"),
        (is_xpu_available, lambda: torch.xpu.device_count(), "MULTI_XPU"),
        (is_npu_available, lambda: torch.npu.device_count(), "MULTI_NPU"),
        (is_neuron_available, lambda: torch.neuron.device_count(), "MULTI_NEURON"),
    )
    for is_available, count_devices, multi_device_type in accelerator_probes:
        if is_available():
            num_devices = count_devices()
            config["num_processes"] = num_devices
            config["use_cpu"] = False
            config["distributed_type"] = multi_device_type if num_devices > 1 else "NO"
            break
    else:
        # No accelerator backend reported itself as available: single-process CPU config.
        config["use_cpu"] = True
        config["num_processes"] = 1
        config["distributed_type"] = "NO"
    config["debug"] = False
    config["enable_cpu_affinity"] = False
    config = ClusterConfig(**config)
    config.to_json_file(path)
    return path
140
+
141
+
142
def default_command_parser(parser, parents):
    """
    Register the `default` sub-command on `parser` (a sub-parsers action) and return the created sub-parser.

    The sub-command accepts `--config_file` (stored as `save_location`) and `--mixed_precision`, and dispatches to
    `default_config_command`.
    """
    config_file_help = (
        "The path to use to store the config file. Will default to a file named default_config.yaml in the cache "
        "location, which is the content of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have "
        "such an environment variable, your cache directory ('~/.cache' or the content of `XDG_CACHE_HOME`) suffixed "
        "with 'huggingface'."
    )
    mixed_precision_help = (
        "Whether or not to use mixed precision training. "
        "Choose between FP16 and BF16 (bfloat16) training. "
        "BF16 training is only supported on Nvidia Ampere GPUs and PyTorch 1.10 or later."
    )

    default_parser = parser.add_parser(
        "default", parents=parents, help=description, formatter_class=SubcommandHelpFormatter
    )
    default_parser.add_argument(
        "--config_file",
        default=default_json_config_file,
        help=config_file_help,
        dest="save_location",
    )
    default_parser.add_argument(
        "--mixed_precision",
        choices=["no", "fp16", "bf16"],
        type=str,
        help=mixed_precision_help,
        default="no",
    )
    default_parser.set_defaults(func=default_config_command)
    return default_parser
167
+
168
+
169
def default_config_command(args):
    """Write the basic config described by the parsed `args` and report where it was saved."""
    saved_path = write_basic_config(args.mixed_precision, args.save_location)
    if not saved_path:
        # `write_basic_config` already printed why nothing was written.
        return
    print(f"accelerate configuration saved at {saved_path}")
accelerate/commands/config/sagemaker.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import json
17
+ import os
18
+
19
+ from ...utils.constants import SAGEMAKER_PARALLEL_EC2_INSTANCES, TORCH_DYNAMO_MODES
20
+ from ...utils.dataclasses import ComputeEnvironment, SageMakerDistributedType
21
+ from ...utils.imports import is_boto3_available
22
+ from .config_args import SageMakerConfig
23
+ from .config_utils import (
24
+ DYNAMO_BACKENDS,
25
+ _ask_field,
26
+ _ask_options,
27
+ _convert_dynamo_backend,
28
+ _convert_mixed_precision,
29
+ _convert_sagemaker_distributed_mode,
30
+ _convert_yes_no_to_bool,
31
+ )
32
+
33
+
34
+ if is_boto3_available():
35
+ import boto3 # noqa: F401
36
+
37
+
38
def _create_iam_role_for_sagemaker(role_name):
    """
    Create the IAM role `role_name` with a SageMaker trust policy and attach an inline permission policy to it.

    If IAM reports that the role already exists, a message is printed and the existing role is left untouched.
    """
    allowed_actions = [
        "sagemaker:*",
        "ecr:GetDownloadUrlForLayer",
        "ecr:BatchGetImage",
        "ecr:BatchCheckLayerAvailability",
        "ecr:GetAuthorizationToken",
        "cloudwatch:PutMetricData",
        "cloudwatch:GetMetricData",
        "cloudwatch:GetMetricStatistics",
        "cloudwatch:ListMetrics",
        "logs:CreateLogGroup",
        "logs:CreateLogStream",
        "logs:DescribeLogStreams",
        "logs:PutLogEvents",
        "logs:GetLogEvents",
        "s3:CreateBucket",
        "s3:ListBucket",
        "s3:GetBucketLocation",
        "s3:GetObject",
        "s3:PutObject",
    ]
    # Lets the SageMaker service assume the role.
    trust_policy = {
        "Version": "2012-10-17",
        "Statement": [
            {"Effect": "Allow", "Principal": {"Service": "sagemaker.amazonaws.com"}, "Action": "sts:AssumeRole"}
        ],
    }
    permission_policy = {
        "Version": "2012-10-17",
        "Statement": [{"Effect": "Allow", "Action": allowed_actions, "Resource": "*"}],
    }

    iam = boto3.client("iam")
    try:
        # Create the role first; the inline policy is only attached when the role is new.
        iam.create_role(RoleName=role_name, AssumeRolePolicyDocument=json.dumps(trust_policy, indent=2))
        iam.put_role_policy(
            RoleName=role_name,
            PolicyName=f"{role_name}_policy_permission",
            PolicyDocument=json.dumps(permission_policy, indent=2),
        )
    except iam.exceptions.EntityAlreadyExistsException:
        print(f"role {role_name} already exists. Using existing one")
90
+
91
+
92
def _get_iam_role_arn(role_name):
    """Query IAM for `role_name` and return the `Arn` field of the returned role description."""
    role_response = boto3.client("iam").get_role(RoleName=role_name)
    return role_response["Role"]["Arn"]
95
+
96
+
97
def get_sagemaker_input():
    """
    Interactively collect the settings for running on Amazon SageMaker and return them as a `SageMakerConfig`.

    Side effects: exports the chosen credentials (`AWS_PROFILE`, or `AWS_ACCESS_KEY_ID` and
    `AWS_SECRET_ACCESS_KEY`) and `AWS_DEFAULT_REGION` into `os.environ`, and creates the
    `accelerate_sagemaker_execution_role` IAM role when the user asks for a new one.
    """
    # --- AWS credentials and region -------------------------------------------------------------------------
    credentials_configuration = _ask_options(
        "How do you want to authorize?",
        ["AWS Profile", "Credentials (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY) "],
        int,
    )
    aws_profile = None
    if credentials_configuration == 0:
        aws_profile = _ask_field("Enter your AWS Profile name: [default] ", default="default")
        os.environ["AWS_PROFILE"] = aws_profile
    else:
        print(
            "Note you will need to provide AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY when you launch you training script with,"
            "`accelerate launch --aws_access_key_id XXX --aws_secret_access_key YYY`"
        )
        aws_access_key_id = _ask_field("AWS Access Key ID: ")
        os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id

        aws_secret_access_key = _ask_field("AWS Secret Access Key: ")
        os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key

    aws_region = _ask_field("Enter your AWS Region: [us-east-1]", default="us-east-1")
    os.environ["AWS_DEFAULT_REGION"] = aws_region

    # --- IAM role -------------------------------------------------------------------------------------------
    role_management = _ask_options(
        "Do you already have an IAM Role for executing Amazon SageMaker Training Jobs?",
        ["Provide IAM Role name", "Create new IAM role using credentials"],
        int,
    )
    if role_management == 0:
        iam_role_name = _ask_field("Enter your IAM role name: ")
    else:
        iam_role_name = "accelerate_sagemaker_execution_role"
        print(f'Accelerate will create an iam role "{iam_role_name}" using the provided credentials')
        _create_iam_role_for_sagemaker(iam_role_name)

    # --- Optional Docker image, input channels and metrics --------------------------------------------------
    is_custom_docker_image = _ask_field(
        "Do you want to use custom Docker image? [yes/NO]: ",
        _convert_yes_no_to_bool,
        default=False,
        error_message="Please enter yes or no.",
    )
    docker_image = None
    if is_custom_docker_image:
        # NOTE(review): the image reference is lower-cased; confirm this is safe for case-sensitive image tags.
        docker_image = _ask_field("Enter your Docker image: ", lambda x: str(x).lower())

    is_sagemaker_inputs_enabled = _ask_field(
        "Do you want to provide SageMaker input channels with data locations? [yes/NO]: ",
        _convert_yes_no_to_bool,
        default=False,
        error_message="Please enter yes or no.",
    )
    sagemaker_inputs_file = None
    if is_sagemaker_inputs_enabled:
        # File paths are kept exactly as typed: lower-casing them breaks lookups on case-sensitive filesystems.
        sagemaker_inputs_file = _ask_field(
            "Enter the path to the SageMaker inputs TSV file with columns (channel_name, data_location): ",
            str,
        )

    is_sagemaker_metrics_enabled = _ask_field(
        "Do you want to enable SageMaker metrics? [yes/NO]: ",
        _convert_yes_no_to_bool,
        default=False,
        error_message="Please enter yes or no.",
    )
    sagemaker_metrics_file = None
    if is_sagemaker_metrics_enabled:
        # Same as above: preserve the path's case.
        sagemaker_metrics_file = _ask_field(
            "Enter the path to the SageMaker metrics TSV file with columns (metric_name, metric_regex): ",
            str,
        )

    # --- Distributed mode and torch dynamo ------------------------------------------------------------------
    distributed_type = _ask_options(
        "What is the distributed mode?",
        ["No distributed training", "Data parallelism"],
        _convert_sagemaker_distributed_mode,
    )
    dynamo_config = {}
    use_dynamo = _ask_field(
        "Do you wish to optimize your script with torch dynamo?[yes/NO]:",
        _convert_yes_no_to_bool,
        default=False,
        error_message="Please enter yes or no.",
    )
    if use_dynamo:
        prefix = "dynamo_"
        dynamo_config[prefix + "backend"] = _ask_options(
            "Which dynamo backend would you like to use?",
            [x.lower() for x in DYNAMO_BACKENDS],
            _convert_dynamo_backend,
            default=2,
        )
        use_custom_options = _ask_field(
            "Do you want to customize the defaults sent to torch.compile? [yes/NO]: ",
            _convert_yes_no_to_bool,
            default=False,
            error_message="Please enter yes or no.",
        )

        if use_custom_options:
            # NOTE(review): `default` is a string here while the backend prompt above passes an index;
            # confirm which form `_ask_options` expects.
            dynamo_config[prefix + "mode"] = _ask_options(
                "Which mode do you want to use?",
                TORCH_DYNAMO_MODES,
                lambda x: TORCH_DYNAMO_MODES[int(x)],
                default="default",
            )
            dynamo_config[prefix + "use_fullgraph"] = _ask_field(
                "Do you want the fullgraph mode or it is ok to break model into several subgraphs? [yes/NO]: ",
                _convert_yes_no_to_bool,
                default=False,
                error_message="Please enter yes or no.",
            )
            dynamo_config[prefix + "use_dynamic"] = _ask_field(
                "Do you want to enable dynamic shape tracing? [yes/NO]: ",
                _convert_yes_no_to_bool,
                default=False,
                error_message="Please enter yes or no.",
            )
            dynamo_config[prefix + "use_regional_compilation"] = _ask_field(
                "Do you want to enable regional compilation? [yes/NO]: ",
                _convert_yes_no_to_bool,
                default=False,
                error_message="Please enter yes or no.",
            )

    # --- Instance type, debugging, machine count ------------------------------------------------------------
    ec2_instance_query = "Which EC2 instance type you want to use for your training?"
    if distributed_type != SageMakerDistributedType.NO:
        ec2_instance_type = _ask_options(
            ec2_instance_query, SAGEMAKER_PARALLEL_EC2_INSTANCES, lambda x: SAGEMAKER_PARALLEL_EC2_INSTANCES[int(x)]
        )
    else:
        # The base query already ends with "?", so only the default hint is appended.
        ec2_instance_query += " [ml.p3.2xlarge]:"
        ec2_instance_type = _ask_field(ec2_instance_query, lambda x: str(x).lower(), default="ml.p3.2xlarge")

    debug = False
    if distributed_type != SageMakerDistributedType.NO:
        debug = _ask_field(
            "Should distributed operations be checked while running for errors? This can avoid timeout issues but will be slower. [yes/NO]: ",
            _convert_yes_no_to_bool,
            default=False,
            error_message="Please enter yes or no.",
        )

    num_machines = 1
    if distributed_type in (SageMakerDistributedType.DATA_PARALLEL, SageMakerDistributedType.MODEL_PARALLEL):
        num_machines = _ask_field(
            "How many machines do you want use? [1]: ",
            int,
            default=1,
        )

    # --- Mixed precision ------------------------------------------------------------------------------------
    mixed_precision = _ask_options(
        "Do you wish to use FP16 or BF16 (mixed precision)?",
        ["no", "fp16", "bf16", "fp8"],
        _convert_mixed_precision,
    )

    if use_dynamo and mixed_precision == "no":
        print(
            "Torch dynamo used without mixed precision requires TF32 to be efficient. Accelerate will enable it by default when launching your scripts."
        )

    return SageMakerConfig(
        image_uri=docker_image,
        compute_environment=ComputeEnvironment.AMAZON_SAGEMAKER,
        distributed_type=distributed_type,
        use_cpu=False,
        dynamo_config=dynamo_config,
        ec2_instance_type=ec2_instance_type,
        profile=aws_profile,
        region=aws_region,
        iam_role_name=iam_role_name,
        mixed_precision=mixed_precision,
        num_machines=num_machines,
        sagemaker_inputs_file=sagemaker_inputs_file,
        sagemaker_metrics_file=sagemaker_metrics_file,
        debug=debug,
    )
accelerate/commands/config/update.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from pathlib import Path
18
+
19
+ from .config_args import default_config_file, load_config_from_file
20
+ from .config_utils import SubcommandHelpFormatter
21
+
22
+
23
+ description = "Update an existing config file with the latest defaults while maintaining the old configuration."
24
+
25
+
26
def update_config(args):
    """
    Update an existing config file with the latest defaults while maintaining the old configuration.

    Args:
        args: Parsed command-line namespace; `args.config_file` is the file to update, or `None` to use
            `default_config_file`.

    Returns:
        The path of the config file that was rewritten.

    Raises:
        `ValueError`: If the file to update (the passed one, or the default one when none was passed) doesn't exist.
    """
    config_file = args.config_file
    if config_file is None:
        # Previously a missing default file fell through to `Path(None)` and crashed with a `TypeError`.
        if not Path(default_config_file).exists():
            raise ValueError(
                f"No config file was passed and the default config file located at {default_config_file} doesn't exist."
            )
        config_file = default_config_file
    elif not Path(config_file).exists():
        raise ValueError(f"The passed config file located at {config_file} doesn't exist.")
    config = load_config_from_file(config_file)

    # Write back in the same format the file was stored in. `str()` keeps this working for `Path` values.
    if str(config_file).endswith(".json"):
        config.to_json_file(config_file)
    else:
        config.to_yaml_file(config_file)
    return config_file
42
+
43
+
44
def update_command_parser(parser, parents):
    """
    Register the `update` sub-command on `parser` (a sub-parsers action) and return the created sub-parser.

    The sub-command accepts an optional `--config_file` and dispatches to `update_config_command`.
    """
    config_file_help = (
        "The path to the config file to update. Will default to a file named default_config.yaml in the cache "
        "location, which is the content of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have "
        "such an environment variable, your cache directory ('~/.cache' or the content of `XDG_CACHE_HOME`) suffixed "
        "with 'huggingface'."
    )
    update_parser = parser.add_parser(
        "update", parents=parents, help=description, formatter_class=SubcommandHelpFormatter
    )
    update_parser.add_argument("--config_file", default=None, help=config_file_help)
    update_parser.set_defaults(func=update_config_command)
    return update_parser
59
+
60
+
61
def update_config_command(args):
    """Run `update_config` for the parsed `args` and print which file was rewritten."""
    updated_path = update_config(args)
    print(f"Successfully updated the configuration file at {updated_path}.")
accelerate/commands/env.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import argparse
18
+ import os
19
+ import platform
20
+ import subprocess
21
+
22
+ import numpy as np
23
+ import psutil
24
+ import torch
25
+
26
+ from accelerate import __version__ as version
27
+ from accelerate.commands.config import default_config_file, load_config_from_file
28
+
29
+ from ..utils import (
30
+ is_mlu_available,
31
+ is_musa_available,
32
+ is_neuron_available,
33
+ is_npu_available,
34
+ is_sdaa_available,
35
+ is_xpu_available,
36
+ )
37
+
38
+
39
def env_command_parser(subparsers=None):
    """
    Build the argument parser for the `env` command.

    When `subparsers` is given, the parser is registered on it as the `env` sub-command and wired to `env_command`;
    otherwise a standalone `argparse.ArgumentParser` is returned.
    """
    standalone = subparsers is None
    parser = argparse.ArgumentParser("Accelerate env command") if standalone else subparsers.add_parser("env")

    parser.add_argument(
        "--config_file", default=None, help="The config file to use for the default values in the launching script."
    )

    if not standalone:
        parser.set_defaults(func=env_command)
    return parser
52
+
53
+
54
def env_command(args):
    """
    Print a summary of the current environment (versions, platform, accelerator, Accelerate config) formatted for
    pasting into a GitHub issue.

    Args:
        args: Parsed command-line namespace; `args.config_file` selects the config to display (`None` means the
            default config file, if it exists).

    Returns:
        `dict`: The collected information, including the loaded config under the "`Accelerate` configs" key.
    """
    pt_version = torch.__version__
    pt_cuda_available = torch.cuda.is_available()
    pt_xpu_available = is_xpu_available()
    pt_mlu_available = is_mlu_available()
    pt_sdaa_available = is_sdaa_available()
    pt_musa_available = is_musa_available()
    pt_npu_available = is_npu_available()
    pt_neuron_available = is_neuron_available()

    accelerator = "N/A"
    if pt_cuda_available:
        accelerator = "CUDA"
    elif pt_xpu_available:
        accelerator = "XPU"
    elif pt_mlu_available:
        accelerator = "MLU"
    elif pt_sdaa_available:
        accelerator = "SDAA"
    elif pt_musa_available:
        accelerator = "MUSA"
    elif pt_npu_available:
        accelerator = "NPU"
    elif pt_neuron_available:
        accelerator = "NEURON"

    accelerate_config = "Not found"
    # Get the default from the config file.
    if args.config_file is not None or os.path.isfile(default_config_file):
        accelerate_config = load_config_from_file(args.config_file).to_dict()

    # if we can run which, get it
    command = None
    bash_location = "Not found"
    if os.name == "nt":
        command = ["where", "accelerate"]
    elif os.name == "posix":
        command = ["which", "accelerate"]
    if command is not None:
        try:
            bash_location = subprocess.check_output(command, text=True, stderr=subprocess.STDOUT).strip()
        except (subprocess.CalledProcessError, OSError):
            # The lookup tool is missing or `accelerate` is not on PATH (e.g. when run through `python -m`).
            # This report is best-effort, so keep the "Not found" placeholder instead of crashing.
            bash_location = "Not found"
    info = {
        "`Accelerate` version": version,
        "Platform": platform.platform(),
        "`accelerate` bash location": bash_location,
        "Python version": platform.python_version(),
        "Numpy version": np.__version__,
        "PyTorch version": f"{pt_version}",
        "PyTorch accelerator": accelerator,
        "System RAM": f"{psutil.virtual_memory().total / 1024**3:.2f} GB",
    }
    if pt_cuda_available:
        info["GPU type"] = torch.cuda.get_device_name()
    elif pt_xpu_available:
        info["XPU type"] = torch.xpu.get_device_name()
    elif pt_mlu_available:
        info["MLU type"] = torch.mlu.get_device_name()
    elif pt_sdaa_available:
        info["SDAA type"] = torch.sdaa.get_device_name()
    elif pt_musa_available:
        info["MUSA type"] = torch.musa.get_device_name()
    elif pt_neuron_available:
        info["NEURON type"] = torch.neuron.get_device_name()
    elif pt_npu_available:
        info["CANN version"] = torch.version.cann

    print("\nCopy-and-paste the text below in your GitHub issue\n")
    print("\n".join([f"- {prop}: {val}" for prop, val in info.items()]))

    print("- `Accelerate` default config:" if args.config_file is None else "- `Accelerate` config passed:")
    # The config is either a dict (loaded) or the "Not found" placeholder string.
    accelerate_config_str = (
        "\n".join([f"\t- {prop}: {val}" for prop, val in accelerate_config.items()])
        if isinstance(accelerate_config, dict)
        else f"\t{accelerate_config}"
    )
    print(accelerate_config_str)

    info["`Accelerate` configs"] = accelerate_config

    return info
133
+
134
+
135
def main() -> int:
    """Standalone entry point: parse the command line, print the environment report and return exit code 0."""
    env_command(env_command_parser().parse_args())
    return 0
140
+
141
+
142
+ if __name__ == "__main__":
143
+ raise SystemExit(main())
accelerate/commands/estimate.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ from typing import Optional
17
+
18
+ import torch
19
+ from huggingface_hub import model_info
20
+ from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
21
+
22
+ from accelerate import init_empty_weights
23
+ from accelerate.commands.utils import CustomArgumentParser
24
+ from accelerate.utils import (
25
+ calculate_maximum_sizes,
26
+ convert_bytes,
27
+ is_timm_available,
28
+ is_transformers_available,
29
+ )
30
+
31
+
32
+ if is_transformers_available():
33
+ import transformers
34
+ from transformers import AutoConfig, AutoModel
35
+
36
+ if is_timm_available():
37
+ import timm
38
+
39
+
40
def verify_on_hub(repo: str, token: Optional[str] = None):
    """
    Verifies that the model is on the hub and returns the model info.

    Returns:
        The result of `model_info(repo, token=token)` on success, the string `"gated"` if the repo is gated (or
        another `OSError` occurred), or the string `"repo"` if the repository was not found.
    """
    try:
        return model_info(repo, token=token)
    # Order matters: `GatedRepoError` subclasses `RepositoryNotFoundError`, which in requests-based
    # `huggingface_hub` releases is itself an `OSError`. With the broad `OSError` handler first, the
    # "repo" case could never be reported.
    except GatedRepoError:
        return "gated"
    except RepositoryNotFoundError:
        return "repo"
    except OSError:
        return "gated"
48
+
49
+
50
def check_has_model(error):
    """
    Checks what library spawned `error` when a model is not found.

    Returns:
        `"timm"`, `"transformers"` or `"unknown"`, decided from the exception type and its first argument.
    """
    # The first argument is not always a message: `OSError(errno, msg)` carries an int there and an exception
    # may have no args at all. Normalise to a string so classifying the error never raises on its own.
    message = str(error.args[0]) if error.args else ""
    if is_timm_available() and isinstance(error, RuntimeError) and "Unknown model" in message:
        return "timm"
    if (
        is_transformers_available()
        and isinstance(error, OSError)
        and "does not appear to have a file named" in message
    ):
        return "transformers"
    return "unknown"
64
+
65
+
66
def create_empty_model(
    model_name: str, library_name: str, trust_remote_code: bool = False, access_token: Optional[str] = None
):
    """
    Builds a weightless copy of a Hub model (under `init_empty_weights`) so that its memory footprint can be
    measured without loading any parameters.

    Args:
        model_name (`str`):
            The model name on the Hub
        library_name (`str`):
            The library the model has an integration with, such as `transformers`. When `None`, the library is
            read from the model's Hub metadata.
        trust_remote_code (`bool`, `optional`, defaults to `False`):
            Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
            should only be set to `True` for repositories you trust and in which you have read the code, as it will
            execute code present on the Hub on your local machine.
        access_token (`str`, `optional`, defaults to `None`):
            The access token to use to access private or gated models on the Hub. (for use on the Gradio app)

    Returns:
        `torch.nn.Module`: The torch model that has been initialized on the `meta` device.

    """
    hub_info = verify_on_hub(model_name, access_token)
    # Simplified errors
    if hub_info == "gated":
        raise OSError(
            f"Repo for model `{model_name}` is gated. You must be authenticated to access it. Please run `huggingface-cli login`."
        )
    if hub_info == "repo":
        raise OSError(
            f"Repo for model `{model_name}` does not exist on the Hub. If you are trying to access a private repo,"
            " make sure you are authenticated via `huggingface-cli login` and have access."
        )

    if library_name is None:
        # Fall back to whatever library the Hub metadata declares.
        library_name = getattr(hub_info, "library_name", False)
        if not library_name:
            raise ValueError(
                f"Model `{model_name}` does not have any library metadata on the Hub, please manually pass in a `--library_name` to use (such as `transformers`)"
            )

    if library_name not in ("transformers", "timm"):
        raise ValueError(
            f"Library `{library_name}` is not supported yet, please open an issue on GitHub for us to add support."
        )

    if library_name == "timm":
        if not is_timm_available():
            raise ImportError(
                f"To check `{model_name}`, `timm` must be installed. Please install it via `pip install timm`"
            )
        print(f"Loading pretrained config for `{model_name}` from `timm`...")
        with init_empty_weights():
            return timm.create_model(model_name, pretrained=False)

    # Only `transformers` is left at this point.
    if not is_transformers_available():
        raise ImportError(
            f"To check `{model_name}`, `transformers` must be installed. Please install it via `pip install transformers`"
        )
    print(f"Loading pretrained config for `{model_name}` from `transformers`...")
    if hub_info.config is None:
        raise RuntimeError(f"Tried to load `{model_name}` with `transformers` but it does not have any metadata.")

    auto_map = hub_info.config.get("auto_map", False)
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code, token=access_token)
    with init_empty_weights():
        # remote code could specify a specific `AutoModel` class in the `auto_map`
        constructor = AutoModel
        if isinstance(auto_map, dict):
            auto_class_name = next((key for key in auto_map if key.startswith("AutoModelFor")), None)
            if auto_class_name is not None:
                constructor = getattr(transformers, auto_class_name)
        # we need to pass the dtype, otherwise it is going to use the torch_dtype that is saved in the config
        model = constructor.from_config(config, torch_dtype=torch.float32, trust_remote_code=trust_remote_code)
    return model
144
+
145
+
146
def create_ascii_table(headers: list, rows: list, title: str):
    """
    Render `headers` and `rows` (lists of strings) as a box-drawn table with a centered `title` row on top.

    A minimal stand-in for `tabulate`: every column is as wide as its widest cell, and the last column is stretched
    when the title would not fit otherwise.
    """
    vertical, horizontal = "│", "─"

    # Widest cell of each column, header included.
    column_widths = [
        max(len(cell) for cell in [row[col] for row in rows] + [headers[col]]) for col in range(len(headers))
    ]

    cell_formats = [f"%{column_widths[col]}s" for col in range(len(rows[0]))]
    row_template = f"{vertical}{vertical.join(cell_formats)}{vertical}"
    extra = 0

    def rule(left, junction, right):
        # Reads `column_widths` and `extra` at call time, so it follows the adjustments made below.
        segments = junction.join([horizontal * width for width in column_widths])
        return f"{left}{segments}{horizontal * extra}{right}"

    separator = rule("├", "┼", "┤")
    if len(title) > sum(column_widths):
        # Title too long: stretch the last column by the amount missing.
        extra = abs(len(title) - len(separator))
        column_widths[-1] += extra

    # Recompute now that the widths may have changed.
    separator = rule("├", "┼", "┤")
    lines = [
        rule("┌", horizontal, "┐"),
        f"{vertical}{title.center(len(separator) - 2)}{vertical}",
        rule("├", "┬", "┤"),
    ]

    # From here on the stretch is carried by the last column itself rather than by `rule`'s padding.
    column_widths[-1] += extra
    header_cells = tuple(text.center(column_widths[col]) for col, text in enumerate(headers))
    lines.append(row_template % header_cells)
    lines.append(separator)
    for row in rows:
        row_cells = tuple(text.center(column_widths[col]) for col, text in enumerate(row))
        lines.append(row_template % row_cells)
    lines.append(f"└{'┴'.join([horizontal * width for width in column_widths])}┘")

    return "\n".join(lines)
185
+
186
+
187
def estimate_command_parser(subparsers=None):
    """
    Build the argument parser for the `estimate-memory` command.

    When `subparsers` is given, the parser is registered on it and wired to `estimate_command`; otherwise a
    standalone `CustomArgumentParser` is returned.
    """
    supported_dtypes = ["float32", "float16", "int8", "int4"]
    standalone = subparsers is None

    if standalone:
        parser = CustomArgumentParser(
            description="Model size estimator for fitting a model onto device(e.g. cuda, xpu) memory."
        )
    else:
        parser = subparsers.add_parser("estimate-memory")

    parser.add_argument("model_name", type=str, help="The model name on the Hugging Face Hub.")
    parser.add_argument(
        "--library_name",
        type=str,
        help="The library the model has an integration with, such as `transformers`, needed only if this information is not stored on the Hub.",
        choices=["timm", "transformers"],
    )
    parser.add_argument(
        "--dtypes",
        type=str,
        nargs="+",
        default=list(supported_dtypes),
        help="The dtypes to use for the model, must be one (or many) of `float32`, `float16`, `int8`, and `int4`",
        choices=list(supported_dtypes),
    )
    parser.add_argument(
        "--trust_remote_code",
        action="store_true",
        help="""Whether or not to allow for custom models defined on the Hub in their own modeling files. This flag
                should only be used for repositories you trust and in which you have read the code, as it will execute
                code present on the Hub on your local machine.""",
        default=False,
    )

    if not standalone:
        parser.set_defaults(func=estimate_command)
    return parser
222
+
223
+
224
def estimate_training_usage(bytes: int, mixed_precision: str, msamp_config: Optional[str] = None) -> dict:
    """
    Estimates the memory needed to train a model of `bytes` size with a batch size of 1.

    Args:
        bytes (`int`):
            The size of the model being trained.
        mixed_precision (`str`):
            The mixed precision that would be ran.
        msamp_config (`str`):
            The msamp config to estimate the training memory for if `mixed_precision` is set to `"fp8"`.

    Returns:
        `dict`: Sizes for the "model", "optimizer", "gradients" and "step" entries; every entry is `-1` when no
        estimate exists for the requested precision.
    """
    full_size = bytes
    half_size = bytes // 2

    half_precision = mixed_precision in ("float16", "bfloat16")
    # With native `TransformersEngine`, there is no memory savings with FP8
    native_fp8 = mixed_precision == "fp8" and msamp_config is None

    if mixed_precision == "float32":
        return {
            "model": full_size,
            "optimizer": full_size * 2,
            "gradients": full_size,
            "step": full_size * 4,
        }

    if half_precision or native_fp8:
        # 2x from optimizer states; the step is sized like the optimizer.
        optimizer_size = full_size * 2
        return {
            # With mixed precision training, the model has weights stored in FP16 and FP32
            "model": full_size,
            "optimizer": optimizer_size,
            # 1.5 from weight gradient + computation (GEMM)
            "gradients": full_size + half_size,
            "step": optimizer_size,
        }

    # Nothing is estimated for any other precision.
    return {"model": -1, "optimizer": -1, "gradients": -1, "step": -1}
257
+
258
+
259
def gather_data(args):
    """
    Instantiate an empty version of `args.model_name` and collect, for each requested dtype, a row of
    `[dtype, largest layer size, total size, training usage dict]`.
    """
    # How much smaller each dtype is than the full-precision sizes reported for the model.
    size_divisors = {"float16": 2, "int8": 4, "int4": 8}

    try:
        model = create_empty_model(
            args.model_name, library_name=args.library_name, trust_remote_code=args.trust_remote_code
        )
    except (RuntimeError, OSError) as e:
        library = check_has_model(e)
        if library == "unknown":
            raise e
        raise RuntimeError(
            f"Tried to load `{args.model_name}` with `{library}` but a possible model to load was not found inside the repo."
        )

    total_size, largest_layer = calculate_maximum_sizes(model)

    data = []
    for dtype in args.dtypes:
        # The training estimate is always computed from the full-precision total.
        training_sizes = estimate_training_usage(total_size, dtype)
        dtype_total_size = total_size
        dtype_largest_layer = largest_layer[0]
        if dtype in size_divisors:
            dtype_total_size /= size_divisors[dtype]
            dtype_largest_layer /= size_divisors[dtype]
        data.append([dtype, dtype_largest_layer, dtype_total_size, training_sizes])
    return data
292
+
293
+
294
def estimate_command(args):
    """Gather the size estimates for `args.model_name`, make them human-readable and print them as a table."""

    def _display(item):
        # Plain numbers are byte counts; dicts are training estimates, of which the largest entry is shown.
        if isinstance(item, (int, float)):
            return convert_bytes(item)
        if isinstance(item, dict):
            training_usage = max(item.values())
            return convert_bytes(training_usage) if training_usage != -1 else "N/A"
        return item

    data = gather_data(args)
    for row in data:
        row[:] = [_display(item) for item in row]

    headers = ["dtype", "Largest Layer", "Total Size", "Training using Adam"]
    title = f"Memory Usage for loading `{args.model_name}`"
    print(create_ascii_table(headers, data, title))
309
+
310
+
311
def main():
    """Standalone entry point: parse the command line and print the memory estimate table."""
    estimate_command(estimate_command_parser().parse_args())
315
+
316
+
317
+ if __name__ == "__main__":
318
+ main()
accelerate/commands/launch.py ADDED
@@ -0,0 +1,1415 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import argparse
18
+ import importlib
19
+ import logging
20
+ import os
21
+ import subprocess
22
+ import sys
23
+ from pathlib import Path
24
+
25
+ import torch
26
+
27
+ from accelerate.commands.config import default_config_file, load_config_from_file
28
+ from accelerate.commands.config.config_args import SageMakerConfig
29
+ from accelerate.commands.config.config_utils import DYNAMO_BACKENDS
30
+ from accelerate.commands.utils import CustomArgumentParser
31
+ from accelerate.state import get_int_from_env
32
+ from accelerate.utils import (
33
+ ComputeEnvironment,
34
+ DistributedType,
35
+ PrepareForLaunch,
36
+ _filter_args,
37
+ check_cuda_p2p_ib_support,
38
+ convert_dict_to_env_variables,
39
+ is_bf16_available,
40
+ is_deepspeed_available,
41
+ is_hpu_available,
42
+ is_mlu_available,
43
+ is_musa_available,
44
+ is_neuron_available,
45
+ is_npu_available,
46
+ is_rich_available,
47
+ is_sagemaker_available,
48
+ is_sdaa_available,
49
+ is_torch_xla_available,
50
+ is_xpu_available,
51
+ patch_environment,
52
+ prepare_deepspeed_cmd_env,
53
+ prepare_multi_gpu_env,
54
+ prepare_sagemager_args_inputs,
55
+ prepare_simple_launcher_cmd_env,
56
+ prepare_tpu,
57
+ str_to_bool,
58
+ )
59
+ from accelerate.utils.constants import DEEPSPEED_MULTINODE_LAUNCHERS, TORCH_DYNAMO_MODES
60
+
61
+
# When `rich` is available, configure the root logger to emit records through rich's `RichHandler`.
if is_rich_available():
    from rich import get_console
    from rich.logging import RichHandler

    FORMAT = "%(message)s"
    logging.basicConfig(handlers=[RichHandler()], datefmt="[%X]", format=FORMAT)


# Module-level logger for the launch command.
logger = logging.getLogger(__name__)
71
+
72
+
# Maps a normalized option name (no leading `--`, underscores only) to the title of the argument group whose
# help should be displayed when that option is present on the command line.
options_to_group = dict(
    multi_gpu="Distributed GPUs",
    tpu="TPU",
    use_deepspeed="DeepSpeed Arguments",
    use_fsdp="FSDP Arguments",
    use_megatron_lm="Megatron-LM Arguments",
    fp8_backend="FP8 Arguments",
)
81
+
82
+
def clean_option(option):
    """
    Normalize a command-line token to a bare, underscore-separated option name.

    Any token that mentions `fp8_backend` (spelled with underscores or hyphens, with or without an attached
    `=value`) collapses to `"fp8_backend"`. Any other token starting with `--` loses that prefix and has every
    remaining `-` replaced by `_`.

    Args:
        option (`str`):
            A single command-line token, e.g. `"--multi-gpu"`.

    Returns:
        `str` or `None`: the normalized name, or `None` when the token does not start with `--`.
    """
    # Hyphens and underscores are interchangeable on the command line, so match both spellings of the FP8 flag.
    if "fp8_backend" in option.replace("-", "_"):
        return "fp8_backend"
    if option.startswith("--"):
        return option[2:].replace("-", "_")
    return None
89
+
90
+
class CustomHelpFormatter(argparse.HelpFormatter):
    """
    This is a custom help formatter that will hide all arguments that are not used in the command line when the help is
    called. This is useful for the case where the user is using a specific platform and only wants to see the arguments
    for that platform.

    When more than one argument follows the command, only the groups listed in `self.titles`, plus the groups that
    `options_to_group` maps from the options actually passed, keep their help. Inside the hardware and training
    paradigm groups only the options that were passed stay visible, marked as "(currently selected)".
    """

    # Groups in which only the options present on the command line are displayed.
    _SELECTABLE_TITLES = ("Hardware Selection Arguments", "Training Paradigm Arguments")
    _SELECTED_SUFFIX = " (currently selected)"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Titles of the groups that are never hidden as a whole.
        self.titles = [
            "Hardware Selection Arguments",
            "Resource Selection Arguments",
            "Training Paradigm Arguments",
            "positional arguments",
            "optional arguments",
            # Python 3.10 renamed argparse's default "optional arguments" group to "options".
            "options",
        ]

    def add_argument(self, action: argparse.Action):
        """
        Register `action` for display, hiding its help when it is not relevant to the current command line.

        Note that `action.help` and `action.option_strings` are modified in place.
        """
        # Drop the program name, and the `launch` subcommand when invoked as `accelerate launch ...`.
        if "accelerate" in sys.argv[0] and "launch" in sys.argv[1:]:
            args = sys.argv[2:]
        else:
            args = sys.argv[1:]

        if len(args) > 1:
            # Normalized names of the `--options` that were passed; `clean_option` yields `None` for other tokens.
            passed = {name for name in map(clean_option, args) if name is not None}
            used_titles = [options_to_group[name] for name in passed if name in options_to_group]
            title = action.container.title
            if title not in self.titles + used_titles:
                action.help = argparse.SUPPRESS
            elif title in self._SELECTABLE_TITLES:
                # Compare normalized names on both sides: the passed names no longer carry the `--` prefix,
                # so matching them against the raw option strings could never succeed.
                own_names = {clean_option(option_string) for option_string in action.option_strings}
                own_names.discard(None)
                if own_names.isdisjoint(passed):
                    action.help = argparse.SUPPRESS
                elif action.help != argparse.SUPPRESS and not (action.help or "").endswith(self._SELECTED_SUFFIX):
                    # Guard against appending the marker twice if the help is rendered more than once.
                    action.help = (action.help or "") + self._SELECTED_SUFFIX

        # Only display the spellings that have no hyphen after the leading `--`.
        action.option_strings = [s for s in action.option_strings if "-" not in s[2:]]
        super().add_argument(action)

    def end_section(self):
        """Close the current section, blanking it first when it holds fewer than two items."""
        if len(self._current_section.items) < 2:
            self._current_section.items = []
            self._current_section.heading = ""
        super().end_section()
139
+
140
+
141
+ def launch_command_parser(subparsers=None):
142
+ description = "Launch a python script in a distributed scenario. Arguments can be passed in with either hyphens (`--num-processes=2`) or underscores (`--num_processes=2`)"
143
+ if subparsers is not None:
144
+ parser = subparsers.add_parser(
145
+ "launch", description=description, add_help=False, allow_abbrev=False, formatter_class=CustomHelpFormatter
146
+ )
147
+ else:
148
+ parser = CustomArgumentParser(
149
+ "Accelerate launch command",
150
+ description=description,
151
+ add_help=False,
152
+ allow_abbrev=False,
153
+ formatter_class=CustomHelpFormatter,
154
+ )
155
+
156
+ parser.add_argument("-h", "--help", action="help", help="Show this help message and exit.")
157
+
158
+ parser.add_argument(
159
+ "--config_file",
160
+ default=None,
161
+ help="The config file to use for the default values in the launching script.",
162
+ )
163
+ parser.add_argument(
164
+ "--quiet",
165
+ "-q",
166
+ action="store_true",
167
+ help="Silence subprocess errors from the launch stack trace and only show the relevant tracebacks. (Only applicable to DeepSpeed and single-process configurations)",
168
+ )
169
+ # Hardware selection arguments
170
+ hardware_args = parser.add_argument_group(
171
+ "Hardware Selection Arguments", "Arguments for selecting the hardware to be used."
172
+ )
173
+ hardware_args.add_argument(
174
+ "--cpu", default=False, action="store_true", help="Whether or not to force the training on the CPU."
175
+ )
176
+ hardware_args.add_argument(
177
+ "--multi_gpu",
178
+ default=False,
179
+ action="store_true",
180
+ help="Whether or not this should launch a distributed GPU training.",
181
+ )
182
+ hardware_args.add_argument(
183
+ "--tpu", default=False, action="store_true", help="Whether or not this should launch a TPU training."
184
+ )
185
+ # Resource selection arguments
186
+ resource_args = parser.add_argument_group(
187
+ "Resource Selection Arguments", "Arguments for fine-tuning how available hardware should be used."
188
+ )
189
+ resource_args.add_argument(
190
+ "--mixed_precision",
191
+ type=str,
192
+ choices=["no", "fp16", "bf16", "fp8"],
193
+ help="Whether or not to use mixed precision training. "
194
+ "Choose between FP16 and BF16 (bfloat16) training. "
195
+ "BF16 training is only supported on Nvidia Ampere GPUs and PyTorch 1.10 or later.",
196
+ )
197
+ resource_args.add_argument(
198
+ "--num_processes", type=int, default=None, help="The total number of processes to be launched in parallel."
199
+ )
200
+ resource_args.add_argument(
201
+ "--num_machines", type=int, default=None, help="The total number of machines used in this training."
202
+ )
203
+ resource_args.add_argument(
204
+ "--num_cpu_threads_per_process",
205
+ type=int,
206
+ default=None,
207
+ help="The number of CPU threads per process. Can be tuned for optimal performance.",
208
+ )
209
+ resource_args.add_argument(
210
+ "--enable_cpu_affinity",
211
+ default=False,
212
+ action="store_true",
213
+ help="Whether or not CPU affinity and balancing should be enabled. Currently only supported on NVIDIA hardware.",
214
+ )
215
+ # Dynamo arguments
216
+ resource_args.add_argument(
217
+ "--dynamo_backend",
218
+ type=str,
219
+ choices=["no"] + [b.lower() for b in DYNAMO_BACKENDS],
220
+ help="Choose a backend to optimize your training with dynamo, see more at "
221
+ "https://github.com/pytorch/torchdynamo.",
222
+ )
223
+ resource_args.add_argument(
224
+ "--dynamo_mode",
225
+ type=str,
226
+ default="default",
227
+ choices=TORCH_DYNAMO_MODES,
228
+ help="Choose a mode to optimize your training with dynamo.",
229
+ )
230
+ resource_args.add_argument(
231
+ "--dynamo_use_fullgraph",
232
+ default=False,
233
+ action="store_true",
234
+ help="Whether to use full graph mode for dynamo or it is ok to break model into several subgraphs",
235
+ )
236
+ resource_args.add_argument(
237
+ "--dynamo_use_dynamic",
238
+ default=False,
239
+ action="store_true",
240
+ help="Whether to enable dynamic shape tracing.",
241
+ )
242
+ resource_args.add_argument(
243
+ "--dynamo_use_regional_compilation",
244
+ default=False,
245
+ action="store_true",
246
+ help="Whether to enable regional compilation.",
247
+ )
248
+
249
+ # Training Paradigm arguments
250
+ paradigm_args = parser.add_argument_group(
251
+ "Training Paradigm Arguments", "Arguments for selecting which training paradigm to be used."
252
+ )
253
+ paradigm_args.add_argument(
254
+ "--use_deepspeed",
255
+ default=False,
256
+ action="store_true",
257
+ help="Whether to use deepspeed.",
258
+ )
259
+ paradigm_args.add_argument(
260
+ "--use_fsdp",
261
+ default=False,
262
+ action="store_true",
263
+ help="Whether to use fsdp.",
264
+ )
265
+ paradigm_args.add_argument(
266
+ "--use_parallelism_config",
267
+ default=False,
268
+ action="store_true",
269
+ help="Whether to use the parallelism config to configure the N-d distributed training.",
270
+ )
271
+ paradigm_args.add_argument(
272
+ "--use_megatron_lm",
273
+ default=False,
274
+ action="store_true",
275
+ help="Whether to use Megatron-LM.",
276
+ )
277
+
278
+ # distributed GPU training arguments
279
+ distributed_args = parser.add_argument_group("Distributed GPUs", "Arguments related to distributed GPU training.")
280
+ distributed_args.add_argument(
281
+ "--gpu_ids",
282
+ default=None,
283
+ help="What GPUs (by id) should be used for training on this machine as a comma-separated list",
284
+ )
285
+ distributed_args.add_argument(
286
+ "--same_network",
287
+ default=False,
288
+ action="store_true",
289
+ help="Whether all machines used for multinode training exist on the same local network.",
290
+ )
291
+ distributed_args.add_argument(
292
+ "--machine_rank", type=int, default=None, help="The rank of the machine on which this script is launched."
293
+ )
294
+ distributed_args.add_argument(
295
+ "--main_process_ip", type=str, default=None, help="The IP address of the machine of rank 0."
296
+ )
297
+ distributed_args.add_argument(
298
+ "--main_process_port",
299
+ type=int,
300
+ default=None,
301
+ help="The port to use to communicate with the machine of rank 0.",
302
+ )
303
+ distributed_args.add_argument(
304
+ "-t",
305
+ "--tee",
306
+ default="0",
307
+ type=str,
308
+ help="Tee std streams into a log file and also to console.",
309
+ )
310
+ distributed_args.add_argument(
311
+ "--log_dir",
312
+ type=str,
313
+ default=None,
314
+ help=(
315
+ "Base directory to use for log files when using torchrun/torch.distributed.run as launcher. "
316
+ "Use with --tee to redirect std streams info log files."
317
+ ),
318
+ )
319
+ distributed_args.add_argument(
320
+ "--role",
321
+ type=str,
322
+ default="default",
323
+ help="User-defined role for the workers.",
324
+ )
325
+ # Rendezvous related arguments
326
+ distributed_args.add_argument(
327
+ "--rdzv_backend",
328
+ type=str,
329
+ default="static",
330
+ help="The rendezvous method to use, such as 'static' (the default) or 'c10d'",
331
+ )
332
+ distributed_args.add_argument(
333
+ "--rdzv_conf",
334
+ type=str,
335
+ default="",
336
+ help="Additional rendezvous configuration (<key1>=<value1>,<key2>=<value2>,...).",
337
+ )
338
+ distributed_args.add_argument(
339
+ "--max_restarts",
340
+ type=int,
341
+ default=0,
342
+ help="Maximum number of worker group restarts before failing.",
343
+ )
344
+ distributed_args.add_argument(
345
+ "--monitor_interval",
346
+ type=float,
347
+ default=0.1,
348
+ help="Interval, in seconds, to monitor the state of workers.",
349
+ )
350
+ parser.add_argument(
351
+ "-m",
352
+ "--module",
353
+ action="store_true",
354
+ help="Change each process to interpret the launch script as a Python module, executing with the same behavior as 'python -m'.",
355
+ )
356
+ parser.add_argument(
357
+ "--no_python",
358
+ action="store_true",
359
+ help="Skip prepending the training script with 'python' - just execute it directly. Useful when the script is not a Python script.",
360
+ )
361
+
362
+ # TPU arguments
363
+ tpu_args = parser.add_argument_group("TPU", "Arguments related to TPU.")
364
+ tpu_args.add_argument(
365
+ "--tpu_cluster",
366
+ action="store_true",
367
+ dest="tpu_use_cluster",
368
+ help="Whether to use a GCP TPU pod for training.",
369
+ )
370
+ tpu_args.add_argument(
371
+ "--no_tpu_cluster",
372
+ action="store_false",
373
+ dest="tpu_use_cluster",
374
+ help="Should not be passed explicitly, this is for internal use only.",
375
+ )
376
+ tpu_args.add_argument(
377
+ "--tpu_use_sudo",
378
+ action="store_true",
379
+ help="Whether to use `sudo` when running the TPU training script in each pod.",
380
+ )
381
+ tpu_args.add_argument(
382
+ "--vm",
383
+ type=str,
384
+ action="append",
385
+ help=(
386
+ "List of single Compute VM instance names. "
387
+ "If not provided we assume usage of instance groups. For TPU pods."
388
+ ),
389
+ )
390
+ tpu_args.add_argument(
391
+ "--env",
392
+ type=str,
393
+ action="append",
394
+ help="List of environment variables to set on the Compute VM instances. For TPU pods.",
395
+ )
396
+ tpu_args.add_argument(
397
+ "--main_training_function",
398
+ type=str,
399
+ default=None,
400
+ help="The name of the main function to be executed in your script (only for TPU training).",
401
+ )
402
+ tpu_args.add_argument(
403
+ "--downcast_bf16",
404
+ action="store_true",
405
+ help="Whether when using bf16 precision on TPUs if both float and double tensors are cast to bfloat16 or if double tensors remain as float32.",
406
+ )
407
+
408
+ # DeepSpeed arguments
409
+ deepspeed_args = parser.add_argument_group("DeepSpeed Arguments", "Arguments related to DeepSpeed.")
410
+ deepspeed_args.add_argument(
411
+ "--deepspeed_config_file",
412
+ default=None,
413
+ type=str,
414
+ help="DeepSpeed config file.",
415
+ )
416
+ deepspeed_args.add_argument(
417
+ "--zero_stage",
418
+ default=None,
419
+ type=int,
420
+ help="DeepSpeed's ZeRO optimization stage (useful only when `use_deepspeed` flag is passed). "
421
+ "If unspecified, will default to `2`.",
422
+ )
423
+ deepspeed_args.add_argument(
424
+ "--offload_optimizer_device",
425
+ default=None,
426
+ type=str,
427
+ help="Decides where (none|cpu|nvme) to offload optimizer states (useful only when `use_deepspeed` flag is passed). "
428
+ "If unspecified, will default to 'none'.",
429
+ )
430
+ deepspeed_args.add_argument(
431
+ "--offload_param_device",
432
+ default=None,
433
+ type=str,
434
+ help="Decides where (none|cpu|nvme) to offload parameters (useful only when `use_deepspeed` flag is passed). "
435
+ "If unspecified, will default to 'none'.",
436
+ )
437
+ deepspeed_args.add_argument(
438
+ "--offload_optimizer_nvme_path",
439
+ default=None,
440
+ type=str,
441
+ help="Decides Nvme Path to offload optimizer states (useful only when `use_deepspeed` flag is passed). "
442
+ "If unspecified, will default to 'none'.",
443
+ )
444
+ deepspeed_args.add_argument(
445
+ "--offload_param_nvme_path",
446
+ default=None,
447
+ type=str,
448
+ help="Decides Nvme Path to offload parameters (useful only when `use_deepspeed` flag is passed). "
449
+ "If unspecified, will default to 'none'.",
450
+ )
451
+ deepspeed_args.add_argument(
452
+ "--gradient_accumulation_steps",
453
+ default=None,
454
+ type=int,
455
+ help="No of gradient_accumulation_steps used in your training script (useful only when `use_deepspeed` flag is passed). "
456
+ "If unspecified, will default to `1`.",
457
+ )
458
+ deepspeed_args.add_argument(
459
+ "--gradient_clipping",
460
+ default=None,
461
+ type=float,
462
+ help="gradient clipping value used in your training script (useful only when `use_deepspeed` flag is passed). "
463
+ "If unspecified, will default to `1.0`.",
464
+ )
465
+ deepspeed_args.add_argument(
466
+ "--zero3_init_flag",
467
+ default=None,
468
+ type=str,
469
+ help="Decides Whether (true|false) to enable `deepspeed.zero.Init` for constructing massive models. "
470
+ "Only applicable with DeepSpeed ZeRO Stage-3. If unspecified, will default to `true`.",
471
+ )
472
+ deepspeed_args.add_argument(
473
+ "--zero3_save_16bit_model",
474
+ default=None,
475
+ type=str,
476
+ help="Decides Whether (true|false) to save 16-bit model weights when using ZeRO Stage-3. "
477
+ "Only applicable with DeepSpeed ZeRO Stage-3. If unspecified, will default to `false`.",
478
+ )
479
+ deepspeed_args.add_argument(
480
+ "--deepspeed_hostfile",
481
+ default=None,
482
+ type=str,
483
+ help="DeepSpeed hostfile for configuring multi-node compute resources.",
484
+ )
485
+ deepspeed_args.add_argument(
486
+ "--deepspeed_exclusion_filter",
487
+ default=None,
488
+ type=str,
489
+ help="DeepSpeed exclusion filter string when using multi-node setup.",
490
+ )
491
+ deepspeed_args.add_argument(
492
+ "--deepspeed_inclusion_filter",
493
+ default=None,
494
+ type=str,
495
+ help="DeepSpeed inclusion filter string when using multi-node setup.",
496
+ )
497
+ deepspeed_args.add_argument(
498
+ "--deepspeed_multinode_launcher",
499
+ default=None,
500
+ type=str,
501
+ help="DeepSpeed multi-node launcher to use, e.g. `pdsh`, `standard`, `openmpi`, `mvapich`, `mpich`, `slurm`, `nossh` (requires DeepSpeed >= 0.14.5). If unspecified, will default to `pdsh`.",
502
+ )
503
+ deepspeed_args.add_argument(
504
+ "--deepspeed_moe_layer_cls_names",
505
+ default=None,
506
+ type=str,
507
+ help="comma-separated list of transformer MoE layer class names (case-sensitive) to wrap ,e.g, `MixtralSparseMoeBlock`, `Qwen2MoeSparseMoeBlock`, `JetMoEAttention,JetMoEBlock` ..."
508
+ " (useful only when `use_deepspeed` flag is passed).",
509
+ )
510
+
511
+ # fsdp arguments
512
+ fsdp_args = parser.add_argument_group("FSDP Arguments", "Arguments related to Fully Shared Data Parallelism.")
513
+ fsdp_args.add_argument(
514
+ "--fsdp_version",
515
+ type=str,
516
+ default="1",
517
+ choices=["1", "2"],
518
+ help="FSDP version to use. (useful only when `use_fsdp` flag is passed).",
519
+ )
520
+ fsdp_args.add_argument(
521
+ "--fsdp_offload_params",
522
+ default="false",
523
+ type=str,
524
+ help="Decides Whether (true|false) to offload parameters and gradients to CPU. (useful only when `use_fsdp` flag is passed).",
525
+ )
526
+ fsdp_args.add_argument(
527
+ "--fsdp_min_num_params",
528
+ type=int,
529
+ default=int(1e8),
530
+ help="FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `use_fsdp` flag is passed).",
531
+ )
532
+ # We enable this for backwards compatibility, throw a warning if this is set in `FullyShardedDataParallelPlugin`
533
+ fsdp_args.add_argument(
534
+ "--fsdp_sharding_strategy",
535
+ type=str,
536
+ default="FULL_SHARD",
537
+ help="FSDP's sharding strategy. (useful only when `use_fsdp` flag is passed and `fsdp_version=1`).",
538
+ )
539
+ fsdp_args.add_argument(
540
+ "--fsdp_reshard_after_forward",
541
+ type=str,
542
+ default="true",
543
+ help="FSDP's Reshard After Forward Strategy. (useful only when `use_fsdp` flag is passed). Supports either boolean (FSDP2) or `FULL_SHARD | SHARD_GRAD_OP | NO_RESHARD` (FSDP1).",
544
+ )
545
+ fsdp_args.add_argument(
546
+ "--fsdp_auto_wrap_policy",
547
+ type=str,
548
+ default=None,
549
+ help="FSDP's auto wrap policy. (useful only when `use_fsdp` flag is passed).",
550
+ )
551
+ fsdp_args.add_argument(
552
+ "--fsdp_transformer_layer_cls_to_wrap",
553
+ default=None,
554
+ type=str,
555
+ help="Transformer layer class name (case-sensitive) to wrap ,e.g, `BertLayer`, `GPTJBlock`, `T5Block` .... "
556
+ "(useful only when `use_fsdp` flag is passed).",
557
+ )
558
+ fsdp_args.add_argument(
559
+ "--fsdp_backward_prefetch",
560
+ default=None,
561
+ type=str,
562
+ help="FSDP's backward prefetch policy. (useful only when `use_fsdp` flag is passed).",
563
+ )
564
+ fsdp_args.add_argument(
565
+ "--fsdp_state_dict_type",
566
+ default=None,
567
+ type=str,
568
+ help="FSDP's state dict type. (useful only when `use_fsdp` flag is passed).",
569
+ )
570
+ fsdp_args.add_argument(
571
+ "--fsdp_forward_prefetch",
572
+ default="false",
573
+ type=str,
574
+ help="If True, then FSDP explicitly prefetches the next upcoming "
575
+ "all-gather while executing in the forward pass (useful only when `use_fsdp` flag is passed).",
576
+ )
577
+ fsdp_args.add_argument(
578
+ "--fsdp_use_orig_params",
579
+ default="true",
580
+ type=str,
581
+ help="If True, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable parameters."
582
+ " (useful only when `use_fsdp` flag is passed).",
583
+ )
584
+ fsdp_args.add_argument(
585
+ "--fsdp_cpu_ram_efficient_loading",
586
+ default="true",
587
+ type=str,
588
+ help="If True, only the first process loads the pretrained model checkoint while all other processes have empty weights. "
589
+ "Only applicable for 🤗 Transformers. When using this, `--fsdp_sync_module_states` needs to True. "
590
+ "(useful only when `use_fsdp` flag is passed).",
591
+ )
592
+ fsdp_args.add_argument(
593
+ "--fsdp_sync_module_states",
594
+ default="true",
595
+ type=str,
596
+ help="If True, each individually wrapped FSDP unit will broadcast module parameters from rank 0."
597
+ " (useful only when `use_fsdp` flag is passed).",
598
+ )
599
+ fsdp_args.add_argument(
600
+ "--fsdp_activation_checkpointing",
601
+ default="false",
602
+ type=str,
603
+ help="Decides Whether (true|false) intermediate activations are freed during the forward pass, and a checkpoint is left as a placeholder. (useful only when `use_fsdp` flag is passed).",
604
+ )
605
+
606
+ # megatron_lm args
607
+ megatron_lm_args = parser.add_argument_group("Megatron-LM Arguments", "Arguments related to Megatron-LM.")
608
+ megatron_lm_args.add_argument(
609
+ "--megatron_lm_tp_degree",
610
+ type=int,
611
+ default=1,
612
+ help="Megatron-LM's Tensor Parallelism (TP) degree. (useful only when `use_megatron_lm` flag is passed).",
613
+ )
614
+ megatron_lm_args.add_argument(
615
+ "--megatron_lm_use_custom_fsdp",
616
+ type=bool,
617
+ default=False,
618
+ help="Whether to use custom FSDP. (useful only when `use_megatron_lm` flag is passed).",
619
+ )
620
+ megatron_lm_args.add_argument(
621
+ "--megatron_lm_no_load_optim",
622
+ type=bool,
623
+ default=False,
624
+ help="Whether to not load optimizer. (useful only when `use_megatron_lm` flag is passed).",
625
+ )
626
+ megatron_lm_args.add_argument(
627
+ "--megatron_lm_eod_mask_loss",
628
+ type=bool,
629
+ default=False,
630
+ help="Whether to use eod mask loss. (useful only when `use_megatron_lm` flag is passed).",
631
+ )
632
+ megatron_lm_args.add_argument(
633
+ "--megatron_lm_overlap_cpu_optimizer_d2h_h2d",
634
+ type=bool,
635
+ default=False,
636
+ help="Whether to overlap CPU optimizer step, gradients D2H and updated parameters H2D. (useful only when `use_megatron_lm` flag is passed).",
637
+ )
638
+ megatron_lm_args.add_argument(
639
+ "--megatron_lm_no_save_optim",
640
+ type=bool,
641
+ default=False,
642
+ help="Whether to not save optimizer. (useful only when `use_megatron_lm` flag is passed).",
643
+ )
644
+ megatron_lm_args.add_argument(
645
+ "--megatron_lm_optimizer_cpu_offload",
646
+ type=bool,
647
+ default=False,
648
+ help="Whether to use CPU offload for optimizer. (useful only when `use_megatron_lm` flag is passed).",
649
+ )
650
+ megatron_lm_args.add_argument(
651
+ "--megatron_lm_use_precision_aware_optimizer",
652
+ type=bool,
653
+ default=False,
654
+ help="Whether to use precision aware optimizer. (useful only when `use_megatron_lm` flag is passed).",
655
+ )
656
+ megatron_lm_args.add_argument(
657
+ "--megatron_lm_decoder_last_pipeline_num_layers",
658
+ type=int,
659
+ default=None,
660
+ help="Megatron-LM's decoder last pipeline number of layers, default None is even split of transformer layers across all pipeline stages.",
661
+ )
662
+ megatron_lm_args.add_argument(
663
+ "--megatron_lm_pp_degree",
664
+ type=int,
665
+ default=1,
666
+ help="Megatron-LM's Pipeline Parallelism (PP) degree. (useful only when `use_megatron_lm` flag is passed).",
667
+ )
668
+ megatron_lm_args.add_argument(
669
+ "--megatron_lm_num_micro_batches",
670
+ type=int,
671
+ default=None,
672
+ help="Megatron-LM's number of micro batches when PP degree > 1. (useful only when `use_megatron_lm` flag is passed).",
673
+ )
674
+ megatron_lm_args.add_argument(
675
+ "--megatron_lm_sequence_parallelism",
676
+ default=None,
677
+ type=str,
678
+ help="Decides Whether (true|false) to enable Sequence Parallelism when TP degree > 1. "
679
+ "(useful only when `use_megatron_lm` flag is passed).",
680
+ )
681
+ megatron_lm_args.add_argument(
682
+ "--megatron_lm_recompute_activations",
683
+ default=None,
684
+ type=str,
685
+ help="Decides Whether (true|false) to enable Selective Activation Recomputation. "
686
+ "(useful only when `use_megatron_lm` flag is passed).",
687
+ )
688
+ megatron_lm_args.add_argument(
689
+ "--megatron_lm_use_distributed_optimizer",
690
+ default=None,
691
+ type=str,
692
+ help="Decides Whether (true|false) to use distributed optimizer "
693
+ "which shards optimizer state and gradients across Data Pralellel (DP) ranks. "
694
+ "(useful only when `use_megatron_lm` flag is passed).",
695
+ )
696
+ megatron_lm_args.add_argument(
697
+ "--megatron_lm_gradient_clipping",
698
+ default=1.0,
699
+ type=float,
700
+ help="Megatron-LM's gradient clipping value based on global L2 Norm (0 to disable). "
701
+ "(useful only when `use_megatron_lm` flag is passed).",
702
+ )
703
+ megatron_lm_args.add_argument(
704
+ "--megatron_lm_recompute_granularity",
705
+ default=None,
706
+ type=str,
707
+ help="Megatron-LM's recompute granularity (full, selective). "
708
+ "(useful only when `use_megatron_lm` flag is passed).",
709
+ )
710
+ megatron_lm_args.add_argument(
711
+ "--megatron_lm_recompute_method",
712
+ default=None,
713
+ type=str,
714
+ help="Megatron-LM's recompute method (uniform, block). (useful only when `use_megatron_lm` flag is passed).",
715
+ )
716
+ megatron_lm_args.add_argument(
717
+ "--megatron_lm_recompute_num_layers",
718
+ default=None,
719
+ type=int,
720
+ help="Megatron-LM's number of layers to recompute. (useful only when `use_megatron_lm` flag is passed).",
721
+ )
722
+ megatron_lm_args.add_argument(
723
+ "--megatron_lm_attention_backend",
724
+ default=None,
725
+ type=str,
726
+ help="Decides Whether (true|false) to enable attention backend. "
727
+ "(useful only when `use_megatron_lm` flag is passed).",
728
+ )
729
+ megatron_lm_args.add_argument(
730
+ "--megatron_lm_expert_model_parallel_size",
731
+ default=None,
732
+ type=int,
733
+ help="Megatron-LM's expert model parallel size. (useful only when `use_megatron_lm` flag is passed).",
734
+ )
735
+ megatron_lm_args.add_argument(
736
+ "--megatron_lm_context_parallel_size",
737
+ default=None,
738
+ type=int,
739
+ help="Megatron-LM's context parallel size. (useful only when `use_megatron_lm` flag is passed).",
740
+ )
741
+ megatron_lm_args.add_argument(
742
+ "--megatron_lm_attention_dropout",
743
+ default=None,
744
+ type=float,
745
+ help="Megatron-LM's attention dropout rate. (useful only when `use_megatron_lm` flag is passed).",
746
+ )
747
+ megatron_lm_args.add_argument(
748
+ "--megatron_lm_hidden_dropout",
749
+ default=None,
750
+ type=float,
751
+ help="Megatron-LM's hidden dropout rate. (useful only when `use_megatron_lm` flag is passed).",
752
+ )
753
+ megatron_lm_args.add_argument(
754
+ "--megatron_lm_attention_softmax_in_fp32",
755
+ default=None,
756
+ type=str,
757
+ help="Decides Whether (true|false) to use fp32 for attention softmax. "
758
+ "(useful only when `use_megatron_lm` flag is passed).",
759
+ )
760
+ megatron_lm_args.add_argument(
761
+ "--megatron_lm_expert_tensor_parallel_size",
762
+ default=None,
763
+ type=int,
764
+ help="Megatron-LM's expert tensor parallel size. (useful only when `use_megatron_lm` flag is passed).",
765
+ )
766
+ megatron_lm_args.add_argument(
767
+ "--megatron_lm_calculate_per_token_loss",
768
+ default=None,
769
+ type=str,
770
+ help="Decides Whether (true|false) to calculate per token loss. "
771
+ "(useful only when `use_megatron_lm` flag is passed).",
772
+ )
773
+ megatron_lm_args.add_argument(
774
+ "--megatron_lm_use_rotary_position_embeddings",
775
+ default=None,
776
+ type=str,
777
+ help="Decides Whether (true|false) to use rotary position embeddings. "
778
+ "(useful only when `use_megatron_lm` flag is passed).",
779
+ )
780
+
781
+ # FP8 arguments
782
+ fp8_args = parser.add_argument_group(
783
+ "FP8 Arguments", "Arguments related to FP8 training (requires `--mixed_precision=fp8`)"
784
+ )
785
+ fp8_args.add_argument(
786
+ "--fp8_backend",
787
+ type=str,
788
+ choices=["ao", "te", "msamp"],
789
+ help="Choose a backend to train with FP8 (ao: torchao, te: TransformerEngine, msamp: MS-AMP)",
790
+ )
791
+ fp8_args.add_argument(
792
+ "--fp8_use_autocast_during_eval",
793
+ default=False,
794
+ action="store_true",
795
+ help="Whether to use FP8 autocast during eval mode (useful only when `--fp8_backend=te` is passed). Generally better metrics are found when this is not passed.",
796
+ )
797
+ fp8_args.add_argument(
798
+ "--fp8_margin",
799
+ type=int,
800
+ default=0,
801
+ help="The margin to use for the gradient scaling (useful only when `--fp8_backend=te` is passed).",
802
+ )
803
+ fp8_args.add_argument(
804
+ "--fp8_interval",
805
+ type=int,
806
+ default=1,
807
+ help="The interval to use for how often the scaling factor is recomputed (useful only when `--fp8_backend=te` is passed).",
808
+ )
809
+ fp8_args.add_argument(
810
+ "--fp8_format",
811
+ type=str,
812
+ default="HYBRID",
813
+ choices=["HYBRID", "E4M3", "E5M2"],
814
+ help="The format to use for the FP8 recipe (useful only when `--fp8_backend=te` is passed).",
815
+ )
816
+ fp8_args.add_argument(
817
+ "--fp8_amax_history_len",
818
+ type=int,
819
+ default=1024,
820
+ help="The length of the history to use for the scaling factor computation (useful only when `--fp8_backend=te` is passed).",
821
+ )
822
+ fp8_args.add_argument(
823
+ "--fp8_amax_compute_algo",
824
+ type=str,
825
+ default="most_recent",
826
+ choices=["max", "most_recent"],
827
+ help="The algorithm to use for the scaling factor computation. (useful only when `--fp8_backend=te` is passed).",
828
+ )
829
+ fp8_args.add_argument(
830
+ "--fp8_override_linear_precision",
831
+ type=lambda x: tuple(map(str_to_bool, x.split(","))),
832
+ default=(False, False, False),
833
+ help="Whether or not to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision. Should be passed in a comma-separated string of booleans (useful only when `--fp8_backend=te` is passed).",
834
+ )
835
+ fp8_args.add_argument(
836
+ "--fp8_opt_level",
837
+ type=str,
838
+ default="O2",
839
+ choices=["O1", "O2"],
840
+ help="What level of 8-bit collective communication should be used with MS-AMP (useful only when `--fp8_backend=msamp` is passed).",
841
+ )
842
+ fp8_args.add_argument(
843
+ "--fp8_enable_fsdp_float8_all_gather",
844
+ default="true",
845
+ type=str_to_bool,
846
+ help="Whether to enable FSDP2 float8 all gather (useful only when `--fp8_backend=ao` is passed).",
847
+ )
848
+ fp8_args.add_argument(
849
+ "--fp8_pad_inner_dim",
850
+ default="true",
851
+ type=str_to_bool,
852
+ help="Whether to pad the inner dimension for FP8 GEMMs (useful only when `--fp8_backend=ao` is passed).",
853
+ )
854
+
855
+ # AWS arguments
856
+ aws_args = parser.add_argument_group("AWS Arguments", "Arguments related to AWS.")
857
+ aws_args.add_argument(
858
+ "--aws_access_key_id",
859
+ type=str,
860
+ default=None,
861
+ help="The AWS_ACCESS_KEY_ID used to launch the Amazon SageMaker training job",
862
+ )
863
+ aws_args.add_argument(
864
+ "--aws_secret_access_key",
865
+ type=str,
866
+ default=None,
867
+ help="The AWS_SECRET_ACCESS_KEY used to launch the Amazon SageMaker training job.",
868
+ )
869
+ parser.add_argument(
870
+ "--debug",
871
+ action="store_true",
872
+ help="Whether to print out the torch.distributed stack trace when something fails.",
873
+ )
874
+ parser.add_argument(
875
+ "training_script",
876
+ type=str,
877
+ help=(
878
+ "The full path to the script to be launched in parallel, followed by all the arguments for the training "
879
+ "script."
880
+ ),
881
+ )
882
+
883
+ # MPI arguments
884
+ mpirun_args = parser.add_argument_group("MPI Arguments", "Arguments related to mpirun for Multi-CPU")
885
+ mpirun_args.add_argument(
886
+ "--mpirun_hostfile",
887
+ type=str,
888
+ default=None,
889
+ help="Location for a hostfile for using Accelerate to launch a multi-CPU training job with mpirun. This will "
890
+ "get passed to the MPI --hostfile or -f parameter, depending on which MPI program is installed.",
891
+ )
892
+
893
+ # ParallelismConfig arguments
894
+ parallelism_config_args = parser.add_argument_group(
895
+ "ParallelismConfig Arguments",
896
+ "Arguments related to the ParallelismConfig used for distributed training.",
897
+ )
898
+
899
+ parallelism_config_args.add_argument(
900
+ "--parallelism_config_dp_replicate_size",
901
+ type=int,
902
+ default=1,
903
+ help="The number of processes for data parallel training. Defaults to 1 (no data parallelism).",
904
+ )
905
+
906
+ parallelism_config_args.add_argument(
907
+ "--parallelism_config_dp_shard_size",
908
+ type=int,
909
+ default=1,
910
+ help="The number of processes for FSDP sharding. Defaults to 1 (No FSDP sharding).",
911
+ )
912
+
913
+ parallelism_config_args.add_argument(
914
+ "--parallelism_config_tp_size",
915
+ type=int,
916
+ default=1,
917
+ help="The number of processes for tensor parallel training. Defaults to 1 (no tensor parallelism).",
918
+ )
919
+
920
+ parallelism_config_args.add_argument(
921
+ "--parallelism_config_cp_size",
922
+ type=int,
923
+ default=1,
924
+ help="The number of processese for context parallel training. Defaults to 1 (no context parallelism).",
925
+ )
926
+
927
+ parallelism_config_args.add_argument(
928
+ "--parallelism_config_cp_backend",
929
+ type=str,
930
+ choices=["torch"],
931
+ default="torch",
932
+ help="Context Parallelism backend: torch (FSDP2) or deepspeed (ALST/Ulysses)",
933
+ )
934
+
935
+ parallelism_config_args.add_argument(
936
+ "--parallelism_config_cp_comm_strategy",
937
+ type=str,
938
+ default="allgather",
939
+ help="The communication strategy for context parallel training. Defaults to 'allgather'. Other option is alltoall",
940
+ )
941
+
942
+ parallelism_config_args.add_argument(
943
+ "--parallelism_config_sp_size",
944
+ type=int,
945
+ default=1,
946
+ help="The number of processese for context parallel training. Defaults to 1 (no context parallelism).",
947
+ )
948
+
949
+ parallelism_config_args.add_argument(
950
+ "--parallelism_config_sp_backend",
951
+ type=str,
952
+ choices=["deepspeed"],
953
+ default="deepspeed",
954
+ help="Sequence Parallelism backend: deepspeed (ALST/Ulysses)",
955
+ )
956
+
957
+ parallelism_config_args.add_argument(
958
+ "--parallelism_config_sp_seq_length",
959
+ type=str,
960
+ default=None,
961
+ help="Sequence length for when batches are all of the same length. For variable sequence lengths across batches set `parallelism_config_sp_seq_length_is_variable=True`",
962
+ )
963
+
964
+ parallelism_config_args.add_argument(
965
+ "--parallelism_config_sp_seq_length_is_variable",
966
+ type=bool,
967
+ default=True,
968
+ help="If `True` will work with a sequence length that may change between batches, in which case `parallelism_config_sp_seq_length` value can be set to anything divisible by sp size or remain unset. If `False` then `parallelism_config_sp_seq_length` needs to match the batch's sequence length dimension. The default is `True`.",
969
+ )
970
+
971
+ parallelism_config_args.add_argument(
972
+ "--parallelism_config_sp_attn_implementation",
973
+ type=str,
974
+ default="sdpa",
975
+ help="Attention implementation to use. Can be one of 'flash_attention_2', 'flash_attention_3' or 'sdpa'. Defaults to `sdpa`.",
976
+ )
977
+
978
+ # Other arguments of the training scripts
979
+ parser.add_argument("training_script_args", nargs=argparse.REMAINDER, help="Arguments of the training script.")
980
+
981
+ if subparsers is not None:
982
+ parser.set_defaults(func=launch_command)
983
+ return parser
984
+
985
+
986
def simple_launcher(args):
    """
    Run the training command in a single child process and report a failing exit status.

    The command and environment come from `prepare_simple_launcher_cmd_env(args)`. When the child exits with a
    non-zero status, a `subprocess.CalledProcessError` is raised, unless `args.quiet` is set, in which case the
    interpreter exits with status 1 instead.
    """
    command, environment = prepare_simple_launcher_cmd_env(args)

    child = subprocess.Popen(command, env=environment)
    child.wait()
    if child.returncode == 0:
        return
    if args.quiet:
        # Quiet mode: no traceback, just a failing exit status.
        sys.exit(1)
    raise subprocess.CalledProcessError(returncode=child.returncode, cmd=command)
996
+
997
+
998
def multi_gpu_launcher(args):
    """
    Launch the training script through `torch.distributed.run`.

    The environment comes from `prepare_multi_gpu_env(args)`. When `check_cuda_p2p_ib_support()` reports no support,
    `NCCL_P2P_DISABLE` and `NCCL_IB_DISABLE` are set to "1" unless already present, and a warning is logged if either
    one had to be added. Failures inside `distrib_run.run` are re-raised, except when rich is available and `--debug`
    was passed, in which case the stack trace is printed through the rich console instead.
    """
    import torch.distributed.run as distrib_run

    launch_env = prepare_multi_gpu_env(args)
    if not check_cuda_p2p_ib_support():
        message = "Using RTX 4000 series which doesn't support faster communication speedups. Ensuring P2P and IB communications are disabled."
        added_any = False
        for flag in ("NCCL_P2P_DISABLE", "NCCL_IB_DISABLE"):
            # Respect a value the caller already exported.
            if flag not in launch_env:
                launch_env[flag] = "1"
                added_any = True
        if added_any:
            logger.warning(message)

    # Read `debug` before `args` is replaced by the filtered namespace below.
    debug = getattr(args, "debug", False)
    extra_args = ["--training_script", args.training_script, "--training_script_args", args.training_script_args]
    args = _filter_args(args, distrib_run.get_args_parser(), extra_args)

    with patch_environment(**launch_env):
        try:
            distrib_run.run(args)
        except Exception:
            if not (is_rich_available() and debug):
                raise
            console = get_console()
            console.print("\n[bold red]Using --debug, `torch.distributed` Stack Trace:[/bold red]")
            console.print_exception(suppress=[__file__], show_locals=False)
1031
+
1032
+
1033
def deepspeed_launcher(args):
    """
    Launch a DeepSpeed training job.

    Raises `ImportError` when DeepSpeed is not installed. With more than one machine and a multinode launcher other
    than `DEEPSPEED_MULTINODE_LAUNCHERS[1]`, the environment is appended to the DeepSpeed environment file and the
    prepared command is run as a subprocess; a non-zero exit raises `subprocess.CalledProcessError`, or exits with
    status 1 when `args.quiet` is set. Otherwise the script is run in-process through `torch.distributed.run`, with
    the same `--debug` stack-trace handling as the multi-GPU launcher.
    """
    import torch.distributed.run as distrib_run

    if not is_deepspeed_available():
        raise ImportError("DeepSpeed is not installed => run `pip3 install deepspeed` or build it from source.")
    from deepspeed.launcher.runner import DEEPSPEED_ENVIRONMENT_NAME

    command, launch_env = prepare_deepspeed_cmd_env(args)
    if not check_cuda_p2p_ib_support():
        message = "Using RTX 4000 series which doesn't support faster communication speedups. Ensuring P2P and IB communications are disabled."
        added_any = False
        for flag in ("NCCL_P2P_DISABLE", "NCCL_IB_DISABLE"):
            # Respect a value the caller already exported.
            if flag not in launch_env:
                launch_env[flag] = "1"
                added_any = True
        if added_any:
            logger.warning(message)

    use_subprocess = (
        args.num_machines > 1 and args.deepspeed_multinode_launcher != DEEPSPEED_MULTINODE_LAUNCHERS[1]
    )
    if use_subprocess:
        # The file is opened in append mode even when nothing ends up being written.
        with open(DEEPSPEED_ENVIRONMENT_NAME, "a") as env_file:
            env_lines = convert_dict_to_env_variables(launch_env)
            if len(env_lines) > 1:
                env_file.writelines(env_lines)

        child = subprocess.Popen(command, env=launch_env)
        child.wait()
        if child.returncode == 0:
            return
        if args.quiet:
            sys.exit(1)
        raise subprocess.CalledProcessError(returncode=child.returncode, cmd=command)

    # Read `debug` before `args` is replaced by the filtered namespace below.
    debug = getattr(args, "debug", False)
    extra_args = ["--training_script", args.training_script, "--training_script_args", args.training_script_args]
    args = _filter_args(args, distrib_run.get_args_parser(), extra_args)
    with patch_environment(**launch_env):
        try:
            distrib_run.run(args)
        except Exception:
            if not (is_rich_available() and debug):
                raise
            console = get_console()
            console.print("\n[bold red]Using --debug, `torch.distributed` Stack Trace:[/bold red]")
            console.print_exception(suppress=[__file__], show_locals=False)
1084
+
1085
+
1086
def tpu_launcher(args):
    """
    Import the training script as a module and spawn its main function with `xmp.spawn`.

    Raises `ValueError` when `--no_python` is set, or when the imported module has no attribute named
    `args.main_training_function`. `sys.argv` is replaced with the module file followed by the training script
    arguments before spawning.
    """
    import torch_xla.distributed.xla_multiprocessing as xmp

    if args.no_python:
        raise ValueError("--no_python cannot be used with TPU launcher")

    args, tpu_env = prepare_tpu(args, {})

    if args.module:
        module_name = args.training_script
    else:
        # Import training_script as a module: make its directory importable, then import by stem.
        script = Path(args.training_script)
        sys.path.append(str(script.parent.resolve()))
        module_name = script.stem

    module = importlib.import_module(module_name)
    entry_name = args.main_training_function
    if not hasattr(module, entry_name):
        raise ValueError(
            f"Your training script should have a function named {args.main_training_function}, or you should pass a "
            "different value to `--main_training_function`."
        )

    # Patch sys.argv
    sys.argv = [module.__file__] + args.training_script_args

    entry_point = getattr(module, entry_name)
    with patch_environment(**tpu_env):
        xmp.spawn(PrepareForLaunch(entry_point), args=())
1115
+
1116
+
1117
def tpu_pod_launcher(args):
    """
    Launch training on a TPU pod by handing an `accelerate-launch` command to `xla_dist`.

    The inner command is stored in `positional` on the namespace filtered for the `xla_dist` parser, optionally
    prefixed with `sudo`. Raises `ValueError` if any `docker_*` option on that namespace holds a value other than
    `""` or `None`. Failures inside `xla_dist.resolve_and_execute` are re-raised, except when rich is available and
    `--debug` was passed, in which case the stack trace is printed through the rich console instead.
    """
    from torch_xla.distributed import xla_dist

    args, pod_env = prepare_tpu(args, {}, True)
    debug = getattr(args, "debug", False)

    script = args.training_script
    script_args = args.training_script_args
    xla_args = _filter_args(
        args, xla_dist.get_args_parser(), ["--tpu", args.tpu_name, "--positional", "", "--restart-tpuvm-pod-server"]
    )

    inner_cmd = ["sudo"] if args.tpu_use_sudo else []
    inner_cmd += [
        "accelerate-launch",
        "--tpu",
        "--no_tpu_cluster",
        "--num_machines",
        "1",
        "--mixed_precision",
        "no",
        "--dynamo_backend",
        "no",
        "--num_processes",
        str(args.num_processes),
        "--main_training_function",
        str(args.main_training_function),
        script,
    ] + script_args

    xla_args.positional = inner_cmd

    # Collect every docker_* option that carries a real value; none of them are supported here.
    bad_flags = "".join(
        f'{option}="{value}"\n'
        for option, value in vars(xla_args).items()
        if option.startswith("docker_") and value != "" and value is not None
    )
    if bad_flags != "":
        raise ValueError(
            f"Docker containers are not supported for TPU pod launcher currently, please remove the following flags:\n{bad_flags}"
        )

    xla_args.env = [f"{k}={v}" for k, v in pod_env.items()]
    xla_args.env.append("ACCELERATE_IN_TPU_POD=1")
    try:
        xla_dist.resolve_and_execute(xla_args)
    except Exception:
        if not (is_rich_available() and debug):
            raise
        console = get_console()
        console.print("\n[bold red]Using --debug, `torch_xla.xla_dist` Stack Trace:[/bold red]")
        console.print_exception(suppress=[__file__], show_locals=False)
1174
+
1175
+
1176
def sagemaker_launcher(sagemaker_config: SageMakerConfig, args):
    """
    Start a training job on Amazon SageMaker through the `HuggingFace` estimator.

    Raises `ImportError` when the sagemaker package is unavailable and `ValueError` when `--module` or `--no_python`
    is set. The estimator keyword arguments and the `fit` inputs both come from `prepare_sagemager_args_inputs`.
    After fitting, the estimator's `model_data` is printed.
    """
    if not is_sagemaker_available():
        raise ImportError(
            "Please install sagemaker to be able to launch training on Amazon SageMaker with `pip install accelerate[sagemaker]`"
        )
    if args.module or args.no_python:
        raise ValueError(
            "SageMaker requires a python training script file and cannot be used with --module or --no_python"
        )

    from sagemaker.huggingface import HuggingFace

    estimator_kwargs, fit_inputs = prepare_sagemager_args_inputs(sagemaker_config, args)

    estimator = HuggingFace(**estimator_kwargs)
    estimator.fit(inputs=fit_inputs)
    print(f"You can find your model data at: {estimator.model_data}")
1194
+
1195
+
1196
def _validate_launch_command(args):
    """
    Sanity-check the parsed `accelerate launch` arguments and fill in the values that were left unset.

    Values come from the config file when one is used (an explicit `--config_file`, or the default config file when
    it exists and `--cpu` is not set), and from detected devices and hard-coded defaults otherwise. Every value that
    had to be defaulted without a config file is collected and reported in a single warning.

    Returns:
        tuple: `(args, defaults, mp_from_config_flag)` where `defaults` is the loaded config (or `None`) and
        `mp_from_config_flag` is `True` only when `mixed_precision` was taken from the config file.

    Raises:
        ValueError: on conflicting launch modes, too few processes or GPU ids for multi-GPU, unsupported bf16, or a
        config that leaves `num_processes` at `-1`.
    """
    # Sanity checks
    exclusive_modes = [args.multi_gpu, args.cpu, args.tpu, args.use_deepspeed, args.use_fsdp]
    if sum(exclusive_modes) > 1:
        raise ValueError(
            "You can only use one of `--cpu`, `--multi_gpu`, `--tpu`, `--use_deepspeed`, `--use_fsdp` at a time."
        )
    if args.multi_gpu and (args.num_processes is not None) and (args.num_processes < 2):
        raise ValueError("You need to use at least 2 processes to use `--multi_gpu`.")

    if (not args.use_fsdp or args.fsdp_version == 1) and args.use_parallelism_config:
        raise ValueError("You cannot use `--use_parallelism_config` without `--use_fsdp` and `--fsdp_version=2`. ")

    defaults = None
    warned = []
    mp_from_config_flag = False

    # Get the default from the config file. `and` binds tighter than `or`, so an explicit
    # `--config_file` is honoured even together with `--cpu`; the parentheses only make that visible.
    use_config = args.config_file is not None or (os.path.isfile(default_config_file) and not args.cpu)
    if use_config:
        defaults = load_config_from_file(args.config_file)

        no_mode_on_cli = not (
            args.multi_gpu
            or args.tpu
            or args.tpu_use_cluster
            or args.use_deepspeed
            or args.use_fsdp
            or args.use_megatron_lm
        )
        if no_mode_on_cli:
            # No launch mode was chosen on the command line: derive it from the config.
            configured = defaults.distributed_type
            args.use_deepspeed = configured == DistributedType.DEEPSPEED
            args.multi_gpu = configured in (
                DistributedType.MULTI_GPU,
                DistributedType.MULTI_NPU,
                DistributedType.MULTI_MLU,
                DistributedType.MULTI_SDAA,
                DistributedType.MULTI_MUSA,
                DistributedType.MULTI_XPU,
                DistributedType.MULTI_HPU,
                DistributedType.MULTI_NEURON,
            )
            args.tpu = configured == DistributedType.XLA
            args.use_fsdp = configured == DistributedType.FSDP
            args.use_megatron_lm = configured == DistributedType.MEGATRON_LM
            args.tpu_use_cluster = defaults.tpu_use_cluster if args.tpu else False
            args.use_parallelism_config = defaults.parallelism_config != {}

        if args.gpu_ids is None:
            args.gpu_ids = defaults.gpu_ids if defaults.gpu_ids is not None else "all"

        if args.multi_gpu and args.num_machines is None:
            args.num_machines = defaults.num_machines

        if len(args.gpu_ids.split(",")) < 2 and (args.gpu_ids != "all") and args.multi_gpu and args.num_machines <= 1:
            raise ValueError(
                "Less than two GPU ids were configured and tried to run on on multiple GPUs. "
                "Please ensure at least two are specified for `--gpu_ids`, or use `--gpu_ids='all'`."
            )

        if defaults.compute_environment == ComputeEnvironment.LOCAL_MACHINE:
            # Update args with the defaults
            for name, attr in defaults.__dict__.items():
                if isinstance(attr, dict):
                    # Copy defaults.somedict.somearg to args.somearg and
                    # defaults.fsdp_config.x to args.fsdp_x
                    for key, value in attr.items():
                        target = key
                        if name == "fsdp_config" and not key.startswith("fsdp"):
                            target = "fsdp_" + key
                        elif name == "fp8_config" and not key.startswith("fp8"):
                            target = "fp8_" + key
                        if hasattr(args, "nondefault") and target not in args.nondefault:
                            setattr(args, target, value)
                elif (
                    name not in ["compute_environment", "mixed_precision", "distributed_type"]
                    and getattr(args, name, None) is None
                ):
                    # Those args are handled separately
                    setattr(args, name, attr)

        if not args.debug:
            args.debug = defaults.debug

        if not args.mixed_precision:
            if defaults.mixed_precision is None:
                args.mixed_precision = "no"
            else:
                args.mixed_precision = defaults.mixed_precision
                mp_from_config_flag = True
        else:
            native_amp = is_bf16_available(True)
            bf16_unsupported = (
                args.mixed_precision == "bf16"
                and not native_amp
                and not (args.tpu and is_torch_xla_available(check_is_tpu=True))
            )
            if bf16_unsupported:
                raise ValueError("bf16 mixed precision requires PyTorch >= 1.10 and a supported device.")

        # Silently set the default here
        if args.dynamo_backend is None:
            args.dynamo_backend = "no"
        if args.num_processes == -1:
            raise ValueError("You need to manually pass in `--num_processes` using this config yaml.")
    else:
        if args.num_processes is None:
            # The first available backend wins; CUDA is the fallback when none of them report availability.
            backend_probes = (
                (is_xpu_available, "xpu"),
                (is_mlu_available, "mlu"),
                (is_sdaa_available, "sdaa"),
                (is_musa_available, "musa"),
                (is_npu_available, "npu"),
                (is_hpu_available, "hpu"),
                (is_neuron_available, "neuron"),
            )
            for is_available, backend in backend_probes:
                if is_available():
                    args.num_processes = getattr(torch, backend).device_count()
                    break
            else:
                args.num_processes = torch.cuda.device_count()
            warned.append(f"\t`--num_processes` was set to a value of `{args.num_processes}`")
        if args.debug is None:
            args.debug = False
        if (
            not args.multi_gpu
            and args.num_processes > 1
            and (
                (is_xpu_available() and torch.xpu.device_count() > 1)
                or (is_npu_available() and torch.npu.device_count() > 1)
                or (is_hpu_available() and torch.hpu.device_count() > 1)
                or (is_mlu_available() and torch.mlu.device_count() > 1)
                or (is_sdaa_available() and torch.sdaa.device_count() > 1)
                or (is_musa_available() and torch.musa.device_count() > 1)
                or (is_neuron_available() and torch.neuron.device_count() > 1)
                or (torch.cuda.is_available() and torch.cuda.device_count() > 1)
            )
        ):
            warned.append(
                "\t\tMore than one GPU was found, enabling multi-GPU training.\n"
                "\t\tIf this was unintended please pass in `--num_processes=1`."
            )
            args.multi_gpu = True
        if args.num_machines is None:
            warned.append("\t`--num_machines` was set to a value of `1`")
            args.num_machines = 1
        if args.mixed_precision is None:
            warned.append("\t`--mixed_precision` was set to a value of `'no'`")
            args.mixed_precision = "no"
        if not hasattr(args, "use_cpu"):
            args.use_cpu = args.cpu
        if args.dynamo_backend is None:
            warned.append("\t`--dynamo_backend` was set to a value of `'no'`")
            args.dynamo_backend = "no"

    if args.debug:
        logger.debug("Running script in debug mode, expect distributed operations to be slightly slower.")

    is_aws_env_disabled = defaults is None or defaults.compute_environment != ComputeEnvironment.AMAZON_SAGEMAKER
    if is_aws_env_disabled and args.num_cpu_threads_per_process is None:
        args.num_cpu_threads_per_process = get_int_from_env(["OMP_NUM_THREADS"], 1)
        if args.use_cpu and args.num_processes >= 1 and get_int_from_env(["OMP_NUM_THREADS"], 0) == 0:
            # OMP_NUM_THREADS is unset (or 0): split the physical cores between the local processes.
            local_size = get_int_from_env(
                ["MPI_LOCALNRANKS", "OMPI_COMM_WORLD_LOCAL_SIZE", "MV2_COMM_WORLD_LOCAL_SIZE"],
                max(int(args.num_processes / args.num_machines), 1),
            )
            import psutil

            threads_per_process = int(psutil.cpu_count(logical=False) / local_size)
            if threads_per_process > 1:
                args.num_cpu_threads_per_process = threads_per_process
                warned.append(
                    f"\t`--num_cpu_threads_per_process` was set to `{args.num_cpu_threads_per_process}` to improve out-of-box performance when training on CPUs"
                )

    if any(warned):
        message = (
            "The following values were not passed to `accelerate launch` and had defaults used instead:\n"
            + "\n".join(warned)
            + "\nTo avoid this warning pass in values for each of the problematic parameters or run `accelerate config`."
        )
        logger.warning(message)
    return args, defaults, mp_from_config_flag
1380
+
1381
+
1382
def launch_command(args):
    """
    Validate `args` and dispatch to the launcher that matches the selected mode.

    With `--cpu` unset the order is: DeepSpeed, then FSDP / Megatron-LM / multi-GPU (all handled by
    `multi_gpu_launcher`), then TPU (pod or single host). If none of those apply, or `--cpu` is set, a config whose
    compute environment is Amazon SageMaker selects the SageMaker launcher; everything else uses the simple
    launcher.
    """
    args, defaults, mp_from_config_flag = _validate_launch_command(args)

    # Use the proper launcher
    if not args.cpu:
        if args.use_deepspeed:
            config_fields = list(defaults.deepspeed_config.keys()) if defaults else []
            if mp_from_config_flag:
                config_fields.append("mixed_precision")
            args.deepspeed_fields_from_accelerate_config = ",".join(config_fields)
            deepspeed_launcher(args)
            return
        if args.use_fsdp or args.use_megatron_lm or args.multi_gpu:
            multi_gpu_launcher(args)
            return
        if args.tpu:
            if args.tpu_use_cluster:
                tpu_pod_launcher(args)
            else:
                tpu_launcher(args)
            return

    if defaults is not None and defaults.compute_environment == ComputeEnvironment.AMAZON_SAGEMAKER:
        sagemaker_launcher(defaults, args)
    else:
        simple_launcher(args)
1406
+
1407
+
1408
def main():
    """Parse the command line with the launch parser and run `launch_command` on the result."""
    launch_command(launch_command_parser().parse_args())
1412
+
1413
+
1414
+ if __name__ == "__main__":
1415
+ main()
accelerate/commands/menu/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from .selection_menu import BulletMenu
accelerate/commands/menu/cursor.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team and Brian Chao. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ A utility for showing and hiding the terminal cursor on Windows and Linux, based on https://github.com/bchao1/bullet
17
+ """
18
+
19
+ import os
20
+ import sys
21
+ from contextlib import contextmanager
22
+
23
+
24
+ # Windows only
25
if os.name == "nt":
    import ctypes
    import msvcrt  # noqa

    class CursorInfo(ctypes.Structure):
        # _fields_ is a specific attr expected by ctypes
        _fields_ = [("size", ctypes.c_int), ("visible", ctypes.c_byte)]


def _set_cursor_visible(visible):
    """
    Show (`visible=True`) or hide (`visible=False`) the terminal cursor.

    On Windows (`os.name == "nt"`) the console cursor info is read, its `visible` field updated and written back. On
    POSIX the ANSI sequence `ESC[?25h` (show) or `ESC[?25l` (hide) is written to stdout and flushed. Any other
    platform is left untouched.
    """
    if os.name == "nt":
        ci = CursorInfo()
        # -11 is the Win32 STD_OUTPUT_HANDLE constant.
        handle = ctypes.windll.kernel32.GetStdHandle(-11)
        # Read the current info first so only the visibility field changes.
        ctypes.windll.kernel32.GetConsoleCursorInfo(handle, ctypes.byref(ci))
        ci.visible = visible
        ctypes.windll.kernel32.SetConsoleCursorInfo(handle, ctypes.byref(ci))
    elif os.name == "posix":
        sys.stdout.write("\033[?25h" if visible else "\033[?25l")
        sys.stdout.flush()


def hide_cursor():
    """Hide the terminal cursor."""
    _set_cursor_visible(False)


def show_cursor():
    """Show the terminal cursor."""
    _set_cursor_visible(True)
56
+
57
+
58
@contextmanager
def hide():
    """
    Context manager that hides the terminal cursor for the duration of the `with` body.

    The cursor is shown again when the body finishes, including when it raises.
    """
    try:
        hide_cursor()
        yield
    finally:
        # `hide_cursor` sits inside the `try`, so the cursor is restored even if hiding itself fails.
        show_cursor()
accelerate/commands/menu/helpers.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team and Brian Chao. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ A variety of helper functions and constants when dealing with terminal menu choices, based on
17
+ https://github.com/bchao1/bullet
18
+ """
19
+
20
+ import enum
21
+ import shutil
22
+ import sys
23
+
24
+
25
# Terminal width in columns, measured once when the module is imported.
TERMINAL_WIDTH = shutil.get_terminal_size().columns

# Letter that ends the cursor-movement escape sequence for each direction name.
CURSOR_TO_CHAR = dict(UP="A", DOWN="B", RIGHT="C", LEFT="D")


class Direction(enum.Enum):
    """Vertical directions used by the menu code."""

    UP = 0
    DOWN = 1
33
+
34
+
35
def forceWrite(content, end=""):
    """Write `str(content)` followed by `end` to stdout and flush immediately."""
    text = str(content) + end
    sys.stdout.write(text)
    sys.stdout.flush()
38
+
39
+
40
def writeColor(content, color, end=""):
    """Write `content` wrapped in the ANSI sequence for `color`, followed by the reset sequence, then `end`."""
    colored = f"\u001b[{color}m{content}\u001b[0m"
    forceWrite(colored, end)
42
+
43
+
44
def reset_cursor():
    """Write a carriage return so the cursor goes back to the start of the current line."""
    carriage_return = "\r"
    forceWrite(carriage_return)
46
+
47
+
48
def move_cursor(num_lines: int, direction: str):
    """
    Write the escape sequence that moves the cursor `num_lines` steps in `direction`.

    `direction` is matched case-insensitively against the keys of `CURSOR_TO_CHAR`; an unknown name raises
    `KeyError`.
    """
    final_char = CURSOR_TO_CHAR[direction.upper()]
    forceWrite(f"\033[{num_lines}{final_char}")
50
+
51
+
52
def clear_line():
    """Overwrite the current line with `TERMINAL_WIDTH` spaces, then return the cursor to the line start."""
    blanks = " " * TERMINAL_WIDTH
    forceWrite(blanks)
    reset_cursor()
55
+
56
+
57
def linebreak():
    """Return the cursor to the line start and draw a rule of `TERMINAL_WIDTH` dashes."""
    reset_cursor()
    rule = "-" * TERMINAL_WIDTH
    forceWrite(rule)
accelerate/commands/menu/input.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team and Brian Chao. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ This file contains utilities for handling input from the user and registering specific keys to specific functions,
17
+ based on https://github.com/bchao1/bullet
18
+ """
19
+
20
+ from .keymap import KEYMAP, get_character
21
+
22
+
23
def mark(key: str):
    """
    Mark the function with the key code so it can be handled in the register
    """

    def decorator(func):
        # Reuse the list left by an earlier decorator so stacked marks accumulate.
        registered = getattr(func, "handle_key", [])
        registered.append(key)
        func.handle_key = registered
        return func

    return decorator
35
+
36
+
37
def mark_multiple(*keys: list[str]):
    """
    Mark the function with the key codes so it can be handled in the register
    """

    def decorator(func):
        # Reuse the list left by an earlier decorator so stacked marks accumulate.
        registered = getattr(func, "handle_key", [])
        registered.extend(keys)
        func.handle_key = registered
        return func

    return decorator
49
+
50
+
51
class KeyHandler(type):
    """
    Metaclass that adds the key handlers to the class

    Every class attribute carrying a `handle_key` list is entered into the class's `key_handler` table under each of
    those keys, and `handle_input` is installed on classes that do not already have a `key_handler`.
    """

    def __new__(cls, name, bases, attrs):
        new_cls = super().__new__(cls, name, bases, attrs)
        if not hasattr(new_cls, "key_handler"):
            new_cls.key_handler = {}
            new_cls.handle_input = KeyHandler.handle_input
        else:
            # Work on a copy of the existing table: handlers added for this class must not be
            # written into the dict a parent class still uses.
            new_cls.key_handler = dict(new_cls.key_handler)

        for value in attrs.values():
            for key in getattr(value, "handle_key", []):
                new_cls.key_handler[key] = value
        return new_cls

    @staticmethod
    def handle_input(cls):
        """
        Read one character and run the handler registered for it.

        Despite its name, `cls` is the instance when this is called as `obj.handle_input()`, because the function is
        assigned onto the class as a plain attribute. Returns the handler's return value, or `None` when no handler
        is registered for the key. The key code is stored on `cls.current_selection` before the handler runs.
        """
        char = get_character()
        # Anything other than the "undefined" sentinel is converted to its integer code point.
        if char != KEYMAP["undefined"]:
            char = ord(char)
        handler = cls.key_handler.get(char)
        if handler:
            cls.current_selection = char
            return handler(cls)
        else:
            return None
80
+
81
+
82
def register(cls):
    """Adds KeyHandler metaclass to the class"""
    # Rebuild the class from its own name, bases and a copy of its namespace, with KeyHandler as metaclass.
    namespace = cls.__dict__.copy()
    return KeyHandler(cls.__name__, cls.__bases__, namespace)
accelerate/commands/menu/keymap.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team and Brian Chao. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Utilities relating to parsing raw characters from the keyboard, based on https://github.com/bchao1/bullet
17
+ """
18
+
19
+ import os
20
+ import string
21
+ import sys
22
+
23
+
24
# Added to the raw code of an arrow key so the result cannot clash with a regular character code.
ARROW_KEY_FLAG = 1 << 8

# Symbolic key names mapped to the integer codes used by the menu.
KEYMAP = dict(
    tab=ord("\t"),
    newline=ord("\r"),
    esc=27,
    up=ARROW_KEY_FLAG + 65,
    down=ARROW_KEY_FLAG + 66,
    right=ARROW_KEY_FLAG + 67,
    left=ARROW_KEY_FLAG + 68,
    mod_int=91,
    undefined=sys.maxsize,
    interrupt=3,
    insert=50,
    delete=51,
    pg_up=53,
    pg_down=54,
)

# Inclusive bounds of the arrow-key code range.
KEYMAP.update(arrow_begin=KEYMAP["up"], arrow_end=KEYMAP["left"])

if sys.platform == "win32":
    # Characters queued by `get_raw_chars` and handed out one per call.
    WIN_CH_BUFFER = []
    # Two-byte Windows key sequences (either prefix byte) mapped to the unflagged arrow codes.
    WIN_KEYMAP = {
        prefix + scan_code: KEYMAP[direction] - ARROW_KEY_FLAG
        for scan_code, direction in ((b"H", "up"), (b"P", "down"), (b"M", "right"), (b"K", "left"))
        for prefix in (b"\xe0", b"\x00")
    }

# The digit keys are addressable by their own character.
for i in range(10):
    KEYMAP[str(i)] = ord("0") + i
61
+
62
+
63
def get_raw_chars():
    """
    Gets raw characters from inputs.

    Reads a single keystroke and returns it as a one-character string. On Windows a two-byte arrow sequence is
    translated into the escape / `mod_int` / letter sequence the rest of the module expects: escape is returned
    immediately and the remaining characters are queued in `WIN_CH_BUFFER` for the following calls.

    Raises:
        NotImplementedError: if `os.name` is neither `"nt"` nor `"posix"`.
    """
    if os.name == "nt":
        import msvcrt

        encoding = "mbcs"
        # Flush the keyboard buffer
        while msvcrt.kbhit():
            msvcrt.getch()
        if len(WIN_CH_BUFFER) == 0:
            # Read the keystroke
            ch = msvcrt.getch()

            # If it is a prefix char, get second part
            if ch in (b"\x00", b"\xe0"):
                ch2 = ch + msvcrt.getch()
                # Translate actual Win chars to bullet char types
                code = WIN_KEYMAP.get(ch2)
                if code is None:
                    # Unknown sequence: hand back its second byte as a character (indexing bytes yields an int,
                    # which callers cannot pass to `ord`).
                    ch = chr(ch2[1])
                else:
                    chx = chr(code)
                    WIN_CH_BUFFER.append(chr(KEYMAP["mod_int"]))
                    WIN_CH_BUFFER.append(chx)
                    # These keys end with a trailing "~" (126). The codes are compared directly: the previous
                    # `KEYMAP[...] - 1 << 9` parsed as `(KEYMAP[...] - 1) << 9` and could never match.
                    if code in (KEYMAP["insert"], KEYMAP["delete"], KEYMAP["pg_up"], KEYMAP["pg_down"]):
                        WIN_CH_BUFFER.append(chr(126))
                    ch = chr(KEYMAP["esc"])
            else:
                ch = ch.decode(encoding)
        else:
            ch = WIN_CH_BUFFER.pop(0)
    elif os.name == "posix":
        import termios
        import tty

        fd = sys.stdin.fileno()
        old_settings = termios.tcgetattr(fd)
        try:
            # Raw mode delivers the key immediately; the saved settings are always restored.
            tty.setraw(fd)
            ch = sys.stdin.read(1)
        finally:
            termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
    else:
        raise NotImplementedError(f"Reading raw characters is not supported on platform {os.name!r}")
    return ch
110
+
111
+
112
def get_character():
    """
    Gets a character from the keyboard and returns the key code.

    Returns the character itself for interrupt, newline and printable keys, the flagged character for an arrow key
    sequence (escape, `mod_int`, letter), and `KEYMAP["undefined"]` for anything else.
    """
    first = get_raw_chars()
    code = ord(first)
    if code in (KEYMAP["interrupt"], KEYMAP["newline"]):
        return first

    if code != KEYMAP["esc"]:
        return first if first in string.printable else KEYMAP["undefined"]

    # Escape sequence: the next character decides whether an arrow key follows.
    if ord(get_raw_chars()) != KEYMAP["mod_int"]:
        return get_raw_chars()

    arrow = ord(get_raw_chars())
    lowest = KEYMAP["arrow_begin"] - ARROW_KEY_FLAG
    highest = KEYMAP["arrow_end"] - ARROW_KEY_FLAG
    if lowest <= arrow <= highest:
        return chr(arrow + ARROW_KEY_FLAG)
    return KEYMAP["undefined"]
accelerate/commands/menu/selection_menu.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team and Brian Chao. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Main driver for the selection menu, based on https://github.com/bchao1/bullet
17
+ """
18
+
19
+ import builtins
20
+ import sys
21
+ from typing import Optional
22
+
23
+ from ...utils.imports import _is_package_available
24
+ from . import cursor, input
25
+ from .helpers import Direction, clear_line, forceWrite, linebreak, move_cursor, reset_cursor, writeColor
26
+ from .keymap import KEYMAP
27
+
28
+
29
# Whether the `google.colab` package is available; treated as unavailable if the lookup itself fails.
try:
    in_colab = _is_package_available("google.colab")
except ModuleNotFoundError:
    in_colab = False
34
+
35
+
36
@input.register
class BulletMenu:
    """
    A CLI menu to select a choice from a list of choices using the keyboard.
    """

    # NOTE(review): runs of spaces inside the string literals of this class may have been collapsed in the copy
    # under review; confirm the exact padding against the repository before merging.

    def __init__(self, prompt: Optional[str] = None, choices: Optional[list] = None):
        """
        Args:
            prompt: Text written above the menu, if any.
            choices: The entries to choose from. `None` means no entries.
        """
        self.position = 0
        # A literal `[]` default would be a single list shared by every menu created without `choices`.
        self.choices = [] if choices is None else choices
        self.prompt = prompt
        if sys.platform == "win32":
            self.arrow_char = "*"
        else:
            self.arrow_char = "➔ "

    def write_choice(self, index, end: str = ""):
        "Writes the choice at the given index, in colour everywhere except on Windows"
        if sys.platform != "win32":
            writeColor(self.choices[index], 32, end)
        else:
            forceWrite(self.choices[index], end)

    def print_choice(self, index: int):
        "Prints the choice at the given index"
        if index == self.position:
            forceWrite(f" {self.arrow_char} ")
            self.write_choice(index)
        else:
            forceWrite(f" {self.choices[index]}")
        reset_cursor()

    def move_direction(self, direction: Direction, num_spaces: int = 1):
        "Should not be directly called, used to move a direction of either up or down"
        old_position = self.position
        # Refuse any move that would leave the list, whatever its length (the check used to assume one row).
        if direction == Direction.DOWN:
            if self.position + num_spaces >= len(self.choices):
                return
            self.position += num_spaces
        else:
            if self.position - num_spaces < 0:
                return
            self.position -= num_spaces
        clear_line()
        self.print_choice(old_position)
        move_cursor(num_spaces, direction.name)
        self.print_choice(self.position)

    @input.mark(KEYMAP["up"])
    def move_up(self):
        self.move_direction(Direction.UP)

    @input.mark(KEYMAP["down"])
    def move_down(self):
        self.move_direction(Direction.DOWN)

    @input.mark(KEYMAP["newline"])
    def select(self):
        move_cursor(len(self.choices) - self.position, "DOWN")
        return self.position

    @input.mark(KEYMAP["interrupt"])
    def interrupt(self):
        move_cursor(len(self.choices) - self.position, "DOWN")
        raise KeyboardInterrupt

    @input.mark_multiple(*[KEYMAP[str(number)] for number in range(10)])
    def select_row(self):
        "Moves the selection to the row whose index matches the digit key that was pressed"
        index = int(chr(self.current_selection))
        if index == self.position or index >= len(self.choices):
            return
        movement = index - self.position
        if movement < 0:
            self.move_direction(Direction.UP, -movement)
        else:
            self.move_direction(Direction.DOWN, movement)

    def run(self, default_choice: int = 0):
        "Start the menu and return the selected choice"
        if self.prompt:
            linebreak()
            forceWrite(self.prompt, "\n")
            if in_colab:
                forceWrite("Please input a choice index (starting from 0), and press enter", "\n")
            else:
                forceWrite("Please select a choice using the arrow or number keys, and selecting with enter", "\n")
        self.position = default_choice
        for i in range(len(self.choices)):
            self.print_choice(i)
            forceWrite("\n")
        move_cursor(len(self.choices) - self.position, "UP")
        with cursor.hide():
            while True:
                if in_colab:
                    try:
                        choice = int(builtins.input())
                    except ValueError:
                        choice = default_choice
                    # Treat an index outside the list like any other invalid entry instead of failing later.
                    if not 0 <= choice < len(self.choices):
                        choice = default_choice
                else:
                    choice = self.handle_input()
                if choice is not None:
                    reset_cursor()
                    # Erase the drawn menu, then leave only the selected entry on screen.
                    for _ in range(len(self.choices) + 1):
                        move_cursor(1, "UP")
                        clear_line()
                    self.write_choice(choice, "\n")
                    return choice
accelerate/commands/merge.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ from accelerate.commands.utils import CustomArgumentParser
17
+ from accelerate.utils import merge_fsdp_weights
18
+
19
+
20
+ description = """Utility to merge the weights from multiple FSDP checkpoints into a single combined checkpoint. Should be used if
21
+ `SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}`.
22
+
23
+ This is a CPU-bound process and requires enough RAM to load the entire model state dict."""
24
+
25
+
26
def merge_command(args):
    """Run `merge_fsdp_weights` with the parsed CLI arguments; safe serialization is on unless `--unsafe_serialization` was given."""
    safe_serialization = not args.unsafe_serialization
    merge_fsdp_weights(args.checkpoint_directory, args.output_path, safe_serialization, args.remove_checkpoint_dir)
30
+
31
+
32
def merge_command_parser(subparsers=None):
    """
    Build the argument parser of the `merge-weights` command.

    Args:
        subparsers: Optional sub-parser collection of a parent CLI. When given, the command is added to it and
            `merge_command` is registered as its `func` default; otherwise a standalone parser is created.

    Returns:
        The configured parser.
    """
    if subparsers is not None:
        parser = subparsers.add_parser("merge-weights", description=description)
    else:
        parser = CustomArgumentParser(description=description)

    parser.add_argument("checkpoint_directory", type=str, help="A directory containing sharded weights saved by FSDP.")
    parser.add_argument(
        "output_path",
        type=str,
        # This positional is required; the help used to promise a default that does not exist.
        help="The path to save the merged weights.",
    )
    parser.add_argument(
        "--unsafe_serialization",
        action="store_true",
        default=False,
        help="Whether to save the merged weights as `.bin` rather than `.safetensors` (not recommended).",
    )
    parser.add_argument(
        "--remove_checkpoint_dir",
        action="store_true",
        help="Whether to remove the checkpoint directory after merging.",
        default=False,
    )

    if subparsers is not None:
        parser.set_defaults(func=merge_command)
    return parser
60
+
61
+
62
def main():
    """Parse the command line with the standalone parser and run the merge."""
    merge_command(merge_command_parser().parse_args())
66
+
67
+
68
+ if __name__ == "__main__":
69
+ main()
accelerate/commands/test.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import argparse
18
+
19
+ from accelerate.test_utils import execute_subprocess_async, path_in_accelerate_package
20
+
21
+
22
def test_command_parser(subparsers=None):
    """
    Build the argument parser of the `test` command.

    When `subparsers` is given the command is added to it and `test_command` is registered as its `func` default;
    otherwise a standalone parser is returned.
    """
    config_file_help = (
        "The path to use to store the config file. Will default to a file named default_config.yaml in the cache "
        "location, which is the content of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have "
        "such an environment variable, your cache directory ('~/.cache' or the content of `XDG_CACHE_HOME`) suffixed "
        "with 'huggingface'."
    )

    standalone = subparsers is None
    parser = argparse.ArgumentParser("Accelerate test command") if standalone else subparsers.add_parser("test")
    parser.add_argument("--config_file", default=None, help=config_file_help)

    if not standalone:
        parser.set_defaults(func=test_command)
    return parser
42
+
43
+
44
def test_command(args):
    """
    Launch the bundled test script through `accelerate-launch` and report success.

    Args:
        args: Parsed arguments; `args.config_file` is forwarded as `--config_file=...` when it is not `None`.
    """
    script_name = path_in_accelerate_package("test_utils", "scripts", "test_script.py")

    cmd = ["accelerate-launch"]
    if args.config_file is not None:
        # Keep the option as one argv entry: splitting a formatted string on whitespace broke any config or
        # script path that contains a space.
        cmd.append(f"--config_file={args.config_file}")
    cmd.append(script_name)

    result = execute_subprocess_async(cmd)
    if result.returncode == 0:
        print("Test is a success! You are ready for your distributed training!")
56
+
57
+
58
def main():
    """Parse the command line with the standalone parser and run the test command."""
    test_command(test_command_parser().parse_args())
62
+
63
+
64
+ if __name__ == "__main__":
65
+ main()
accelerate/commands/to_fsdp2.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import enum
18
+ import logging
19
+ from pathlib import Path
20
+
21
+ import yaml
22
+
23
+ from accelerate.commands.utils import CustomArgumentParser
24
+
25
+
26
class ConversionStatus(enum.Enum):
    """Markers used in `ARGUMENT_KEY_MAPPING` for FSDP1 arguments that have no FSDP2 counterpart."""

    # The argument is not yet implemented in FSDP2.
    NOT_YET_IMPLEMENTED = 0
    # The argument has been removed in FSDP2.
    REMOVED = -1
29
+
30
+
31
# FSDP1 argument name -> FSDP2 argument name, or a `ConversionStatus` when there is no FSDP2 equivalent.
ARGUMENT_KEY_MAPPING = {
    # New keys in FSDP2
    "fsdp_version": "fsdp_version",
    "fsdp_reshard_after_forward": "fsdp_reshard_after_forward",
    # https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md
    # https://huggingface.co/docs/accelerate/en/usage_guides/fsdp
    "fsdp_auto_wrap_policy": "fsdp_auto_wrap_policy",
    "fsdp_backward_prefetch": ConversionStatus.REMOVED,
    "fsdp_forward_prefetch": ConversionStatus.NOT_YET_IMPLEMENTED,
    "fsdp_cpu_ram_efficient_loading": "fsdp_cpu_ram_efficient_loading",
    "fsdp_offload_params": "fsdp_offload_params",
    "fsdp_sharding_strategy": "fsdp_reshard_after_forward",
    "fsdp_state_dict_type": "fsdp_state_dict_type",
    "fsdp_sync_module_states": ConversionStatus.REMOVED,
    "fsdp_transformer_layer_cls_to_wrap": "fsdp_transformer_layer_cls_to_wrap",
    "fsdp_min_num_params": "fsdp_min_num_params",
    "fsdp_use_orig_params": ConversionStatus.REMOVED,
    "fsdp_activation_checkpointing": "fsdp_activation_checkpointing",
}

# Per-argument translation of FSDP1 values. Both keys use the same sharding-strategy table;
# `fsdp_reshard_after_forward` is needed to convert newly created configs using FSDP1 to FSDP2.
ARGUMENT_VALUE_MAPPING = {
    argument: {
        "FULL_SHARD": True,
        "SHARD_GRAD_OP": False,
        "HYBRID_SHARD": True,
        "HYBRID_SHARD_ZERO2": False,
        "NO_SHARD": False,
    }
    for argument in ("fsdp_sharding_strategy", "fsdp_reshard_after_forward")
}
67
+
68
+ logger = logging.getLogger(__name__)
69
+
70
+
71
def _validate_to_fsdp2_args(args):
    """
    Check the parsed `to-fsdp2` arguments before any file is touched.

    Raises:
        FileNotFoundError: if `args.config_file` does not exist.
        ValueError: if neither `--overwrite` nor `--output_file` was given.
        FileExistsError: if `args.output_file` already exists and `--overwrite` is not set.
    """
    if not Path(args.config_file).exists():
        raise FileNotFoundError(f"Config file {args.config_file} not found")

    # With --overwrite every remaining combination is acceptable.
    if args.overwrite:
        return

    if args.output_file is None:
        raise ValueError("If --overwrite is not set, --output_file must be provided")

    if Path(args.output_file).exists():
        raise FileExistsError(f"Output file {args.output_file} already exists and --overwrite is not set")
80
+
81
+
82
def convert_config_to_fsdp2(config: dict) -> dict:
    """
    Convert the `fsdp_config` section of an Accelerate config from FSDP1 to FSDP2 argument names and values.

    Args:
        config: The loaded config. It is modified in place: `config["fsdp_config"]` is replaced by the converted
            section.

    Returns:
        The same `config` object. It is returned unchanged when it has no (or an empty) `fsdp_config` or when that
        section already declares `fsdp_version` 2.
    """
    fsdp_config = config.get("fsdp_config", {})

    if not fsdp_config:
        logger.info("No FSDP config found in the config file, skipping conversion...")
        return config

    if fsdp_config.get("fsdp_version", 1) == 2:
        logger.warning("Config already specifies FSDP2, skipping conversion...")
        logger.warning(
            "If the config doesn't use new argument names, change `fsdp_version` to `1` and rerun the command."
        )
        return config

    new_fsdp_config = {}
    for key, value in fsdp_config.items():
        target = ARGUMENT_KEY_MAPPING.get(key)

        # The status checks must come first. A previous catch-all branch copied every `ConversionStatus` entry
        # verbatim, which made the branches below unreachable and kept removed arguments in the FSDP2 config.
        if target is ConversionStatus.REMOVED:
            logger.warning(f"Argument {key} has been removed in FSDP2, skipping this key...")
            continue

        if target is ConversionStatus.NOT_YET_IMPLEMENTED:
            logger.warning(f"Argument {key} is not yet implemented in FSDP2, skipping this key...")
            continue

        if target is None:
            # Unknown argument: it is carried over untouched.
            logger.warning(f"Argument {key} is not being converted, skipping this key...")
            new_fsdp_config[key] = value
            continue

        if key in ARGUMENT_VALUE_MAPPING:
            # Values without a translation are kept as they are.
            value = ARGUMENT_VALUE_MAPPING[key].get(value, value)
        new_fsdp_config[target] = value

    new_fsdp_config["fsdp_version"] = 2
    config["fsdp_config"] = new_fsdp_config
    return config
124
+
125
+
126
def to_fsdp2_command_parser(subparsers=None):
    """
    Build the argument parser of the `to-fsdp2` command.

    When `subparsers` is given the command is added to it and `to_fsdp2_command` is registered as its `func`
    default; otherwise a standalone `CustomArgumentParser` is returned.
    """
    description = "Convert an Accelerate config from FSDP1 to FSDP2"
    standalone = subparsers is None

    if standalone:
        parser = CustomArgumentParser(description=description)
    else:
        parser = subparsers.add_parser("to-fsdp2", description=description)

    parser.add_argument("--config_file", type=str, help="The config file to convert to FSDP2", required=True)
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Overwrite the config file if it exists",
        default=False,
    )
    parser.add_argument(
        "--output_file",
        type=str,
        help="The path to the output file to write the converted config to. If not provided, the input file will be overwritten (if --overwrite is set)",
        default=None,
    )

    if not standalone:
        parser.set_defaults(func=to_fsdp2_command)
    return parser
151
+
152
+
153
def load_config(config_file: str) -> dict:
    """
    Load a YAML config file with `yaml.safe_load`.

    Raises:
        ValueError: if the file yields an empty (falsy) document.
    """
    with open(config_file) as handle:
        loaded = yaml.safe_load(handle)

    if loaded:
        return loaded
    raise ValueError("Config file is empty")
160
+
161
+
162
def to_fsdp2_command(args):
    """Validate the arguments, convert the config file to FSDP2 and write the result as YAML."""
    _validate_to_fsdp2_args(args)
    original = load_config(args.config_file)

    # With --overwrite and no explicit destination, the input file itself is rewritten.
    if args.output_file is None and args.overwrite:
        args.output_file = args.config_file

    converted = convert_config_to_fsdp2(original)

    with open(args.output_file, "w") as destination:
        yaml.dump(converted, destination)
accelerate/commands/tpu.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import argparse
18
+ import os
19
+ import subprocess
20
+
21
+ from packaging.version import Version, parse
22
+
23
+ from accelerate.commands.config.config_args import default_config_file, load_config_from_file
24
+
25
+
26
+ _description = "Run commands across TPU VMs for initial setup before running `accelerate launch`."
27
+
28
+
29
def tpu_command_parser(subparsers=None):
    """
    Build the argument parser of the `tpu-config` command.

    When `subparsers` is given the command is added to it and `tpu_command_launcher` is registered as its `func`
    default; otherwise a standalone parser is returned.
    """
    standalone = subparsers is None
    if standalone:
        parser = argparse.ArgumentParser("Accelerate tpu-config command", description=_description)
    else:
        parser = subparsers.add_parser("tpu-config", description=_description)

    # Core arguments
    config_group = parser.add_argument_group(
        "Config Arguments", "Arguments that can be configured through `accelerate config`."
    )
    config_group.add_argument(
        "--config_file", type=str, default=None, help="Path to the config file to use for accelerate."
    )
    config_group.add_argument(
        "--tpu_name",
        default=None,
        help="The name of the TPU to use. If not specified, will use the TPU specified in the config file.",
    )
    config_group.add_argument(
        "--tpu_zone",
        default=None,
        help="The zone of the TPU to use. If not specified, will use the zone specified in the config file.",
    )

    tpu_group = parser.add_argument_group("TPU Arguments", "Arguments for options ran inside the TPU.")
    tpu_group.add_argument(
        "--use_alpha",
        action="store_true",
        help="Whether to use `gcloud alpha` when running the TPU training script instead of `gcloud`.",
    )
    tpu_group.add_argument(
        "--command_file",
        default=None,
        help="The path to the file containing the commands to run on the pod on startup.",
    )
    # Each occurrence of --command takes one or more words, so the parsed value is a list of lists.
    tpu_group.add_argument(
        "--command",
        action="append",
        nargs="+",
        help="A command to run on the pod. Can be passed multiple times.",
    )
    tpu_group.add_argument(
        "--install_accelerate",
        action="store_true",
        help="Whether to install accelerate on the pod. Defaults to False.",
    )
    tpu_group.add_argument(
        "--accelerate_version",
        default="latest",
        help="The version of accelerate to install on the pod. If not specified, will use the latest pypi version. Specify 'dev' to install from GitHub.",
    )
    tpu_group.add_argument(
        "--debug", action="store_true", help="If set, will print the command that would be run instead of running it."
    )

    if not standalone:
        parser.set_defaults(func=tpu_command_launcher)
    return parser
88
+
89
+
90
def tpu_command_launcher(args):
    """
    Assemble the setup commands and run them on every worker of a TPU VM through `gcloud ... tpu-vm ssh`.

    Values missing on the command line are filled from the config file when one is given or the default config
    file exists. With `args.debug` the `gcloud` command is only printed.

    Args:
        args: Parsed arguments of the `tpu-config` command; several attributes are rewritten in place.

    Raises:
        ValueError: if neither a command file nor a command is available.
        subprocess.CalledProcessError: if the `gcloud` command exits with a non-zero status.
    """
    # Get the default from the config file if it exists.
    if args.config_file is not None or os.path.isfile(default_config_file):
        defaults = load_config_from_file(args.config_file)
        if not args.command_file and defaults.command_file is not None and not args.command:
            args.command_file = defaults.command_file
        if not args.command and defaults.commands is not None:
            args.command = defaults.commands
        if not args.tpu_name:
            args.tpu_name = defaults.tpu_name
        if not args.tpu_zone:
            args.tpu_zone = defaults.tpu_zone

    # Turn the requested version into the argument handed to `pip install`.
    if args.accelerate_version == "dev":
        args.accelerate_version = "git+https://github.com/huggingface/accelerate.git"
    elif args.accelerate_version == "latest":
        args.accelerate_version = "accelerate -U"
    elif isinstance(parse(args.accelerate_version), Version):
        args.accelerate_version = f"accelerate=={args.accelerate_version}"

    if not args.command_file and not args.command:
        raise ValueError("You must specify either a command file or a command to run on the pod.")

    if args.command_file:
        with open(args.command_file) as f:
            args.command = [f.read().splitlines()]

    # To turn list of lists into list of strings
    if isinstance(args.command[0], list):
        args.command = [line for cmd in args.command for line in cmd]
    # Default to the shared folder and install accelerate
    new_cmd = ["cd /usr/share"]
    if args.install_accelerate:
        new_cmd += [f"pip install {args.accelerate_version}"]
    new_cmd += args.command
    args.command = "; ".join(new_cmd)

    # Then send it to gcloud
    # Eventually try to use google-api-core to do this instead of subprocess
    cmd = ["gcloud"]
    if args.use_alpha:
        cmd += ["alpha"]
    cmd += [
        "compute",
        "tpus",
        "tpu-vm",
        "ssh",
        args.tpu_name,
        "--zone",
        args.tpu_zone,
        "--command",
        args.command,
        "--worker",
        "all",
    ]
    if args.debug:
        print(f"Running {' '.join(cmd)}")
        return
    # `check=True` stops here on failure, so success is no longer reported for a command that failed.
    subprocess.run(cmd, check=True)
    print("Successfully setup pod.")
151
+
152
+
153
def main():
    """Parse the command line with the standalone parser and run the TPU launcher."""
    tpu_command_launcher(tpu_command_parser().parse_args())
accelerate/commands/utils.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import argparse
16
+
17
+
18
class _StoreAction(argparse.Action):
    """
    Custom action that allows for `-` or `_` to be passed in for an argument.

    It also records, in the namespace's `nondefault` set, the destination of every argument that was actually
    supplied.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Keep every option string and, right after it, its dashed twin when anything past the first two
        # characters contains an underscore.
        self.option_strings = [
            variant
            for flag in self.option_strings
            for variant in ((flag, flag.replace("_", "-")) if "_" in flag[2:] else (flag,))
        ]

    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, values)
        try:
            seen = namespace.nondefault
        except AttributeError:
            seen = namespace.nondefault = set()
        seen.add(self.dest)
38
+
39
+
40
class _StoreConstAction(_StoreAction):
    """
    Same as `argparse._StoreConstAction` but uses the custom `_StoreAction`.
    """

    def __init__(self, option_strings, dest, const, default=None, required=False, help=None):
        # `nargs=0`: the option takes no value on the command line.
        settings = dict(
            option_strings=option_strings,
            dest=dest,
            nargs=0,
            const=const,
            default=default,
            required=required,
            help=help,
        )
        super().__init__(**settings)

    def __call__(self, parser, namespace, values, option_string=None):
        # Store the configured constant instead of the (empty) parsed values.
        super().__call__(parser, namespace, self.const, option_string)
58
+
59
+
60
class _StoreTrueAction(_StoreConstAction):
    """
    Same as `argparse._StoreTrueAction` but uses the custom `_StoreConstAction`.
    """

    def __init__(self, option_strings, dest, default=None, required=False, help=None):
        # A store-true flag is a store-const flag whose constant is `True`.
        super().__init__(
            option_strings=option_strings,
            dest=dest,
            const=True,
            default=default,
            required=required,
            help=help,
        )
76
+
77
+
78
class CustomArgumentGroup(argparse._ArgumentGroup):
    """
    Custom argument group that allows for the use of `-` or `_` in arguments passed and overrides the help for each
    when applicable.
    """

    def _add_action(self, action):
        """Swap argparse's store / store-const / store-true actions for the custom ones, then register the action."""
        fields = vars(action)
        # Order matters: the store-true check has to run before the store-const one.
        if isinstance(action, argparse._StoreTrueAction):
            action = _StoreTrueAction(
                fields["option_strings"], fields["dest"], fields["default"], fields["required"], fields["help"]
            )
        elif isinstance(action, argparse._StoreConstAction):
            action = _StoreConstAction(
                fields["option_strings"],
                fields["dest"],
                fields["const"],
                fields["default"],
                fields["required"],
                fields["help"],
            )
        elif isinstance(action, argparse._StoreAction):
            action = _StoreAction(**fields)
        return super()._add_action(action)
103
+
104
+
105
class CustomArgumentParser(argparse.ArgumentParser):
    """
    Custom argument parser that allows for the use of `-` or `_` in arguments passed and overrides the help for each
    when applicable.
    """

    def add_argument(self, *args, **kwargs):
        """
        Register an argument, substituting the custom store actions.

        Returns:
            The created action, as `argparse.ArgumentParser.add_argument` does (it used to be discarded).
        """
        # NOTE(review): the `else` is read as belonging to the outer `if`, i.e. arguments without an explicit
        # action get `_StoreAction` and other explicit actions are left alone. Confirm against the repository.
        if "action" in kwargs:
            # Translate action -> class
            if kwargs["action"] == "store_true":
                kwargs["action"] = _StoreTrueAction
        else:
            kwargs["action"] = _StoreAction
        return super().add_argument(*args, **kwargs)

    def add_argument_group(self, *args, **kwargs):
        """Create a `CustomArgumentGroup` bound to this parser, register it and return it."""
        group = CustomArgumentGroup(self, *args, **kwargs)
        self._action_groups.append(group)
        return group
accelerate/test_utils/__init__.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from .testing import (
15
+ DEFAULT_LAUNCH_COMMAND,
16
+ are_the_same_tensors,
17
+ assert_exception,
18
+ capture_call_output,
19
+ device_count,
20
+ execute_subprocess_async,
21
+ get_launch_command,
22
+ get_torch_dist_unique_port,
23
+ memory_allocated_func,
24
+ path_in_accelerate_package,
25
+ pytest_xdist_worker_id,
26
+ require_bnb,
27
+ require_cpu,
28
+ require_cuda,
29
+ require_cuda_or_hpu,
30
+ require_cuda_or_xpu,
31
+ require_fp8,
32
+ require_fp16,
33
+ require_huggingface_suite,
34
+ require_mlu,
35
+ require_mps,
36
+ require_multi_device,
37
+ require_multi_gpu,
38
+ require_multi_gpu_or_xpu,
39
+ require_multi_xpu,
40
+ require_musa,
41
+ require_non_cpu,
42
+ require_non_hpu,
43
+ require_non_torch_xla,
44
+ require_non_xpu,
45
+ require_npu,
46
+ require_pippy,
47
+ require_sdaa,
48
+ require_single_device,
49
+ require_single_gpu,
50
+ require_single_xpu,
51
+ require_torch_min_version,
52
+ require_torchao,
53
+ require_torchvision,
54
+ require_tpu,
55
+ require_transformer_engine,
56
+ require_transformer_engine_mxfp8,
57
+ require_xpu,
58
+ run_first,
59
+ skip,
60
+ slow,
61
+ torch_device,
62
+ )
63
+ from .training import RegressionDataset, RegressionModel
64
+
65
+
66
+ from .scripts import test_script, test_sync, test_ops # isort: skip
accelerate/test_utils/examples.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ A collection of utilities for comparing `examples/complete_*_example.py` scripts with the capabilities inside of each
18
+ `examples/by_feature` example. `compare_against_test` is the main function that should be used when testing, while the
19
+ others are used to either get the code that matters, or to preprocess them (such as stripping comments)
20
+ """
21
+
22
+ import os
23
+ from typing import Optional
24
+
25
+
26
def get_function_contents_by_name(lines: list[str], name: str):
    """
    Extracts a function from `lines` of segmented source code with the name `name`.

    The function body is taken to run from the first line containing `def {name}` up to (but excluding) the next
    line containing `def main` (for `training_function`) or `if __name__` (for `main`). If that closing line never
    appears, everything up to the end of `lines` is returned.

    Args:
        lines (`List[str]`):
            Source code of a script separated by line.
        name (`str`):
            The name of the function to extract. Should be either `training_function` or `main`

    Returns:
        `List[str]` with the lines of the function (starting at its `def` line), or `None` when no line containing
        `def {name}` exists in `lines`.

    Raises:
        `ValueError`: if `name` is neither `training_function` nor `main`.
    """
    if name not in ("training_function", "main"):
        raise ValueError(f"Incorrect function name passed: {name}, choose either 'main' or 'training_function'")

    start_marker = f"def {name}"
    # The line that follows the wanted function in the expected script layout.
    end_marker = "def main" if name == "training_function" else "if __name__"

    good_lines, found_start = [], False
    for line in lines:
        if not found_start:
            if start_marker in line:
                found_start = True
                good_lines.append(line)
            continue
        if end_marker in line:
            break
        good_lines.append(line)

    # Keep what was collected even when the closing marker is missing (function is last in the file);
    # a function that was never found is still reported as `None`.
    return good_lines if found_start else None
50
+
51
+
52
def clean_lines(lines: list[str]):
    """
    Drops comment-only lines and bare newline entries from `lines`.

    A line is dropped when, ignoring leading whitespace, it begins with `#`, or when it is exactly a single newline
    character. Every other line is kept as-is, in order.

    Args:
        lines (`List[str]`):
            Source code of a script separated by line.
    """
    kept = []
    for entry in lines:
        if entry == "\n":
            continue
        if entry.lstrip().startswith("#"):
            continue
        kept.append(entry)
    return kept
61
+
62
+
63
def compare_against_test(
    base_filename: str, feature_filename: str, parser_only: bool, secondary_filename: Optional[str] = None
):
    """
    Tests whether the additional code inside of `feature_filename` was implemented in `base_filename`. This should be
    used when testing to see if `complete_*_.py` examples have all of the implementations from each of the
    `examples/by_feature/*` scripts.

    It utilizes `nlp_example.py` to extract out all of the repeated training code, so that only the new additional code
    is examined and checked. If something *other* than `nlp_example.py` should be used, such as `cv_example.py` for the
    `complete_cv_example.py` script, it should be passed in for the `secondary_filename` parameter.

    Note that `examples/nlp_example.py` is opened relative to the current working directory.

    Args:
        base_filename (`str` or `os.PathLike`):
            The filepath of a single "complete" example script to test, such as `examples/complete_cv_example.py`
        feature_filename (`str` or `os.PathLike`):
            The filepath of a single feature example script. The contents of this script are checked to see if they
            exist in `base_filename`
        parser_only (`bool`):
            Whether to compare only the `main()` sections in both files, or to compare the contents of
            `training_function()`
        secondary_filename (`str`, *optional*):
            A potential secondary filepath that should be included in the check. This function extracts the base
            functionalities off of "examples/nlp_example.py", so if `base_filename` is a script other than
            `complete_nlp_example.py`, the template script should be included here. Such as `examples/cv_example.py`

    Returns:
        `List[str]`: the lines found in the feature script that are neither part of the template nor present in the
        new parts of `base_filename`. An empty list means nothing is missing.
    """
    with open(base_filename) as f:
        base_file_contents = f.readlines()
    # The shared template, resolved against the current working directory
    with open(os.path.abspath(os.path.join("examples", "nlp_example.py"))) as f:
        full_file_contents = f.readlines()
    with open(feature_filename) as f:
        feature_file_contents = f.readlines()
    if secondary_filename is not None:
        with open(secondary_filename) as f:
            secondary_file_contents = f.readlines()

    # Pull the same function (`main` or `training_function`) out of every file and strip comments and blank lines.
    # `full_file_func` (from `nlp_example.py`) is the baseline whose lines get removed from the others below.
    if parser_only:
        base_file_func = clean_lines(get_function_contents_by_name(base_file_contents, "main"))
        full_file_func = clean_lines(get_function_contents_by_name(full_file_contents, "main"))
        feature_file_func = clean_lines(get_function_contents_by_name(feature_file_contents, "main"))
        if secondary_filename is not None:
            secondary_file_func = clean_lines(get_function_contents_by_name(secondary_file_contents, "main"))
    else:
        base_file_func = clean_lines(get_function_contents_by_name(base_file_contents, "training_function"))
        full_file_func = clean_lines(get_function_contents_by_name(full_file_contents, "training_function"))
        feature_file_func = clean_lines(get_function_contents_by_name(feature_file_contents, "training_function"))
        if secondary_filename is not None:
            secondary_file_func = clean_lines(
                get_function_contents_by_name(secondary_file_contents, "training_function")
            )

    # This dataloader line is never counted as new code, whatever its indentation
    _dl_line = "train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size)\n"

    # Specific code in our script that differs from the full version, aka what is new
    new_feature_code = []
    passed_idxs = []  # We keep track of the idxs just in case it's a repeated statement
    # Lines are consumed through an explicit iterator so that an extra line can be skipped further down.
    # NOTE(review): `i` is only appended to `passed_idxs` after the membership check for that same `i`, so the
    # `i not in passed_idxs` test looks like it is always true - confirm before relying on it.
    it = iter(feature_file_func)
    for i in range(len(feature_file_func) - 1):
        if i not in passed_idxs:
            line = next(it)
            if (line not in full_file_func) and (line.lstrip() != _dl_line):
                if "TESTING_MOCKED_DATALOADERS" not in line:
                    new_feature_code.append(line)
                    passed_idxs.append(i)
                else:
                    # Skip over the `config['num_epochs'] = 2` statement
                    _ = next(it)

    # Extract out just the new parts of `base_file_func`, i.e. the lines of the "complete" script that are not in
    # the template
    new_full_example_parts = []
    passed_idxs = []  # We keep track of the idxs just in case it's a repeated statement
    for i, line in enumerate(base_file_func):
        if i not in passed_idxs:
            if (line not in full_file_func) and (line.lstrip() != _dl_line):
                if "TESTING_MOCKED_DATALOADERS" not in line:
                    new_full_example_parts.append(line)
                    passed_idxs.append(i)

    # Finally, get the overall diff
    diff_from_example = [line for line in new_feature_code if line not in new_full_example_parts]
    if secondary_filename is not None:
        # Lines of the raw template file that the secondary script's function does not contain are also excused
        diff_from_two = [line for line in full_file_contents if line not in secondary_file_func]
        diff_from_example = [line for line in diff_from_example if line not in diff_from_two]

    return diff_from_example
accelerate/test_utils/scripts/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
accelerate/test_utils/scripts/external_deps/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
accelerate/test_utils/scripts/external_deps/test_checkpointing.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import argparse
15
+ import json
16
+ import os
17
+
18
+ import evaluate
19
+ import torch
20
+ from datasets import load_dataset
21
+ from torch.optim import AdamW
22
+ from torch.utils.data import DataLoader
23
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
24
+
25
+ from accelerate import Accelerator, DistributedType
26
+ from accelerate.utils.deepspeed import DummyOptim, DummyScheduler
27
+
28
+
29
+ MAX_GPU_BATCH_SIZE = 16
30
+ EVAL_BATCH_SIZE = 32
31
+
32
+
33
def get_dataloaders(accelerator: Accelerator, batch_size: int = 16, model_name: str = "bert-base-cased"):
    """
    Builds the train and validation `DataLoader`s for the `mrpc` configuration of the `glue` dataset.

    Args:
        accelerator (`Accelerator`):
            An `Accelerator` object
        batch_size (`int`, *optional*):
            The batch size for the train DataLoader. The validation DataLoader uses `EVAL_BATCH_SIZE`.
        model_name (`str`, *optional*):
            Name or path handed to `AutoTokenizer.from_pretrained` to load the tokenizer.

    Returns:
        A `(train_dataloader, eval_dataloader)` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    raw_datasets = load_dataset("glue", "mrpc")

    def tokenize_function(examples):
        # max_length=None => use the model max length (it's actually the default)
        return tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)

    # Tokenize every split, then rename 'label' to 'labels', the name the transformers models expect
    tokenized_datasets = raw_datasets.map(
        tokenize_function, batched=True, remove_columns=["idx", "sentence1", "sentence2"], load_from_cache_file=False
    ).rename_column("label", "labels")

    def collate_fn(examples):
        # On TPU it's best to pad everything to the same length or training will be very slow.
        if accelerator.distributed_type == DistributedType.XLA:
            pad_options = {"padding": "max_length", "max_length": 128}
        else:
            pad_options = {"padding": "longest"}
        return tokenizer.pad(examples, return_tensors="pt", **pad_options)

    train_split = tokenized_datasets["train"]
    eval_split = tokenized_datasets["validation"]

    # Only the training data is shuffled.
    train_dataloader = DataLoader(train_split, shuffle=True, collate_fn=collate_fn, batch_size=batch_size)
    eval_dataloader = DataLoader(eval_split, shuffle=False, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE)

    return train_dataloader, eval_dataloader
76
+
77
+
78
def evaluation_loop(accelerator, model, eval_dataloader, metric):
    """
    Runs `model` over `eval_dataloader`, feeds gathered predictions and labels to `metric`, and returns the
    `"accuracy"` entry of `metric.compute()`.

    When `accelerator.use_distributed` is set, the gathered tensors of the last batch are truncated so that the
    total number of samples passed to the metric does not exceed `len(eval_dataloader.dataset)`.
    """
    model.eval()
    samples_seen = 0
    for step, batch in enumerate(eval_dataloader):
        # We could avoid this line since we set the accelerator with `device_placement=True`.
        batch.to(accelerator.device)
        with torch.no_grad():
            logits = model(**batch).logits
        # It is slightly faster to call gather once on a tuple than once per tensor
        predictions, references = accelerator.gather((logits.argmax(dim=-1), batch["labels"]))
        if accelerator.use_distributed:
            # In a multiprocess environment, the last batch has duplicates
            if step != len(eval_dataloader) - 1:
                samples_seen += references.shape[0]
            else:
                remaining = len(eval_dataloader.dataset) - samples_seen
                predictions = predictions[:remaining]
                references = references[:remaining]
        metric.add_batch(predictions=predictions, references=references)

    return metric.compute()["accuracy"]
104
+
105
+
106
def training_function(config, args):
    """
    Trains a sequence classification model epoch by epoch, saving a checkpoint and a small JSON state file after
    every epoch, or - when `args.resume_from_checkpoint` is set - reloads a checkpoint and verifies it against the
    matching JSON state file instead of training.

    Args:
        config (`dict`):
            Hyper-parameters. The keys `lr`, `num_epochs`, `seed` and `batch_size` are read.
        args:
            Parsed command line arguments. The attributes `model_name_or_path`, `output_dir`,
            `resume_from_checkpoint` and `partial_train_epoch` are read.

    Raises:
        `AssertionError`: when resuming and the accuracy, scheduler learning rate, optimizer learning rate or epoch
        differ from the values stored in `state_{epoch}.json`.
    """
    # Initialize accelerator
    accelerator = Accelerator()

    # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs
    lr = config["lr"]
    num_epochs = int(config["num_epochs"])
    seed = int(config["seed"])
    batch_size = int(config["batch_size"])
    model_name = args.model_name_or_path

    set_seed(seed)
    train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size, model_name)

    # Instantiate the model (we build the model here so that the seed also control new weights initialization)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, return_dict=True)

    # Instantiate optimizer
    # A real AdamW is used unless a DeepSpeed plugin is active and its config already declares an optimizer,
    # in which case the `DummyOptim` placeholder is used instead.
    optimizer_cls = (
        AdamW
        if accelerator.state.deepspeed_plugin is None
        or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config
        else DummyOptim
    )
    optimizer = optimizer_cls(params=model.parameters(), lr=lr)

    # With DeepSpeed the accumulation steps come from its config, otherwise no accumulation is used
    if accelerator.state.deepspeed_plugin is not None:
        gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[
            "gradient_accumulation_steps"
        ]
    else:
        gradient_accumulation_steps = 1
    max_training_steps = (len(train_dataloader) * num_epochs) // gradient_accumulation_steps

    # Instantiate scheduler
    # Same rule as for the optimizer: a real scheduler unless the DeepSpeed config declares one.
    if (
        accelerator.state.deepspeed_plugin is None
        or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
    ):
        lr_scheduler = get_linear_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=0,
            num_training_steps=max_training_steps,
        )
    else:
        lr_scheduler = DummyScheduler(optimizer, total_num_steps=max_training_steps, warmup_num_steps=0)

    # Prepare everything
    # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the
    # prepare method.
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )

    # We need to keep track of how many total steps we have iterated over
    overall_step = 0
    # We also need to keep track of the starting epoch so files are named properly
    starting_epoch = 0
    metric = evaluate.load("glue", "mrpc")
    ending_epoch = num_epochs

    # Optionally stop early so that a later run can resume from the saved checkpoint
    if args.partial_train_epoch is not None:
        ending_epoch = args.partial_train_epoch

    if args.resume_from_checkpoint:
        accelerator.load_state(args.resume_from_checkpoint)
        # The checkpoint path is expected to contain `epoch_<number>`; read the digits that follow `epoch_`
        epoch_string = args.resume_from_checkpoint.split("epoch_")[1]
        state_epoch_num = ""
        for char in epoch_string:
            if char.isdigit():
                state_epoch_num += char
            else:
                break
        starting_epoch = int(state_epoch_num) + 1
        accuracy = evaluation_loop(accelerator, model, eval_dataloader, metric)
        accelerator.print("resumed checkpoint performance:", accuracy)
        accelerator.print("resumed checkpoint's scheduler's lr:", lr_scheduler.get_lr()[0])
        accelerator.print("resumed optimizers's lr:", optimizer.param_groups[0]["lr"])
        # Compare the reloaded values with what was recorded when the checkpoint was written
        with open(os.path.join(args.output_dir, f"state_{starting_epoch - 1}.json")) as f:
            resumed_state = json.load(f)
        assert resumed_state["accuracy"] == accuracy, "Accuracy mismatch, loading from checkpoint failed"
        assert resumed_state["lr"] == lr_scheduler.get_lr()[0], (
            "Scheduler learning rate mismatch, loading from checkpoint failed"
        )
        assert resumed_state["optimizer_lr"] == optimizer.param_groups[0]["lr"], (
            "Optimizer learning rate mismatch, loading from checkpoint failed"
        )
        assert resumed_state["epoch"] == starting_epoch - 1, "Epoch mismatch, loading from checkpoint failed"
        # Resuming only verifies the checkpoint; no further training happens in this run
        return

    # Now we train the model
    state = {}
    for epoch in range(starting_epoch, ending_epoch):
        model.train()
        for step, batch in enumerate(train_dataloader):
            outputs = model(**batch)
            loss = outputs.loss
            loss = loss / gradient_accumulation_steps
            accelerator.backward(loss)
            if step % gradient_accumulation_steps == 0:
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            overall_step += 1
        # Save a checkpoint for this epoch, then record the values a resumed run is checked against
        output_dir = f"epoch_{epoch}"
        output_dir = os.path.join(args.output_dir, output_dir)
        accelerator.save_state(output_dir)
        accuracy = evaluation_loop(accelerator, model, eval_dataloader, metric)
        state["accuracy"] = accuracy
        state["lr"] = lr_scheduler.get_lr()[0]
        state["optimizer_lr"] = optimizer.param_groups[0]["lr"]
        state["epoch"] = epoch
        state["step"] = overall_step
        accelerator.print(f"epoch {epoch}:", state)

        accelerator.wait_for_everyone()
        # Only the main process writes the JSON state file
        if accelerator.is_main_process:
            with open(os.path.join(args.output_dir, f"state_{epoch}.json"), "w") as f:
                json.dump(state, f)
    accelerator.end_training()
227
+
228
+
229
def main():
    """
    Parses the command line arguments, builds the hyper-parameter `config` and runs `training_function`.

    The learning rate (2e-5), seed (42) and batch size (16) are fixed; only the number of epochs comes from the
    command line.
    """
    # The description now matches what this script does (checkpoint saving / resuming).
    parser = argparse.ArgumentParser(
        description="Simple example of a training script that checks saving and resuming from checkpoints."
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="bert-base-cased",
        help="Path to pretrained model or model identifier from huggingface.co/models.",
        required=False,
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=".",
        help="Optional save directory where all checkpoint folders will be stored. Default is the current working directory.",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="If the training should continue from a checkpoint folder.",
    )
    parser.add_argument(
        "--partial_train_epoch",
        type=int,
        default=None,
        help="If passed, the training will stop after this number of epochs.",
    )
    parser.add_argument(
        "--num_epochs",
        type=int,
        default=2,
        help="Number of train epochs.",
    )
    args = parser.parse_args()
    config = {"lr": 2e-5, "num_epochs": args.num_epochs, "seed": 42, "batch_size": 16}

    training_function(config, args)


if __name__ == "__main__":
    main()
accelerate/test_utils/scripts/external_deps/test_ds_alst_ulysses_sp.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Test script for verifying ALST/Ulysses SP works
17
+ """
18
+
19
+ import torch
20
+ from deepspeed.runtime.utils import move_to_device
21
+ from transformers import AutoModelForCausalLM, AutoTokenizer
22
+
23
+ from accelerate import Accelerator
24
+ from accelerate.utils import ParallelismConfig, set_seed
25
+ from accelerate.utils.dataclasses import DeepSpeedSequenceParallelConfig
26
+
27
+
28
set_seed(42)

world_size = 2
model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM"

micro_batch_size = 1

# Sequence parallelism through the DeepSpeed backend, spread over `world_size` ranks
parallelism_config = ParallelismConfig(
    sp_backend="deepspeed",
    sp_size=world_size,
    # dp_shard_size=1, # set if dp is wanted as well
    sp_handler=DeepSpeedSequenceParallelConfig(
        sp_seq_length=256,
        sp_seq_length_is_variable=True,
        sp_attn_implementation="sdpa",
    ),
)

accelerator = Accelerator(
    parallelism_config=parallelism_config,
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Synthetic data: `samples` rows of `seqlen` consecutive integers, no real text involved
samples = 4
seqlen = 32
input_ids = torch.arange(1, seqlen * samples + 1).view(-1, seqlen) + 100
position_ids = torch.arange(seqlen * samples).view(-1, seqlen)

ds = torch.utils.data.TensorDataset(input_ids, position_ids)


def collate_fn(batch):
    """Turns the first (and, with `micro_batch_size = 1`, only) sample of `batch` into a model input dict."""
    input_ids, position_ids = batch[0]
    return dict(
        input_ids=input_ids.unsqueeze(0),
        position_ids=position_ids.unsqueeze(0),
        labels=input_ids.unsqueeze(0),
    )


dl = torch.utils.data.DataLoader(ds, batch_size=micro_batch_size, collate_fn=collate_fn)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

rank = torch.distributed.get_rank()

if rank == 0:
    print(f"DL orig: {len(dl)} samples")

model, optimizer, dl = accelerator.prepare(model, optimizer, dl)

if rank == 0:
    print(f"DL w/ adapter: {len(dl)} samples")

sp_size = parallelism_config.sp_size if parallelism_config else 1
if sp_size > 1:
    from deepspeed.utils import groups

    sp_group = groups._get_sequence_parallel_group()
    sp_world_size = parallelism_config.sp_size

unwrapped_model = accelerator.unwrap_model(model)

# Normal training loop
# The loop index is named `step` (not `iter`) so the builtin `iter` stays usable in this module.
for step, batch in enumerate(dl):
    optimizer.zero_grad()

    if rank == 0:
        print(f"batch {step}: seqlen: {len(batch['input_ids'][0])}")
    batch = move_to_device(batch, model.device)
    outputs = model(**batch)

    # The loss is computed here from `shift_labels` rather than by the model forward
    shift_labels = batch["shift_labels"]
    loss = unwrapped_model.loss_function(
        logits=outputs.logits,
        labels=None,
        shift_labels=shift_labels,
        vocab_size=unwrapped_model.config.vocab_size,
    )

    if sp_size > 1:
        # differentiable weighted per-shard-loss aggregation across ranks
        losses_per_rank = torch.distributed.nn.functional.all_gather(loss, group=sp_group)
        # special dealing with SFT that has prompt tokens that aren't used in loss computation
        good_tokens = (shift_labels != -100).view(-1).sum()
        good_tokens_per_rank = torch.distributed.nn.functional.all_gather(good_tokens, group=sp_group)
        # `sp_rank` is a distinct name so it cannot be confused with this process' own `rank`
        total_loss = sum(
            losses_per_rank[sp_rank] * good_tokens_per_rank[sp_rank]
            for sp_rank in range(sp_world_size)
            if good_tokens_per_rank[sp_rank] > 0
        )
        total_good_tokens = sum(good_tokens_per_rank)
        loss = total_loss / max(total_good_tokens, 1)

    if rank == 0:
        accelerator.print(f"{step}: {loss=}")
    accelerator.log(dict(train_loss=loss, step=step))

    accelerator.backward(loss)
    optimizer.step()

accelerator.end_training()
accelerate/test_utils/scripts/external_deps/test_ds_multiple_model.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Test script for verifying multiple models can be utilized with Accelerate + DeepSpeed:
17
+
18
+ Scenario 1: One model is training, another model is being used for inference/logits to impact training in some form.
19
+ Scenario 2: Two models are training simultaneously, which means two optimizers, etc.
20
+ """
21
+
22
+ import argparse
23
+ from pathlib import Path
24
+
25
+ import evaluate
26
+ import torch
27
+ from datasets import load_dataset
28
+ from torch.optim import AdamW
29
+ from torch.utils.data import DataLoader
30
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup
31
+
32
+ from accelerate import Accelerator, DeepSpeedPlugin, DistributedType
33
+ from accelerate.state import AcceleratorState
34
+ from accelerate.utils.deepspeed import get_active_deepspeed_plugin
35
+
36
+
37
+ EVAL_BATCH_SIZE = 16
38
+
39
+
40
class NoiseModel(torch.nn.Module):
    """
    Minimal module holding a single float32 parameter, `noise_factor`, that scales whatever is passed to `forward`.
    """

    def __init__(self, noise_factor=0.1):
        super().__init__()
        initial_value = torch.tensor(noise_factor, dtype=torch.float32)
        self.noise_factor = torch.nn.Parameter(initial_value)

    def forward(self, loss):
        # Multiply the incoming value by the parameter
        scaled_loss = loss * self.noise_factor
        return scaled_loss
47
+
48
+
49
def get_dataloaders(accelerator: Accelerator, batch_size: int = 16, model_name: str = "bert-base-cased"):
    """
    Builds the train and validation `DataLoader`s for the `mrpc` configuration of the `glue` dataset.

    Args:
        accelerator (`Accelerator`):
            An `Accelerator` object
        batch_size (`int`, *optional*):
            The batch size for the train DataLoader. The validation DataLoader uses `EVAL_BATCH_SIZE`.
        model_name (`str`, *optional*):
            Name or path handed to `AutoTokenizer.from_pretrained` to load the tokenizer.

    Returns:
        A `(train_dataloader, eval_dataloader)` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    raw_datasets = load_dataset("glue", "mrpc")

    def tokenize_function(examples):
        # max_length=None => use the model max length (it's actually the default)
        return tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)

    # Tokenize all the examples in all the splits of the dataset
    tokenized_datasets = raw_datasets.map(
        tokenize_function, batched=True, remove_columns=["idx", "sentence1", "sentence2"], load_from_cache_file=False
    )
    # The models of the transformers library expect the label column to be called 'labels'
    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

    def collate_fn(examples):
        # On TPU it's best to pad everything to the same length or training will be very slow.
        on_xla = accelerator.distributed_type == DistributedType.XLA
        if not on_xla:
            return tokenizer.pad(examples, padding="longest", return_tensors="pt")
        return tokenizer.pad(examples, padding="max_length", max_length=128, return_tensors="pt")

    # Both loaders share the collate function; only the training data is shuffled.
    shared_options = {"collate_fn": collate_fn}
    train_dataloader = DataLoader(
        tokenized_datasets["train"], shuffle=True, batch_size=batch_size, **shared_options
    )
    eval_dataloader = DataLoader(
        tokenized_datasets["validation"], shuffle=False, batch_size=EVAL_BATCH_SIZE, **shared_options
    )

    return train_dataloader, eval_dataloader
92
+
93
+
94
# Resolve the absolute location of this script
test_file_path = __file__
path = Path(test_file_path).resolve()
# Go six directory levels up from this file; the resulting directory is expected to contain `tests/deepspeed`
# (presumably the repository root - verify against the actual checkout layout)
test_file_dir_str = str(path.parent.parent.parent.parent.parent.parent)

# Create our DS plugins
# We use custom schedulers and optimizers, hence `model_only`
# Maps a short stage name to the path of the matching DeepSpeed JSON config file
ds_config_file = dict(
    zero2=f"{test_file_dir_str}/tests/deepspeed/ds_config_zero2_model_only.json",
    zero3=f"{test_file_dir_str}/tests/deepspeed/ds_config_zero3_model_only.json",
)
104
+
105
+
106
def single_model_training(config, args):
    """Train one model under the ZeRO-2 plugin while a `NoiseModel` runs under the ZeRO-3 plugin.

    Args:
        config: Mapping read for the keys `"num_epochs"`, `"batch_size"` and `"lr"`.
        args: Parsed command-line arguments; only `args.model_name_or_path` is read here.

    Raises:
        AssertionError: If the active DeepSpeed plugin is not the expected one after
            construction / selection, or if the best accuracy seen over all epochs is
            not strictly greater than the accuracy recorded for epoch 0.
    """
    # Training a single model, we have a `noise` model that is untrainable used to inject some noise into the training process
    num_epochs = config["num_epochs"]
    zero2_plugin = DeepSpeedPlugin(hf_ds_config=ds_config_file["zero2"])
    zero3_plugin = DeepSpeedPlugin(hf_ds_config=ds_config_file["zero3"])

    # The keys are the names later passed to `select_deepspeed_plugin`.
    deepspeed_plugins = {"training": zero2_plugin, "inference": zero3_plugin}

    # Initialize accelerator
    accelerator = Accelerator(
        deepspeed_plugins=deepspeed_plugins,
        mixed_precision="bf16",
    )

    # Initialize model under zero2 plugin
    assert get_active_deepspeed_plugin(accelerator.state) is zero2_plugin
    train_model = AutoModelForSequenceClassification.from_pretrained(args.model_name_or_path)
    train_dataloader, eval_dataloader = get_dataloaders(
        accelerator, batch_size=config["batch_size"], model_name=args.model_name_or_path
    )
    # The scheduler length is computed from the dataloader length before `prepare` is called.
    max_training_steps = len(train_dataloader) * config["num_epochs"]
    optimizer = AdamW(train_model.parameters(), lr=config["lr"])
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=0, num_training_steps=max_training_steps
    )

    train_dataloader, eval_dataloader, train_model, optimizer, lr_scheduler = accelerator.prepare(
        train_dataloader, eval_dataloader, train_model, optimizer, lr_scheduler
    )

    # Now prepare the model under zero3 plugin
    accelerator.state.select_deepspeed_plugin("inference")
    assert get_active_deepspeed_plugin(accelerator.state) is zero3_plugin
    inference_model = NoiseModel()
    inference_model = accelerator.prepare(inference_model)
    inference_model.eval()

    # Run training loop
    # Switch the active plugin back to the one used for the trainable model.
    accelerator.state.select_deepspeed_plugin("training")
    # We also need to keep track of the starting epoch so files are named properly
    starting_epoch = 0

    # Now we train the model
    best_performance = 0
    metric = evaluate.load("glue", "mrpc")
    performance_metric = {}
    for epoch in range(starting_epoch, num_epochs):
        train_model.train()
        # NOTE(review): this puts the noise model back into train mode even though it was
        # set to eval mode above; its forward pass below still runs under `torch.no_grad()`.
        inference_model.train()
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(train_model):
                outputs_1 = train_model(**batch)
                with torch.no_grad():
                    outputs_2 = inference_model(outputs_1.loss)
                # Combine the losses
                loss = outputs_1.loss + outputs_2
                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

        train_model.eval()
        for step, batch in enumerate(eval_dataloader):
            with torch.no_grad():
                outputs = train_model(**batch)
            predictions = outputs.logits.argmax(dim=-1)
            # It is slightly faster to call this once, than multiple times
            predictions, references = accelerator.gather_for_metrics((predictions, batch["labels"]))
            metric.add_batch(
                predictions=predictions,
                references=references,
            )

        eval_metric = metric.compute()
        # Use accelerator.print to print only on the main process.
        accelerator.print(f"epoch {epoch}:", eval_metric)
        performance_metric[f"epoch-{epoch}"] = eval_metric["accuracy"]

        if best_performance < eval_metric["accuracy"]:
            best_performance = eval_metric["accuracy"]
    # Strict comparison: some later epoch must beat epoch 0.
    assert best_performance > performance_metric["epoch-0"]
187
+
188
+
189
def multiple_model_training(config, args):
    """Train two copies of the same pretrained model, one under ZeRO-2 and one under ZeRO-3.

    Two `Accelerator` objects are created; each model/optimizer/scheduler triple is
    prepared and stepped through its own accelerator, while both consume the same
    batches from dataloaders prepared by the ZeRO-2 accelerator.

    Args:
        config: Mapping read for the keys `"num_epochs"`, `"batch_size"` and `"lr"`.
        args: Parsed command-line arguments; only `args.model_name_or_path` is read here.

    Raises:
        AssertionError: If the active plugin or wrapped engine is not the expected one,
            or if either model's best accuracy is not strictly greater than its
            epoch-0 accuracy.
    """
    # This will essentially be like a k-fold model, but one model is Zero-2 and another model is Zero-3
    num_epochs = config["num_epochs"]
    zero2_plugin = DeepSpeedPlugin(hf_ds_config=ds_config_file["zero2"])
    zero3_plugin = DeepSpeedPlugin(hf_ds_config=ds_config_file["zero3"])

    # The keys are the names later passed to `select_deepspeed_plugin`.
    deepspeed_plugins = {"zero2": zero2_plugin, "zero3": zero3_plugin}

    # Initialize accelerator
    zero2_accelerator = Accelerator(
        deepspeed_plugins=deepspeed_plugins,
        mixed_precision="bf16",
    )

    # Since an `AcceleratorState` has already been made, we can just reuse it here
    zero3_accelerator = Accelerator()

    # Initialize model under zero2 plugin
    assert get_active_deepspeed_plugin(zero2_accelerator.state) is zero2_plugin
    zero2_model = AutoModelForSequenceClassification.from_pretrained(args.model_name_or_path)
    train_dataloader, eval_dataloader = get_dataloaders(
        zero2_accelerator, batch_size=config["batch_size"], model_name=args.model_name_or_path
    )
    # Computed before `prepare`; the same step count is reused for the ZeRO-3 scheduler below.
    max_training_steps = len(train_dataloader) * config["num_epochs"]
    zero2_optimizer = AdamW(zero2_model.parameters(), lr=config["lr"])
    zero2_lr_scheduler = get_linear_schedule_with_warmup(
        zero2_optimizer, num_warmup_steps=0, num_training_steps=max_training_steps
    )

    train_dataloader, eval_dataloader, zero2_model, zero2_optimizer, zero2_lr_scheduler = zero2_accelerator.prepare(
        train_dataloader, eval_dataloader, zero2_model, zero2_optimizer, zero2_lr_scheduler
    )
    assert zero2_accelerator.deepspeed_engine_wrapped.engine is zero2_model

    # now do Zero3
    zero3_accelerator.state.select_deepspeed_plugin("zero3")
    # No dataloader is passed to the ZeRO-3 `prepare` call, so the micro batch size is
    # copied over from the ZeRO-2 plugin's config.
    zero3_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = zero2_plugin.deepspeed_config[
        "train_micro_batch_size_per_gpu"
    ]
    assert get_active_deepspeed_plugin(zero3_accelerator.state) is zero3_plugin
    zero3_model = AutoModelForSequenceClassification.from_pretrained(args.model_name_or_path)
    zero3_optimizer = AdamW(zero3_model.parameters(), lr=config["lr"])
    zero3_lr_scheduler = get_linear_schedule_with_warmup(
        zero3_optimizer, num_warmup_steps=0, num_training_steps=max_training_steps
    )
    zero3_model, zero3_optimizer, zero3_lr_scheduler = zero3_accelerator.prepare(
        zero3_model, zero3_optimizer, zero3_lr_scheduler
    )
    assert zero3_accelerator.deepspeed_engine_wrapped.engine is zero3_model

    # Run training loop
    starting_epoch = 0

    # Now we train the model
    # The `_a` names track the ZeRO-2 model, the `_b` names track the ZeRO-3 model.
    best_performance_a = 0
    best_performance_b = 0
    metric_a = evaluate.load("glue", "mrpc")
    metric_b = evaluate.load("glue", "mrpc")
    performance_metric_a = {}
    performance_metric_b = {}
    for epoch in range(starting_epoch, num_epochs):
        zero2_model.train()
        zero3_model.train()
        for step, batch in enumerate(train_dataloader):
            with zero2_accelerator.accumulate(zero2_model, zero3_model):
                # Each model is stepped independently on the same batch, through its own accelerator.
                outputs_1 = zero2_model(**batch)
                zero2_accelerator.backward(outputs_1.loss)
                zero2_optimizer.step()
                zero2_lr_scheduler.step()
                zero2_optimizer.zero_grad()
                outputs_2 = zero3_model(**batch)
                zero3_accelerator.backward(outputs_2.loss)
                zero3_optimizer.step()
                zero3_lr_scheduler.step()
                zero3_optimizer.zero_grad()

        zero2_model.eval()
        zero3_model.eval()
        for step, batch in enumerate(eval_dataloader):
            with torch.no_grad():
                logits_a = zero2_model(**batch).logits
                logits_b = zero3_model(**batch).logits
            # Take the argmax of each model's logits separately (the two models are scored independently)
            predictions_a = logits_a.argmax(dim=-1)
            predictions_b = logits_b.argmax(dim=-1)
            # It is slightly faster to call this once, than multiple times
            predictions_a, predictions_b, references = zero2_accelerator.gather_for_metrics(
                (predictions_a, predictions_b, batch["labels"])
            )
            metric_a.add_batch(
                predictions=predictions_a,
                references=references,
            )
            metric_b.add_batch(
                predictions=predictions_b,
                references=references,
            )

        eval_metric_a = metric_a.compute()
        eval_metric_b = metric_b.compute()
        # Use accelerator.print to print only on the main process.
        zero2_accelerator.print(f"epoch {epoch}:", eval_metric_a, eval_metric_b)
        performance_metric_a[f"epoch-{epoch}"] = eval_metric_a["accuracy"]
        performance_metric_b[f"epoch-{epoch}"] = eval_metric_b["accuracy"]

        if best_performance_a < eval_metric_a["accuracy"]:
            best_performance_a = eval_metric_a["accuracy"]
        if best_performance_b < eval_metric_b["accuracy"]:
            best_performance_b = eval_metric_b["accuracy"]
    # Strict comparison: for each model, some later epoch must beat epoch 0.
    assert best_performance_a > performance_metric_a["epoch-0"]
    assert best_performance_b > performance_metric_b["epoch-0"]
300
+
301
+
302
def main():
    """Parse the command line and run both DeepSpeed multi-plugin training scenarios.

    Runs `single_model_training` first, resets the shared `AcceleratorState`, then
    runs `multiple_model_training` with the same configuration.

    Command-line options:
        --model_name_or_path: Model path or hub identifier (default `bert-base-cased`).
        --performance_lower_bound: Parsed for CLI compatibility; not read by this function.
        --num_epochs: Number of training epochs (default 3).
    """
    # The description previously claimed this script tracked peak GPU memory usage,
    # which is what a different test script does; describe this script instead.
    parser = argparse.ArgumentParser(
        description="Test script that trains models under multiple DeepSpeed plugins (ZeRO-2 and ZeRO-3)."
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="bert-base-cased",
        help="Path to pretrained model or model identifier from huggingface.co/models.",
        required=False,
    )
    parser.add_argument(
        "--performance_lower_bound",
        type=float,
        default=None,
        help="Optional lower bound for the performance metric. If set, the training will throw error when the performance metric drops below this value.",
    )
    parser.add_argument(
        "--num_epochs",
        type=int,
        default=3,
        help="Number of train epochs.",
    )
    args = parser.parse_args()
    config = {"lr": 2e-5, "num_epochs": args.num_epochs, "seed": 42, "batch_size": 8}
    single_model_training(config, args)
    # Clear the accelerator state left by the first scenario before the second one
    # builds its own accelerators.
    AcceleratorState._reset_state(True)
    multiple_model_training(config, args)
328
+
329
+
330
+ if __name__ == "__main__":
331
+ main()
accelerate/test_utils/scripts/external_deps/test_metrics.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import logging
16
+ import math
17
+ import os
18
+ from copy import deepcopy
19
+
20
+ import datasets
21
+ import evaluate
22
+ import torch
23
+ import transformers
24
+ from datasets import load_dataset
25
+ from torch.utils.data import DataLoader, IterableDataset
26
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
27
+
28
+ from accelerate import Accelerator, DataLoaderConfiguration, DistributedType
29
+ from accelerate.data_loader import DataLoaderDispatcher
30
+ from accelerate.test_utils import RegressionDataset, RegressionModel, torch_device
31
+ from accelerate.utils import is_torch_xla_available, set_seed
32
+
33
+
34
# Set at import time, before any model or tokenizer is loaded below.
# NOTE(review): presumably silences transformers' advisory warnings — confirm against transformers docs.
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
35
+
36
+
37
class ListHandler(logging.Handler):
    """Logging handler that stores every record it receives in memory.

    Records are appended, unformatted, to the `logs` list so that a test can
    inspect what was emitted while the handler was attached.
    """

    def __init__(self, *args, **kwargs):
        logging.Handler.__init__(self, *args, **kwargs)
        # Records in the order they were emitted.
        self.logs = list()

    def emit(self, record):
        """Remember `record` instead of writing it anywhere."""
        self.logs += [record]
44
+
45
+
46
def get_basic_setup(accelerator, num_samples=82, batch_size=16):
    """Build a regression model, a prepared copy of it, and a prepared dataloader.

    Args:
        accelerator: Accelerator used for device placement and for `prepare`.
        num_samples: Length of the `RegressionDataset` to create.
        batch_size: Batch size given to the `DataLoader`.

    Returns:
        A tuple `(model, ddp_model, dataloader)`: the unprepared model moved to
        `accelerator.device`, the deep copy passed through `accelerator.prepare`,
        and the prepared dataloader.
    """
    set_seed(42)
    reference_model = RegressionModel()
    # Copy before anything touches the reference model so both start identical.
    prepared_model = deepcopy(reference_model)
    regression_data = RegressionDataset(length=num_samples)
    loader = DataLoader(regression_data, batch_size=batch_size)
    reference_model.to(accelerator.device)
    prepared_model, loader = accelerator.prepare(prepared_model, loader)
    return reference_model, prepared_model, loader
56
+
57
+
58
def get_dataloader(accelerator: Accelerator, use_longest=False):
    """Create an unshuffled `DataLoader` over the tokenized GLUE MRPC validation split.

    Args:
        accelerator: Used only for its `main_process_first()` context around tokenization.
        use_longest: When true, each batch is padded to its longest example;
            otherwise every batch is padded to a fixed length of 128.

    Returns:
        A `DataLoader` with batch size 16 yielding padded PyTorch tensors.
    """
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/mrpc-bert-base-cased")
    validation_split = load_dataset("glue", "mrpc", split="validation")

    def tokenize_function(examples):
        return tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)

    with accelerator.main_process_first():
        tokenized = validation_split.map(
            tokenize_function,
            batched=True,
            remove_columns=["idx", "sentence1", "sentence2"],
        )

    tokenized = tokenized.rename_column("label", "labels")

    def collate_fn(examples):
        if not use_longest:
            return tokenizer.pad(examples, padding="max_length", max_length=128, return_tensors="pt")
        return tokenizer.pad(examples, padding="longest", return_tensors="pt")

    return DataLoader(tokenized, shuffle=False, collate_fn=collate_fn, batch_size=16)
81
+
82
+
83
def get_mrpc_setup(dispatch_batches, split_batches):
    """Create an accelerator plus baseline and prepared model/dataloader pairs for MRPC.

    Args:
        dispatch_batches: Forwarded to `DataLoaderConfiguration`; its negation also
            selects longest-padding in `get_dataloader`.
        split_batches: Forwarded to `DataLoaderConfiguration`.

    Returns:
        A tuple `(setup, accelerator)` where `setup["ddp"]` is
        `[prepared_model, prepared_dataloader, torch_device]` and `setup["no"]` is
        `[model, dataloader, accelerator.device]`.
    """
    accelerator = Accelerator(
        dataloader_config=DataLoaderConfiguration(dispatch_batches=dispatch_batches, split_batches=split_batches)
    )
    raw_dataloader = get_dataloader(accelerator, not dispatch_batches)
    raw_model = AutoModelForSequenceClassification.from_pretrained(
        "hf-internal-testing/mrpc-bert-base-cased", return_dict=True
    )
    prepared_model, prepared_dataloader = accelerator.prepare(raw_model, raw_dataloader)
    setup = {
        "ddp": [prepared_model, prepared_dataloader, torch_device],
        "no": [raw_model, raw_dataloader, accelerator.device],
    }
    return setup, accelerator
95
+
96
+
97
def generate_predictions(model, dataloader, accelerator):
    """Run `model` over `dataloader` and return the gathered outputs and targets.

    Each batch must be a mapping with exactly two values, unpacked in order as
    (model input, target). The forward pass runs without gradient tracking and
    both tensors go through `accelerator.gather_for_metrics` per batch.

    Returns:
        A tuple `(logits, targets)`, each the concatenation over all batches.
    """
    gathered_logits, gathered_targets = [], []
    for batch in dataloader:
        features, labels = batch.values()
        with torch.no_grad():
            prediction = model(features)
        prediction, labels = accelerator.gather_for_metrics((prediction, labels))
        gathered_logits.append(prediction)
        gathered_targets.append(labels)
    return torch.cat(gathered_logits), torch.cat(gathered_targets)
111
+
112
+
113
def test_torch_metrics(
    accelerator: Accelerator, num_samples=82, dispatch_batches=False, split_batches=False, batch_size=16
):
    """Check that gathering predictions over the whole dataset yields exactly `num_samples` items.

    `dispatch_batches` and `split_batches` are accepted but not read by this function.
    """
    prepared_model, loader = get_basic_setup(accelerator, num_samples, batch_size)[1:]
    gathered_logits = generate_predictions(prepared_model, loader, accelerator)[0]
    assert len(gathered_logits) == num_samples, (
        f"Unexpected number of inputs:\n Expected: {num_samples}\n Actual: {len(gathered_logits)}"
    )
121
+
122
+
123
def test_mrpc(dispatch_batches: bool = False, split_batches: bool = False):
    """Compare GLUE MRPC metrics of an unprepared model against the prepared one.

    The baseline pass runs the raw model on the raw dataloader; the second pass
    runs the prepared model on the prepared dataloader and gathers predictions
    with `gather_for_metrics`. Accuracy and F1 of both passes must be close.
    """
    glue_metric = evaluate.load("glue", "mrpc")
    setups, accelerator = get_mrpc_setup(dispatch_batches, split_batches)

    # First do baseline
    plain_model, plain_loader, plain_device = setups["no"]
    plain_model.to(plain_device)
    plain_model.eval()
    for inputs in plain_loader:
        inputs.to(plain_device)
        with torch.inference_mode():
            result = plain_model(**inputs)
        glue_metric.add_batch(predictions=result.logits.argmax(dim=-1), references=inputs["labels"])
    baseline = glue_metric.compute()

    # Then do distributed
    dist_model, dist_loader, _ = setups["ddp"]
    dist_model.eval()
    for inputs in dist_loader:
        with torch.inference_mode():
            result = dist_model(**inputs)
        gathered_preds, gathered_refs = accelerator.gather_for_metrics(
            (result.logits.argmax(dim=-1), inputs["labels"])
        )
        glue_metric.add_batch(predictions=gathered_preds, references=gathered_refs)
    distributed = glue_metric.compute()

    for key in ("accuracy", "f1"):
        assert math.isclose(baseline[key], distributed[key]), (
            f"Baseline and Distributed are not the same for key {key}:\n\tBaseline: {baseline[key]}\n\tDistributed: {distributed[key]}\n"
        )
154
+
155
+
156
def test_gather_for_metrics_with_non_tensor_objects_iterable_dataset():
    """Gather an iterable dataset of 30 plain Python ints and expect no log records.

    On the main process a `ListHandler` is attached to the `accelerate.accelerator`
    logger for the duration of the iteration; it must stay empty.
    """

    class DummyIterableDataset(IterableDataset):
        def __init__(self, data):
            self.data = data

        def __len__(self):
            return len(self.data)

        def __iter__(self):
            for item in self.data:
                yield item

    source = DummyIterableDataset(list(range(30)))
    accelerator = Accelerator()
    prepared_dataloader = accelerator.prepare(DataLoader(source, batch_size=4))

    if accelerator.is_main_process:
        logger = logging.root.manager.loggerDict["accelerate.accelerator"]
        list_handler = ListHandler()
        logger.addHandler(list_handler)

    batches_for_metrics = [accelerator.gather_for_metrics(batch) for batch in prepared_dataloader]

    assert torch.cat(batches_for_metrics).size(0) == 30

    if accelerator.is_main_process:
        assert len(list_handler.logs) == 0
        logger.removeHandler(list_handler)
186
+
187
+
188
def test_gather_for_metrics_with_iterable_dataset():
    """Gather an iterable dataset built from a 30-element tensor and expect no log records.

    Also checks that preparing a dataloader over an `IterableDataset` produces a
    `DataLoaderDispatcher`.
    """

    class DummyIterableDataset(IterableDataset):
        def __init__(self, data):
            self.data = data

        def __len__(self):
            return len(self.data)

        def __iter__(self):
            for item in self.data:
                yield item

    source = DummyIterableDataset(torch.as_tensor(range(30)))
    raw_dataloader = DataLoader(source, batch_size=4)

    accelerator = Accelerator()
    prepared_dataloader = accelerator.prepare(raw_dataloader)

    assert isinstance(prepared_dataloader, DataLoaderDispatcher)

    if accelerator.is_main_process:
        logger = logging.root.manager.loggerDict["accelerate.accelerator"]
        list_handler = ListHandler()
        logger.addHandler(list_handler)

    batches_for_metrics = [accelerator.gather_for_metrics(batch) for batch in prepared_dataloader]

    assert torch.cat(batches_for_metrics).size(0) == 30

    if accelerator.is_main_process:
        assert len(list_handler.logs) == 0

        logger.removeHandler(list_handler)
222
+
223
+
224
def test_gather_for_metrics_drop_last():
    """Check that `gather_for_metrics` returns only full batches when `drop_last=True`.

    The dataset holds one item more than ten per process, so a trailing partial
    batch exists and has to be dropped. The second batch is gathered and must
    contain one complete per-device batch from every process.
    """
    accelerator = Accelerator()
    per_device_batch_size = 5
    num_items = (10 * accelerator.num_processes) + 1
    dataloader = accelerator.prepare(
        DataLoader(range(num_items), batch_size=per_device_batch_size, drop_last=True)
    )

    batches = iter(dataloader)
    next(batches)  # Skip first batch tensor([0, 1, 2, 3, 4], device='cuda:0')
    gathered_items = accelerator.gather_for_metrics(next(batches))

    # Should return a full set of complete batches from each GPU
    num_expected_items = per_device_batch_size * accelerator.num_processes
    assert gathered_items.size(0) == (num_expected_items), (
        f"Expected number of items: {num_expected_items}, Actual: {gathered_items.size(0)}"
    )
241
+
242
+
243
def main():
    """Run the `gather_for_metrics` test suite under the applicable dataloader configurations.

    The MRPC and iterable-dataset tests run only on non-CPU devices when TorchXLA is
    not available; the torch-metrics tests run for every non-XLA distributed type.
    The accelerator state is reset between scenarios so each one starts from a fresh
    configuration.
    """
    dataloader_config = DataLoaderConfiguration(split_batches=False, dispatch_batches=False)
    accelerator = Accelerator(dataloader_config=dataloader_config)
    # Only the local main process keeps warning-level logs; the others log errors only.
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
    # TorchXLA does not support batch dispatching. 'put_on_device' is always False for
    # TorchXLA, which can cause a value error in 'prepare_data_loader' function.
    dispatch_batches_options = [False] if accelerator.state.distributed_type == DistributedType.XLA else [True, False]

    # Temporarily close this test for TorchXLA due to the 'Cannot set version_counter for
    # inference tensor' error in inference mode. Reopen it after TorchXLA fixes this bug.
    # These are a bit slower so they should only be ran on the GPU or TPU
    if accelerator.device.type != "cpu" and not is_torch_xla_available():
        if accelerator.is_local_main_process:
            print("**Testing gather_for_metrics**")
        for split_batches in [True, False]:
            for dispatch_batches in dispatch_batches_options:
                if accelerator.is_local_main_process:
                    print(f"With: `split_batches={split_batches}`, `dispatch_batches={dispatch_batches}`")
                test_mrpc(dispatch_batches, split_batches)
                # `test_mrpc` builds its own accelerator; reset so the next combination can reconfigure.
                accelerator.state._reset_state()
        print("test_gather_for_metrics_with_iterable_dataset")
        test_gather_for_metrics_with_iterable_dataset()
        print("test gather_for_metrics_with_non_tensor_objects_iterable_dataset")
        test_gather_for_metrics_with_non_tensor_objects_iterable_dataset()

    # MpDeviceLoader in TorchXLA is an asynchronous loader that preloads several batches into cache.
    # This can cause the 'end_of_dataloader' of DataLoaderStateMixin to be set earlier than intended.
    # Skip this test when TorchXLA is enabled.
    if accelerator.state.distributed_type != DistributedType.XLA:
        if accelerator.is_local_main_process:
            print("**Test torch metrics**")
        for split_batches in [True, False]:
            for dispatch_batches in dispatch_batches_options:
                dataloader_config = DataLoaderConfiguration(
                    split_batches=split_batches, dispatch_batches=dispatch_batches
                )
                accelerator = Accelerator(dataloader_config=dataloader_config)
                if accelerator.is_local_main_process:
                    print(f"With: `split_batches={split_batches}`, `dispatch_batches={dispatch_batches}`, length=99")
                test_torch_metrics(accelerator, 99)
                accelerator.state._reset_state()
    if accelerator.is_local_main_process:
        print("**Test last batch is not dropped when perfectly divisible**")
    accelerator = Accelerator()
    test_torch_metrics(accelerator, 512)
    accelerator.state._reset_state()
    if accelerator.is_local_main_process:
        print("**Test that `drop_last` is taken into account**")
    test_gather_for_metrics_drop_last()
    accelerator.end_training()
    accelerator.state._reset_state()
299
+
300
+
301
def _mp_fn(index):
    """Entry point for `xla_spawn`; `index` is accepted but not used."""
    # For xla_spawn (TPUs)
    main()
304
+
305
+
306
+ if __name__ == "__main__":
307
+ main()
accelerate/test_utils/scripts/external_deps/test_peak_memory_usage.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import argparse
15
+ import gc
16
+ import json
17
+ import os
18
+
19
+ import torch
20
+ from datasets import load_dataset
21
+ from torch.optim import AdamW
22
+ from torch.utils.data import DataLoader
23
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
24
+
25
+ from accelerate import Accelerator, DistributedType
26
+ from accelerate.utils import (
27
+ is_hpu_available,
28
+ is_mlu_available,
29
+ is_musa_available,
30
+ is_neuron_available,
31
+ is_npu_available,
32
+ is_sdaa_available,
33
+ is_xpu_available,
34
+ )
35
+ from accelerate.utils.deepspeed import DummyOptim, DummyScheduler
36
+
37
+
38
# NOTE(review): not referenced by `get_dataloaders` below; confirm whether the rest of the script uses it.
MAX_GPU_BATCH_SIZE = 16
# Batch size of the validation dataloader built by `get_dataloaders`.
EVAL_BATCH_SIZE = 32
40
+
41
+
42
+ # Converting Bytes to Megabytes
43
def b2mb(x):
    """Convert a byte count to whole mebibytes, truncating toward zero."""
    bytes_per_mb = 1 << 20
    return int(x / bytes_per_mb)
45
+
46
+
47
+ # This context manager is used to track the peak memory usage of the process
48
class TorchTracemalloc:
    """Context manager measuring device memory used inside the `with` block.

    On entry the first available backend (checked in the order CUDA, MLU, SDAA,
    MUSA, NPU, XPU, HPU, Neuron) has its peak-memory gauge reset and its currently
    allocated memory recorded in `begin`. On exit the same backend is queried for
    `end` (allocated bytes) and `peak` (maximum allocated bytes), from which

    * `used`   = `end - begin`, in megabytes, and
    * `peaked` = `peak - begin`, in megabytes

    are derived with `b2mb`.
    """

    def __enter__(self):
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_max_memory_allocated()  # reset the peak gauge to zero
            self.begin = torch.cuda.memory_allocated()
        elif is_mlu_available():
            torch.mlu.empty_cache()
            torch.mlu.reset_max_memory_allocated()  # reset the peak gauge to zero
            self.begin = torch.mlu.memory_allocated()
        elif is_sdaa_available():
            torch.sdaa.empty_cache()
            torch.sdaa.reset_max_memory_allocated()  # reset the peak gauge to zero
            self.begin = torch.sdaa.memory_allocated()
        elif is_musa_available():
            torch.musa.empty_cache()
            torch.musa.reset_max_memory_allocated()  # reset the peak gauge to zero
            self.begin = torch.musa.memory_allocated()
        elif is_npu_available():
            torch.npu.empty_cache()
            torch.npu.reset_max_memory_allocated()  # reset the peak gauge to zero
            self.begin = torch.npu.memory_allocated()
        elif is_xpu_available():
            torch.xpu.empty_cache()
            torch.xpu.reset_peak_memory_stats()  # reset the peak gauge to zero
            self.begin = torch.xpu.memory_allocated()
        elif is_hpu_available():
            # torch.hpu.empty_cache() # not available on hpu as it reserves all device memory for the current process
            torch.hpu.reset_peak_memory_stats()  # reset the peak gauge to zero
            self.begin = torch.hpu.memory_allocated()
        elif is_neuron_available():
            torch.neuron.empty_cache()
            torch.neuron.reset_peak_memory_stats()  # reset the peak gauge to zero
            self.begin = torch.neuron.memory_allocated()
        return self

    def __exit__(self, *exc):
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            self.end = torch.cuda.memory_allocated()
            self.peak = torch.cuda.max_memory_allocated()
        elif is_mlu_available():
            torch.mlu.empty_cache()
            self.end = torch.mlu.memory_allocated()
            # The peak reading must go to `peak`; storing it in `begin` would discard
            # the baseline recorded in `__enter__` and leave `peak` undefined.
            self.peak = torch.mlu.max_memory_allocated()
        elif is_sdaa_available():
            torch.sdaa.empty_cache()
            self.end = torch.sdaa.memory_allocated()
            self.peak = torch.sdaa.max_memory_allocated()
        elif is_musa_available():
            torch.musa.empty_cache()
            self.end = torch.musa.memory_allocated()
            self.peak = torch.musa.max_memory_allocated()
        elif is_npu_available():
            torch.npu.empty_cache()
            self.end = torch.npu.memory_allocated()
            self.peak = torch.npu.max_memory_allocated()
        elif is_xpu_available():
            torch.xpu.empty_cache()
            self.end = torch.xpu.memory_allocated()
            self.peak = torch.xpu.max_memory_allocated()
        elif is_hpu_available():
            # torch.hpu.empty_cache() # not available on hpu as it reserves all device memory for the current process
            self.end = torch.hpu.memory_allocated()
            self.peak = torch.hpu.max_memory_allocated()
        elif is_neuron_available():
            torch.neuron.empty_cache()
            self.end = torch.neuron.memory_allocated()
            self.peak = torch.neuron.max_memory_allocated()
        self.used = b2mb(self.end - self.begin)
        self.peaked = b2mb(self.peak - self.begin)
        # print(f"delta used/peak {self.used:4d}/{self.peaked:4d}")
122
+
123
+
124
def get_dataloaders(
    accelerator: Accelerator,
    batch_size: int = 16,
    model_name: str = "bert-base-cased",
    n_train: int = 320,
    n_val: int = 160,
):
    """
    Build the training and validation `DataLoader`s for a slice of GLUE MRPC.

    Args:
        accelerator (`Accelerator`):
            Consulted for its `distributed_type` when choosing the padding strategy.
        batch_size (`int`, *optional*):
            Batch size of the training dataloader (the validation dataloader uses `EVAL_BATCH_SIZE`).
        model_name (`str`, *optional*):
            Name of the model whose tokenizer is loaded.
        n_train (`int`, *optional*):
            Number of leading training examples to keep.
        n_val (`int`, *optional*):
            Number of leading validation examples to keep.

    Returns:
        A tuple `(train_dataloader, eval_dataloader)`.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    split_spec = {"train": f"train[:{n_train}]", "validation": f"validation[:{n_val}]"}
    datasets = load_dataset("glue", "mrpc", split=split_spec)

    def tokenize_function(examples):
        # max_length=None => use the model max length (it's actually the default)
        return tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)

    # Tokenize every split, drop the raw text columns, and rename 'label' to 'labels',
    # the column name the transformers models expect.
    tokenized_datasets = datasets.map(
        tokenize_function, batched=True, remove_columns=["idx", "sentence1", "sentence2"], load_from_cache_file=False
    ).rename_column("label", "labels")

    def collate_fn(examples):
        # On TPU it's best to pad everything to the same length or training will be very slow.
        if accelerator.distributed_type != DistributedType.XLA:
            return tokenizer.pad(examples, padding="longest", return_tensors="pt")
        return tokenizer.pad(examples, padding="max_length", max_length=128, return_tensors="pt")

    train_dataloader = DataLoader(
        tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size
    )
    eval_dataloader = DataLoader(
        tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE
    )
    return train_dataloader, eval_dataloader
180
+
181
+
182
def training_function(config, args):
    """
    Fine-tune a sequence classification model and record the peak memory measured during every epoch.

    The per-epoch totals are written by the main process to `peak_memory_utilization.json` inside
    `args.output_dir`.

    Args:
        config (`dict`):
            Hyper-parameters; the keys `lr`, `num_epochs`, `seed` and `batch_size` are read.
        args (`argparse.Namespace`):
            Parsed command line; `model_name_or_path`, `n_train`, `n_val`, `peak_memory_upper_bound` and
            `output_dir` are read.

    Raises:
        `AssertionError`: If `args.peak_memory_upper_bound` is set and the total peak of an epoch exceeds it.
    """
    accelerator = Accelerator()

    # Unpack the hyper-parameters
    lr = config["lr"]
    num_epochs = int(config["num_epochs"])
    seed = int(config["seed"])
    batch_size = int(config["batch_size"])
    model_name = args.model_name_or_path

    set_seed(seed)
    train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size, model_name, args.n_train, args.n_val)

    # The model is built after seeding so that the seed also controls the initialization of new weights
    model = AutoModelForSequenceClassification.from_pretrained(model_name, return_dict=True)

    # `None` when DeepSpeed is not in use, otherwise the DeepSpeed config dictionary
    ds_plugin = accelerator.state.deepspeed_plugin
    ds_config = None if ds_plugin is None else ds_plugin.deepspeed_config

    # A placeholder optimizer is used when the DeepSpeed config declares its own
    if ds_config is not None and "optimizer" in ds_config:
        optimizer_cls = DummyOptim
    else:
        optimizer_cls = AdamW
    optimizer = optimizer_cls(params=model.parameters(), lr=lr)

    gradient_accumulation_steps = 1 if ds_config is None else ds_config["gradient_accumulation_steps"]
    max_training_steps = (len(train_dataloader) * num_epochs) // gradient_accumulation_steps

    # Same idea for the scheduler: a placeholder when the DeepSpeed config declares one
    if ds_config is not None and "scheduler" in ds_config:
        lr_scheduler = DummyScheduler(optimizer, total_num_steps=max_training_steps, warmup_num_steps=0)
    else:
        lr_scheduler = get_linear_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=0,
            num_training_steps=max_training_steps,
        )

    # `prepare` returns the objects in the order they were passed in
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )

    # Total number of batches iterated over
    overall_step = 0
    # Epoch the training starts from
    starting_epoch = 0

    train_total_peak_memory = {}
    for epoch in range(starting_epoch, num_epochs):
        with TorchTracemalloc() as tracemalloc:
            model.train()
            for step, batch in enumerate(train_dataloader):
                loss = model(**batch).loss / gradient_accumulation_steps
                accelerator.backward(loss)
                if step % gradient_accumulation_steps == 0:
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()

                overall_step += 1

        # Report the memory figures collected by the tracker for this epoch
        begin_mb = b2mb(tracemalloc.begin)
        total_peak = tracemalloc.peaked + begin_mb
        accelerator.print(f"Memory before entering the train : {begin_mb}")
        accelerator.print(f"Memory consumed at the end of the train (end-begin): {tracemalloc.used}")
        accelerator.print(f"Peak Memory consumed during the train (max-begin): {tracemalloc.peaked}")
        accelerator.print(f"Total Peak Memory consumed during the train (max): {total_peak}")

        epoch_key = f"epoch-{epoch}"
        train_total_peak_memory[epoch_key] = total_peak
        if args.peak_memory_upper_bound is not None:
            assert train_total_peak_memory[epoch_key] <= args.peak_memory_upper_bound, (
                "Peak memory usage exceeded the upper bound"
            )

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        results_path = os.path.join(args.output_dir, "peak_memory_utilization.json")
        with open(results_path, "w") as f:
            json.dump(train_total_peak_memory, f)
    accelerator.end_training()
276
+
277
+
278
def main():
    """Parse the command line and launch `training_function` with a fixed set of hyper-parameters."""
    parser = argparse.ArgumentParser(description="Simple example of training script tracking peak GPU memory usage.")

    # (flag, keyword arguments) pairs, registered in this order
    cli_options = (
        (
            "--model_name_or_path",
            {
                "type": str,
                "default": "bert-base-cased",
                "help": "Path to pretrained model or model identifier from huggingface.co/models.",
                "required": False,
            },
        ),
        (
            "--output_dir",
            {
                "type": str,
                "default": ".",
                "help": "Optional save directory where all checkpoint folders will be stored. Default is the current working directory.",
            },
        ),
        (
            "--peak_memory_upper_bound",
            {
                "type": float,
                "default": None,
                "help": "The upper bound of peak memory usage in MB. If set, the training will throw an error if the peak memory usage exceeds this value.",
            },
        ),
        ("--n_train", {"type": int, "default": 320, "help": "Number of training examples to use."}),
        ("--n_val", {"type": int, "default": 160, "help": "Number of validation examples to use."}),
        ("--num_epochs", {"type": int, "default": 1, "help": "Number of train epochs."}),
    )
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    hyper_parameters = {"lr": 2e-5, "num_epochs": args.num_epochs, "seed": 42, "batch_size": 16}
    training_function(hyper_parameters, args)
320
+
321
+
322
+ if __name__ == "__main__":
323
+ main()
accelerate/test_utils/scripts/external_deps/test_performance.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import argparse
15
+ import json
16
+ import os
17
+ from contextlib import nullcontext
18
+ from pathlib import Path
19
+
20
+ import evaluate
21
+ import torch
22
+ from datasets import load_dataset
23
+ from torch.optim import AdamW
24
+ from torch.utils.data import DataLoader
25
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup
26
+
27
+ from accelerate import Accelerator, DistributedType
28
+ from accelerate.parallelism_config import ParallelismConfig
29
+ from accelerate.utils import SAFE_WEIGHTS_NAME, set_seed
30
+ from accelerate.utils.deepspeed import DummyOptim, DummyScheduler
31
+
32
+
33
+ MAX_GPU_BATCH_SIZE = 16
34
+ EVAL_BATCH_SIZE = 32
35
+
36
+
37
def get_dataloaders(accelerator: Accelerator, batch_size: int = 16, model_name: str = "bert-base-cased"):
    """
    Builds the train and validation `DataLoader`s for the `mrpc` configuration of the `glue` dataset.

    Args:
        accelerator (`Accelerator`):
            An `Accelerator` object; its `distributed_type` selects the padding strategy.
        batch_size (`int`, *optional*, defaults to 16):
            The batch size of the train DataLoader. The validation DataLoader always uses `EVAL_BATCH_SIZE`.
        model_name (`str`, *optional*, defaults to `"bert-base-cased"`):
            The name or path handed to `AutoTokenizer.from_pretrained`.

    Returns:
        `tuple`: The train DataLoader followed by the validation DataLoader.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    raw_datasets = load_dataset("glue", "mrpc")

    def tokenize_function(examples):
        # max_length=None => use the model max length (it's actually the default)
        return tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)

    # Tokenize every split, then rename 'label' to 'labels', the name expected by the models of the
    # transformers library
    tokenized_datasets = raw_datasets.map(
        tokenize_function, batched=True, remove_columns=["idx", "sentence1", "sentence2"], load_from_cache_file=False
    ).rename_column("label", "labels")

    def collate_fn(examples):
        # On TPU it's best to pad everything to the same length or training will be very slow.
        if accelerator.distributed_type == DistributedType.XLA:
            padding_options = {"padding": "max_length", "max_length": 128}
        else:
            padding_options = {"padding": "longest"}
        return tokenizer.pad(examples, return_tensors="pt", **padding_options)

    train_dataloader = DataLoader(
        tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size
    )
    eval_dataloader = DataLoader(
        tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE
    )
    return train_dataloader, eval_dataloader
81
+
82
+
83
def training_function(config, args):
    """
    Train and evaluate a sequence classification model on GLUE MRPC, checking the learning-rate schedule, the
    accuracy and (without a TP plan) that the model can be saved.

    The per-epoch accuracy is written by the main process to `all_results.json` inside `args.output_dir`.

    Args:
        config (`dict`):
            Hyper-parameters; the keys `lr`, `num_epochs`, `seed` and `batch_size` are read.
        args (`argparse.Namespace`):
            Parsed command line; `tp_size`, `tp_plan`, `model_name_or_path`, `add_pad_token`,
            `performance_lower_bound` and `output_dir` are read.

    Raises:
        `AssertionError`: If one of the learning-rate, accuracy or saved-weights checks fails.
    """
    # need this for DeepSpeed tests as `args.tp_size` would be None and `torch.distributed.init_device_mesh` would fail
    accelerator_kwargs = {}
    if args.tp_size is not None:
        accelerator_kwargs["parallelism_config"] = ParallelismConfig(tp_size=args.tp_size)
    accelerator = Accelerator(**accelerator_kwargs)

    # Unpack the hyper-parameters
    lr = config["lr"]
    num_epochs = int(config["num_epochs"])
    seed = int(config["seed"])
    batch_size = int(config["batch_size"])
    model_name = args.model_name_or_path

    set_seed(seed)
    train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size, model_name)

    # Forward the TP related options only when they were provided
    model_kwargs = {}
    if args.tp_plan is not None:
        model_kwargs["tp_plan"] = args.tp_plan
    if args.tp_size is not None:
        model_kwargs["tp_size"] = args.tp_size

    # The model is built after seeding so that the seed also controls the initialization of new weights
    model = AutoModelForSequenceClassification.from_pretrained(model_name, return_dict=True, **model_kwargs)

    if args.add_pad_token and model.config.pad_token_id is None:
        model.config.pad_token_id = 0

    # `None` when DeepSpeed is not in use, otherwise the DeepSpeed config dictionary
    ds_plugin = accelerator.state.deepspeed_plugin
    ds_config = None if ds_plugin is None else ds_plugin.deepspeed_config

    # A placeholder optimizer is used when the DeepSpeed config declares its own
    if ds_config is not None and "optimizer" in ds_config:
        optimizer_cls = DummyOptim
    else:
        optimizer_cls = AdamW
    optimizer = optimizer_cls(params=model.parameters(), lr=lr)

    max_training_steps = len(train_dataloader) * num_epochs

    # `linear_decay_scheduler` records whether the learning-rate assertions below apply
    linear_decay_scheduler = ds_config is None or "scheduler" not in ds_config
    if linear_decay_scheduler:
        lr_scheduler = get_linear_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=0,
            num_training_steps=max_training_steps,
        )
    else:
        lr_scheduler = DummyScheduler(optimizer, total_num_steps=max_training_steps, warmup_num_steps=0)

    # `prepare` returns the objects in the order they were passed in
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )

    # Epoch the training starts from
    starting_epoch = 0

    metric = evaluate.load("glue", "mrpc")
    best_performance = 0
    performance_metric = {}
    expected_lr_after_first_optim_step = lr * (
        1 - 1 / (max_training_steps / accelerator.num_processes / accelerator.gradient_accumulation_steps)
    )
    lr_scheduler_check_completed = False

    for epoch in range(starting_epoch, num_epochs):
        model.train()
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(model):
                loss = model(**batch).loss
                accelerator.backward(loss)
                if args.tp_plan is None:
                    context = nullcontext
                else:
                    from torch.distributed._tensor.experimental import implicit_replication

                    context = implicit_replication
                with context():
                    optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # assert the learning rate after first optimizer step
            if (
                accelerator.sync_gradients
                and not lr_scheduler_check_completed
                and linear_decay_scheduler
                and accelerator.state.mixed_precision == "no"
            ):
                assert lr_scheduler.get_last_lr()[0] == expected_lr_after_first_optim_step, (
                    f"Wrong lr found at second step, expected {expected_lr_after_first_optim_step}, got {lr_scheduler.get_last_lr()[0]}"
                )
                lr_scheduler_check_completed = True

        model.eval()
        samples_seen = 0
        for step, batch in enumerate(eval_dataloader):
            # We could avoid this line since we set the accelerator with `device_placement=True`.
            batch.to(accelerator.device)
            with torch.no_grad():
                outputs = model(**batch)
            predictions = outputs.logits.argmax(dim=-1)
            # It is slightly faster to call this once, than multiple times
            predictions, references = accelerator.gather((predictions, batch["labels"]))
            # If we are in a multiprocess environment, the last batch has duplicates
            if accelerator.use_distributed:
                if step == len(eval_dataloader) - 1:
                    remaining = len(eval_dataloader.dataset) - samples_seen
                    predictions = predictions[:remaining]
                    references = references[:remaining]
                else:
                    samples_seen += references.shape[0]
            metric.add_batch(
                predictions=predictions,
                references=references,
            )

        eval_metric = metric.compute()
        # Use accelerator.print to print only on the main process.
        accelerator.print(f"epoch {epoch}:", eval_metric)
        accuracy = eval_metric["accuracy"]
        performance_metric[f"epoch-{epoch}"] = accuracy
        best_performance = max(best_performance, accuracy)

    # check that the LR is 0
    if linear_decay_scheduler and accelerator.state.mixed_precision == "no":
        assert lr_scheduler.get_last_lr()[0] == 0, (
            f"Wrong lr found at last step, expected 0, got {lr_scheduler.get_last_lr()[0]}"
        )

    if args.performance_lower_bound is not None:
        assert args.performance_lower_bound <= best_performance, (
            f"Best performance metric {best_performance} is lower than the lower bound {args.performance_lower_bound}"
        )

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        results_path = os.path.join(args.output_dir, "all_results.json")
        with open(results_path, "w") as f:
            json.dump(performance_metric, f)

    # TODO: skip saving of the model test for TP until the feature lands
    save_model_check = args.tp_plan is None
    if save_model_check:
        # Finally try saving the model
        accelerator.save_model(model, args.output_dir)
    accelerator.wait_for_everyone()
    if save_model_check:
        assert Path(args.output_dir, SAFE_WEIGHTS_NAME).exists(), (
            "Model was not saved when calling `Accelerator.save_model`"
        )
    accelerator.end_training()
246
+
247
+
248
def _str_to_bool(value):
    """Convert a command line value such as `"true"` or `"False"` into a `bool`."""
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("yes", "true", "t", "y", "1"):
        return True
    # The empty string stays falsy, as it was with `type=bool`
    if normalized in ("no", "false", "f", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def main():
    """
    Parse the command line and launch `training_function` with a fixed set of hyper-parameters.

    `--add_pad_token` accepts an explicit boolean (`--add_pad_token True` / `--add_pad_token False`) or can be
    passed without a value to enable it.
    """
    parser = argparse.ArgumentParser(description="Simple example of training script tracking model performance.")
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="bert-base-cased",
        help="Path to pretrained model or model identifier from huggingface.co/models.",
        required=False,
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=".",
        help="Optional save directory where all checkpoint folders will be stored. Default is the current working directory.",
    )
    parser.add_argument(
        "--performance_lower_bound",
        type=float,
        default=None,
        help="Optional lower bound for the performance metric. If set, the training will throw error when the performance metric drops below this value.",
    )
    parser.add_argument(
        "--num_epochs",
        type=int,
        default=3,
        help="Number of train epochs.",
    )
    # `type=bool` would turn every non-empty string (including "False") into `True`, so the value is parsed
    # explicitly instead.
    parser.add_argument(
        "--add_pad_token",
        type=_str_to_bool,
        nargs="?",
        const=True,
        default=False,
        help="To add pad token if not exists.",
    )
    parser.add_argument(
        "--tp_plan",
        type=str,
        default=None,
        help="pass 'auto' to use TP",
    )
    parser.add_argument(
        "--tp_size",
        type=int,
        default=None,
        help="TP size to be used to shard the model",
    )
    args = parser.parse_args()
    config = {"lr": 2e-5, "num_epochs": args.num_epochs, "seed": 42, "batch_size": 16}
    training_function(config, args)
296
+
297
+
298
+ if __name__ == "__main__":
299
+ main()
accelerate/test_utils/scripts/external_deps/test_pippy.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ from transformers import (
16
+ BertConfig,
17
+ BertForMaskedLM,
18
+ GPT2Config,
19
+ GPT2ForSequenceClassification,
20
+ )
21
+
22
+ from accelerate import PartialState
23
+ from accelerate.inference import prepare_pippy
24
+ from accelerate.test_utils import torch_device
25
+ from accelerate.utils import DistributedType, set_seed
26
+
27
+
28
+ model_to_config = {
29
+ "bert": (BertForMaskedLM, BertConfig, 512),
30
+ "gpt2": (GPT2ForSequenceClassification, GPT2Config, 1024),
31
+ }
32
+
33
+
34
def get_model_and_data_for_text(model_name, device, num_processes: int = 2):
    """
    Instantiate a model registered in `model_to_config` together with random token ids to trace and run it.

    Args:
        model_name (`str`):
            A key of `model_to_config`.
        device:
            The device the random inputs are created on.
        num_processes (`int`, *optional*, defaults to 2):
            The number of rows of the inference inputs.

    Returns:
        `tuple`: The model, a `(1, seq_len)` tracing input and a `(num_processes, seq_len)` inference input.
    """
    model_cls, config_cls, seq_len = model_to_config[model_name]
    config_args = {}
    # Eventually needed for batch inference tests on gpt-2 when bs != 1
    # if model_name == "gpt2":
    #     config_args["pad_token_id"] = 0
    model_config = config_cls(**config_args)
    model = model_cls(model_config)

    def random_token_ids(num_rows):
        return torch.randint(
            low=0,
            high=model_config.vocab_size,
            size=(num_rows, seq_len),
            device=device,
            dtype=torch.int64,
            requires_grad=False,
        )

    trace_input = random_token_ids(1)
    inference_inputs = random_token_ids(num_processes)
    return model, trace_input, inference_inputs
46
+
47
+
48
def test_bert(batch_size: int = 2):
    """Run a BERT model prepared with `prepare_pippy` and check that only the last process gets an output."""
    set_seed(42)
    state = PartialState()
    model, trace_input, inference_inputs = get_model_and_data_for_text("bert", "cpu", batch_size)
    model = prepare_pippy(model, example_args=(trace_input,), no_split_module_classes=model._no_split_modules)
    # For inference args need to be a tuple
    inputs = inference_inputs.to(torch_device)
    with torch.no_grad():
        output = model(inputs)
    # Zach: Check that we just grab the real outputs we need at the end
    if state.is_last_process:
        assert output is not None, "Output was not generated in the last process!"
    else:
        assert output is None, "Output was not generated on just the last process!"
62
+
63
+
64
def test_gpt2(batch_size: int = 2):
    """Run a GPT2 model prepared with `prepare_pippy` and check that only the last process gets an output."""
    set_seed(42)
    state = PartialState()
    model, trace_input, inference_inputs = get_model_and_data_for_text("gpt2", "cpu", batch_size)
    model = prepare_pippy(model, example_args=(trace_input,), no_split_module_classes=model._no_split_modules)
    # For inference args need to be a tuple
    inputs = inference_inputs.to(torch_device)
    with torch.no_grad():
        output = model(inputs)
    # Zach: Check that we just grab the real outputs we need at the end
    if state.is_last_process:
        assert output is not None, "Output was not generated in the last process!"
    else:
        assert output is None, "Output was not generated on just the last process!"
78
+
79
+
80
+ # Currently disabled, enable again once PyTorch pippy interface can trace a resnet34
81
+ # def test_resnet(batch_size: int = 2):
82
+ # set_seed(42)
83
+ # state = PartialState()
84
+ # model = resnet34()
85
+ # input_tensor = torch.rand(1, 3, 224, 224)
86
+ # model = prepare_pippy(
87
+ # model,
88
+ # example_args=(input_tensor,),
89
+ # )
90
+ # inference_inputs = torch.rand(batch_size, 3, 224, 224)
91
+ # inputs = send_to_device(inference_inputs, torch_device)
92
+ # with torch.no_grad():
93
+ # output = model(inputs)
94
+ # # Zach: Check that we just grab the real outputs we need at the end
95
+ # if not state.is_last_process:
96
+ # assert output is None, "Output was not generated on just the last process!"
97
+ # else:
98
+ # assert output is not None, "Output was not generated in the last process!"
99
+
100
+
101
if __name__ == "__main__":
    state = PartialState()
    state.print("Testing pippy integration...")
    # The tests only run for these multi-device distributed types
    multi_device_types = (DistributedType.MULTI_GPU, DistributedType.MULTI_XPU, DistributedType.MULTI_HPU)
    try:
        if state.distributed_type not in multi_device_types:
            print("Less than two GPUs found, not running tests!")
        else:
            state.print("Testing GPT2...")
            test_gpt2()
            # Issue: When modifying the tokenizer for batch GPT2 inference, there's an issue
            # due to references
            # NameError: cannot access free variable 'chunk_args_list' where it is not associated with a value in enclosing scope
            # test_gpt2(3)
            state.print("Testing BERT...")
            test_bert()
    finally:
        # Always tear the process group down, even when a test fails
        state.destroy_process_group()
accelerate/test_utils/scripts/external_deps/test_zero3_integration.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch.distributed
16
+
17
+ from accelerate.test_utils import require_huggingface_suite, torch_device
18
+ from accelerate.utils import is_transformers_available
19
+
20
+
21
+ if is_transformers_available():
22
+ from transformers import AutoModel, TrainingArguments
23
+
24
+
25
+ GPT2_TINY = "sshleifer/tiny-gpt2"
26
+
27
+
28
@require_huggingface_suite
def init_torch_dist_then_launch_deepspeed():
    """
    Initialize `torch.distributed` first, then build `TrainingArguments` with a ZeRO stage 3 DeepSpeed config
    and load a tiny GPT2 model.
    """
    # Communication backend per device type; anything else falls back to "nccl"
    backend_by_device = {"xpu": "xccl", "hpu": "hccl"}
    backend = backend_by_device.get(torch_device, "nccl")
    torch.distributed.init_process_group(backend=backend)

    zero3_config = {
        "zero_optimization": {
            "stage": 3,
        },
        "train_batch_size": "auto",
        "train_micro_batch_size_per_gpu": "auto",
    }
    train_args = TrainingArguments(
        output_dir="./",
        deepspeed=zero3_config,
    )
    model = AutoModel.from_pretrained(GPT2_TINY)
    assert train_args is not None
    assert model is not None
52
+
53
+
54
def main():
    """Script entry point: runs `init_torch_dist_then_launch_deepspeed`."""
    init_torch_dist_then_launch_deepspeed()
56
+
57
+
58
+ if __name__ == "__main__":
59
+ main()
accelerate/test_utils/scripts/test_cli.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+
16
+ from accelerate.utils import is_xpu_available
17
+
18
+
19
def main():
    """Print how many CUDA or XPU devices are visible; reports 0 GPUs when neither is available."""
    if torch.cuda.is_available():
        accelerator_type, num_accelerators = "GPU", torch.cuda.device_count()
    elif is_xpu_available():
        accelerator_type, num_accelerators = "XPU", torch.xpu.device_count()
    else:
        accelerator_type, num_accelerators = "GPU", 0
    print(f"Successfully ran on {num_accelerators} {accelerator_type}s")
29
+
30
+
31
+ if __name__ == "__main__":
32
+ main()
accelerate/test_utils/scripts/test_ddp_comm_hook.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+
16
+ from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs, PartialState
17
+ from accelerate.utils import is_hpu_available
18
+
19
+
20
class MockModel(torch.nn.Module):
    """Module with a single `(40, 20)` parameter `p` that is initialized identically for every instance."""

    def __init__(self):
        super().__init__()
        # Re-seeding right before sampling gives every instance the same initial weights
        torch.manual_seed(0)
        initial_weights = torch.randn(40, 20)
        self.p = torch.nn.Parameter(initial_weights)

    def forward(self, x, rank):
        """Return `p` multiplied element-wise by `x` raised to the power `1 + rank`."""
        exponent = 1 + rank
        powered = x**exponent
        return self.p * powered
28
+
29
+
30
def _run_and_get_grads(model, rank):
    """
    Run one forward/backward pass on a fixed-seed random `(40, 20)` input and return the gradient of the first
    parameter of `model`.
    """
    torch.manual_seed(2024)
    sample = torch.randn(40, 20)
    model(sample, rank).mean().backward()
    first_param = next(model.parameters())
    return first_param.grad
37
+
38
+
39
def test_ddp_comm_hook(comm_hook, comm_wrapper, comm_state_option):
    """
    Check that a model prepared with the given DDP communication hook settings produces gradients close to
    those of a plain `DistributedDataParallel` model.

    Args:
        comm_hook: The `comm_hook` passed to `DistributedDataParallelKwargs`.
        comm_wrapper: The `comm_wrapper` passed to `DistributedDataParallelKwargs`.
        comm_state_option (`dict`): The `comm_state_option` passed to `DistributedDataParallelKwargs`.
    """
    handler = DistributedDataParallelKwargs(
        comm_hook=comm_hook,
        comm_wrapper=comm_wrapper,
        comm_state_option=comm_state_option,
    )
    accelerator = Accelerator(kwargs_handlers=[handler])
    local_rank = accelerator.local_process_index

    # Gradients with the communication hook installed by `accelerator.prepare`
    hooked_model = accelerator.prepare(MockModel())
    hook_grads = _run_and_get_grads(hooked_model, local_rank)

    # Gradients of a DDP model wrapped by hand, without any hook
    reference_model = torch.nn.parallel.DistributedDataParallel(
        MockModel().to(accelerator.device),
        device_ids=[local_rank],
        output_device=local_rank,
    )
    reference_grads = _run_and_get_grads(reference_model, local_rank)

    torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-2, atol=1e-2)
58
+
59
+
60
def main():
    """Run `test_ddp_comm_hook` over a list of hook/wrapper combinations, skipping FP16/BF16 ones on HPU."""
    hooks = DDPCommunicationHookType
    # (comm_hook, comm_wrapper, comm_state_option)
    test_cases = [
        (hooks.NO, hooks.NO, {}),
        (hooks.FP16, hooks.NO, {}),
        (hooks.BF16, hooks.NO, {}),
        (hooks.POWER_SGD, hooks.NO, {}),
        (hooks.POWER_SGD, hooks.FP16, {}),
        (hooks.POWER_SGD, hooks.BF16, {}),
        (hooks.POWER_SGD, hooks.NO, {"matrix_approximation_rank": 2}),
        (hooks.BATCHED_POWER_SGD, hooks.NO, {}),
        (hooks.BATCHED_POWER_SGD, hooks.FP16, {}),
        (hooks.BATCHED_POWER_SGD, hooks.BF16, {}),
    ]
    hpu_unsupported = {hooks.FP16, hooks.BF16}

    for comm_hook, comm_wrapper, comm_state_option in test_cases:
        if is_hpu_available() and (comm_hook in hpu_unsupported or comm_wrapper in hpu_unsupported):
            print(f"Skipping test DDP comm hook: {comm_hook}, comm wrapper: {comm_wrapper} on HPU")
            continue

        print(f"Test DDP comm hook: {comm_hook}, comm wrapper: {comm_wrapper}")
        test_ddp_comm_hook(comm_hook, comm_wrapper, comm_state_option)
    PartialState().destroy_process_group()
82
+
83
+
84
+ if __name__ == "__main__":
85
+ main()
accelerate/test_utils/scripts/test_distributed_data_loop.py ADDED
@@ -0,0 +1,429 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import pickle
18
+ import tempfile
19
+ import warnings
20
+ from unittest.mock import Mock
21
+
22
+ import torch
23
+ from torch.utils.data import (
24
+ BatchSampler,
25
+ DataLoader,
26
+ Dataset,
27
+ IterableDataset,
28
+ RandomSampler,
29
+ TensorDataset,
30
+ default_collate,
31
+ )
32
+
33
+ from accelerate.accelerator import Accelerator, DataLoaderConfiguration
34
+ from accelerate.utils.dataclasses import DistributedType
35
+
36
+
37
+ NUM_ELEMENTS = 22
38
+ NUM_WORKERS = 4
39
+ BATCH_SIZE = 4
40
+
41
+
42
class DummyDataset(Dataset):
    """
    Map-style dataset of `NUM_ELEMENTS` synthetic samples.

    Each sample is a dictionary with its `index`, a `label` (`index % 2`) and a `random_augmentation` value
    that is drawn again on every access.
    """

    def __len__(self):
        return NUM_ELEMENTS

    def __getitem__(self, index):
        """
        Return one sample for an `int` index, or a list of samples for a slice or an iterable of indices.
        """
        squeeze = False

        if isinstance(index, int):
            index = [index]
            squeeze = True
        elif isinstance(index, slice):
            # Resolve the slice against the dataset length (the class defines no `size` attribute).
            index = list(range(*index.indices(len(self))))
        else:
            index = list(index)

        batch = [{"index": i, "label": i % 2, "random_augmentation": torch.rand(1).item()} for i in index]

        # A single integer index yields the sample itself rather than a one-element list
        if squeeze:
            batch = batch[0]

        return batch
63
+
64
+
65
class DummyIterableDataset(IterableDataset):
    """Iterable-style dataset that yields the items of the iterable it was built with, in order."""

    def __init__(self, data):
        self.data = data

    def __iter__(self):
        for item in self.data:
            yield item
71
+
72
+
73
def create_accelerator(even_batches=True):
    """
    Create an `Accelerator` whose dataloader configuration uses the given `even_batches` value.

    Raises:
        `AssertionError`: If the accelerator does not report exactly two processes.
    """
    accelerator = Accelerator(dataloader_config=DataLoaderConfiguration(even_batches=even_batches))
    assert accelerator.num_processes == 2, "this script expects that two GPUs are available"
    return accelerator
78
+
79
+
80
def create_dataloader(
    accelerator: Accelerator, dataset_size: int, batch_size: int, iterable: bool = False, shuffle: bool = False
):
    """
    Create a simple DataLoader to use during the test cases

    Args:
        accelerator: The accelerator used to `prepare` the dataloader.
        dataset_size: Number of elements; the data are the integers `0..dataset_size - 1`.
        batch_size: Batch size handed to the `DataLoader`.
        iterable: If `True`, wrap the values in a `DummyIterableDataset`, otherwise in a `TensorDataset`.
        shuffle: If `True`, permute the values once (with `torch.randperm`) before building the dataset.

    Returns:
        The dataloader returned by `accelerator.prepare`.
    """
    values = torch.as_tensor(range(dataset_size))
    if shuffle:
        values = values[torch.randperm(values.size(0))]
    if iterable:
        dataset = DummyIterableDataset(values)
    else:
        # Build the map-style dataset from `values` as well, so that `shuffle=True` is honoured for both
        # dataset kinds (it used to be rebuilt from the unshuffled range, silently ignoring `shuffle`).
        dataset = TensorDataset(values)

    dl = DataLoader(dataset, batch_size=batch_size)
    dl = accelerator.prepare(dl)

    return dl
98
+
99
+
100
def verify_dataloader_batch_sizes(
    accelerator: Accelerator,
    dataset_size: int,
    batch_size: int,
    process_0_expected_batch_sizes: list[int],
    process_1_expected_batch_sizes: list[int],
):
    """
    Build a prepared dataloader, record the size of every batch it yields on this process, and compare the
    recorded sizes with the expectation for process 0 or process 1 (other process indices are not checked).
    """
    dl = create_dataloader(accelerator=accelerator, dataset_size=dataset_size, batch_size=batch_size)

    observed_sizes = []
    for batch in dl:
        # Each batch is a sequence whose first entry holds the tensor of values.
        observed_sizes.append(len(batch[0]))

    rank = accelerator.process_index
    if rank == 0:
        assert observed_sizes == process_0_expected_batch_sizes
    elif rank == 1:
        assert observed_sizes == process_1_expected_batch_sizes
118
+
119
+
120
def test_default_ensures_even_batch_sizes():
    """
    Check that, with the default accelerator, both processes are expected to see the same batch sizes.
    """
    accelerator = create_accelerator()

    scenarios = [
        # without padding, we would expect a different number of batches
        {"dataset_size": 3, "batch_size": 1, "expected": [1, 1]},
        # without padding, we would expect the same number of batches, but different sizes
        {"dataset_size": 7, "batch_size": 2, "expected": [2, 2]},
    ]
    for scenario in scenarios:
        verify_dataloader_batch_sizes(
            accelerator,
            dataset_size=scenario["dataset_size"],
            batch_size=scenario["batch_size"],
            process_0_expected_batch_sizes=scenario["expected"],
            process_1_expected_batch_sizes=scenario["expected"],
        )
140
+
141
+
142
def test_can_disable_even_batches():
    """
    Check the per-process batch sizes expected when the accelerator is created with `even_batches=False`.
    """
    accelerator = create_accelerator(even_batches=False)

    # (dataset_size, batch_size, expected sizes on process 0, expected sizes on process 1)
    cases = (
        (3, 1, [1, 1], [1]),
        (7, 2, [2, 2], [2, 1]),
    )
    for dataset_size, batch_size, rank_0_sizes, rank_1_sizes in cases:
        verify_dataloader_batch_sizes(
            accelerator,
            dataset_size=dataset_size,
            batch_size=batch_size,
            process_0_expected_batch_sizes=rank_0_sizes,
            process_1_expected_batch_sizes=rank_1_sizes,
        )
160
+
161
+
162
def test_can_join_uneven_inputs():
    """
    Run forward/backward passes inside `join_uneven_inputs` with `even_batches=False` and check which batch
    indices each process iterated over: `[0, 1]` on process 0 and `[0]` on process 1.
    """
    accelerator = create_accelerator(even_batches=False)
    ddp_model = accelerator.prepare(torch.nn.Linear(1, 1))
    dl = create_dataloader(accelerator, dataset_size=3, batch_size=1)

    seen_steps = []
    with accelerator.join_uneven_inputs([ddp_model]):
        for step, batch in enumerate(dl):
            prediction = ddp_model(batch[0].float())
            prediction.sum().backward()
            seen_steps.append(step)

    accelerator.wait_for_everyone()

    rank = accelerator.process_index
    if rank == 0:
        assert seen_steps == [0, 1]
    elif rank == 1:
        assert seen_steps == [0]
184
+
185
+
186
def test_join_raises_warning_for_non_ddp_distributed(accelerator):
    """
    Check that entering `accelerator.join_uneven_inputs` emits a `UserWarning` containing
    "only supported for multi-GPU".

    Args:
        accelerator: The accelerator under test; the caller is responsible for putting it into the state that
            should trigger the warning.
    """
    with warnings.catch_warnings(record=True) as w:
        # Record every warning regardless of the filters in effect or of whether the same warning has already
        # been emitted from that location earlier in the process.
        warnings.simplefilter("always")
        with accelerator.join_uneven_inputs([Mock()]):
            pass

    # Match against all recorded `UserWarning`s instead of `w[-1]`: an empty list then fails with a readable
    # assertion rather than an `IndexError`, and unrelated warnings emitted later cannot hide the expected one.
    user_warnings = [str(record.message) for record in w if issubclass(record.category, UserWarning)]
    assert any("only supported for multi-GPU" in message for message in user_warnings), (
        f"Expected a UserWarning containing 'only supported for multi-GPU', got: {user_warnings}"
    )
193
+
194
+
195
def test_join_can_override_even_batches():
    """
    Check that `join_uneven_inputs(..., even_batches=False)` overrides `batch_sampler.even_batches` on both
    prepared dataloaders while the context is active, and that the default (`True`) is back afterwards.
    """
    default_even_batches, overridden_even_batches = True, False

    accelerator = create_accelerator(even_batches=default_even_batches)
    ddp_model = accelerator.prepare(torch.nn.Linear(1, 1))
    train_dl = create_dataloader(accelerator, dataset_size=3, batch_size=1)
    valid_dl = create_dataloader(accelerator, dataset_size=3, batch_size=1)

    with accelerator.join_uneven_inputs([ddp_model], even_batches=overridden_even_batches):
        # Snapshot the values while the override is active; they are checked after leaving the context.
        values_inside = (train_dl.batch_sampler.even_batches, valid_dl.batch_sampler.even_batches)

    for value_inside in values_inside:
        assert value_inside == overridden_even_batches
    for dataloader in (train_dl, valid_dl):
        assert dataloader.batch_sampler.even_batches == default_even_batches
212
+
213
+
214
def test_join_can_override_for_mixed_type_dataloaders():
    """
    Check that overriding `even_batches` in `join_uneven_inputs` works when both an iterable-style and a
    map-style dataloader have been prepared: no `AttributeError` may escape, the map-style dataloader sees the
    override inside the context, and its default is back afterwards.
    """
    default_even_batches, overridden_even_batches = True, False

    accelerator = create_accelerator(even_batches=default_even_batches)
    ddp_model = accelerator.prepare(torch.nn.Linear(1, 1))
    # The iterable dataloader only needs to be prepared; it is not used directly.
    create_dataloader(accelerator, dataset_size=3, batch_size=1, iterable=True)
    batch_dl = create_dataloader(accelerator, dataset_size=3, batch_size=1)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        try:
            with accelerator.join_uneven_inputs([ddp_model], even_batches=overridden_even_batches):
                value_inside_context = batch_dl.batch_sampler.even_batches
        except AttributeError:
            # ensure attribute error is not raised when processing iterable dl
            raise AssertionError

    assert value_inside_context == overridden_even_batches
    assert batch_dl.batch_sampler.even_batches == default_even_batches
234
+
235
+
236
def test_join_raises_warning_for_iterable_when_overriding_even_batches():
    """
    Check that overriding `even_batches` in `join_uneven_inputs` while an iterable-style dataloader has been
    prepared emits a `UserWarning` containing "only supported for map-style datasets".
    """
    accelerator = create_accelerator()
    model = torch.nn.Linear(1, 1)
    ddp_model = accelerator.prepare(model)
    # The iterable dataloader only needs to be prepared; it is not used directly.
    create_dataloader(accelerator, dataset_size=3, batch_size=1, iterable=True)

    with warnings.catch_warnings(record=True) as w:
        # Record every warning regardless of the filters in effect or of whether the same warning has already
        # been emitted from that location earlier in the process.
        warnings.simplefilter("always")
        with accelerator.join_uneven_inputs([ddp_model], even_batches=False):
            pass

    # Match against all recorded `UserWarning`s instead of `w[-1]`: an empty list then fails with a readable
    # assertion rather than an `IndexError`, and unrelated warnings emitted later cannot hide the expected one.
    user_warnings = [str(record.message) for record in w if issubclass(record.category, UserWarning)]
    assert any("only supported for map-style datasets" in message for message in user_warnings), (
        f"Expected a UserWarning containing 'only supported for map-style datasets', got: {user_warnings}"
    )
248
+
249
+
250
def test_pickle_accelerator():
    """
    Round-trip an accelerator (with a prepared dataloader) through `pickle` and compare the `state.__dict__` of
    the original and the restored object.
    """
    accelerator = create_accelerator()
    data_loader = create_dataloader(accelerator, dataset_size=32, batch_size=4)
    accelerator.prepare(data_loader)

    restored_accelerator = pickle.loads(pickle.dumps(accelerator))
    # TODO: Maybe this should be implemented as __eq__ for AcceleratorState?
    assert accelerator.state.__dict__ == restored_accelerator.state.__dict__
258
+
259
+
260
def test_data_loader(data_loader, accelerator):
    """
    Prepare `data_loader`, iterate it once, gather the `index` field of every batch with
    `gather_for_metrics`, and check that exactly `NUM_ELEMENTS` distinct indices were seen.
    """
    prepared_loader = accelerator.prepare(data_loader)

    seen_indices = []
    for batch in prepared_loader:
        gathered_index, _ = accelerator.gather_for_metrics((batch["index"], batch["label"]))
        seen_indices.extend(gathered_index.detach().cpu().numpy().tolist())

    # Only the number of distinct indices matters for the check below.
    num_distinct = len(set(seen_indices))
    assert num_distinct == NUM_ELEMENTS, (
        "Not all the dataset elements have been iterated in an epoch due to duplication of samples across processes."
    )
276
+
277
+
278
def _test_stateful_dataloader_resume(accelerator, iterable):
    """
    Helper: iterate a stateful dataloader, capture its `state_dict()` at step index 2 (i.e. once three batches
    have been drawn), finish the epoch, then `load_state_dict` the captured state and iterate again. The batches
    produced after resuming must equal the batches that followed the save point in the first pass.

    The accelerator's `dataloader_config` is temporarily replaced to enable `use_stateful_dataloader` and is
    restored on exit. `iterable` selects between the iterable-style and the map-style dataset.
    """
    save_step = 2
    original_config = accelerator.dataloader_config
    try:
        accelerator.dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
        prepared_dl = create_dataloader(
            accelerator, dataset_size=32 * accelerator.num_processes, batch_size=4, iterable=iterable, shuffle=True
        )

        batches_after_save = []
        for step, batch in enumerate(prepared_dl):
            if step == save_step:
                state_dict = prepared_dl.state_dict()
            elif step > save_step:
                batches_after_save.append(batch)
        expected_batches = accelerator.gather(batches_after_save)

        prepared_dl.load_state_dict(state_dict)
        resumed_batches = accelerator.gather([batch for batch in prepared_dl])

        assert len(expected_batches) == len(resumed_batches), (
            f"Expected {len(expected_batches)} batches after resume, got {len(resumed_batches)}"
        )
        for b1, b2 in zip(expected_batches, resumed_batches):
            for v1, v2 in zip(b1, b2):
                assert torch.equal(v1, v2), f"Batch {b1} and {b2} are not equal"
    finally:
        accelerator.dataloader_config = original_config
313
+
314
+
315
def test_stateful_dataloader(accelerator):
    """
    Tests that a stateful dataloader can be iterated over, have its state captured with `state_dict` after a few
    batches, and then be resumed from that state with `load_state_dict`.

    The resumed batches should be the same as the rest of the data that was iterated over after saving. Runs the
    check first with an iterable-style dataset, then with a map-style one.
    """
    for iterable in (True, False):
        _test_stateful_dataloader_resume(accelerator, iterable=iterable)
324
+
325
+
326
def _test_stateful_dataloader_save_state_resume(accelerator, iterable):
    """
    Helper: iterate a stateful dataloader, call `Accelerator.save_state` into a temporary directory at step
    index 2, finish the epoch, then `Accelerator.load_state` from that directory and iterate again. The batches
    produced after resuming must equal the batches that followed the save point in the first pass.

    The accelerator's `dataloader_config` is temporarily replaced to enable `use_stateful_dataloader` and is
    restored on exit.
    """
    save_step = 2
    original_config = accelerator.dataloader_config
    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            accelerator.dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
            prepared_dl = create_dataloader(
                accelerator, dataset_size=32 * accelerator.num_processes, batch_size=4, iterable=iterable, shuffle=True
            )

            batches_after_save = []
            for step, batch in enumerate(prepared_dl):
                if step == save_step:
                    accelerator.save_state(tmpdir)
                elif step > save_step:
                    batches_after_save.append(batch)
            expected_batches = accelerator.gather(batches_after_save)

            accelerator.load_state(tmpdir)
            resumed_batches = accelerator.gather([batch for batch in prepared_dl])

            assert len(expected_batches) == len(resumed_batches), (
                f"Expected {len(expected_batches)} batches after resume, got {len(resumed_batches)}"
            )
            for b1, b2 in zip(expected_batches, resumed_batches):
                for v1, v2 in zip(b1, b2):
                    assert torch.equal(v1, v2), f"Batch {b1} and {b2} are not equal"
    finally:
        accelerator.dataloader_config = original_config
359
+
360
+
361
def test_stateful_dataloader_save_state(accelerator):
    """
    Tests that a stateful dataloader can be iterated over, saved after a few batches using
    `Accelerator.save_state`, and then resumed from the saved state.

    The resumed batches should be the same as the rest of the data that was iterated over after saving. Runs the
    check first with an iterable-style dataset, then with a map-style one.
    """
    for iterable in (True, False):
        _test_stateful_dataloader_save_state_resume(accelerator, iterable=iterable)
370
+
371
+
372
def main():
    """
    Run every check of this script in order: the `even_batches` / `join_uneven_inputs` tests, the pickling
    test, the `DataLoader` coverage tests on `DummyDataset`, and the stateful dataloader tests.
    """
    accelerator = create_accelerator()
    torch.manual_seed(accelerator.process_index)

    # Tests that build their own accelerator and take no arguments.
    standalone_tests = (
        ("Test that even_batches variable ensures uniform batches across processes", test_default_ensures_even_batch_sizes),
        ("Run tests with even_batches disabled", test_can_disable_even_batches),
        ("Test joining uneven inputs", test_can_join_uneven_inputs),
        ("Test overriding even_batches when joining uneven inputs", test_join_can_override_even_batches),
        ("Test overriding even_batches for mixed dataloader types", test_join_can_override_for_mixed_type_dataloaders),
        (
            "Test overriding even_batches raises a warning for iterable dataloaders",
            test_join_raises_warning_for_iterable_when_overriding_even_batches,
        ),
    )
    for message, test_fn in standalone_tests:
        accelerator.print(message)
        test_fn()

    accelerator.print("Test join with non DDP distributed raises warning")
    # Temporarily switch the distributed type to FSDP for this check, then put the original value back.
    original_state = accelerator.state.distributed_type
    accelerator.state.distributed_type = DistributedType.FSDP
    test_join_raises_warning_for_non_ddp_distributed(accelerator)
    accelerator.state.distributed_type = original_state

    accelerator.print("Test pickling an accelerator")
    test_pickle_accelerator()

    dataset = DummyDataset()

    accelerator.print("Test DataLoader with shuffle=False")
    sequential_loader = DataLoader(dataset, shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
    test_data_loader(sequential_loader, accelerator)

    accelerator.print("Test DataLoader with shuffle=True")
    shuffled_loader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
    test_data_loader(shuffled_loader, accelerator)

    accelerator.print("Test DataLoader with batch_sampler")
    batch_sampler = BatchSampler(RandomSampler(dataset), batch_size=BATCH_SIZE, drop_last=False)
    batch_sampler_loader = DataLoader(dataset, batch_sampler=batch_sampler, num_workers=NUM_WORKERS)
    test_data_loader(batch_sampler_loader, accelerator)

    accelerator.print("Test DataLoader with sampler as an instance of `BatchSampler`")
    sampler = BatchSampler(RandomSampler(dataset), batch_size=BATCH_SIZE, drop_last=False)
    sampler_loader = DataLoader(
        dataset, sampler=sampler, batch_size=None, collate_fn=default_collate, num_workers=NUM_WORKERS
    )
    test_data_loader(sampler_loader, accelerator)

    test_stateful_dataloader(accelerator)
    test_stateful_dataloader_save_state(accelerator)

    accelerator.end_training()


if __name__ == "__main__":
    main()
accelerate/test_utils/scripts/test_merge_weights.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import gc
16
+ import logging
17
+ import shutil
18
+ from pathlib import Path
19
+
20
+ import torch
21
+ from safetensors.torch import load_file
22
+ from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy, StateDictType
23
+ from torch.utils.data import DataLoader
24
+
25
+ from accelerate import Accelerator, FullyShardedDataParallelPlugin
26
+ from accelerate.commands.merge import merge_command, merge_command_parser
27
+ from accelerate.state import AcceleratorState
28
+ from accelerate.test_utils import torch_device
29
+ from accelerate.test_utils.training import RegressionDataset
30
+ from accelerate.utils import merge_fsdp_weights, patch_environment, save_fsdp_model
31
+
32
+
33
+ logging.basicConfig(level=logging.INFO)
34
+
35
+ parser = merge_command_parser()
36
+
37
+
38
class TinyModel(torch.nn.Module):
    """
    Small two-layer network: `Linear(16, 16)` -> `ReLU` -> `Linear(16, 16)`.

    A `Softmax` module is also registered as `softmax`, but `forward` does not apply it.
    """

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(16, 16)
        self.activation = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(16, 16)
        self.softmax = torch.nn.Softmax()

    def forward(self, x):
        hidden = self.linear1(x)
        hidden = self.activation(hidden)
        return self.linear2(hidden)
48
+
49
+
50
def setup():
    """
    Reset any existing `AcceleratorState`, then build a `TinyModel` prepared by an `Accelerator` that uses an
    FSDP plugin with `FULL_SHARD` sharding and `SHARDED_STATE_DICT` state dicts.

    Returns:
        A `(model, plugin, accelerator)` tuple, where `model` is the prepared model.
    """
    if AcceleratorState._shared_state != {}:
        AcceleratorState()._reset_state()

    fsdp_plugin = FullyShardedDataParallelPlugin(
        sharding_strategy=ShardingStrategy.FULL_SHARD, state_dict_type=StateDictType.SHARDED_STATE_DICT
    )
    tiny_model = TinyModel()
    # The auto-wrap policy is chosen while the environment is patched with `fsdp_auto_wrap_policy`.
    with patch_environment(fsdp_auto_wrap_policy="SIZE_BASED_WRAP"):
        fsdp_plugin.set_auto_wrap_policy(tiny_model)

    accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
    prepared_model = accelerator.prepare(tiny_model)
    return prepared_model, fsdp_plugin, accelerator
62
+
63
+
64
def mock_training(accelerator, model):
    """
    Train `model` for three passes over a 128-sample `RegressionDataset` (batch size 16, no shuffling) using SGD
    with `lr=0.1` and an MSE loss, and return the prepared model.
    """
    dataloader = DataLoader(RegressionDataset(length=128, seed=42), batch_size=16, shuffle=False)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    dataloader, model, optimizer = accelerator.prepare(dataloader, model, optimizer)

    num_epochs = 3
    for _ in range(num_epochs):
        for batch in dataloader:
            model.zero_grad()
            prediction = model(batch["x"])
            loss = torch.nn.functional.mse_loss(prediction, batch["y"])
            accelerator.backward(loss)
            optimizer.step()
    return model
78
+
79
+
80
def check_weights(operation, state_1, state_2):
    """
    Compare the values of two state dicts pairwise, in iteration order.

    With `operation == "same"` every pair must be `torch.allclose`; with any other value every pair must NOT be
    `torch.allclose`. An `AssertionError` is raised on the first pair that violates the expectation.
    """
    expect_close = operation == "same"
    for tensor_a, tensor_b in zip(state_1.values(), state_2.values()):
        is_close = torch.allclose(tensor_a, tensor_b)
        assert is_close == expect_close
86
+
87
+
88
def check_safetensors_weights(path, model):
    """
    Load `path / "model.safetensors"` into a fresh `TinyModel` and check that the fresh model's weights differ
    from `model` before loading and match `model` after loading.
    """
    loaded_state = load_file(path / "model.safetensors")
    reference_model = TinyModel().to(torch_device)

    # A freshly initialised model must not already match the given model.
    check_weights("diff", model.state_dict(), reference_model.state_dict())
    reference_model.load_state_dict(loaded_state)
    check_weights("same", model.state_dict(), reference_model.state_dict())
94
+
95
+
96
def check_pytorch_weights(path, model):
    """
    Load `path / "pytorch_model.bin"` (with `weights_only=True`) into a fresh `TinyModel` and check that the
    fresh model's weights differ from `model` before loading and match `model` after loading.
    """
    loaded_state = torch.load(path / "pytorch_model.bin", weights_only=True)
    reference_model = TinyModel().to(torch_device)

    # A freshly initialised model must not already match the given model.
    check_weights("diff", model.state_dict(), reference_model.state_dict())
    reference_model.load_state_dict(loaded_state)
    check_weights("same", model.state_dict(), reference_model.state_dict())
102
+
103
+
104
def test_merge_weights_safetensors(model, path):
    """
    Merge the sharded checkpoint in `path / "pytorch_model_fsdp_0"` into `path` with
    `safe_serialization=True`, then verify the `model.safetensors` file found in `path` against `model`.
    """
    sharded_checkpoint_dir = path / "pytorch_model_fsdp_0"
    merge_fsdp_weights(sharded_checkpoint_dir, path, safe_serialization=True)
    check_safetensors_weights(path, model)
108
+
109
+
110
def test_merge_weights_command_safetensors(model, path):
    """
    Run the merge command (through the module-level `parser`) on `path / "pytorch_model_fsdp_0"` with `path` as
    output, then verify the `model.safetensors` file found in `path` against `model`.
    """
    cli_arguments = [str(path / "pytorch_model_fsdp_0"), str(path)]
    merge_command(parser.parse_args(cli_arguments))
    check_safetensors_weights(path, model)
114
+
115
+
116
def test_merge_weights_pytorch(model, path):
    """
    Merge the sharded checkpoint in `path / "pytorch_model_fsdp_0"` into `path` with
    `safe_serialization=False`, then verify the `pytorch_model.bin` file found in `path` against `model`.
    """
    sharded_checkpoint_dir = path / "pytorch_model_fsdp_0"
    merge_fsdp_weights(sharded_checkpoint_dir, path, safe_serialization=False)
    check_pytorch_weights(path, model)
120
+
121
+
122
def test_merge_weights_command_pytorch(model, path):
    """
    Run the merge command (through the module-level `parser`) with `--unsafe_serialization` on
    `path / "pytorch_model_fsdp_0"` with `path` as output, then verify the `pytorch_model.bin` file found in
    `path` against `model`.
    """
    cli_arguments = [str(path / "pytorch_model_fsdp_0"), str(path), "--unsafe_serialization"]
    merge_command(parser.parse_args(cli_arguments))
    check_pytorch_weights(path, model)
126
+
127
+
128
if __name__ == "__main__":
    # Note this test requires at least two accelerators!
    model, plugin, accelerator = setup()
    try:
        if accelerator.num_processes > 1:
            # Bound before the inner `try` so that the cleanup below can always refer to it.
            out_path = Path("test_merge_weights_fsdp_weights")
            try:
                # Initial setup for things. `exist_ok=True` already tolerates a directory that exists (for
                # instance because another process created it first), so no separate `exists()` check is needed.
                out_path.mkdir(parents=True, exist_ok=True)

                # Train briefly so the weights aren't the baseline
                model = mock_training(accelerator, model)
                accelerator.wait_for_everyone()

                gc.collect()  # Needed for some lingering refs after training
                save_fsdp_model(plugin, accelerator, model, out_path)
                accelerator.wait_for_everyone()

                # Finally we can test
                test_merge_weights_safetensors(model, out_path)
                test_merge_weights_command_safetensors(model, out_path)
                test_merge_weights_pytorch(model, out_path)
                test_merge_weights_command_pytorch(model, out_path)
            finally:
                # Cleanup in case of any failures. Exceptions propagate on their own, so no `except` clause
                # that merely re-raises is required.
                if accelerator.is_main_process:
                    shutil.rmtree(out_path)
                accelerator.wait_for_everyone()
    finally:
        # Always finish the run, whether or not the multi-process branch executed or failed.
        accelerator.end_training()
accelerate/test_utils/scripts/test_notebook.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Test file to ensure that in general certain situational setups for notebooks work.
16
+ """
17
+
18
+ import os
19
+ import time
20
+
21
+ from pytest import mark, raises
22
+ from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
23
+
24
+ from accelerate import PartialState, notebook_launcher
25
+ from accelerate.test_utils import require_bnb
26
+ from accelerate.utils import is_bnb_available, is_xpu_available
27
+
28
+
29
def basic_function():
    """
    Print the `PartialState` of the calling process.
    """
    state = PartialState()
    print(f"PartialState:\n{state}")
32
+
33
+
34
def tough_nut_function(queue):
    """
    Fail until the counter stored in `queue` reaches zero.

    Returns immediately when `queue` is empty. Otherwise the counter is taken from the queue; while it is
    positive, the decremented value is put back and a `RuntimeError` is raised. Once it is no longer positive,
    the `PartialState` is printed instead.
    """
    if queue.empty():
        return

    remaining_failures = queue.get()
    if remaining_failures <= 0:
        print(f"PartialState:\n{PartialState()}")
        return

    # Hand the decremented counter back before failing so that the next attempt sees it.
    queue.put(remaining_failures - 1)
    raise RuntimeError("The nut hasn't cracked yet! Try again.")
43
+
44
+
45
def bipolar_sleep_function(sleep_sec: int):
    """
    Raise a `RuntimeError` on processes with an even `process_index`; sleep for `sleep_sec` seconds on the
    others.
    """
    is_even_process = PartialState().process_index % 2 == 0
    if is_even_process:
        raise RuntimeError("I'm an even process. I don't like to sleep.")
    time.sleep(sleep_sec)
51
+
52
+
53
+ NUM_PROCESSES = int(os.environ.get("ACCELERATE_NUM_PROCESSES", 1))
54
+
55
+
56
def test_can_initialize():
    """
    Launch `basic_function` (no arguments) through `notebook_launcher` on `NUM_PROCESSES` processes.
    """
    no_arguments = ()
    notebook_launcher(basic_function, no_arguments, num_processes=NUM_PROCESSES)
58
+
59
+
60
@mark.skipif(NUM_PROCESSES < 2, reason="Need at least 2 processes to test static rendezvous backends")
def test_static_rdzv_backend():
    """
    Launch `basic_function` through `notebook_launcher` with `rdzv_backend="static"`.
    """
    no_arguments = ()
    notebook_launcher(basic_function, no_arguments, num_processes=NUM_PROCESSES, rdzv_backend="static")
63
+
64
+
65
@mark.skipif(NUM_PROCESSES < 2, reason="Need at least 2 processes to test c10d rendezvous backends")
def test_c10d_rdzv_backend():
    """
    Launch `basic_function` through `notebook_launcher` with `rdzv_backend="c10d"`.
    """
    no_arguments = ()
    notebook_launcher(basic_function, no_arguments, num_processes=NUM_PROCESSES, rdzv_backend="c10d")
68
+
69
+
70
@mark.skipif(NUM_PROCESSES < 2, reason="Need at least 2 processes to test fault tolerance")
def test_fault_tolerant(max_restarts: int = 3):
    """
    Launch `tough_nut_function` with a queue seeded with `max_restarts`, allowing `notebook_launcher` the same
    number of restarts.
    """
    # Use torch.multiprocessing to get the right context for the current device
    import torch.multiprocessing as mp

    # Get appropriate context - 'spawn' for XPU, 'fork' for others
    start_method = "spawn" if is_xpu_available() else "fork"
    queue = mp.get_context(start_method).Queue()
    queue.put(max_restarts)
    notebook_launcher(tough_nut_function, (queue,), num_processes=NUM_PROCESSES, max_restarts=max_restarts)
83
+
84
+
85
@mark.skipif(NUM_PROCESSES < 2, reason="Need at least 2 processes to test monitoring")
def test_monitoring(monitor_interval: float = 0.01, sleep_sec: int = 100):
    """
    Launch `bipolar_sleep_function` and check that the launcher raises `ChildFailedError` with the even
    process's message, and that this happens in less than `sleep_sec` seconds.
    """
    start_time = time.time()
    with raises(ChildFailedError, match="I'm an even process. I don't like to sleep."):
        notebook_launcher(
            bipolar_sleep_function,
            (sleep_sec,),
            num_processes=NUM_PROCESSES,
            monitor_interval=monitor_interval,
        )
    elapsed = time.time() - start_time
    assert elapsed < sleep_sec, "Monitoring did not stop the process in time."
96
+
97
+
98
@require_bnb
def test_problematic_imports():
    """
    Check that importing `bitsandbytes` and then calling `notebook_launcher` raises a `RuntimeError` matching
    "Please keep these imports".
    """
    no_arguments = ()
    with raises(RuntimeError, match="Please keep these imports"):
        import bitsandbytes as bnb  # noqa: F401

        notebook_launcher(basic_function, no_arguments, num_processes=NUM_PROCESSES)
104
+
105
+
106
def main():
    """
    Run every notebook-launcher check of this script in order. The bitsandbytes check only runs when
    bitsandbytes is available, and the process group is destroyed at the end when more than one process is used.
    """
    unconditional_tests = (
        ("Test basic notebook can be ran", test_can_initialize),
        ("Test static rendezvous backend", test_static_rdzv_backend),
        ("Test c10d rendezvous backend", test_c10d_rdzv_backend),
        ("Test fault tolerant", test_fault_tolerant),
        ("Test monitoring", test_monitoring),
    )
    for message, test_fn in unconditional_tests:
        print(message)
        test_fn()

    if is_bnb_available():
        print("Test problematic imports (bnb)")
        test_problematic_imports()

    if NUM_PROCESSES > 1:
        PartialState().destroy_process_group()


if __name__ == "__main__":
    main()
accelerate/test_utils/scripts/test_ops.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import torch
18
+
19
+ from accelerate import PartialState
20
+ from accelerate.test_utils.testing import assert_exception
21
+ from accelerate.utils.dataclasses import DistributedType
22
+ from accelerate.utils.operations import (
23
+ DistributedOperationException,
24
+ broadcast,
25
+ copy_tensor_to_devices,
26
+ gather,
27
+ gather_object,
28
+ pad_across_processes,
29
+ reduce,
30
+ )
31
+
32
+
33
def create_tensor(state):
    """
    Return a float tensor of `state.num_processes` consecutive values on `state.device`, starting at
    `state.num_processes * state.process_index + 1`.
    """
    world_size = state.num_processes
    offset = world_size * state.process_index
    values = torch.arange(world_size) + 1.0 + offset
    return values.to(state.device)
35
+
36
+
37
def test_gather(state):
    """
    Gather the per-process tensor from `create_tensor` and check that the result lists the integers from 1 to
    `state.num_processes ** 2`.
    """
    gathered = gather(create_tensor(state))
    expected_values = list(range(1, state.num_processes**2 + 1))
    assert gathered.tolist() == expected_values
41
+
42
+
43
def test_gather_object(state):
    """
    Gather the single-element list `[process_index]` from every process and check that the result is
    `[0, 1, ..., num_processes - 1]`. Does nothing for the XLA distributed type.
    """
    # Gather objects in TorchXLA is not supported.
    if state.distributed_type == DistributedType.XLA:
        return

    gathered_obj = gather_object([state.process_index])
    world_size = state.num_processes
    assert len(gathered_obj) == world_size, f"{gathered_obj}, {len(gathered_obj)} != {world_size}"
    expected = list(range(world_size))
    assert gathered_obj == expected, f"{gathered_obj} != {expected}"
51
+
52
+
53
def test_gather_non_contiguous(state):
    """
    Check that `gather` accepts a non-contiguous tensor without raising. Does nothing for the XLA distributed
    type.
    """
    # Skip this test because the 'is_contiguous' function of XLA tensor always returns True.
    if state.distributed_type == DistributedType.XLA:
        return

    # Create a non-contiguous tensor (enforce non-contiguity after device memory allocation)
    base = torch.arange(12, device=state.device).view(4, 3)
    transposed = base.t()
    assert not transposed.is_contiguous()
    # Shouldn't error out
    gather(transposed)
63
+
64
+
65
def test_broadcast(state):
    """
    Broadcast the per-process tensor from `create_tensor` and check that every process ends up with the values
    `1..num_processes` in a tensor of shape `[num_processes]`.
    """
    world_size = state.num_processes
    result = broadcast(create_tensor(state))
    assert result.shape == torch.Size([world_size])
    assert result.tolist() == list(range(1, world_size + 1))
70
+
71
+
72
def test_pad_across_processes(state):
    """
    Give the main process a tensor that is one element longer than on the other processes, pad across
    processes, and check that every process ends up with the longer shape; on non-main processes the padded
    tensor must be `0..num_processes - 1` followed by a `0`.
    """
    world_size = state.num_processes
    # We need to pad the tensor with one more element if we are the main process
    # to ensure that we can pad
    length = world_size + 1 if state.is_main_process else world_size
    tensor = torch.arange(length).to(state.device)

    padded = pad_across_processes(tensor)
    assert padded.shape == torch.Size([world_size + 1])
    if not state.is_main_process:
        assert padded.tolist() == list(range(0, world_size)) + [0]
83
+
84
+
85
def test_reduce_sum(state):
    """
    Reduce the per-process tensor from `create_tensor` with `"sum"` and compare against `[4.0, 6]`. Does
    nothing unless exactly two processes are running.
    """
    # For now runs on only two processes
    if state.num_processes != 2:
        return

    reduced = reduce(create_tensor(state), "sum")
    expected = torch.tensor([4.0, 6]).to(state.device)
    assert torch.allclose(reduced, expected), f"{reduced} != {expected}"
93
+
94
+
95
def test_reduce_mean(state):
    """
    Reduce the per-process tensor from `create_tensor` with `"mean"` and compare against `[2.0, 3]`. Does
    nothing unless exactly two processes are running.
    """
    # For now runs on only two processes
    if state.num_processes != 2:
        return

    reduced = reduce(create_tensor(state), "mean")
    expected = torch.tensor([2.0, 3]).to(state.device)
    assert torch.allclose(reduced, expected), f"{reduced} != {expected}"
103
+
104
+
105
def test_op_checker(state):
    """
    With `state.debug` enabled, check that `pad_across_processes`, `reduce` and `broadcast` each raise a
    `DistributedOperationException` when process 0 and the other processes pass tensors of different shapes.

    Does nothing for the `NO` and `XLA` distributed types. `state.debug` is set back to `False` before
    returning, including when one of the checks fails.
    """
    # Must be in a distributed state, and gathering is currently not supported in TorchXLA.
    if state.distributed_type in [DistributedType.NO, DistributedType.XLA]:
        return
    state.debug = True
    try:
        # `pad_across_processes`
        if state.process_index == 0:
            data = {"tensor": torch.tensor([[0.0, 1, 2, 3, 4]]).to(state.device)}
        else:
            data = {"tensor": torch.tensor([[[0.0, 1, 2, 3, 4, 5]]]).to(state.device)}

        with assert_exception(DistributedOperationException):
            pad_across_processes(data, dim=0)

        # `reduce`
        if state.process_index == 0:
            data = {"tensor": torch.tensor([[0.0, 1, 2, 3, 4]]).to(state.device)}
        else:
            data = {"tensor": torch.tensor([[[0.0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]]).to(state.device)}

        with assert_exception(DistributedOperationException):
            reduce(data)

        # `broadcast`
        if state.process_index == 0:
            data = {"tensor": torch.tensor([[0.0, 1, 2, 3, 4]]).to(state.device)}
        else:
            data = {"tensor": torch.tensor([[[0.0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]]).to(state.device)}

        with assert_exception(DistributedOperationException):
            broadcast(data)
    finally:
        # Reset the flag even if a check above raised, so a failure here cannot leave debug mode switched on.
        state.debug = False
138
+
139
+
140
def test_copy_tensor_to_devices(state):
    """
    Create an int tensor `[1, 2, 3]` on the main process only (`None` elsewhere), pass it through
    `copy_tensor_to_devices`, and check that every process ends up with `[1, 2, 3]` on its device. Does nothing
    unless the distributed type is `MULTI_GPU` or `XLA`.
    """
    supported_types = [DistributedType.MULTI_GPU, DistributedType.XLA]
    if state.distributed_type not in supported_types:
        return

    source = torch.tensor([1, 2, 3], dtype=torch.int).to(state.device) if state.is_main_process else None
    copied = copy_tensor_to_devices(source)
    expected = torch.tensor([1, 2, 3], dtype=torch.int, device=state.device)
    assert torch.allclose(copied, expected)
149
+
150
+
151
def _mp_fn(index):
    """
    Entry point for xla_spawn (TPUs): runs `main()`; `index` is accepted but not used.
    """
    main()
154
+
155
+
156
def main():
    """
    Create the `PartialState`, print it, run every operation check of this script in order, and destroy the
    process group at the end.
    """
    state = PartialState()
    state.print(f"State: {state}")

    checks = (
        ("testing gather", test_gather),
        ("testing gather_object", test_gather_object),
        ("testing gather non-contiguous", test_gather_non_contiguous),
        ("testing broadcast", test_broadcast),
        ("testing pad_across_processes", test_pad_across_processes),
        ("testing reduce_sum", test_reduce_sum),
        ("testing reduce_mean", test_reduce_mean),
        ("testing op_checker", test_op_checker),
        ("testing sending tensors across devices", test_copy_tensor_to_devices),
    )
    for label, check in checks:
        state.print(label)
        check(state)

    state.destroy_process_group()


if __name__ == "__main__":
    main()
accelerate/test_utils/scripts/test_script.py ADDED
@@ -0,0 +1,909 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import contextlib
18
+ import io
19
+ import math
20
+ import time
21
+ from copy import deepcopy
22
+ from pathlib import Path
23
+
24
+ import numpy as np
25
+ import torch
26
+ from torch.utils.data import DataLoader, Dataset
27
+
28
+ from accelerate import Accelerator
29
+ from accelerate.data_loader import SeedableRandomSampler, prepare_data_loader
30
+ from accelerate.state import AcceleratorState
31
+ from accelerate.test_utils import RegressionDataset, RegressionModel, are_the_same_tensors
32
+ from accelerate.utils import (
33
+ DataLoaderConfiguration,
34
+ DistributedType,
35
+ gather,
36
+ gather_object,
37
+ is_bf16_available,
38
+ is_cuda_available,
39
+ is_datasets_available,
40
+ is_fp16_available,
41
+ is_hpu_available,
42
+ is_mps_available,
43
+ is_pytest_available,
44
+ set_seed,
45
+ synchronize_rng_states,
46
+ )
47
+
48
+
49
# Absolute/relative tolerances for the parameter comparisons in this script;
# HPU runs are given a looser bound (1e-3) than every other backend (1e-6).
ATOL = RTOL = 1e-3 if is_hpu_available() else 1e-6
55
+
56
+
57
def generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler=False):
    """Creates a dataloader that can also use the `SeedableRandomSampler`.

    Args:
        train_set: Dataset to iterate over.
        generator: `torch.Generator` driving the shuffling.
        batch_size: Number of samples per batch.
        use_seedable_sampler: When true, shuffle through a `SeedableRandomSampler`
            instead of the `DataLoader`'s built-in `shuffle=True`.

    Returns:
        A shuffling `DataLoader` over `train_set`.
    """
    if not use_seedable_sampler:
        return DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)

    # The SeedableRandomSampler is needed during distributed setups
    # for full reproducibility across processes with the `DataLoader`
    seedable_sampler = SeedableRandomSampler(
        generator=generator,
        data_source=train_set,
        num_samples=len(train_set),
    )
    return DataLoader(train_set, batch_size=batch_size, sampler=seedable_sampler)
70
+
71
+
72
def print_main(state):
    """Print a main-process marker line that includes `state.process_index`."""
    message = f"Printing from the main process {state.process_index}"
    print(message)
74
+
75
+
76
def print_local_main(state):
    """Print a local-main-process marker line that includes `state.local_process_index`."""
    message = f"Printing from the local main process {state.local_process_index}"
    print(message)
78
+
79
+
80
def print_last(state):
    """Print a last-process marker line that includes `state.process_index`."""
    message = f"Printing from the last process {state.process_index}"
    print(message)
82
+
83
+
84
def print_on(state, process_idx):
    """Print the requested `process_idx` followed by the actual `state.process_index`."""
    message = f"Printing from process {process_idx}: {state.process_index}"
    print(message)
86
+
87
+
88
def process_execution_check():
    """Check `main_process_first` ordering and the `on_*process` decorators.

    Part one has every process append a line to a shared scratch file inside
    `main_process_first()`; the main process then verifies its own line comes first
    and that every other process wrote exactly once. Part two wraps the `print_*`
    helpers with each decorator and asserts, via captured stdout, that only the
    targeted process printed.
    """
    accelerator = Accelerator()
    num_processes = accelerator.num_processes
    # Test main_process_first context manager
    path = Path("check_main_process_first.txt")
    with accelerator.main_process_first():
        if accelerator.is_main_process:
            time.sleep(0.1)  # ensure main process takes longest
            with open(path, "a+") as f:
                f.write("Currently in the main process\n")
        else:
            with open(path, "a+") as f:
                f.write("Now on another process\n")
    accelerator.wait_for_everyone()

    if accelerator.is_main_process:
        with open(path) as f:
            text = "".join(f.readlines())
        try:
            assert text.startswith("Currently in the main process\n"), "Main process was not first"
            if num_processes > 1:
                assert text.endswith("Now on another process\n"), "Main process was not first"
            assert text.count("Now on another process\n") == accelerator.num_processes - 1, (
                f"Only wrote to file {text.count('Now on another process') + 1} times, not {accelerator.num_processes}"
            )
        except AssertionError:
            # Remove the scratch file before propagating the failure.
            path.unlink()
            raise

    # Success path: the main process removes the scratch file.
    if accelerator.is_main_process and path.exists():
        path.unlink()
    accelerator.wait_for_everyone()
    # Test the decorators
    f = io.StringIO()
    with contextlib.redirect_stdout(f):
        accelerator.on_main_process(print_main)(accelerator.state)
    result = f.getvalue().rstrip()
    if accelerator.is_main_process:
        assert result == "Printing from the main process 0", f"{result} != Printing from the main process 0"
    else:
        assert f.getvalue().rstrip() == "", f'{result} != ""'
    # Empty the capture buffer and rewind it so the next check starts clean.
    f.truncate(0)
    f.seek(0)

    with contextlib.redirect_stdout(f):
        accelerator.on_local_main_process(print_local_main)(accelerator.state)
    if accelerator.is_local_main_process:
        assert f.getvalue().rstrip() == "Printing from the local main process 0"
    else:
        assert f.getvalue().rstrip() == ""
    f.truncate(0)
    f.seek(0)

    with contextlib.redirect_stdout(f):
        accelerator.on_last_process(print_last)(accelerator.state)
    if accelerator.is_last_process:
        assert f.getvalue().rstrip() == f"Printing from the last process {accelerator.state.num_processes - 1}"
    else:
        assert f.getvalue().rstrip() == ""
    f.truncate(0)
    f.seek(0)

    # Target each process index in turn; only the matching process should print.
    for process_idx in range(num_processes):
        with contextlib.redirect_stdout(f):
            accelerator.on_process(print_on, process_index=process_idx)(accelerator.state, process_idx)
        if accelerator.process_index == process_idx:
            assert f.getvalue().rstrip() == f"Printing from process {process_idx}: {accelerator.process_index}"
        else:
            assert f.getvalue().rstrip() == ""
        f.truncate(0)
        f.seek(0)
159
+
160
+
161
def init_state_check():
    """Instantiate `AcceleratorState` and print it; the local main process also prints a banner line."""
    # Test we can instantiate this twice in a row.
    state = AcceleratorState()
    if state.local_process_index == 0:
        print("Testing, testing. 1, 2, 3.")
    print(state)
167
+
168
+
169
def rng_sync_check():
    """Synchronize RNG states and assert that each one is identical across processes.

    Covers the torch CPU RNG, the CUDA or XPU RNG when the distributed type is
    MULTI_GPU or MULTI_XPU respectively, and a standalone `torch.Generator`.
    """
    state = AcceleratorState()
    synchronize_rng_states(["torch"])
    assert are_the_same_tensors(torch.get_rng_state()), "RNG states improperly synchronized on CPU."
    # Device RNGs are only checked for the matching distributed type.
    if state.distributed_type == DistributedType.MULTI_GPU:
        synchronize_rng_states(["cuda"])
        assert are_the_same_tensors(torch.cuda.get_rng_state()), "RNG states improperly synchronized on GPU."
    elif state.distributed_type == DistributedType.MULTI_XPU:
        synchronize_rng_states(["xpu"])
        assert are_the_same_tensors(torch.xpu.get_rng_state()), "RNG states improperly synchronized on XPU."
    generator = torch.Generator()
    synchronize_rng_states(["generator"], generator=generator)
    assert are_the_same_tensors(generator.get_state()), "RNG states improperly synchronized in generator."

    if state.local_process_index == 0:
        print("All rng are properly synched.")
185
+
186
+
187
def dl_preparation_check():
    """Check `prepare_data_loader` with and without `split_batches`, shuffled and not.

    Non-shuffled loaders must gather back to 0..length-1 in order; shuffled loaders
    must gather back to the same set of indices once sorted.
    """
    state = AcceleratorState()
    length = 32 * state.num_processes

    # Non-shuffled, default batching.
    loader = DataLoader(range(length), batch_size=8)
    loader = prepare_data_loader(loader, state.device, state.num_processes, state.process_index, put_on_device=True)
    gathered = torch.cat([gather(batch) for batch in loader])

    assert torch.equal(gathered.cpu(), torch.arange(0, length).long()), "Wrong non-shuffled dataloader result."

    # Non-shuffled, each batch split across processes.
    loader = DataLoader(range(length), batch_size=8)
    loader = prepare_data_loader(
        loader,
        state.device,
        state.num_processes,
        state.process_index,
        put_on_device=True,
        split_batches=True,
    )
    gathered = torch.cat([gather(batch) for batch in loader])
    assert torch.equal(gathered.cpu(), torch.arange(0, length).long()), "Wrong non-shuffled dataloader result."

    if state.process_index == 0:
        print("Non-shuffled dataloader passing.")

    # Shuffled, default batching: order is arbitrary, so compare after sorting.
    loader = DataLoader(range(length), batch_size=8, shuffle=True)
    loader = prepare_data_loader(loader, state.device, state.num_processes, state.process_index, put_on_device=True)
    seen = sorted(torch.cat([gather(batch) for batch in loader]).tolist())
    assert seen == list(range(length)), "Wrong shuffled dataloader result."

    # Shuffled, each batch split across processes.
    loader = DataLoader(range(length), batch_size=8, shuffle=True)
    loader = prepare_data_loader(
        loader,
        state.device,
        state.num_processes,
        state.process_index,
        put_on_device=True,
        split_batches=True,
    )
    seen = sorted(torch.cat([gather(batch) for batch in loader]).tolist())
    assert seen == list(range(length)), "Wrong shuffled dataloader result."

    if state.local_process_index == 0:
        print("Shuffled dataloader passing.")
245
+
246
+
247
def central_dl_preparation_check():
    """Check `prepare_data_loader` with `dispatch_batches=True`, shuffled and not.

    Mirrors the non-dispatching check: non-shuffled loaders must gather back to
    0..length-1 in order, shuffled loaders to the same indices once sorted.
    """
    state = AcceleratorState()
    length = 32 * state.num_processes

    # Non-shuffled, dispatched batches.
    loader = DataLoader(range(length), batch_size=8)
    loader = prepare_data_loader(
        loader, state.device, state.num_processes, state.process_index, put_on_device=True, dispatch_batches=True
    )
    gathered = torch.cat([gather(batch) for batch in loader])
    assert torch.equal(gathered.cpu(), torch.arange(0, length).long()), "Wrong non-shuffled dataloader result."

    # Non-shuffled, dispatched and split batches.
    loader = DataLoader(range(length), batch_size=8)
    loader = prepare_data_loader(
        loader,
        state.device,
        state.num_processes,
        state.process_index,
        put_on_device=True,
        split_batches=True,
        dispatch_batches=True,
    )
    gathered = torch.cat([gather(batch) for batch in loader])
    assert torch.equal(gathered.cpu(), torch.arange(0, length).long()), "Wrong non-shuffled dataloader result."

    if state.process_index == 0:
        print("Non-shuffled central dataloader passing.")

    # Shuffled, dispatched batches: order is arbitrary, so compare after sorting.
    loader = DataLoader(range(length), batch_size=8, shuffle=True)
    loader = prepare_data_loader(
        loader, state.device, state.num_processes, state.process_index, put_on_device=True, dispatch_batches=True
    )
    seen = sorted(torch.cat([gather(batch) for batch in loader]).tolist())
    assert seen == list(range(length)), "Wrong shuffled dataloader result."

    # Shuffled, dispatched and split batches.
    loader = DataLoader(range(length), batch_size=8, shuffle=True)
    loader = prepare_data_loader(
        loader,
        state.device,
        state.num_processes,
        state.process_index,
        put_on_device=True,
        split_batches=True,
        dispatch_batches=True,
    )
    seen = sorted(torch.cat([gather(batch) for batch in loader]).tolist())
    assert seen == list(range(length)), "Wrong shuffled dataloader result."

    if state.local_process_index == 0:
        print("Shuffled central dataloader passing.")
310
+
311
+
312
def custom_sampler_check():
    """Check that `prepare_data_loader` keeps a user-supplied batch sampler."""
    state = AcceleratorState()

    class CustomDataset(Dataset):
        """Minimal map-style dataset over an indexable container."""

        def __init__(self, data):
            self.data = data

        def __len__(self):
            return len(self.data)

        def __getitem__(self, index):
            return self.data[index]

    class CustomBatchSampler:
        """Batch sampler yielding `ceil(dataset_length / batch_size)` index arrays per pass."""

        def __init__(self, dataset_length: int, batch_size: int, shuffle: bool = True):
            self.batch_size = batch_size
            self.data_index = np.arange(dataset_length)
            self.shuffle = shuffle

        def __iter__(self):
            order = np.random.permutation(self.data_index) if self.shuffle else self.data_index
            yield from np.array_split(order, len(self))

        def __len__(self):
            return math.ceil(len(self.data_index) / self.batch_size)

    dataset = CustomDataset(range(32 * state.num_processes))
    sampler = CustomBatchSampler(len(dataset), batch_size=8)
    dl = DataLoader(dataset, batch_sampler=sampler)
    dl = prepare_data_loader(dl, state.device, state.num_processes, state.process_index)
    # We just need to ensure that `dl.batch_sampler` (or `dl.batch_sampler.batch_sampler`)
    # is indeed the old batch sampler.
    innermost = getattr(dl.batch_sampler, "batch_sampler", dl.batch_sampler)
    assert isinstance(innermost, CustomBatchSampler), (
        "Custom sampler was changed after calling `prepare_data_loader`"
    )
356
+
357
+
358
def check_seedable_sampler():
    """Check that re-seeding and resetting the epoch replays the same three epochs of samples."""
    # Set seed
    set_seed(42)
    dataset = RegressionDataset(length=10, seed=42)
    loader = DataLoader(dataset, batch_size=2, shuffle=True)

    accelerator = Accelerator(dataloader_config=DataLoaderConfiguration(use_seedable_sampler=True))
    loader = accelerator.prepare(loader)
    # First pass: three full epochs, concatenated in iteration order.
    original_items = torch.cat([batch["x"] for _ in range(3) for batch in loader])

    # Set seed again and the epoch
    set_seed(42)
    loader.set_epoch(0)
    new_items = torch.cat([batch["x"] for _ in range(3) for batch in loader])
    assert torch.allclose(original_items, new_items), "Did not obtain the same items with the same seed and epoch."
382
+
383
+
384
def check_seedable_sampler_in_batch_sampler_shard():
    """Check that the sampler nested inside the prepared batch sampler is a `SeedableRandomSampler`.

    Requires more than one process.
    """
    set_seed(42)

    accelerator = Accelerator(dataloader_config=DataLoaderConfiguration(use_seedable_sampler=True))
    assert accelerator.num_processes > 1, "This test requires more than one process."

    base_loader = DataLoader(list(range(10)), batch_size=1, shuffle=True)
    sharded_loader = prepare_data_loader(dataloader=base_loader, use_seedable_sampler=True)

    # Walk down: prepared loader -> its batch sampler -> wrapped batch sampler -> sampler.
    inner_sampler = sharded_loader.batch_sampler.batch_sampler.sampler
    assert isinstance(inner_sampler, SeedableRandomSampler), (
        "Sampler in BatchSamplerShard is not SeedableRandomSampler."
    )
401
+
402
+
403
def check_seedable_sampler_with_data_seed():
    """Check that changing `data_seed` changes the samples drawn over three epochs."""
    # Set seed
    set_seed(42)
    data_seed = 42
    dataset = RegressionDataset(length=10, seed=42)
    base_loader = DataLoader(dataset, batch_size=2, shuffle=True)

    config = DataLoaderConfiguration(use_seedable_sampler=True, data_seed=data_seed)
    accelerator = Accelerator(dataloader_config=config)
    prepared = accelerator.prepare(base_loader)
    original_items = torch.cat([batch["x"] for _ in range(3) for batch in prepared])

    # Set new data seed
    config.data_seed = 43
    accelerator = Accelerator(dataloader_config=config)
    prepared = accelerator.prepare(base_loader)
    new_items = torch.cat([batch["x"] for _ in range(3) for batch in prepared])
    assert not torch.allclose(original_items, new_items), "Obtained the same items with different data seed."
429
+
430
+
431
def mock_training(length, batch_size, generator, use_seedable_sampler=False):
    """Train a `RegressionModel` for three epochs without an `Accelerator`.

    Seeds everything with 42, builds a `RegressionDataset` of `length` samples and
    runs plain SGD (lr=0.1) on the MSE loss.

    Returns:
        The `(train_set, model)` pair.
    """
    set_seed(42)
    generator.manual_seed(42)
    dataset = RegressionDataset(length=length, seed=42)

    loader = generate_baseline_dataloader(dataset, generator, batch_size, use_seedable_sampler)
    net = RegressionModel()
    sgd = torch.optim.SGD(net.parameters(), lr=0.1)
    for _ in range(3):
        for batch in loader:
            net.zero_grad()
            prediction = net(batch["x"])
            torch.nn.functional.mse_loss(prediction, batch["y"]).backward()
            sgd.step()
    return dataset, net
447
+
448
+
449
def _run_accelerated_training(accelerator, train_set, generator, batch_size, use_seedable_sampler):
    """Train a fresh `RegressionModel` for three epochs under `accelerator`; return it unwrapped on CPU."""
    train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler)
    model = RegressionModel()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
    # Reseed after `prepare`, using the same seeds `mock_training` uses for the baseline.
    set_seed(42)
    generator.manual_seed(42)
    for _ in range(3):
        for batch in train_dl:
            model.zero_grad()
            output = model(batch["x"])
            loss = torch.nn.functional.mse_loss(output, batch["y"])
            accelerator.backward(loss)
            optimizer.step()

    return accelerator.unwrap_model(model).cpu()


def _assert_matches_baseline(baseline, model):
    """Assert that parameters `a` and `b` of `model` are within ATOL/RTOL of `baseline`'s."""
    for name in ("a", "b"):
        torch.testing.assert_close(
            getattr(baseline, name),
            getattr(model, name),
            atol=ATOL,
            rtol=RTOL,
            msg=lambda msg: f"Did not obtain the same model on CPU or distributed training.\n{msg}",
        )


def training_check(use_seedable_sampler=False):
    """Check that training through `Accelerator` reproduces the plain `mock_training` baseline.

    The baseline model is compared against models trained (1) with default batching,
    (2) with `split_batches=True`, and, when available, (3) with bf16 and (4) with fp16
    mixed precision. On CUDA/MPS it also runs a forward pass through a model unwrapped
    with `keep_fp32_wrapper=True`.

    Args:
        use_seedable_sampler: Forwarded to the dataloader builder and to
            `DataLoaderConfiguration` wherever a configuration is created.
    """
    state = AcceleratorState()
    generator = torch.Generator()
    batch_size = 8
    length = batch_size * 4 * state.num_processes

    # Baseline: plain training using the global batch size.
    train_set, old_model = mock_training(length, batch_size * state.num_processes, generator, use_seedable_sampler)
    assert are_the_same_tensors(old_model.a), "Did not obtain the same model on both processes."
    assert are_the_same_tensors(old_model.b), "Did not obtain the same model on both processes."

    accelerator = Accelerator()
    model = _run_accelerated_training(accelerator, train_set, generator, batch_size, use_seedable_sampler)
    _assert_matches_baseline(old_model, model)

    accelerator.print("Training yielded the same results on one CPU or distributed setup with no batch split.")

    # With `split_batches=True` the dataloader is built with the global batch size.
    dataloader_config = DataLoaderConfiguration(split_batches=True, use_seedable_sampler=use_seedable_sampler)
    accelerator = Accelerator(dataloader_config=dataloader_config)
    model = _run_accelerated_training(
        accelerator, train_set, generator, batch_size * state.num_processes, use_seedable_sampler
    )
    _assert_matches_baseline(old_model, model)

    accelerator.print("Training yielded the same results on one CPU or distributed setup with batch split.")

    # FP32 wrapper check
    if is_cuda_available() or is_mps_available():
        # Mostly a test that model.forward will have autocast when running unwrap_model(model, keep_fp32_wrapper=True)
        print("Keep fp32 wrapper check.")
        AcceleratorState._reset_state()
        accelerator = Accelerator(mixed_precision="fp16")

        model = torch.nn.Linear(2, 4)
        model = accelerator.prepare(model)
        model_with_fp32_wrapper = accelerator.unwrap_model(model, keep_fp32_wrapper=True)

        # Run forward with fp16 as input.
        # When the model is with mixed precision wrapper, no error will be raised.
        input_tensor = torch.Tensor([1, 2]).to(dtype=torch.float16, device=accelerator.device)
        model_with_fp32_wrapper(input_tensor)

    # BF16 support
    if is_bf16_available():
        # Mostly a test that BF16 doesn't crash as the operation inside the model is not converted to BF16
        print("BF16 training check.")
        AcceleratorState._reset_state()
        dataloader_config = DataLoaderConfiguration(use_seedable_sampler=use_seedable_sampler)
        accelerator = Accelerator(mixed_precision="bf16", dataloader_config=dataloader_config)
        model = _run_accelerated_training(accelerator, train_set, generator, batch_size, use_seedable_sampler)
        _assert_matches_baseline(old_model, model)

    # FP16 support (HPU fp16 model seems to be off by 10% from the CPU, which is a lot of numerical error)
    if is_fp16_available() and not is_hpu_available():
        # Mostly a test that FP16 doesn't crash as the operation inside the model is not converted to FP16
        print("FP16 training check.")
        AcceleratorState._reset_state()
        dataloader_config = DataLoaderConfiguration(use_seedable_sampler=use_seedable_sampler)
        accelerator = Accelerator(mixed_precision="fp16", dataloader_config=dataloader_config)
        model = _run_accelerated_training(accelerator, train_set, generator, batch_size, use_seedable_sampler)
        _assert_matches_baseline(old_model, model)
621
+
622
+
623
def test_split_between_processes_dataset(datasets_Dataset):
    """Check `split_between_processes` on dataset objects, with and without padding.

    Args:
        datasets_Dataset: Class exposing `from_list`, used to build the test datasets.
    """
    state = AcceleratorState()
    world = state.num_processes

    # Evenly divisible, no padding: two items per process.
    data = datasets_Dataset.from_list([{"k": v} for v in range(2 * world)])
    with state.split_between_processes(data, apply_padding=False) as results:
        assert len(results) == 2, (
            f"Each process did not have two items. Process index: {state.process_index}; Length: {len(results)}"
        )

    # One item short, no padding: the last process gets a single item.
    data = datasets_Dataset.from_list([{"k": v} for v in range(2 * world - 1)])
    with state.split_between_processes(data, apply_padding=False) as results:
        if state.is_last_process:
            assert len(results) == 1, (
                f"Last process did not receive a single item. Process index: {state.process_index}; Length: {len(results)}"
            )
        else:
            assert len(results) == 2, (
                f"One of the intermediate processes did not receive two items. Process index: {state.process_index}; Length: {len(results)}"
            )
    state.wait_for_everyone()

    odd_data = datasets_Dataset.from_list([{"k": v} for v in range(2 * world - 1)])
    even_data = datasets_Dataset.from_list([{"k": v} for v in range(2 * world)])

    for data in (odd_data, even_data):
        expected_output = data["k"]

        with state.split_between_processes(data, apply_padding=True) as results:
            if world == 1:
                assert len(results) == len(data), (
                    f"Single process did not receive all items. Process index: {state.process_index}; Length: {len(results)}"
                )
            else:
                assert len(results) == 2, (
                    f"Each process did not have two items. Process index: {state.process_index}; Length: {len(results)}"
                )

            local_rows = [row for row in results]

            state.wait_for_everyone()

            # Anything gathered beyond the original length is dropped before comparing.
            gathered_results = gather_object(local_rows)
            output = [row["k"] for row in gathered_results[: len(data)]]

            assert expected_output == output, f"Gathered results is incorrect. Expected: {expected_output}; Got: {output}"
669
+
670
+
671
def test_split_between_processes_list():
    """Check `split_between_processes` on plain lists, with and without padding."""
    state = AcceleratorState()
    world = state.num_processes

    data = list(range(0, 2 * world))
    with state.split_between_processes(data) as results:
        assert len(results) == 2, (
            f"Each process did not have two items. Process index: {state.process_index}; Length: {len(results)}"
        )
    state.wait_for_everyone()

    even_data = list(range(0, (2 * world)))
    odd_data = list(range(0, (2 * world) - 1))
    for data in (odd_data, even_data):
        expected_output = data

        with state.split_between_processes(data, apply_padding=True) as results:
            num_samples_per_device = math.ceil(len(data) / world)
            # Test all processes gets the correct number of item(s)
            assert len(results) == num_samples_per_device, (
                f"Process {state.device} did not get the correct number of item(s). Process index: {state.process_index}; Length: {len(results)}"
            )

            local_items = [item for item in results]

            state.wait_for_everyone()

            # Anything gathered beyond the original length is dropped before comparing.
            output = gather_object(local_items)[: len(data)]

            assert expected_output == output, f"Gathered results is incorrect. Expected: {expected_output}; Got: {output}"
702
+
703
+
704
def test_split_between_processes_nested_dict():
    """Check `split_between_processes` on a dict holding two lists and a tensor.

    Only runs the comparison for 1, 2 or 4 processes. Process 0 is always checked;
    with two processes the second one is checked; with four processes only
    process 3 is additionally checked (processes 1 and 2 are not verified).
    """
    state = AcceleratorState()
    a = [1, 2, 3, 4, 5, 6, 7, 8]
    b = ["a", "b", "c", "d", "e", "f", "g", "h"]
    c = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
    if state.num_processes in (1, 2, 4):
        data = {"a": a, "b": b, "c": c}
        # Expectations are taken from a pristine copy rather than from `data` itself
        # (presumably because the split may modify `data` in place -- confirm).
        data_copy = deepcopy(data)
        with state.split_between_processes(data) as results:
            # Pick the slice of the original data this process is expected to hold.
            if state.process_index == 0:
                window = slice(None, 8 // state.num_processes)
            elif state.num_processes == 2:
                window = slice(4, None)
            elif state.process_index == 3:
                window = slice(-2, None)
            else:
                window = None

            if window is not None:
                # We return a list each time
                for key in ("a", "b"):
                    expected = data_copy[key][window]
                    assert results[key] == expected, (
                        f"Did not obtain expected values for `{key}` on process {state.process_index}, "
                        f"expected `{expected}`, received: {results[key]}"
                    )
                expected_c = data_copy["c"][window]
                assert torch.allclose(results["c"], expected_c), (
                    f"Did not obtain expected values for `c` on process {state.process_index}, "
                    f"expected `{expected_c}`, received: {results['c']}"
                )

    state.wait_for_everyone()
740
+
741
+
742
def test_split_between_processes_tensor():
    """Check `split_between_processes` on tensors, with and without padding."""
    state = AcceleratorState()
    world = state.num_processes

    if world > 1:
        data = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]]).to(state.device)
        with state.split_between_processes(data) as results:
            # Process 0 is compared with the first row, every other process with the second.
            expected_row = [0, 1, 2, 3] if state.process_index == 0 else [4, 5, 6, 7]
            expected = torch.tensor([expected_row]).to(state.device)
            torch.testing.assert_close(results, expected)
    state.wait_for_everyone()

    even_data = torch.tensor([[i] for i in range(2 * world)]).to(state.device)
    odd_data = torch.tensor([[i] for i in range(2 * world - 1)]).to(state.device)
    for data in (even_data, odd_data):
        expected_output = [torch.tensor(row) for row in data.tolist()]

        with state.split_between_processes(data, apply_padding=True) as results:
            num_samples_per_device = math.ceil(len(data) / world)
            assert len(results) == num_samples_per_device, (
                f"Process {state.device} did not get the correct number of item(s). Process index: {state.process_index}; Length: {len(results)}"
            )
            # Move each row to CPU before it is gathered as a Python object.
            local_rows = [row.to("cpu") for row in results]

            state.wait_for_everyone()

            # Anything gathered beyond the original length is dropped before comparing.
            output = gather_object(local_rows)[: len(data)]

            assert expected_output == output, f"Gathered results is incorrect. Expected: {expected_output}; Got: {output}"
774
+
775
+
776
def test_split_between_processes_evenly():
    """Check that 17 items are split so the first `17 % num_processes` processes get one extra.

    The comparison only runs for 1, 2, 4 or 8 processes.
    """
    state = AcceleratorState()
    if state.num_processes in (1, 2, 4, 8):
        data = list(range(17))
        base_count, num_extras = divmod(len(data), state.num_processes)
        with state.split_between_processes(data) as results:
            # The leading `num_extras` processes each absorb one leftover item.
            expected_count = base_count + 1 if state.process_index < num_extras else base_count
            assert len(results) == expected_count, (
                f"Each Process should have even elements. Expected: {expected_count}, Actual: {len(results)}"
            )
    state.wait_for_everyone()
792
+
793
+
794
def test_trigger():
    """Check that a trigger set on the main process is seen once by `check_trigger` and then cleared."""
    accelerator = Accelerator()

    # Nothing has been set yet, so the trigger must read as False.
    initial = accelerator.check_trigger()
    assert initial is False

    # Only the main process sets the breakpoint.
    if accelerator.is_main_process:
        accelerator.set_trigger()

    # The check calls `all_reduce` and triggers a sync, so the trigger is
    # expected to be active on all processes.
    after_set = accelerator.check_trigger()
    assert after_set is True

    # A second check must find the trigger reset after the sync.
    after_sync = accelerator.check_trigger()
    assert after_sync is False
809
+
810
+
811
def test_reinstantiated_state():
    """Check that `prepare` raises the custom `AttributeError` once the shared state has been reset."""
    import pytest

    expected_fragments = (
        "`AcceleratorState` object has no attribute",
        "This happens if `AcceleratorState._reset_state()`",
    )

    AcceleratorState._reset_state()
    simple_model = torch.nn.Linear(1, 1)
    # Build an accelerator first ...
    accelerator = Accelerator()
    # ... then reset the state underneath it, breaking the state the accelerator holds.
    AcceleratorState._reset_state()
    # Preparing a model must now fail early with the custom error.
    with pytest.raises(AttributeError) as cm:
        accelerator.prepare(simple_model)
    for fragment in expected_fragments:
        assert fragment in str(cm.value.args[0])
825
+
826
+
827
def main():
    """Run every check in this script in order, printing a section banner before each group.

    Split-between-processes checks print from global process 0 (`process_index == 0`);
    all other banners print from local process 0 (`local_process_index == 0`).
    The function returns early for DeepSpeed, before the training, trigger and
    reinstantiated-state checks, and therefore also before `destroy_process_group()`.
    """
    accelerator = Accelerator()
    state = accelerator.state
    if state.local_process_index == 0:
        print("**Initialization**")
    init_state_check()
    state.wait_for_everyone()

    if state.distributed_type == DistributedType.MULTI_GPU:
        num_processes_per_node = torch.cuda.device_count()
    else:
        num_processes_per_node = state.num_processes

    # We only run this test on non-multinode
    if num_processes_per_node == state.num_processes:
        if state.process_index == 0:
            print("\n**Test process execution**")
        process_execution_check()

        if state.process_index == 0:
            print("\n**Test split between processes as a list**")
        test_split_between_processes_list()

        if state.process_index == 0:
            print("\n**Test split between processes as a dict**")
        test_split_between_processes_nested_dict()

        if state.process_index == 0:
            print("\n**Test split between processes as a tensor**")
        test_split_between_processes_tensor()

        if state.process_index == 0:
            print("\n**Test split between processes evenly**")
        test_split_between_processes_evenly()

        if state.process_index == 0:
            print("\n**Test split between processes as a datasets.Dataset**")
        if is_datasets_available():
            # Imported lazily so the script still runs without `datasets` installed.
            from datasets import Dataset as datasets_Dataset

            test_split_between_processes_dataset(datasets_Dataset)
        else:
            print("Skipped because Hugging Face datasets is not available")

    if state.local_process_index == 0:
        print("\n**Test random number generator synchronization**")
    rng_sync_check()

    if state.local_process_index == 0:
        print("\n**DataLoader integration test**")
    dl_preparation_check()
    # The remaining dataloader checks are skipped on XLA.
    if state.distributed_type != DistributedType.XLA:
        central_dl_preparation_check()
        custom_sampler_check()
        check_seedable_sampler()
        check_seedable_sampler_with_data_seed()

    if state.num_processes > 1:
        check_seedable_sampler_in_batch_sampler_shard()

    # Trainings are not exactly the same in DeepSpeed and CPU mode
    if state.distributed_type == DistributedType.DEEPSPEED:
        return

    if state.local_process_index == 0:
        print("\n**Training integration test**")
    training_check(use_seedable_sampler=False)
    training_check(use_seedable_sampler=True)

    if state.local_process_index == 0:
        print("\n**Breakpoint trigger test**")
    test_trigger()

    # `test_reinstantiated_state` imports pytest, so it only runs when pytest is installed.
    if is_pytest_available():
        if state.local_process_index == 0:
            print("\n**Test reinstantiated state**")
        test_reinstantiated_state()

    state.destroy_process_group()
906
+
907
+
908
# Run all checks when this file is executed as a script.
if __name__ == "__main__":
    main()
accelerate/test_utils/scripts/test_sync.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from copy import deepcopy
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch.optim import AdamW
20
+ from torch.optim.lr_scheduler import LambdaLR
21
+ from torch.utils.data import DataLoader
22
+
23
+ from accelerate.accelerator import Accelerator, DataLoaderConfiguration, GradientAccumulationPlugin
24
+ from accelerate.state import GradientState
25
+ from accelerate.test_utils import RegressionDataset, RegressionModel
26
+ from accelerate.utils import DistributedType, set_seed
27
+
28
+
29
def check_model_parameters(model_a, model_b, did_step, iteration, **kwargs):
    """Assert that the gradients of two models match, or differ, as expected.

    Args:
        model_a: First model; its parameters are compared pairwise with those of `model_b`.
        model_b: Second model.
        did_step: When truthy the gradients must be close; otherwise they must not be.
        iteration: Iteration number, used only in the failure message.
        **kwargs: Forwarded to `torch.allclose` (for example `rtol`).

    Parameters of `model_a` with `requires_grad=False` are skipped.
    """
    for param, grad_param in zip(model_a.parameters(), model_b.parameters()):
        if not param.requires_grad:
            continue
        grads_match = torch.allclose(param.grad, grad_param.grad, **kwargs)
        if did_step:
            # Grads should be in sync
            assert grads_match is True, (
                f"Gradients not in sync when they should be at iteration {iteration}:\nmodel_a grad ({param.grad}) != model_b grad ({grad_param.grad})"
            )
        else:
            # Grads should not be in sync
            assert grads_match is False, (
                f"Gradients in sync when they should not be at iteration {iteration}:\nmodel_a grad ({param.grad}) == model_b grad ({grad_param.grad})"
            )
43
+
44
+
45
def step_model(model, input, target, accelerator, do_backward=True):
    """Run one forward pass with MSE loss and backpropagate it.

    Args:
        model: Model to train; it is switched to train mode first.
        input: Batch fed to the model.
        target: Regression target; moved to the output's device before the loss.
        accelerator: Provides `backward` and `gradient_accumulation_steps`.
        do_backward: When truthy (default) the loss goes through `accelerator.backward`.
            When falsy the loss is divided by `accelerator.gradient_accumulation_steps`
            and `loss.backward()` is called directly.
    """
    model.train()
    output = model(input)
    loss = F.mse_loss(output, target.to(output.device))
    if do_backward:
        accelerator.backward(loss)
        return
    # Manual path: scale the loss here, then call plain autograd.
    loss /= accelerator.gradient_accumulation_steps
    loss.backward()
54
+
55
+
56
def get_training_setup(accelerator, sched=False):
    """Return everything needed to perform basic training.

    Seeds with 42, builds a `RegressionModel`, a deep copy of it, and a
    batch-size-16 dataloader over an 80-sample `RegressionDataset`. Only the copy
    and the dataloader (plus the copy's optimizer and scheduler when `sched` is
    truthy) go through `accelerator.prepare`; the base model is only moved to
    `accelerator.device`.

    Returns:
        `(model, ddp_model, dataloader)` when `sched` is falsy, otherwise
        `(model, opt, sched, dataloader, ddp_model, ddp_opt, ddp_sched)`.
    """
    set_seed(42)
    model = RegressionModel()
    # The copy is taken before the base model is moved to the device.
    ddp_model = deepcopy(model)
    dataloader = DataLoader(RegressionDataset(length=80), batch_size=16)
    model.to(accelerator.device)

    if not sched:
        ddp_model, dataloader = accelerator.prepare(ddp_model, dataloader)
        return model, ddp_model, dataloader

    opt = AdamW(params=model.parameters(), lr=1e-3)
    ddp_opt = AdamW(params=ddp_model.parameters(), lr=1e-3)
    base_sched = LambdaLR(opt, lr_lambda=lambda epoch: epoch**0.65)
    ddp_sched = LambdaLR(ddp_opt, lr_lambda=lambda epoch: epoch**0.65)
    ddp_model, ddp_opt, ddp_sched, dataloader = accelerator.prepare(ddp_model, ddp_opt, ddp_sched, dataloader)
    return (model, opt, base_sched, dataloader, ddp_model, ddp_opt, ddp_sched)
77
+
78
+
79
def test_noop_sync(accelerator):
    """Check that `no_sync` does nothing on a single CPU or GPU.

    Over three iterations the prepared model is stepped inside `no_sync` on even
    iterations and outside it on odd ones; its gradients must match the
    reference model's every time.
    """
    model, ddp_model, dataloader = get_training_setup(accelerator)
    # Only the first batch is used; it is reshuffled at the end of each iteration.
    ddp_input, ddp_target = next(iter(dataloader)).values()
    for iteration in range(3):
        # Gather the distributed inputs and targets for the base model
        gathered_input, gathered_target = accelerator.gather((ddp_input, ddp_target))
        gathered_input = gathered_input.to(accelerator.device)
        gathered_target = gathered_target.to(accelerator.device)
        # Ground-truth step on the non-"DDP" model
        step_model(model, gathered_input, gathered_target, accelerator)

        # "Gradient accumulation" (a noop here): accumulate locally on even iterations
        accumulate_locally = iteration % 2 == 0
        if accumulate_locally:
            with accelerator.no_sync(ddp_model):
                step_model(ddp_model, ddp_input, ddp_target, accelerator)
        else:
            step_model(ddp_model, ddp_input, ddp_target, accelerator)

        # Since `no_sync` is a noop, `ddp_model` and `model` grads should always be in sync
        check_model_parameters(model, ddp_model, True, iteration)
        for param, ddp_param in zip(model.parameters(), ddp_model.parameters()):
            if param.requires_grad:
                assert torch.allclose(param.grad, ddp_param.grad), (
                    f"Gradients not in sync when they should be:\nModel grad ({param.grad}) != DDP grad ({ddp_param.grad})"
                )

        # Shuffle ddp_input on each iteration
        torch.manual_seed(1337 + iteration)
        shuffled_order = torch.randperm(len(ddp_input))
        ddp_input = ddp_input[shuffled_order]
111
+
112
+
113
def test_distributed_sync(accelerator):
    """Check that `no_sync` behaves properly in a distributed setup.

    Over three iterations the prepared model is stepped inside `no_sync` on even
    iterations and outside it on odd ones. Its gradients must differ from the
    reference model's on even iterations and match on odd ones.
    """
    model, ddp_model, dataloader = get_training_setup(accelerator)
    # Only the first batch is used; it is reshuffled at the end of each iteration.
    ddp_input, ddp_target = next(iter(dataloader)).values()
    for iteration in range(3):
        # Gather the distributed inputs and targets for the base model
        gathered_input, gathered_target = accelerator.gather((ddp_input, ddp_target))
        gathered_input = gathered_input.to(accelerator.device)
        gathered_target = gathered_target.to(accelerator.device)
        # Ground-truth step on the non-"DDP" model
        step_model(model, gathered_input, gathered_target, accelerator)

        accumulate_locally = iteration % 2 == 0
        if accumulate_locally:
            # Accumulate grads locally
            with accelerator.no_sync(ddp_model):
                step_model(ddp_model, ddp_input, ddp_target, accelerator)
        else:
            # Sync grads
            step_model(ddp_model, ddp_input, ddp_target, accelerator)

        # DDP model and model should only be in sync when not (iteration % 2 == 0)
        for param, ddp_param in zip(model.parameters(), ddp_model.parameters()):
            if not param.requires_grad:
                continue
            grads_match = torch.allclose(param.grad, ddp_param.grad)
            if accumulate_locally:
                assert grads_match is False, (
                    f"Gradients in sync when they should not be:\nModel grad ({param.grad}) == DDP grad ({ddp_param.grad})"
                )
            else:
                assert grads_match is True, (
                    f"Gradients not in sync when they should be:\nModel grad ({param.grad}) != DDP grad ({ddp_param.grad})"
                )

        # Shuffle ddp_input on each iteration
        torch.manual_seed(1337 + iteration)
        shuffled_order = torch.randperm(len(ddp_input))
        ddp_input = ddp_input[shuffled_order]
151
+
152
+
153
def test_distributed_sync_multiple_fwd(accelerator):
    """Check `no_sync` / `trigger_sync_in_backward` with several forwards followed by several backwards.

    Phase 1 runs `num_iterations` forward passes of the prepared model inside
    `no_sync` and stores the losses, stepping the reference model each time.
    Phase 2 backpropagates the stored losses: gradients must differ from the
    reference model's after every backward except the last, which runs inside
    `trigger_sync_in_backward` and must leave them matching.
    """
    # Test on distributed setup that context manager behaves properly when used with multiple forwards followed by multiple backwards
    model, ddp_model, dataloader = get_training_setup(accelerator)
    # Do multiple forwards
    losses = []
    num_iterations = 3
    for iteration in range(num_iterations):
        # A fresh iterator is built each time, so this reads the dataloader's first batch on every pass.
        ddp_input, ddp_target = next(iter(dataloader)).values()

        # Gather the distributed inputs and targs for the base model
        input, target = accelerator.gather((ddp_input, ddp_target))
        input, target = input.to(accelerator.device), target.to(accelerator.device)

        # Perform our initial ground truth step in non "DDP"
        step_model(model, input, target, accelerator)

        # Accumulate grads locally
        with accelerator.no_sync(ddp_model):
            ddp_output = ddp_model(ddp_input)
            loss = F.mse_loss(ddp_output, ddp_target.to(ddp_output.device))
            losses.append(loss)

    # Do multiple backwards and sync only at the last backward
    for iteration in range(num_iterations):
        loss = losses[iteration]

        if iteration < num_iterations - 1:
            # Accumulate grads locally
            accelerator.backward(loss)

            # DDP model and model should only be in sync after last backward
            for param, ddp_param in zip(model.parameters(), ddp_model.parameters()):
                if not param.requires_grad:
                    continue
                # Grads should not be in sync
                assert torch.allclose(param.grad, ddp_param.grad) is False, (
                    f"Gradients in sync when they should not be:\nModel grad ({param.grad}) == DDP grad ({ddp_param.grad})"
                )

        else:
            # Sync grads if last backward
            with accelerator.trigger_sync_in_backward(ddp_model):
                accelerator.backward(loss)

            # DDP model and model should only be in sync after last backward
            for param, ddp_param in zip(model.parameters(), ddp_model.parameters()):
                if not param.requires_grad:
                    continue
                # Grads should be in sync
                assert torch.allclose(param.grad, ddp_param.grad) is True, (
                    f"Gradients not in sync when they should be:\nModel grad ({param.grad}) != DDP grad ({ddp_param.grad})"
                )
205
+
206
+
207
def test_gradient_accumulation(split_batches=False, dispatch_batches=False, sync_each_batch=False):
    """Check gradient synchronisation under `accelerator.accumulate` with two accumulation steps.

    Gradients of the prepared model must match the reference model's on every
    second batch, on the last batch, or on every batch when `sync_each_batch`
    is truthy; on all other batches they must differ.

    Args:
        split_batches: Forwarded to `DataLoaderConfiguration`.
        dispatch_batches: Forwarded to `DataLoaderConfiguration`.
        sync_each_batch: Forwarded to `GradientAccumulationPlugin`.
    """
    accumulation_plugin = GradientAccumulationPlugin(num_steps=2, sync_each_batch=sync_each_batch)
    loader_config = DataLoaderConfiguration(split_batches=split_batches, dispatch_batches=dispatch_batches)
    accelerator = Accelerator(
        dataloader_config=loader_config,
        gradient_accumulation_plugin=accumulation_plugin,
    )
    # Test that context manager behaves properly
    model, ddp_model, dataloader = get_training_setup(accelerator)
    for iteration, batch in enumerate(dataloader):
        ddp_input, ddp_target = batch.values()
        # Gather the distributed inputs and targets for the base model
        gathered_input, gathered_target = accelerator.gather((ddp_input, ddp_target))
        gathered_input = gathered_input.to(accelerator.device)
        gathered_target = gathered_target.to(accelerator.device)
        # Ground-truth step on the non-"DDP" model (manual loss scaling path)
        step_model(model, gathered_input, gathered_target, accelerator, False)
        # Step the prepared model under the accumulation wrapper
        with accelerator.accumulate(ddp_model):
            step_model(ddp_model, ddp_input, ddp_target, accelerator)

        # DDP model and model should only be in sync when not (iteration % 2 == 0)
        for param, ddp_param in zip(model.parameters(), ddp_model.parameters()):
            if not param.requires_grad:
                continue
            grads_match = torch.allclose(param.grad, ddp_param.grad)
            if ((iteration + 1) % 2 == 0) or (iteration == len(dataloader) - 1) or sync_each_batch:
                assert grads_match is True, (
                    f"Gradients not in sync when they should be at iteration {iteration}:\nModel grad ({param.grad}) != DDP grad ({ddp_param.grad})"
                )
            else:
                assert grads_match is False, (
                    f"Gradients in sync when they should not be at iteration {iteration}:\nModel grad ({param.grad}) == DDP grad ({ddp_param.grad})"
                )

        # Shuffle ddp_input on each iteration
        torch.manual_seed(1337 + iteration)
        shuffled_order = torch.randperm(len(ddp_input))
        ddp_input = ddp_input[shuffled_order]
    GradientState._reset_state()
246
+
247
+
248
def test_gradient_accumulation_with_opt_and_scheduler(
    split_batches=False, dispatch_batches=False, sync_each_batch=False
):
    """Check `accelerator.accumulate` together with an optimizer and an LR scheduler.

    A reference model/optimizer/scheduler is stepped by hand alongside the
    prepared ones. After every batch the learning rates of both optimizers must
    be equal, and with more than one process the gradients are compared through
    `check_model_parameters`.

    Args:
        split_batches: Forwarded to `DataLoaderConfiguration`; also selects how
            many times the reference scheduler is stepped per accumulation boundary.
        dispatch_batches: Forwarded to `DataLoaderConfiguration`.
        sync_each_batch: Forwarded to `GradientAccumulationPlugin`.
    """
    gradient_accumulation_plugin = GradientAccumulationPlugin(num_steps=2, sync_each_batch=sync_each_batch)
    dataloader_config = DataLoaderConfiguration(split_batches=split_batches, dispatch_batches=dispatch_batches)
    accelerator = Accelerator(
        dataloader_config=dataloader_config,
        gradient_accumulation_plugin=gradient_accumulation_plugin,
    )
    # Test that context manager behaves properly
    model, opt, sched, dataloader, ddp_model, ddp_opt, ddp_sched = get_training_setup(accelerator, True)
    for iteration, batch in enumerate(dataloader):
        ddp_input, ddp_target = batch.values()
        # Gather the distributed inputs and targs for the base model
        input, target = accelerator.gather((ddp_input, ddp_target))
        input, target = input.to(accelerator.device), target.to(accelerator.device)
        # Perform our initial ground truth step in non "DDP"
        model.train()
        ddp_model.train()
        step_model(model, input, target, accelerator, False)
        opt.step()

        # Step the reference scheduler only on accumulation boundaries (every 2nd batch or the last batch).
        if ((iteration + 1) % 2 == 0) or ((iteration + 1) == len(dataloader)):
            if split_batches:
                sched.step()
            else:
                # One scheduler step per process when batches are not split.
                for _ in range(accelerator.num_processes):
                    sched.step()

        # Perform gradient accumulation under wrapper
        with accelerator.accumulate(ddp_model):
            step_model(ddp_model, ddp_input, ddp_target, accelerator)
            ddp_opt.step()
            ddp_sched.step()

        # Learning rates should be the same
        assert opt.param_groups[0]["lr"] == ddp_opt.param_groups[0]["lr"], (
            f"Learning rates found in each optimizer did not align\nopt: {opt.param_groups[0]['lr']}\nDDP opt: {ddp_opt.param_groups[0]['lr']}\n"
        )
        did_step = (((iteration + 1) % 2) == 0) or ((iteration + 1) == len(dataloader))
        if accelerator.num_processes > 1:
            check_model_parameters(
                model,
                ddp_model,
                did_step or sync_each_batch,  # syncs at each grad_accum interval or if sync_each_batch==True
                iteration,
                rtol=1e-3,  # needs a relative tolerance due to roundoff errors
            )

        if did_step:
            opt.zero_grad()  # flush gradients every accum step
            ddp_opt.zero_grad()

        # Reseed torch's global RNG on each iteration (nothing is shuffled in this test)
        torch.manual_seed(1337 + iteration)
    GradientState._reset_state()
304
+
305
+
306
def test_dataloader_break():
    """Check `gradient_state` bookkeeping when a second dataloader is iterated inside the first.

    `active_dataloader` must be `None` before and after iteration and must be the
    dataloader currently being iterated in between. `end_of_dataloader` must be
    set only on each dataloader's last batch.
    """
    accelerator = Accelerator()
    first_dset = RegressionDataset(length=80)
    first_dataloader = DataLoader(first_dset, batch_size=16)
    second_dset = RegressionDataset(length=96)
    second_dataloader = DataLoader(second_dset, batch_size=16)
    first_dataloader, second_dataloader = accelerator.prepare(first_dataloader, second_dataloader)

    assert accelerator.gradient_state.active_dataloader is None
    for iteration, _ in enumerate(first_dataloader):
        assert id(accelerator.gradient_state.active_dataloader) == id(first_dataloader)
        if iteration < len(first_dataloader) - 1:
            assert not accelerator.gradient_state.end_of_dataloader
            # On the second batch of the outer loop, run the whole second dataloader as a nested loop.
            if iteration == 1:
                for batch_num, _ in enumerate(second_dataloader):
                    assert id(accelerator.gradient_state.active_dataloader) == id(second_dataloader)
                    if batch_num < len(second_dataloader) - 1:
                        assert not accelerator.gradient_state.end_of_dataloader
                    else:
                        assert accelerator.gradient_state.end_of_dataloader
        else:
            assert accelerator.gradient_state.end_of_dataloader
    assert accelerator.gradient_state.active_dataloader is None
329
+
330
+
331
def main():
    """Run the sync and gradient-accumulation checks that apply to the current distributed type.

    Banners are printed from local process 0 only. The process group is
    destroyed at the end.
    """
    accelerator = Accelerator()
    state = accelerator.state
    if state.local_process_index == 0:
        print("**Test `accumulate` gradient accumulation with dataloader break**")
    # The dataloader-break check is skipped on XLA.
    if state.distributed_type != DistributedType.XLA:
        test_dataloader_break()
    if state.distributed_type == DistributedType.NO:
        if state.local_process_index == 0:
            print("**Test NOOP `no_sync` context manager**")
        test_noop_sync(accelerator)
    if state.distributed_type in (
        DistributedType.MULTI_GPU,
        DistributedType.MULTI_NPU,
        DistributedType.MULTI_MLU,
        DistributedType.MULTI_SDAA,
        DistributedType.MULTI_MUSA,
        DistributedType.MULTI_CPU,
        DistributedType.MULTI_HPU,
        DistributedType.MULTI_NEURON,
    ):
        if state.local_process_index == 0:
            print("**Test Distributed `no_sync` context manager**")
        test_distributed_sync(accelerator)
        if state.local_process_index == 0:
            print("**Test Distributed `no_sync` context manager with multiple forwards**")
        test_distributed_sync_multiple_fwd(accelerator)
    # Same list as above except that MULTI_CPU is not included.
    if state.distributed_type in (
        DistributedType.MULTI_GPU,
        DistributedType.MULTI_NPU,
        DistributedType.MULTI_MLU,
        DistributedType.MULTI_SDAA,
        DistributedType.MULTI_MUSA,
        DistributedType.MULTI_HPU,
        DistributedType.MULTI_NEURON,
    ):
        for split_batch in [True, False]:
            for dispatch_batches in [True, False]:
                for sync_each_batch in [True, False]:
                    if state.local_process_index == 0:
                        print(
                            "**Test `accumulate` gradient accumulation, ",
                            f"`split_batches={split_batch}` and `dispatch_batches={dispatch_batches}` and `sync_each_batch={sync_each_batch}`**",
                        )
                    test_gradient_accumulation(split_batch, dispatch_batches, sync_each_batch)

    # Currently will break on torch 2.0 +, need to investigate why
    if state.local_process_index == 0:
        print(
            "**Test `accumulate` gradient accumulation with optimizer and scheduler, ",
            "`split_batches=False`, `dispatch_batches=False`, `sync_each_batch=False`**",
        )
    # The all-defaults combination runs for every distributed type.
    test_gradient_accumulation_with_opt_and_scheduler()
    if state.distributed_type in (
        DistributedType.MULTI_GPU,
        DistributedType.MULTI_NPU,
        DistributedType.MULTI_MLU,
        DistributedType.MULTI_SDAA,
        DistributedType.MULTI_MUSA,
        DistributedType.MULTI_HPU,
        DistributedType.MULTI_NEURON,
    ):
        for split_batch in [True, False]:
            for dispatch_batches in [True, False]:
                for sync_each_batch in [True, False]:
                    # The all-False combination was already run just above.
                    if not split_batch and not dispatch_batches and not sync_each_batch:
                        continue
                    if state.local_process_index == 0:
                        print(
                            "**Test `accumulate` gradient accumulation with optimizer and scheduler, ",
                            f"`split_batches={split_batch}` and `dispatch_batches={dispatch_batches}` and `sync_each_batch={sync_each_batch}`**",
                        )
                    test_gradient_accumulation_with_opt_and_scheduler(split_batch, dispatch_batches, sync_each_batch)
    state.destroy_process_group()
405
+
406
+
407
def _mp_fn(index):
    """Entry point for xla_spawn (TPUs); `index` is accepted but not used."""
    # For xla_spawn (TPUs)
    main()
410
+
411
+
412
# Run all checks when this file is executed as a script.
if __name__ == "__main__":
    main()
accelerate/test_utils/testing.py ADDED
@@ -0,0 +1,889 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import asyncio
16
+ import inspect
17
+ import io
18
+ import os
19
+ import re
20
+ import shutil
21
+ import subprocess
22
+ import sys
23
+ import tempfile
24
+ import unittest
25
+ from contextlib import contextmanager
26
+ from functools import partial
27
+ from pathlib import Path
28
+ from typing import Optional, Union
29
+ from unittest import mock
30
+
31
+ import torch
32
+
33
+ import accelerate
34
+
35
+ from ..state import AcceleratorState
36
+ from ..utils import (
37
+ check_cuda_fp8_capability,
38
+ compare_versions,
39
+ gather,
40
+ is_aim_available,
41
+ is_bnb_available,
42
+ is_clearml_available,
43
+ is_comet_ml_available,
44
+ is_cuda_available,
45
+ is_datasets_available,
46
+ is_deepspeed_available,
47
+ is_dvclive_available,
48
+ is_fp8_available,
49
+ is_fp16_available,
50
+ is_habana_gaudi1,
51
+ is_hpu_available,
52
+ is_import_timer_available,
53
+ is_matplotlib_available,
54
+ is_mlflow_available,
55
+ is_mlu_available,
56
+ is_mps_available,
57
+ is_musa_available,
58
+ is_neuron_available,
59
+ is_npu_available,
60
+ is_pandas_available,
61
+ is_pippy_available,
62
+ is_pytest_available,
63
+ is_schedulefree_available,
64
+ is_sdaa_available,
65
+ is_swanlab_available,
66
+ is_tensorboard_available,
67
+ is_timm_available,
68
+ is_torch_version,
69
+ is_torch_xla_available,
70
+ is_torchao_available,
71
+ is_torchdata_stateful_dataloader_available,
72
+ is_torchvision_available,
73
+ is_trackio_available,
74
+ is_transformer_engine_available,
75
+ is_transformer_engine_mxfp8_available,
76
+ is_transformers_available,
77
+ is_triton_available,
78
+ is_wandb_available,
79
+ is_xpu_available,
80
+ str_to_bool,
81
+ )
82
+
83
+
84
def get_backend():
    """Detect the available backend and return how to query it.

    Backends are probed in a fixed order and the first available one wins:
    XLA, CUDA, MPS, MLU, SDAA, MUSA, NPU, XPU, HPU, Neuron, then CPU.

    Returns:
        A `(device_name, device_count, memory_allocated_func)` tuple. The XLA
        branch reports the `torch.cuda` device count and memory function. MPS
        reports one device, with a real memory function only when
        `is_mps_available(min_version="2.0")` holds. CPU reports one device and
        a function that always returns 0.
    """
    if is_torch_xla_available():
        return "xla", torch.cuda.device_count(), torch.cuda.memory_allocated
    if is_cuda_available():
        return "cuda", torch.cuda.device_count(), torch.cuda.memory_allocated
    if is_mps_available(min_version="2.0"):
        return "mps", 1, torch.mps.current_allocated_memory
    if is_mps_available():
        return "mps", 1, lambda: 0
    if is_mlu_available():
        return "mlu", torch.mlu.device_count(), torch.mlu.memory_allocated
    if is_sdaa_available():
        return "sdaa", torch.sdaa.device_count(), torch.sdaa.memory_allocated
    if is_musa_available():
        return "musa", torch.musa.device_count(), torch.musa.memory_allocated
    if is_npu_available():
        return "npu", torch.npu.device_count(), torch.npu.memory_allocated
    if is_xpu_available():
        return "xpu", torch.xpu.device_count(), torch.xpu.memory_allocated
    if is_hpu_available():
        return "hpu", torch.hpu.device_count(), torch.hpu.memory_allocated
    if is_neuron_available():
        return "neuron", torch.neuron.device_count(), torch.neuron.memory_allocated
    # No accelerator detected: fall back to CPU.
    return "cpu", 1, lambda: 0
109
+
110
+
111
# Backend detected once at import time: device name, device count, and allocated-memory function.
torch_device, device_count, memory_allocated_func = get_backend()
112
+
113
+
114
def get_launch_command(**kwargs) -> list:
    """
    Wraps around `kwargs` to help simplify launching from `subprocess`.

    Each keyword becomes one argument after `accelerate launch`: a value of `True`
    yields a bare `--key` flag, `None` is skipped, and anything else (including
    `False`) yields `--key=value`.

    Example:
    ```python
    # returns ['accelerate', 'launch', '--num_processes=2', '--device_count=2']
    get_launch_command(num_processes=2, device_count=2)
    ```
    """
    arguments = [
        f"--{name}" if (isinstance(value, bool) and value) else f"--{name}={value}"
        for name, value in kwargs.items()
        if value is not None
    ]
    return ["accelerate", "launch", *arguments]
131
+
132
+
133
# Default `accelerate launch` command: one process per detected device, monitor interval of 0.1.
DEFAULT_LAUNCH_COMMAND = get_launch_command(num_processes=device_count, monitor_interval=0.1)
134
+
135
+
136
def parse_flag_from_env(key, default=False):
    """Read a boolean flag from the environment.

    Args:
        key: Name of the environment variable.
        default: Value returned when the variable is not set.

    Returns:
        `default` when `key` is unset, otherwise the result of `str_to_bool` on its value.

    Raises:
        ValueError: If the variable is set to something `str_to_bool` rejects. The
            original error is attached as `__cause__`.
    """
    raw_value = os.environ.get(key)
    if raw_value is None:
        # KEY isn't set, default to `default`.
        return default
    # KEY is set, convert it to True or False.
    try:
        return str_to_bool(raw_value)
    except ValueError as err:
        # More values are supported, but let's keep the message simple.
        raise ValueError(f"If set, {key} must be yes or no.") from err
150
+
151
+
152
# Whether slow tests run; read once from the RUN_SLOW environment variable at import time.
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
153
+
154
+
155
def skip(test_case):
    """Decorator that skips a test unconditionally."""
    always_skip = unittest.skip("Test was skipped")
    return always_skip(test_case)
158
+
159
+
160
def slow(test_case):
    """
    Decorator marking a test as slow. Slow tests are skipped unless the module-level `_run_slow_tests` flag (read
    from the RUN_SLOW environment variable) is truthy.
    """
    run_when_enabled = unittest.skipUnless(_run_slow_tests, "test is slow")
    return run_when_enabled(test_case)
166
+
167
+
168
def require_cpu(test_case):
    """
    Decorator marking a test that must only run on the CPU. The test is skipped whenever the detected
    `torch_device` is anything other than "cpu".
    """
    running_on_cpu = torch_device == "cpu"
    return unittest.skipUnless(running_on_cpu, "test requires only a CPU")(test_case)
173
+
174
+
175
def require_non_cpu(test_case):
    """
    Decorator marking a test that requires a hardware accelerator backend. The test is skipped when the detected
    `torch_device` is "cpu".
    """
    has_accelerator = torch_device != "cpu"
    return unittest.skipUnless(has_accelerator, "test requires a GPU")(test_case)
181
+
182
+
183
def require_cuda(test_case):
    """
    Decorator marking a test that requires CUDA. The test is skipped when CUDA is not available or when TorchXLA
    is available.
    """
    cuda_condition = is_cuda_available() and not is_torch_xla_available()
    return unittest.skipUnless(cuda_condition, "test requires a GPU")(test_case)
189
+
190
+
191
def require_cuda_or_hpu(test_case):
    """
    Decorator marking a test that requires CUDA or HPU. The test runs when CUDA is available without TorchXLA,
    or when an HPU is available; otherwise it is skipped.
    """
    supported = (is_cuda_available() and not is_torch_xla_available()) or is_hpu_available()
    return unittest.skipUnless(supported, "test requires a GPU or HPU")(test_case)
199
+
200
+
201
def require_xpu(test_case):
    """
    Decorator marking a test that requires XPU. The test is skipped when no XPU is available.
    """
    has_xpu = is_xpu_available()
    return unittest.skipUnless(has_xpu, "test requires a XPU")(test_case)
206
+
207
+
208
def require_cuda_or_xpu(test_case):
    """
    Decorator for tests that need CUDA (without TorchXLA) or an XPU; skipped when neither is available.
    """
    # Both probes are always evaluated, CUDA first, exactly as before.
    has_cuda = is_cuda_available() and not is_torch_xla_available()
    has_xpu = is_xpu_available()
    if has_cuda or has_xpu:
        return test_case
    return unittest.skip("test requires a CUDA GPU or XPU")(test_case)
216
+
217
+
218
def require_non_xpu(test_case):
    """
    Decorator for tests that must not run on XPU: skipped when `torch_device` is `"xpu"`.
    """
    not_on_xpu = torch_device != "xpu"
    return unittest.skipUnless(not_on_xpu, "test requires a non-XPU")(test_case)
223
+
224
+
225
def require_non_hpu(test_case):
    """
    Decorator for tests that must not run on HPU: skipped when `torch_device` is `"hpu"`.
    """
    if torch_device != "hpu":
        return test_case
    return unittest.skip("test requires a non-HPU")(test_case)
230
+
231
+
232
def require_fp16(test_case):
    """
    Decorator for tests that need FP16: skipped when `is_fp16_available()` is false.
    """
    fp16_supported = is_fp16_available()
    return unittest.skipUnless(fp16_supported, "test requires FP16 support")(test_case)
238
+
239
+
240
def require_fp8(test_case):
    """
    Decorator for tests that need FP8: skipped when FP8 is not supported.

    `is_fp8_available` only checks for libraries, so device capability is checked here as well: a CUDA device
    that fails `check_cuda_fp8_capability()` or a Habana Gaudi1 HPU disables the test.
    """
    # Every probe is evaluated, in the same order as before.
    libraries_present = is_fp8_available()
    cuda_lacks_fp8 = torch.cuda.is_available() and not check_cuda_fp8_capability()
    on_gaudi1 = is_hpu_available() and is_habana_gaudi1()

    supported = libraries_present and not cuda_lacks_fp8 and not on_gaudi1
    return unittest.skipUnless(supported, "test requires FP8 support")(test_case)
256
+
257
+
258
def require_fsdp2(test_case):
    """
    Decorator for tests that need FSDP2: skipped unless the installed torch version is >= 2.5.0.
    """
    torch_is_recent_enough = is_torch_version(">=", "2.5.0")
    return unittest.skipUnless(torch_is_recent_enough, "test requires FSDP2 (torch >= 2.5.0)")(test_case)
260
+
261
+
262
def require_mlu(test_case):
    """
    Decorator for tests that need an MLU: skipped when `is_mlu_available()` is false.
    """
    if is_mlu_available():
        return test_case
    return unittest.skip("test require a MLU")(test_case)
267
+
268
+
269
def require_sdaa(test_case):
    """
    Decorator for tests that need SDAA: skipped when `is_sdaa_available()` is false.
    """
    sdaa_present = is_sdaa_available()
    return unittest.skipUnless(sdaa_present, "test require a SDAA")(test_case)
274
+
275
+
276
def require_musa(test_case):
    """
    Decorator for tests that need MUSA: skipped when `is_musa_available()` is false.
    """
    if is_musa_available():
        return test_case
    return unittest.skip("test require a MUSA")(test_case)
281
+
282
+
283
def require_npu(test_case):
    """
    Decorator for tests that need an NPU: skipped when `is_npu_available()` is false.
    """
    npu_present = is_npu_available()
    return unittest.skipUnless(npu_present, "test require a NPU")(test_case)
288
+
289
+
290
def require_neuron(test_case):
    """
    Decorator for tests that need Neuron Cores: skipped when `is_neuron_available()` is false.
    """
    if is_neuron_available():
        return test_case
    return unittest.skip("test require Neuron Cores")(test_case)
295
+
296
+
297
def require_mps(test_case):
    """
    Decorator for tests that need the `mps` backend: skipped when `is_mps_available()` is false.
    """
    mps_supported = is_mps_available()
    return unittest.skipUnless(mps_supported, "test requires a `mps` backend support in `torch`")(test_case)
303
+
304
+
305
def require_huggingface_suite(test_case):
    """
    Decorator for tests that need both transformers and datasets; skipped when either one is unavailable.
    """
    suite_installed = is_transformers_available() and is_datasets_available()
    if suite_installed:
        return test_case
    return unittest.skip("test requires the Hugging Face suite")(test_case)
313
+
314
+
315
def require_transformers(test_case):
    """
    Decorator for tests that need transformers: skipped when `is_transformers_available()` is false.
    """
    transformers_installed = is_transformers_available()
    return unittest.skipUnless(transformers_installed, "test requires the transformers library")(test_case)
320
+
321
+
322
def require_timm(test_case):
    """
    Decorator for tests that need timm: skipped when `is_timm_available()` is false.
    """
    if is_timm_available():
        return test_case
    return unittest.skip("test requires the timm library")(test_case)
327
+
328
+
329
def require_torchvision(test_case):
    """
    Decorator for tests that need torchvision: skipped when `is_torchvision_available()` is false.
    """
    torchvision_installed = is_torchvision_available()
    return unittest.skipUnless(torchvision_installed, "test requires the torchvision library")(test_case)
334
+
335
+
336
def require_triton(test_case):
    """
    Decorator for tests that need triton: skipped when `is_triton_available()` is false.
    """
    if is_triton_available():
        return test_case
    return unittest.skip("test requires the triton library")(test_case)
341
+
342
+
343
def require_schedulefree(test_case):
    """
    Decorator for tests that need schedulefree: skipped when `is_schedulefree_available()` is false.
    """
    schedulefree_installed = is_schedulefree_available()
    return unittest.skipUnless(schedulefree_installed, "test requires the schedulefree library")(test_case)
348
+
349
+
350
def require_bnb(test_case):
    """
    Decorator for tests that need bitsandbytes: skipped when `is_bnb_available()` is false.
    """
    if is_bnb_available():
        return test_case
    return unittest.skip("test requires the bitsandbytes library")(test_case)
355
+
356
+
357
def require_tpu(test_case):
    """
    Decorator for tests that need TPUs: skipped unless `is_torch_xla_available(check_is_tpu=True)` is true.
    """
    tpu_present = is_torch_xla_available(check_is_tpu=True)
    return unittest.skipUnless(tpu_present, "test requires TPU")(test_case)
362
+
363
+
364
def require_non_torch_xla(test_case):
    """
    Decorator for tests that need an environment without TorchXLA: skipped when TorchXLA is available.
    """
    xla_absent = not is_torch_xla_available()
    return unittest.skipUnless(xla_absent, "test requires an env without TorchXLA")(test_case)
370
+
371
+
372
def require_single_device(test_case):
    """
    Decorator for tests that need exactly one accelerator: skipped when `torch_device` is `"cpu"` or when
    `device_count` is not 1.
    """
    exactly_one_accelerator = torch_device != "cpu" and device_count == 1
    if exactly_one_accelerator:
        return test_case
    return unittest.skip("test requires a single device accelerator")(test_case)
380
+
381
+
382
def require_single_gpu(test_case):
    """
    Decorator for tests that need exactly one CUDA GPU: skipped when `torch.cuda.device_count()` is not 1.
    """
    gpu_count = torch.cuda.device_count()
    return unittest.skipUnless(gpu_count == 1, "test requires a GPU")(test_case)
388
+
389
+
390
+ def require_single_xpu(test_case):
391
+ """
392
+ Decorator marking a test that requires CUDA on a single XPU. These tests are skipped when there are no XPU
393
+ available or number of xPUs is more than one.
394
+ """
395
+ return unittest.skipUnless(torch.xpu.device_count() == 1, "test requires a XPU")(test_case)
396
+
397
+
398
def require_multi_device(test_case):
    """
    Decorator for tests that need a multi-device setup: skipped unless `device_count` is greater than 1.
    """
    if device_count > 1:
        return test_case
    return unittest.skip("test requires multiple hardware accelerators")(test_case)
404
+
405
+
406
def require_multi_gpu(test_case):
    """
    Decorator for tests that need several CUDA GPUs: skipped unless `torch.cuda.device_count()` exceeds 1.
    """
    gpu_count = torch.cuda.device_count()
    return unittest.skipUnless(gpu_count > 1, "test requires multiple GPUs")(test_case)
412
+
413
+
414
+ def require_multi_xpu(test_case):
415
+ """
416
+ Decorator marking a test that requires a multi-XPU setup. These tests are skipped on a machine without multiple
417
+ XPUs.
418
+ """
419
+ return unittest.skipUnless(torch.xpu.device_count() > 1, "test requires multiple XPUs")(test_case)
420
+
421
+
422
def require_multi_gpu_or_xpu(test_case):
    """
    Decorator for tests that need several GPUs or XPUs: skipped unless CUDA or XPU is available and
    `device_count` is greater than 1.
    """
    backend_present = is_cuda_available() or is_xpu_available()
    enough_devices = backend_present and device_count > 1
    return unittest.skipUnless(enough_devices, "test requires multiple GPUs or XPUs")(test_case)
430
+
431
+
432
def require_deepspeed(test_case):
    """
    Decorator for tests that need DeepSpeed: skipped when `is_deepspeed_available()` is false.
    """
    if is_deepspeed_available():
        return test_case
    return unittest.skip("test requires DeepSpeed")(test_case)
437
+
438
+
439
def require_tp(test_case):
    """
    Decorator for tests that need TP support: skipped unless torch is >= 2.3.0 and transformers is >= 4.52.0.
    """
    versions_ok = is_torch_version(">=", "2.3.0") and compare_versions("transformers", ">=", "4.52.0")
    skip_reason = "test requires torch version >= 2.3.0 and transformers version >= 4.52.0"
    return unittest.skipUnless(versions_ok, skip_reason)(test_case)
447
+
448
+
449
def require_torch_min_version(test_case=None, version=None):
    """
    Decorator for tests that need at least a given torch version; skipped when the installed torch is older.

    When called without a test (`@require_torch_min_version(version="x.y")`), a partial application is returned
    that will decorate the test once it is supplied.
    """
    if test_case is None:
        return partial(require_torch_min_version, version=version)

    meets_minimum = is_torch_version(">=", version)
    return unittest.skipUnless(meets_minimum, f"test requires torch version >= {version}")(test_case)
457
+
458
+
459
def require_tensorboard(test_case):
    """
    Decorator for tests that need tensorboard: skipped when `is_tensorboard_available()` is false.
    """
    tensorboard_installed = is_tensorboard_available()
    return unittest.skipUnless(tensorboard_installed, "test requires Tensorboard")(test_case)
465
+
466
+
467
def require_wandb(test_case):
    """
    Decorator for tests that need wandb: skipped when `is_wandb_available()` is false.
    """
    if is_wandb_available():
        return test_case
    return unittest.skip("test requires wandb")(test_case)
472
+
473
+
474
def require_trackio(test_case):
    """
    Decorator for tests that need trackio: skipped when `is_trackio_available()` is false.
    """
    trackio_installed = is_trackio_available()
    return unittest.skipUnless(trackio_installed, "test requires trackio")(test_case)
479
+
480
+
481
def require_comet_ml(test_case):
    """
    Decorator for tests that need comet_ml: skipped when `is_comet_ml_available()` is false.
    """
    if is_comet_ml_available():
        return test_case
    return unittest.skip("test requires comet_ml")(test_case)
486
+
487
+
488
def require_aim(test_case):
    """
    Decorator for tests that need aim: skipped when `is_aim_available()` is false.
    """
    aim_installed = is_aim_available()
    return unittest.skipUnless(aim_installed, "test requires aim")(test_case)
493
+
494
+
495
def require_clearml(test_case):
    """
    Decorator for tests that need clearml: skipped when `is_clearml_available()` is false.
    """
    if is_clearml_available():
        return test_case
    return unittest.skip("test requires clearml")(test_case)
500
+
501
+
502
def require_dvclive(test_case):
    """
    Decorator for tests that need dvclive: skipped when `is_dvclive_available()` is false.
    """
    dvclive_installed = is_dvclive_available()
    return unittest.skipUnless(dvclive_installed, "test requires dvclive")(test_case)
507
+
508
+
509
def require_swanlab(test_case):
    """
    Decorator for tests that need swanlab: skipped when `is_swanlab_available()` is false.
    """
    if is_swanlab_available():
        return test_case
    return unittest.skip("test requires swanlab")(test_case)
514
+
515
+
516
def require_pandas(test_case):
    """
    Decorator for tests that need pandas: skipped when `is_pandas_available()` is false.
    """
    pandas_installed = is_pandas_available()
    return unittest.skipUnless(pandas_installed, "test requires pandas")(test_case)
521
+
522
+
523
def require_mlflow(test_case):
    """
    Decorator for tests that need mlflow: skipped when `is_mlflow_available()` is false.
    """
    if is_mlflow_available():
        return test_case
    return unittest.skip("test requires mlflow")(test_case)
528
+
529
+
530
def require_pippy(test_case):
    """
    Decorator for tests that need pippy: skipped when pippy is unavailable, or when running on a Habana Gaudi1
    device (which, per the original documentation, does not support pippy).
    """
    pippy_usable = is_pippy_available() and not is_habana_gaudi1()
    return unittest.skipUnless(pippy_usable, "test requires pippy")(test_case)
536
+
537
+
538
def require_import_timer(test_case):
    """
    Decorator for tests that need the tuna interpreter: skipped when `is_import_timer_available()` is false.
    """
    if is_import_timer_available():
        return test_case
    return unittest.skip("test requires tuna interpreter")(test_case)
544
+
545
+
546
def require_transformer_engine(test_case):
    """
    Decorator for tests that need transformer engine: skipped when `is_transformer_engine_available()` is false.
    """
    engine_installed = is_transformer_engine_available()
    return unittest.skipUnless(engine_installed, "test requires transformers engine")(test_case)
552
+
553
+
554
def require_transformer_engine_mxfp8(test_case):
    """
    Decorator for tests that need transformer engine MXFP8 block scaling: skipped when
    `is_transformer_engine_mxfp8_available()` is false.
    """
    if is_transformer_engine_mxfp8_available():
        return test_case
    return unittest.skip("test requires transformers engine MXFP8 block scaling")(test_case)
562
+
563
+
564
def require_torchao(test_case):
    """
    Decorator for tests that need torchao: skipped when `is_torchao_available()` is false.
    """
    torchao_installed = is_torchao_available()
    return unittest.skipUnless(torchao_installed, "test requires torchao")(test_case)
569
+
570
+
571
def require_matplotlib(test_case):
    """
    Decorator for tests that need matplotlib: skipped when `is_matplotlib_available()` is false.
    """
    if is_matplotlib_available():
        return test_case
    return unittest.skip("test requires matplotlib")(test_case)
577
+
578
+
579
# Evaluated once at import time: true when at least one of wandb, tensorboard, trackio or swanlab is
# available AND comet_ml is not available. Used as the skip condition of `require_trackers`.
_atleast_one_tracker_available = (
    any([is_wandb_available(), is_tensorboard_available(), is_trackio_available(), is_swanlab_available()])
    and not is_comet_ml_available()
)
583
+
584
+
585
def require_trackers(test_case):
    """
    Decorator for tests that need at least one tracking library: skipped when the module-level
    `_atleast_one_tracker_available` flag is false.
    """
    if _atleast_one_tracker_available:
        return test_case
    skip_reason = "test requires at least one tracker to be available and for `comet_ml` to not be installed"
    return unittest.skip(skip_reason)(test_case)
594
+
595
+
596
def require_torchdata_stateful_dataloader(test_case):
    """
    Decorator for tests that need `torchdata.stateful_dataloader`.

    The test is skipped when `is_torchdata_stateful_dataloader_available()` is false.
    """
    stateful_dataloader_installed = is_torchdata_stateful_dataloader_available()
    return unittest.skipUnless(stateful_dataloader_installed, "test requires torchdata.stateful_dataloader")(
        test_case
    )
606
+
607
+
608
def run_first(test_case):
    """
    Decorator that tags a test with `pytest.mark.order(1)`.

    Per the original documentation, with the pytest-order plugin installed such tests are guaranteed to run
    first, which helps on setups such as Gaudi where a device can only be used by one process at a time:
    tests that spawn subprocesses are launched before anything else to avoid device allocation conflicts.

    When pytest is not available the test is returned unchanged.
    """
    if not is_pytest_available():
        return test_case

    import pytest

    return pytest.mark.order(1)(test_case)
625
+
626
+
627
class TempDirTestCase(unittest.TestCase):
    """
    A TestCase class that keeps a single temporary directory open for the duration of the class, wipes its
    data at the start of a test, and then destroys it at the end of the TestCase.

    Useful for when a class or API requires a single constant folder throughout its use, such as Weights and Biases

    The temporary directory location will be stored in `self.tmpdir`
    """

    # Set to False in a subclass to keep the directory contents between tests.
    clear_on_setup = True

    @classmethod
    def setUpClass(cls):
        "Creates a temporary directory with `tempfile.mkdtemp` and stores its path in `cls.tmpdir`"
        cls.tmpdir = Path(tempfile.mkdtemp())

    @classmethod
    def tearDownClass(cls):
        "Remove `cls.tmpdir` after test suite has finished"
        if os.path.exists(cls.tmpdir):
            shutil.rmtree(cls.tmpdir)

    def setUp(self):
        "Destroy all contents in `self.tmpdir`, but not `self.tmpdir`"
        if not self.clear_on_setup:
            return
        # Only the top-level entries need visiting: `rmtree` takes care of everything below a directory.
        # The listing is snapshotted first so the directory is never modified while it is being iterated.
        for entry in list(self.tmpdir.iterdir()):
            if entry.is_dir() and not entry.is_symlink():
                shutil.rmtree(entry)
            else:
                # Regular files and symlinks (including dangling ones and links to directories, which
                # `shutil.rmtree` refuses to handle) are simply unlinked.
                entry.unlink()
658
+
659
+
660
class AccelerateTestCase(unittest.TestCase):
    """
    A TestCase class that will reset the accelerator state at the end of every test. Every test that checks or utilizes
    the `AcceleratorState` class should inherit from this to avoid silent failures due to state being shared between
    tests.
    """

    def tearDown(self):
        # Run the regular unittest teardown first, then clear the shared state.
        super().tearDown()
        # Reset the state of the AcceleratorState singleton.
        AcceleratorState._reset_state(True)
671
+
672
+
673
class MockingTestCase(unittest.TestCase):
    """
    A TestCase class designed to dynamically add various mockers that should be used in every test, mimicking the
    behavior of a class-wide mock when defining one normally will not do.

    Useful when a mock requires specific information available only initialized after `TestCase.setUpClass`, such as
    setting an environment variable with that information.

    The `add_mocks` function should be ran at the end of a `TestCase`'s `setUp` function, after a call to
    `super().setUp()` such as:
    ```python
    def setUp(self):
        super().setUp()
        mocks = mock.patch.dict(os.environ, {"SOME_ENV_VAR": "SOME_VALUE"})
        self.add_mocks(mocks)
    ```
    """

    def add_mocks(self, mocks: Union[mock.Mock, list[mock.Mock]]):
        """
        Add custom mocks for tests that should be repeated on each test. Should be called during
        `MockingTestCase.setUp`, after `super().setUp()`.

        Each mock is started immediately and its `stop` method is registered with `addCleanup`.

        Args:
            mocks (`mock.Mock` or list of `mock.Mock`):
                Mocks that should be added to the `TestCase` after `TestCase.setUpClass` has been run
        """
        # A single mock is wrapped in a list; lists and tuples are stored as given.
        self.mocks = mocks if isinstance(mocks, (tuple, list)) else [mocks]
        for m in self.mocks:
            m.start()
            self.addCleanup(m.stop)
704
+
705
+
706
def are_the_same_tensors(tensor):
    """
    Return `True` when every entry gathered across processes equals this process's copy of `tensor`.

    The tensor is cloned with a new leading dimension, moved to the `AcceleratorState` device, gathered, and
    each gathered entry is compared on CPU against the local value.
    """
    device = AcceleratorState().device
    local_copy = tensor[None].clone().to(device)
    gathered = gather(local_copy).cpu()
    reference = local_copy[0].cpu()
    return all(torch.equal(gathered[index], reference) for index in range(gathered.shape[0]))
715
+
716
+
717
class _RunOutput:
    """Plain record holding a finished subprocess's return code and its captured stdout and stderr."""

    def __init__(self, returncode, stdout, stderr):
        self.returncode, self.stdout, self.stderr = returncode, stdout, stderr
722
+
723
+
724
async def _read_stream(stream, callback):
    """Feed every line read from `stream` to `callback`, stopping at the first empty read."""
    line = await stream.readline()
    while line:
        callback(line)
        line = await stream.readline()
731
+
732
+
733
async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> _RunOutput:
    """
    Run `cmd` as a subprocess, collecting its stdout and stderr line by line while it runs.

    Args:
        cmd: Command as a sequence of strings; `cmd[0]` is the program and the rest are its arguments.
        env: Environment passed to the subprocess.
        stdin: Passed through as the subprocess's stdin.
        timeout: Timeout handed to `asyncio.wait` for the two stream-reading tasks. It does not bound the
            final `p.wait()`, so it does not limit the total run time.
        quiet: When true, captured lines are not echoed to this process's stdout/stderr.
        echo: When true, the command line is printed before it is started.

    Returns:
        `_RunOutput` with the return code and the captured stdout/stderr as lists of decoded lines.
    """
    if echo:
        print("\nRunning: ", " ".join(cmd))

    p = await asyncio.create_subprocess_exec(
        cmd[0],
        *cmd[1:],
        stdin=stdin,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
        env=env,
    )

    # note: there is a warning for a possible deadlock when using `wait` with huge amounts of data in the pipe
    # https://docs.python.org/3/library/asyncio-subprocess.html#asyncio.asyncio.subprocess.Process.wait
    #
    # If it starts hanging, will need to switch to the following code. The problem is that no data
    # will be seen until it's done and if it hangs for example there will be no debug info.
    # out, err = await p.communicate()
    # return _RunOutput(p.returncode, out, err)

    out = []
    err = []

    def tee(line, sink, pipe, label=""):
        # Decode one raw line, record it in `sink`, and mirror it to `pipe` unless running quietly.
        line = line.decode("utf-8").rstrip()
        sink.append(line)
        if not quiet:
            print(label, line, file=pipe)

    # XXX: the timeout doesn't seem to make any difference here
    # (the reader tasks are not cancelled when the timeout elapses and `p.wait()` below is unbounded)
    await asyncio.wait(
        [
            asyncio.create_task(_read_stream(p.stdout, lambda l: tee(l, out, sys.stdout, label="stdout:"))),
            asyncio.create_task(_read_stream(p.stderr, lambda l: tee(l, err, sys.stderr, label="stderr:"))),
        ],
        timeout=timeout,
    )
    return _RunOutput(await p.wait(), out, err)
772
+
773
+
774
def execute_subprocess_async(cmd: list, env=None, stdin=None, timeout=180, quiet=False, echo=True) -> _RunOutput:
    """
    Run `cmd` in a subprocess while streaming its output, and raise if it fails.

    Args:
        cmd: Command as a list; `Path` entries are converted to strings. The caller's list is left untouched.
        env: Environment passed to the subprocess.
        stdin: Passed through as the subprocess's stdin.
        timeout: Forwarded to `_stream_subprocess`.
        quiet: When true, the subprocess output is captured but not echoed.
        echo: When true, the command line is printed before it is started.

    Returns:
        `_RunOutput` with the return code and captured stdout/stderr lines.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero return code.
    """
    # Cast every path in `cmd` to a string, on a copy so the caller's list is not mutated
    cmd = [str(c) if isinstance(c, Path) else c for c in cmd]

    result = asyncio.run(_stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo))

    cmd_str = " ".join(cmd)
    # A negative return code means the child was terminated by a signal (POSIX); that is a failure too.
    if result.returncode != 0:
        stderr = "\n".join(result.stderr)
        raise RuntimeError(
            f"'{cmd_str}' failed with returncode {result.returncode}\n\n"
            f"The combined stderr from workers follows:\n{stderr}"
        )

    return result
791
+
792
+
793
def pytest_xdist_worker_id():
    """
    Returns an int value of worker's numerical id under `pytest-xdist`'s concurrent workers `pytest -n N` regime, or 0
    if `-n 1` or `pytest-xdist` isn't being used.

    The id is read from the `PYTEST_XDIST_WORKER` environment variable (for example `"gw3"` gives `3`).
    """
    worker = os.environ.get("PYTEST_XDIST_WORKER", "gw0")
    # `count` and `flags` are passed by keyword: passing them positionally is deprecated since Python 3.13.
    worker = re.sub(r"^gw", "", worker, count=0, flags=re.M)
    return int(worker)
801
+
802
+
803
def get_torch_dist_unique_port():
    """
    Returns a port number suitable for `torch.distributed.launch`'s `--master_port` argument.

    The base port 29500 is offset by the `pytest-xdist` worker id so that concurrent workers don't try to use
    the same port at once.
    """
    base_port = 29500
    return base_port + pytest_xdist_worker_id()
813
+
814
+
815
class SubprocessCallException(Exception):
    """Raised by `run_command` when the executed command exits with an error."""
817
+
818
+
819
def run_command(command: list[str], return_stdout=False, env=None):
    """
    Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. Will also properly capture
    if an error occurred while running `command`

    Args:
        command: Command as a list; `Path` entries are converted to strings. The caller's list is left untouched.
        return_stdout: When true, return the captured output (stderr is merged into stdout) decoded as UTF-8.
        env: Environment for the subprocess; defaults to a copy of `os.environ`.

    Returns:
        The captured output when `return_stdout` is true, otherwise `None`.

    Raises:
        SubprocessCallException: If the command exits with a non-zero status; the message carries its output.
    """
    # Cast every path in `command` to a string, on a copy so the caller's list is not mutated
    command = [str(c) if isinstance(c, Path) else c for c in command]
    if env is None:
        env = os.environ.copy()
    try:
        output = subprocess.check_output(command, stderr=subprocess.STDOUT, env=env)
    except subprocess.CalledProcessError as e:
        raise SubprocessCallException(
            f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
        ) from e
    if return_stdout:
        if hasattr(output, "decode"):
            output = output.decode("utf-8")
        return output
840
+
841
+
842
def path_in_accelerate_package(*components: str) -> Path:
    """
    Build a path inside the directory that contains the `accelerate` package.

    Args:
        *components: Path components appended after the package directory.

    Returns:
        `Path`: The package directory joined with `components`.
    """
    package_file = Path(inspect.getfile(accelerate))
    return package_file.parent.joinpath(*components)
855
+
856
+
857
@contextmanager
def assert_exception(exception_class: Exception, msg: Optional[str] = None) -> bool:
    """
    Context manager asserting that the wrapped code raises an instance of `exception_class`.

    When `msg` is given, it must also appear in the string form of the raised exception. An `AssertionError`
    is raised if the wrapped code finishes without raising.
    """
    try:
        yield
    except Exception as e:
        assert isinstance(e, exception_class), f"Expected exception of type {exception_class} but got {type(e)}"
        if msg is not None:
            assert msg in str(e), f"Expected message '{msg}' to be in exception but got '{str(e)}'"
    else:
        # Reached only when the body completed normally.
        raise AssertionError(f"Expected exception of type {exception_class} but ran without issue.")
874
+
875
+
876
def capture_call_output(func, *args, **kwargs):
    """
    Takes in a `func` with `args` and `kwargs` and returns the captured stdout as a string

    `sys.stdout` is redirected only for the duration of the call and is restored afterwards, including when
    `func` raises; any exception from `func` propagates unchanged.
    """
    from contextlib import redirect_stdout

    captured_output = io.StringIO()
    with redirect_stdout(captured_output):
        func(*args, **kwargs)
    return captured_output.getvalue()
accelerate/test_utils/training.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import numpy as np
16
+ import torch
17
+ from torch.utils.data import DataLoader
18
+
19
+ from accelerate.utils.dataclasses import DistributedType
20
+
21
+
22
class RegressionDataset:
    """
    Synthetic 1-D regression data: `y = a * x + b` plus Gaussian noise (scale 0.1), with `x` drawn from a
    standard normal. Both arrays are float32 and `seed` feeds `numpy.random.default_rng`.
    """

    def __init__(self, a=2, b=3, length=64, seed=None):
        generator = np.random.default_rng(seed)
        self.length = length
        # Inputs are drawn before the noise so the random stream is consumed in the same order as before.
        inputs = generator.normal(size=(length,)).astype(np.float32)
        noise = generator.normal(scale=0.1, size=(length,)).astype(np.float32)
        self.x = inputs
        self.y = a * inputs + b + noise

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        return dict(x=self.x[i], y=self.y[i])
34
+
35
+
36
class RegressionModel(torch.nn.Module):
    """
    Minimal linear model `x * a + b` with two scalar float parameters.

    `double_output` is accepted but not used by this implementation. The first forward call prints the
    parameter and input dtypes.
    """

    def __init__(self, a=0, b=0, double_output=False):
        super().__init__()
        slope = torch.tensor(a).float()
        intercept = torch.tensor(b).float()
        self.a = torch.nn.Parameter(slope)
        self.b = torch.nn.Parameter(intercept)
        self.first_batch = True

    def forward(self, x=None):
        if self.first_batch:
            print(f"Model dtype: {self.a.dtype}, {self.b.dtype}. Input dtype: {x.dtype}")
            self.first_batch = False
        prediction = x * self.a + self.b
        return prediction
48
+
49
+
50
def mocked_dataloaders(accelerator, batch_size: int = 16):
    """
    Build train and validation dataloaders over the MRPC sample CSVs, tokenized with `bert-base-cased`.

    NOTE: `batch_size` is accepted but not used; the train loader uses a batch size of 2 (shuffled) and the
    validation loader a batch size of 1.

    Returns:
        Tuple `(train_dataloader, eval_dataloader)`.
    """
    from datasets import load_dataset
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    csv_files = {"train": "tests/test_samples/MRPC/train.csv", "validation": "tests/test_samples/MRPC/dev.csv"}
    raw_datasets = load_dataset("csv", data_files=csv_files)

    # Map every distinct label of the train split to an integer id.
    label_to_id = {label: index for index, label in enumerate(raw_datasets["train"].unique("label"))}

    def tokenize_function(examples):
        # max_length=None => use the model max length (it's actually the default)
        encoded = tokenizer(
            examples["sentence1"], examples["sentence2"], truncation=True, max_length=None, padding="max_length"
        )
        if "label" in examples:
            encoded["labels"] = [label_to_id[label] for label in examples["label"]]
        return encoded

    # Tokenize all examples in all splits and drop the raw text columns.
    tokenized_datasets = raw_datasets.map(
        tokenize_function,
        batched=True,
        remove_columns=["sentence1", "sentence2", "label"],
    )

    def collate_fn(examples):
        # On TPU it's best to pad everything to the same length or training will be very slow.
        if accelerator.distributed_type == DistributedType.XLA:
            padding_options = {"padding": "max_length", "max_length": 128}
        else:
            padding_options = {"padding": "longest"}
        return tokenizer.pad(examples, return_tensors="pt", **padding_options)

    train_dataloader = DataLoader(tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=2)
    eval_dataloader = DataLoader(tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=1)
    return train_dataloader, eval_dataloader
88
+
89
+
90
def mocked_dataloaders_for_autoregressive_models(accelerator, batch_size: int = 16):
    """
    Build train and validation dataloaders over the MRPC sample CSVs for next-token prediction, tokenized with
    the `HuggingFaceTB/SmolLM-360M` tokenizer (its EOS token doubles as the padding token).

    Each batch is padded one token longer than needed, then shifted: `labels` are the tokens from position 1
    onward and `input_ids` (and `attention_mask`, if present) drop the last position. Label positions equal to
    the pad token id are replaced by -100.

    NOTE: `batch_size` is accepted but not used; the train loader uses a batch size of 2 and the validation
    loader a batch size of 1, neither shuffled.

    Returns:
        Tuple `(train_dataloader, eval_dataloader)`.
    """
    from datasets import load_dataset
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-360M")
    tokenizer.pad_token = tokenizer.eos_token

    csv_files = {"train": "tests/test_samples/MRPC/train.csv", "validation": "tests/test_samples/MRPC/dev.csv"}
    raw_datasets = load_dataset("csv", data_files=csv_files)

    def tokenize_function(examples):
        # max_length=None => use the model max length (it's actually the default)
        return tokenizer(examples["sentence1"], truncation=True, max_length=None, return_attention_mask=False)

    # Tokenize all examples in all splits, letting the main process go first.
    with accelerator.main_process_first():
        tokenized_datasets = raw_datasets.map(
            tokenize_function,
            batched=True,
            remove_columns=["sentence1", "sentence2", "label"],
        )

    def collate_fn(examples):
        # On TPU it's best to pad everything to the same length or training will be very slow.
        if accelerator.distributed_type == DistributedType.XLA:
            max_length = 128
        else:
            max_length = max(len(example["input_ids"]) for example in examples)

        # When using mixed precision we want round multiples of 8/16
        precision = accelerator.mixed_precision
        pad_to_multiple_of = 16 if precision == "fp8" else (8 if precision != "no" else None)

        # One extra position is padded so that inputs and shifted labels end up `max_length` long.
        batch = tokenizer.pad(
            examples,
            padding="max_length",
            max_length=max_length + 1,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors="pt",
        )

        token_ids = batch["input_ids"]
        batch["labels"] = token_ids[:, 1:]
        batch["input_ids"] = token_ids[:, :-1]
        if "attention_mask" in batch:
            batch["attention_mask"] = batch["attention_mask"][:, :-1]

        # Replace label positions equal to the pad token id with -100.
        is_padding = batch["labels"] == tokenizer.pad_token_id
        batch["labels"] = torch.where(is_padding, -100, batch["labels"])
        return batch

    train_dataloader = DataLoader(tokenized_datasets["train"], shuffle=False, collate_fn=collate_fn, batch_size=2)
    eval_dataloader = DataLoader(tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=1)
    return train_dataloader, eval_dataloader
accelerate/utils/__init__.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from ..parallelism_config import ParallelismConfig
15
+ from .ao import convert_model_to_fp8_ao, filter_first_and_last_linear_layers, has_ao_layers
16
+ from .constants import (
17
+ MITA_PROFILING_AVAILABLE_PYTORCH_VERSION,
18
+ MODEL_NAME,
19
+ OPTIMIZER_NAME,
20
+ PROFILE_PATTERN_NAME,
21
+ RNG_STATE_NAME,
22
+ SAFE_MODEL_NAME,
23
+ SAFE_WEIGHTS_INDEX_NAME,
24
+ SAFE_WEIGHTS_NAME,
25
+ SAFE_WEIGHTS_PATTERN_NAME,
26
+ SAMPLER_NAME,
27
+ SCALER_NAME,
28
+ SCHEDULER_NAME,
29
+ TORCH_DISTRIBUTED_OPERATION_TYPES,
30
+ TORCH_LAUNCH_PARAMS,
31
+ WEIGHTS_INDEX_NAME,
32
+ WEIGHTS_NAME,
33
+ WEIGHTS_PATTERN_NAME,
34
+ XPU_PROFILING_AVAILABLE_PYTORCH_VERSION,
35
+ )
36
+ from .dataclasses import (
37
+ AORecipeKwargs,
38
+ AutocastKwargs,
39
+ BnbQuantizationConfig,
40
+ ComputeEnvironment,
41
+ CustomDtype,
42
+ DataLoaderConfiguration,
43
+ DDPCommunicationHookType,
44
+ DeepSpeedPlugin,
45
+ DeepSpeedSequenceParallelConfig,
46
+ DistributedDataParallelKwargs,
47
+ DistributedType,
48
+ DynamoBackend,
49
+ FP8RecipeKwargs,
50
+ FullyShardedDataParallelPlugin,
51
+ GradientAccumulationPlugin,
52
+ GradScalerKwargs,
53
+ InitProcessGroupKwargs,
54
+ KwargsHandler,
55
+ LoggerType,
56
+ MegatronLMPlugin,
57
+ MSAMPRecipeKwargs,
58
+ PrecisionType,
59
+ ProfileKwargs,
60
+ ProjectConfiguration,
61
+ RNGType,
62
+ SageMakerDistributedType,
63
+ TensorInformation,
64
+ TERecipeKwargs,
65
+ TorchContextParallelConfig,
66
+ TorchDynamoPlugin,
67
+ TorchTensorParallelConfig,
68
+ TorchTensorParallelPlugin,
69
+ add_model_config_to_megatron_parser,
70
+ )
71
+ from .environment import (
72
+ are_libraries_initialized,
73
+ check_cuda_fp8_capability,
74
+ check_cuda_p2p_ib_support,
75
+ clear_environment,
76
+ convert_dict_to_env_variables,
77
+ get_cpu_distributed_information,
78
+ get_current_device_type,
79
+ get_gpu_info,
80
+ get_int_from_env,
81
+ parse_choice_from_env,
82
+ parse_flag_from_env,
83
+ patch_environment,
84
+ purge_accelerate_environment,
85
+ set_numa_affinity,
86
+ str_to_bool,
87
+ )
88
+ from .imports import (
89
+ deepspeed_required,
90
+ is_4bit_bnb_available,
91
+ is_8bit_bnb_available,
92
+ is_aim_available,
93
+ is_bf16_available,
94
+ is_bitsandbytes_multi_backend_available,
95
+ is_bnb_available,
96
+ is_boto3_available,
97
+ is_clearml_available,
98
+ is_comet_ml_available,
99
+ is_cuda_available,
100
+ is_datasets_available,
101
+ is_deepspeed_available,
102
+ is_dvclive_available,
103
+ is_fp8_available,
104
+ is_fp16_available,
105
+ is_habana_gaudi1,
106
+ is_hpu_available,
107
+ is_import_timer_available,
108
+ is_lomo_available,
109
+ is_matplotlib_available,
110
+ is_megatron_lm_available,
111
+ is_mlflow_available,
112
+ is_mlu_available,
113
+ is_mps_available,
114
+ is_msamp_available,
115
+ is_musa_available,
116
+ is_neuron_available,
117
+ is_npu_available,
118
+ is_pandas_available,
119
+ is_peft_available,
120
+ is_pippy_available,
121
+ is_pynvml_available,
122
+ is_pytest_available,
123
+ is_rich_available,
124
+ is_sagemaker_available,
125
+ is_schedulefree_available,
126
+ is_sdaa_available,
127
+ is_swanlab_available,
128
+ is_tensorboard_available,
129
+ is_timm_available,
130
+ is_torch_xla_available,
131
+ is_torchao_available,
132
+ is_torchdata_available,
133
+ is_torchdata_stateful_dataloader_available,
134
+ is_torchvision_available,
135
+ is_trackio_available,
136
+ is_transformer_engine_available,
137
+ is_transformer_engine_mxfp8_available,
138
+ is_transformers_available,
139
+ is_triton_available,
140
+ is_wandb_available,
141
+ is_weights_only_available,
142
+ is_xccl_available,
143
+ is_xpu_available,
144
+ torchao_required,
145
+ )
146
+ from .modeling import (
147
+ align_module_device,
148
+ calculate_maximum_sizes,
149
+ check_device_map,
150
+ check_tied_parameters_in_config,
151
+ check_tied_parameters_on_same_device,
152
+ compute_module_sizes,
153
+ convert_file_size_to_int,
154
+ dtype_byte_size,
155
+ find_tied_parameters,
156
+ get_balanced_memory,
157
+ get_grad_scaler,
158
+ get_max_layer_size,
159
+ get_max_memory,
160
+ get_mixed_precision_context_manager,
161
+ has_offloaded_params,
162
+ id_tensor_storage,
163
+ infer_auto_device_map,
164
+ is_peft_model,
165
+ load_checkpoint_in_model,
166
+ load_offloaded_weights,
167
+ load_state_dict,
168
+ named_module_tensors,
169
+ retie_parameters,
170
+ set_module_tensor_to_device,
171
+ )
172
+ from .offload import (
173
+ OffloadedWeightsLoader,
174
+ PrefixedDataset,
175
+ extract_submodules_state_dict,
176
+ load_offloaded_weight,
177
+ offload_state_dict,
178
+ offload_weight,
179
+ save_offload_index,
180
+ )
181
+ from .operations import (
182
+ CannotPadNestedTensorWarning,
183
+ GatheredParameters,
184
+ broadcast,
185
+ broadcast_object_list,
186
+ concatenate,
187
+ convert_outputs_to_fp32,
188
+ convert_to_fp32,
189
+ copy_tensor_to_devices,
190
+ find_batch_size,
191
+ find_device,
192
+ gather,
193
+ gather_object,
194
+ get_data_structure,
195
+ honor_type,
196
+ ignorant_find_batch_size,
197
+ initialize_tensors,
198
+ is_namedtuple,
199
+ is_tensor_information,
200
+ is_torch_tensor,
201
+ listify,
202
+ pad_across_processes,
203
+ pad_input_tensors,
204
+ recursively_apply,
205
+ reduce,
206
+ send_to_device,
207
+ slice_tensors,
208
+ )
209
+ from .versions import compare_versions, is_torch_version
210
+
211
+
212
+ if is_deepspeed_available():
213
+ from .deepspeed import (
214
+ DeepSpeedEngineWrapper,
215
+ DeepSpeedOptimizerWrapper,
216
+ DeepSpeedSchedulerWrapper,
217
+ DummyOptim,
218
+ DummyScheduler,
219
+ HfDeepSpeedConfig,
220
+ get_active_deepspeed_plugin,
221
+ map_pytorch_optim_to_deepspeed,
222
+ )
223
+
224
+ from .bnb import has_4bit_bnb_layers, load_and_quantize_model
225
+ from .fsdp_utils import (
226
+ disable_fsdp_ram_efficient_loading,
227
+ enable_fsdp_ram_efficient_loading,
228
+ ensure_weights_retied,
229
+ fsdp2_apply_ac,
230
+ fsdp2_canonicalize_names,
231
+ fsdp2_load_full_state_dict,
232
+ fsdp2_prepare_model,
233
+ fsdp2_switch_optimizer_parameters,
234
+ get_fsdp2_grad_scaler,
235
+ load_fsdp_model,
236
+ load_fsdp_optimizer,
237
+ merge_fsdp_weights,
238
+ save_fsdp_model,
239
+ save_fsdp_optimizer,
240
+ )
241
+ from .launch import (
242
+ PrepareForLaunch,
243
+ _filter_args,
244
+ prepare_deepspeed_cmd_env,
245
+ prepare_multi_gpu_env,
246
+ prepare_sagemager_args_inputs,
247
+ prepare_simple_launcher_cmd_env,
248
+ prepare_tpu,
249
+ )
250
+
251
+ # For docs
252
+ from .megatron_lm import (
253
+ AbstractTrainStep,
254
+ BertTrainStep,
255
+ GPTTrainStep,
256
+ MegatronLMDummyDataLoader,
257
+ MegatronLMDummyScheduler,
258
+ T5TrainStep,
259
+ avg_losses_across_data_parallel_group,
260
+ )
261
+
262
+
263
+ if is_megatron_lm_available():
264
+ from .megatron_lm import (
265
+ MegatronEngine,
266
+ MegatronLMOptimizerWrapper,
267
+ MegatronLMSchedulerWrapper,
268
+ gather_across_data_parallel_groups,
269
+ )
270
+ from .megatron_lm import initialize as megatron_lm_initialize
271
+ from .megatron_lm import prepare_data_loader as megatron_lm_prepare_data_loader
272
+ from .megatron_lm import prepare_model_optimizer_scheduler as megatron_lm_prepare_model_optimizer_scheduler
273
+ from .megatron_lm import prepare_optimizer as megatron_lm_prepare_optimizer
274
+ from .megatron_lm import prepare_scheduler as megatron_lm_prepare_scheduler
275
+ from .memory import find_executable_batch_size, release_memory
276
+ from .other import (
277
+ check_os_kernel,
278
+ clean_state_dict_for_safetensors,
279
+ compile_regions,
280
+ compile_regions_deepspeed,
281
+ convert_bytes,
282
+ extract_model_from_parallel,
283
+ get_module_children_bottom_up,
284
+ get_pretty_name,
285
+ has_compiled_regions,
286
+ is_compiled_module,
287
+ is_port_in_use,
288
+ load,
289
+ merge_dicts,
290
+ model_has_dtensor,
291
+ recursive_getattr,
292
+ save,
293
+ wait_for_everyone,
294
+ write_basic_config,
295
+ )
296
+ from .random import set_seed, synchronize_rng_state, synchronize_rng_states
297
+ from .torch_xla import install_xla
298
+ from .tqdm import tqdm
299
+ from .transformer_engine import (
300
+ apply_fp8_autowrap,
301
+ contextual_fp8_autocast,
302
+ convert_model,
303
+ has_transformer_engine_layers,
304
+ )
accelerate/utils/ao.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Needed utilities for torchao FP8 training.
17
+ """
18
+
19
+ from functools import partial
20
+ from typing import TYPE_CHECKING, Callable, Optional
21
+
22
+ import torch
23
+
24
+ from .imports import is_torchao_available, torchao_required
25
+
26
+
27
+ if TYPE_CHECKING:
28
+ if is_torchao_available():
29
+ from torchao.float8.float8_linear import Float8LinearConfig
30
+
31
+
32
def find_first_last_linear_layers(model: torch.nn.Module):
    """
    Return the names of the first and last `torch.nn.Linear` modules found in `model`.

    Names are the ones yielded by `model.named_modules()`, in traversal order. Both entries are `None` when the model
    contains no linear layer, and both are the same name when it contains exactly one.

    This is needed during FP8 to avoid issues with instability by keeping the first and last layers unquantized.

    Ref: https://x.com/xariusrke/status/1826669142604141052
    """
    linear_names = [
        layer_name for layer_name, layer in model.named_modules() if isinstance(layer, torch.nn.Linear)
    ]
    if not linear_names:
        return None, None
    return linear_names[0], linear_names[-1]
47
+
48
+
49
def filter_linear_layers(module, fqn: str, layers_to_filter: list[str]) -> bool:
    """
    Decide whether `module` passes the FP8 conversion filter.

    A `torch.nn.Linear` passes only if both its `in_features` and `out_features` are divisible by 16 and its name is
    not listed in `layers_to_filter`. Any module that is not a `torch.nn.Linear` passes unconditionally.

    Args:
        module (`torch.nn.Module`):
            The module to check.
        fqn (`str`):
            The fully qualified name of the layer.
        layers_to_filter (`List[str]`):
            The list of layer names to reject.

    Returns:
        `bool`: `False` for a linear layer that is misaligned or explicitly filtered, `True` otherwise.
    """
    if not isinstance(module, torch.nn.Linear):
        return True
    dims_aligned = module.in_features % 16 == 0 and module.out_features % 16 == 0
    # The name lookup is only reached for correctly aligned layers.
    return dims_aligned and fqn not in layers_to_filter
70
+
71
+
72
def filter_first_and_last_linear_layers(module, fqn: str) -> bool:
    """
    A filter function that rejects the first and last linear layers found in `module`.

    The first/last linear layer names are looked up with `find_first_last_linear_layers(module)` and the decision is
    then delegated to `filter_linear_layers` with those two names as the layers to filter.

    <Tip>

        For stability reasons, we skip the first and last linear layers. Converting them can lead to the model not
        training or converging properly.

    </Tip>

    Args:
        module (`torch.nn.Module`):
            The module to check.
        fqn (`str`):
            The fully qualified name of the layer.

    Returns:
        `bool`: The result of `filter_linear_layers` for `module` and `fqn`.
    """
    # NOTE(review): the first/last names are computed relative to the `module` passed in, not relative to a root
    # model -- confirm that callers supply `fqn` in the same namespace.
    excluded_layers = list(find_first_last_linear_layers(module))
    return filter_linear_layers(module, fqn, layers_to_filter=excluded_layers)
91
+
92
+
93
@torchao_required
def has_ao_layers(model: torch.nn.Module):
    """
    Return `True` if at least one module inside `model` is a torchao `Float8Linear`, `False` otherwise.
    """
    from torchao.float8.float8_linear import Float8Linear

    return any(isinstance(submodule, Float8Linear) for _, submodule in model.named_modules())
101
+
102
+
103
@torchao_required
def convert_model_to_fp8_ao(
    model: torch.nn.Module,
    config: Optional["Float8LinearConfig"] = None,
    module_filter_func: Optional[Callable] = filter_first_and_last_linear_layers,
):
    """
    Converts `nn.Linear` layers in the model to torchao's `Float8Linear` layer inplace, using `module_filter_func` to
    decide which layers are converted.

    Args:
        model (`torch.nn.Module`):
            The model to convert.
        config (`torchao.float8.Float8LinearConfig`, *optional*):
            The configuration for the FP8 training. Recommended to utilize
            `torchao.float8.recipe_name_to_linear_config` to generate this. In general, the default config should be
            sufficient (what is passed when set to `None`).
        module_filter_func (`Callable`, *optional*, defaults to `filter_first_and_last_linear_layers`):
            Optional function that must take in a module and layer name, and returns a boolean indicating whether the
            module should be converted to FP8. When explicitly set to `None`, `filter_linear_layers` is used with the
            first and last linear layers of `model` excluded. See `filter_linear_layers` for an example.

    Example:

    ```python
    from accelerate import Accelerator
    from accelerate.utils.ao import convert_model_to_fp8_ao

    accelerator = Accelerator()

    model = MyModel()
    model.to(accelerator.device)
    convert_model_to_fp8_ao(model)

    model.train()
    ```
    """
    from torchao.float8 import convert_to_float8_training

    if module_filter_func is None:
        # Only walk the model when the fallback filter actually needs the first/last layer names.
        first_linear, last_linear = find_first_last_linear_layers(model)
        module_filter_func = partial(filter_linear_layers, layers_to_filter=[first_linear, last_linear])
    convert_to_float8_training(model, module_filter_fn=module_filter_func, config=config)
accelerate/utils/bnb.py ADDED
@@ -0,0 +1,464 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import logging
17
+ import os
18
+ from copy import deepcopy
19
+ from typing import Optional, Union
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+
24
+ from accelerate.utils.imports import (
25
+ is_4bit_bnb_available,
26
+ is_8bit_bnb_available,
27
+ )
28
+
29
+ from ..big_modeling import dispatch_model, init_empty_weights
30
+ from .dataclasses import BnbQuantizationConfig
31
+ from .modeling import (
32
+ find_tied_parameters,
33
+ get_balanced_memory,
34
+ infer_auto_device_map,
35
+ load_checkpoint_in_model,
36
+ offload_weight,
37
+ set_module_tensor_to_device,
38
+ )
39
+
40
+
41
+ logger = logging.getLogger(__name__)
42
+
43
+
44
def load_and_quantize_model(
    model: torch.nn.Module,
    bnb_quantization_config: BnbQuantizationConfig,
    weights_location: Optional[Union[str, os.PathLike]] = None,
    device_map: Optional[dict[str, Union[int, str, torch.device]]] = None,
    no_split_module_classes: Optional[list[str]] = None,
    max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None,
    offload_folder: Optional[Union[str, os.PathLike]] = None,
    offload_state_dict: bool = False,
):
    """
    This function will quantize the input model with the associated config passed in `bnb_quantization_config`. If the
    model is in the meta device, we will load and dispatch the weights according to the `device_map` passed. If the
    model is already loaded, we will quantize the model and put the model on the GPU.

    Note that `bnb_quantization_config.skip_modules` and `bnb_quantization_config.keep_in_fp32_modules` are filled in
    and extended in place by this function.

    Args:
        model (`torch.nn.Module`):
            Input model. The model can be already loaded or on the meta device
        bnb_quantization_config (`BnbQuantizationConfig`):
            The bitsandbytes quantization parameters
        weights_location (`str` or `os.PathLike`):
            The folder weights_location to load. It can be:
            - a path to a file containing a whole model state dict
            - a path to a `.json` file containing the index to a sharded checkpoint
            - a path to a folder containing a unique `.index.json` file and the shards of a checkpoint.
            - a path to a folder containing a unique pytorch_model.bin file.
        device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*):
            A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer
            name, once a given module name is inside, every submodule of it will be sent to the same device.
        no_split_module_classes (`List[str]`, *optional*):
            A list of layer class names that should never be split across device (for instance any layer that has a
            residual connection).
        max_memory (`Dict`, *optional*):
            A dictionary device identifier to maximum memory. Will default to the maximum memory available if unset.
        offload_folder (`str` or `os.PathLike`, *optional*):
            If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
        offload_state_dict (`bool`, *optional*, defaults to `False`):
            If `True`, will temporarily offload the CPU state dict on the hard drive to avoid getting out of CPU RAM if
            the weight of the CPU state dict + the biggest shard does not fit.

    Returns:
        `torch.nn.Module`: The quantized model

    Raises:
        ImportError: If 8-bit loading is requested but `is_8bit_bnb_available()` is false.
        ValueError: If 4-bit loading is requested but `is_4bit_bnb_available()` is false.
        RuntimeError: If the model is already loaded and neither CUDA nor XPU is available, or if the model is on the
            meta device and `weights_location` is `None`.
    """

    load_in_4bit = bnb_quantization_config.load_in_4bit
    load_in_8bit = bnb_quantization_config.load_in_8bit

    if load_in_8bit and not is_8bit_bnb_available():
        raise ImportError(
            "You have a version of `bitsandbytes` that is not compatible with 8bit quantization,"
            " make sure you have the latest version of `bitsandbytes` installed."
        )
    if load_in_4bit and not is_4bit_bnb_available():
        raise ValueError(
            "You have a version of `bitsandbytes` that is not compatible with 4bit quantization,"
            "make sure you have the latest version of `bitsandbytes` installed."
        )

    modules_on_cpu = []
    # custom device map: remember which entries the caller placed on cpu or disk
    if isinstance(device_map, dict) and len(device_map.keys()) > 1:
        modules_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]

    # We keep some modules such as the lm_head in their original dtype for numerical stability reasons
    if bnb_quantization_config.skip_modules is None:
        bnb_quantization_config.skip_modules = get_keys_to_not_convert(model)

    # add cpu modules to skip modules only for 4-bit modules
    if load_in_4bit:
        bnb_quantization_config.skip_modules.extend(modules_on_cpu)
    # NOTE: this is the same list object as `bnb_quantization_config.skip_modules` (no copy is made), so the
    # `extend` below also grows the list stored on the config.
    modules_to_not_convert = bnb_quantization_config.skip_modules

    # We add the modules we want to keep in full precision
    if bnb_quantization_config.keep_in_fp32_modules is None:
        bnb_quantization_config.keep_in_fp32_modules = []
    keep_in_fp32_modules = bnb_quantization_config.keep_in_fp32_modules
    modules_to_not_convert.extend(keep_in_fp32_modules)

    # compatibility with peft
    model.is_loaded_in_4bit = load_in_4bit
    model.is_loaded_in_8bit = load_in_8bit

    # The device of the first parameter decides between the "already loaded" and the "meta" code path.
    model_device = get_parameter_device(model)
    if model_device.type != "meta":
        # quantization of an already loaded model
        logger.warning(
            "It is not recommended to quantize a loaded model. "
            "The model should be instantiated under the `init_empty_weights` context manager."
        )
        model = replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=modules_to_not_convert)
        # convert param to the right dtype
        dtype = bnb_quantization_config.torch_dtype
        for name, param in model.named_parameters():
            # substring match on the parameter name: fp32 modules win over the generic dtype cast
            if any(module_to_keep_in_fp32 in name for module_to_keep_in_fp32 in keep_in_fp32_modules):
                param.data = param.data.to(torch.float32)
            elif torch.is_floating_point(param):
                param.data = param.data.to(dtype)
        if model_device.type == "cuda":
            model.cuda(torch.cuda.current_device())
            torch.cuda.empty_cache()
        elif torch.cuda.is_available():
            model.to(torch.cuda.current_device())
        elif torch.xpu.is_available():
            model.to(torch.xpu.current_device())
        else:
            raise RuntimeError("No GPU or Intel XPU found. A GPU or Intel XPU is needed for quantization.")
        # NOTE(review): this message is emitted on every non-raising branch above, including when the model was
        # already on cuda.
        logger.info(
            f"The model device type is {model_device.type}. However, gpu or intel xpu is needed for quantization."
            "We move the model to it."
        )
        return model

    elif weights_location is None:
        raise RuntimeError(
            f"`weights_location` needs to be the folder path containing the weights of the model, but we found {weights_location} "
        )

    else:
        # Model is on the meta device: swap in the bnb layers without allocating weights, then load the checkpoint.
        with init_empty_weights():
            model = replace_with_bnb_layers(
                model, bnb_quantization_config, modules_to_not_convert=modules_to_not_convert
            )
        device_map = get_quantized_model_device_map(
            model,
            bnb_quantization_config,
            device_map,
            max_memory=max_memory,
            no_split_module_classes=no_split_module_classes,
        )
        # NOTE(review): `offload_state_dict` defaults to `False` in the signature, so this branch only triggers when
        # a caller explicitly passes `None` -- confirm this is intended.
        if offload_state_dict is None and device_map is not None and "disk" in device_map.values():
            offload_state_dict = True

        # True when at least one entry of the final device map is placed on cpu or disk
        offload = any(x in list(device_map.values()) for x in ["cpu", "disk"])

        load_checkpoint_in_model(
            model,
            weights_location,
            device_map,
            dtype=bnb_quantization_config.torch_dtype,
            offload_folder=offload_folder,
            offload_state_dict=offload_state_dict,
            keep_in_fp32_modules=bnb_quantization_config.keep_in_fp32_modules,
            offload_8bit_bnb=load_in_8bit and offload,
        )
        return dispatch_model(model, device_map=device_map, offload_dir=offload_folder)
189
+
190
+
191
def get_quantized_model_device_map(
    model, bnb_quantization_config, device_map=None, max_memory=None, no_split_module_classes=None
):
    """
    Resolve the `device_map` to use for a model that is going to be quantized.

    - `None` becomes a single-entry map placing the whole model on the current CUDA device, or the current XPU device
      when CUDA is not available.
    - A string (`"auto"`, `"balanced"`, `"balanced_low_0"` or `"sequential"`) is turned into a concrete map with
      `infer_auto_device_map`.
    - A dict is kept as given.

    The resulting dict is then checked for entries placed on `"cpu"` or `"disk"` whose key is not one of the skipped
    or fp32 modules of `bnb_quantization_config`.

    Raises:
        RuntimeError: If `device_map` is `None` and neither CUDA nor XPU is available.
        ValueError: If `device_map` is an unsupported string, or if 4-bit loading is requested and a module that is
            not skipped ends up on the CPU or the disk.
    """
    if device_map is None:
        if torch.cuda.is_available():
            device_map = {"": torch.cuda.current_device()}
        elif torch.xpu.is_available():
            device_map = {"": torch.xpu.current_device()}
        else:
            raise RuntimeError("No GPU found. A GPU is needed for quantization.")
        logger.info("The device_map was not initialized.Setting device_map to `{'':torch.cuda.current_device()}`.")

    if isinstance(device_map, str):
        if device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
            raise ValueError(
                "If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or "
                "'sequential'."
            )

        # Per-parameter dtype overrides, matched by substring on the parameter name. The fp32 entries are applied
        # last, so they take precedence over the skip-module entries for names matching both.
        special_dtypes = {}
        special_dtypes.update(
            {
                name: bnb_quantization_config.torch_dtype
                for name, _ in model.named_parameters()
                if any(m in name for m in bnb_quantization_config.skip_modules)
            }
        )
        special_dtypes.update(
            {
                name: torch.float32
                for name, _ in model.named_parameters()
                if any(m in name for m in bnb_quantization_config.keep_in_fp32_modules)
            }
        )

        kwargs = {}
        kwargs["special_dtypes"] = special_dtypes
        kwargs["no_split_module_classes"] = no_split_module_classes
        kwargs["dtype"] = bnb_quantization_config.target_dtype

        # get max_memory for each device.
        if device_map != "sequential":
            max_memory = get_balanced_memory(
                model,
                low_zero=(device_map == "balanced_low_0"),
                max_memory=max_memory,
                **kwargs,
            )

        kwargs["max_memory"] = max_memory
        device_map = infer_auto_device_map(model, **kwargs)

    if isinstance(device_map, dict):
        # check if don't have any quantized module on the cpu
        modules_not_to_convert = bnb_quantization_config.skip_modules + bnb_quantization_config.keep_in_fp32_modules

        # Only exact key matches are removed here; the remaining entries are the ones that will be quantized.
        device_map_without_some_modules = {
            key: device_map[key] for key in device_map.keys() if key not in modules_not_to_convert
        }
        for device in ["cpu", "disk"]:
            if device in device_map_without_some_modules.values():
                if bnb_quantization_config.load_in_4bit:
                    raise ValueError(
                        """
                        Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit
                        the quantized model. If you want to dispatch the model on the CPU or the disk while keeping
                        these modules in `torch_dtype`, you need to pass a custom `device_map` to
                        `load_and_quantize_model`. Check
                        https://huggingface.co/docs/accelerate/main/en/usage_guides/quantization#offload-modules-to-cpu-and-disk
                        for more details.
                        """
                    )
                else:
                    logger.info(
                        "Some modules are are offloaded to the CPU or the disk. Note that these modules will be converted to 8-bit"
                    )
        del device_map_without_some_modules
    return device_map
269
+
270
+
271
def replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=None, current_key_name=None):
    """
    Swap the `torch.nn.Linear` modules of `model` for their `bitsandbytes` 8-bit or 4-bit counterparts.

    The actual (recursive) replacement is done by `_replace_with_bnb_layers`; this wrapper supplies an empty skip list
    when none is given and logs a warning when nothing was replaced.

    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
        bnb_quantization_config (`BnbQuantizationConfig`):
            The bitsandbytes quantization parameters, forwarded to `_replace_with_bnb_layers`.
        modules_to_not_convert (`List[str]`):
            Names of the modules to not quantize convert. In practice we keep the `lm_head` in full precision for
            numerical stability reasons.
        current_key_name (`List[str]`, *optional*):
            An array to track the current key of the recursion. This is used to check whether the current key (part of
            it) is not in the list of modules to not convert.

    Returns:
        The model returned by `_replace_with_bnb_layers`.
    """
    skip_list = [] if modules_to_not_convert is None else modules_to_not_convert

    converted_model, any_replaced = _replace_with_bnb_layers(
        model, bnb_quantization_config, skip_list, current_key_name
    )
    if any_replaced:
        return converted_model

    logger.warning(
        "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
        " this can happen for some architectures such as gpt2 that uses Conv1D instead of Linear layers."
        " Please double check your model architecture, or submit an issue on github if you think this is"
        " a bug."
    )
    return converted_model
301
+
302
+
303
def _replace_with_bnb_layers(
    model,
    bnb_quantization_config,
    modules_to_not_convert=None,
    current_key_name=None,
):
    """
    Recursive worker behind `replace_with_bnb_layers`.

    Walks the direct children of `model`, replaces eligible `nn.Linear` children with a bitsandbytes layer, and
    recurses into children that have children of their own. `current_key_name` holds the path of child names leading
    to the module currently being visited.

    Returns the model together with a boolean telling whether at least one layer was replaced.
    """
    # bitsandbytes will initialize device(e.g. CUDA, XPU) on import, so it needs to be imported lazily
    import bitsandbytes as bnb

    if current_key_name is None:
        current_key_name = []

    has_been_replaced = False
    for child_name, child in model.named_children():
        current_key_name.append(child_name)
        if isinstance(child, nn.Linear) and child_name not in modules_to_not_convert:
            # A layer is excluded when a skip key equals its dotted path or appears in it followed by a dot.
            dotted_key = ".".join(current_key_name)
            excluded = any(
                ((key in dotted_key) and (key + "." in dotted_key)) or key == dotted_key
                for key in modules_to_not_convert
            )
            if not excluded:
                # Build the bnb replacement with the same shape as the original linear layer
                if bnb_quantization_config.load_in_8bit:
                    replacement = bnb.nn.Linear8bitLt(
                        child.in_features,
                        child.out_features,
                        child.bias is not None,
                        has_fp16_weights=False,
                        threshold=bnb_quantization_config.llm_int8_threshold,
                    )
                elif bnb_quantization_config.load_in_4bit:
                    replacement = bnb.nn.Linear4bit(
                        child.in_features,
                        child.out_features,
                        child.bias is not None,
                        bnb_quantization_config.bnb_4bit_compute_dtype,
                        compress_statistics=bnb_quantization_config.bnb_4bit_use_double_quant,
                        quant_type=bnb_quantization_config.bnb_4bit_quant_type,
                    )
                else:
                    raise ValueError("load_in_8bit and load_in_4bit can't be both False")
                # Carry over the existing weight (and bias) data
                replacement.weight.data = child.weight.data
                if child.bias is not None:
                    replacement.bias.data = child.bias.data
                replacement.requires_grad_(False)
                setattr(model, child_name, replacement)
                has_been_replaced = True
        if len(list(child.children())) > 0:
            _, child_replaced = _replace_with_bnb_layers(
                child, bnb_quantization_config, modules_to_not_convert, current_key_name
            )
            has_been_replaced |= child_replaced
        # Drop this child's name before moving on to its next sibling
        current_key_name.pop(-1)
    return model, has_been_replaced
367
+
368
+
369
def get_keys_to_not_convert(model):
    r"""
    An utility function to get the key of the module to keep in full precision if any. For example for CausalLM
    modules we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures,
    we want to keep the tied weights of the model. The function will return a list of the keys of the modules to not
    convert in int8.

    Parameters:
        model (`torch.nn.Module`):
            Input model

    Returns:
        `List[str]`: Names of the tied parameters plus the name of the last direct child of `model` (if it has any
        children), with `.weight` / `.bias` removed. An empty list is returned for a base model without tied
        parameters.
    """
    # Create a copy of the model
    with init_empty_weights():
        tied_model = deepcopy(model)  # this has 0 cost since it is done inside `init_empty_weights` context manager`

    tied_params = find_tied_parameters(tied_model)
    # For compatibility with Accelerate < 0.18
    if isinstance(tied_params, dict):
        tied_keys = [key for group in tied_params.values() for key in group] + list(tied_params.keys())
    else:
        tied_keys = [key for group in tied_params for key in group]
    has_tied_params = len(tied_keys) > 0

    # Check if it is a base model
    is_base_model = False
    if hasattr(model, "base_model_prefix"):
        is_base_model = not hasattr(model, model.base_model_prefix)

    # Ignore this for base models (BertModel, GPT2Model, etc.)
    if (not has_tied_params) and is_base_model:
        return []

    # otherwise they have an attached head; a model without any child module has no last module to protect
    list_modules = list(model.named_children())
    list_last_module = [list_modules[-1][0]] if list_modules else []

    # add last module together with tied weights
    intersection = set(list_last_module) - set(tied_keys)
    list_untouched = list(set(tied_keys)) + list(intersection)

    # remove ".weight" and ".bias" from the keys
    names_to_remove = [".weight", ".bias"]
    filtered_module_names = []
    for name in list_untouched:
        for name_to_remove in names_to_remove:
            if name_to_remove in name:
                name = name.replace(name_to_remove, "")
        filtered_module_names.append(name)

    return filtered_module_names
419
+
420
+
421
def has_4bit_bnb_layers(model):
    """Check if we have at least one `bnb.nn.Linear4bit` layer inside our model"""
    # bitsandbytes will initialize device(e.g. CUDA, XPU) on import, so it needs to be imported lazily
    import bitsandbytes as bnb

    return any(isinstance(submodule, bnb.nn.Linear4bit) for submodule in model.modules())
430
+
431
+
432
def get_parameter_device(parameter: nn.Module):
    """Return the device of the first parameter yielded by `parameter.parameters()`."""
    first_param = next(parameter.parameters())
    return first_param.device
434
+
435
+
436
def quantize_and_offload_8bit(model, param, param_name, new_dtype, offload_folder, offload_index, fp16_statistics):
    """
    Offload an 8-bit weight (and its `SCB` statistics, when present) to `offload_folder`, then reset the parameter on
    the model to an empty tensor on the meta device.

    When `fp16_statistics` is `None`, the parameter is first set on device `0` through
    `set_module_tensor_to_device`, and the resulting module parameter is what gets offloaded. Otherwise `param` and
    `fp16_statistics` are offloaded as given.

    Args:
        model: The model owning the parameter.
        param: The tensor value of the parameter.
        param_name (`str`): Dotted name of the parameter inside `model`.
        new_dtype: The dtype forwarded to `set_module_tensor_to_device`.
        offload_folder: Folder passed to `offload_weight`.
        offload_index: Index passed to `offload_weight` as `index`.
        fp16_statistics: The `SCB` statistics of an already quantized weight, or `None`.
    """
    # if it is not quantized, we quantize and offload the quantized weights and the SCB stats
    if fp16_statistics is None:
        # presumably placing the weight on device 0 is what triggers the 8-bit quantization -- TODO confirm
        set_module_tensor_to_device(model, param_name, 0, dtype=new_dtype, value=param)
        tensor_name = param_name
        module = model
        # Walk the dotted path down to the module that directly owns the parameter
        if "." in tensor_name:
            splits = tensor_name.split(".")
            for split in splits[:-1]:
                new_module = getattr(module, split)
                if new_module is None:
                    raise ValueError(f"{module} has no attribute {split}.")
                module = new_module
            tensor_name = splits[-1]
        # offload weights
        module._parameters[tensor_name].requires_grad = False
        offload_weight(module._parameters[tensor_name], param_name, offload_folder, index=offload_index)
        if hasattr(module._parameters[tensor_name], "SCB"):
            # NOTE(review): `str.replace` swaps every occurrence of "weight" in the name, not only the final
            # component -- confirm parameter names never contain "weight" elsewhere.
            offload_weight(
                module._parameters[tensor_name].SCB,
                param_name.replace("weight", "SCB"),
                offload_folder,
                index=offload_index,
            )
    else:
        offload_weight(param, param_name, offload_folder, index=offload_index)
        offload_weight(fp16_statistics, param_name.replace("weight", "SCB"), offload_folder, index=offload_index)

    # Replace the parameter on the model with an empty meta tensor of the same size
    set_module_tensor_to_device(model, param_name, "meta", dtype=new_dtype, value=torch.empty(*param.size()))
accelerate/utils/constants.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import operator as op
16
+
17
+ import torch
18
+
19
+
20
# Base names (no extension) from which the weight file names below are derived.
MODEL_NAME = "pytorch_model"
SAFE_MODEL_NAME = "model"

# Stand-alone names.
SCALER_NAME = "scaler.pt"
RNG_STATE_NAME = "random_states"
OPTIMIZER_NAME = "optimizer"
SCHEDULER_NAME = "scheduler"
SAMPLER_NAME = "sampler"

# Patterns keep a literal `{suffix}` placeholder (presumably filled in later with `str.format`).
PROFILE_PATTERN_NAME = "profile_{suffix}.json"

# `.bin` weight files.
WEIGHTS_NAME = MODEL_NAME + ".bin"
WEIGHTS_PATTERN_NAME = MODEL_NAME + "{suffix}.bin"
WEIGHTS_INDEX_NAME = WEIGHTS_NAME + ".index.json"

# `.safetensors` weight files.
SAFE_WEIGHTS_NAME = SAFE_MODEL_NAME + ".safetensors"
SAFE_WEIGHTS_PATTERN_NAME = SAFE_MODEL_NAME + "{suffix}.safetensors"
SAFE_WEIGHTS_INDEX_NAME = SAFE_WEIGHTS_NAME + ".index.json"
34
# SageMaker-related values.
SAGEMAKER_PYTORCH_VERSION = "1.10.2"
SAGEMAKER_PYTHON_VERSION = "py38"
SAGEMAKER_TRANSFORMERS_VERSION = "4.17.0"
SAGEMAKER_PARALLEL_EC2_INSTANCES = [
    "ml.p3.16xlarge",
    "ml.p3dn.24xlarge",
    "ml.p4dn.24xlarge",
]

# FSDP option names.
FSDP_SHARDING_STRATEGY = [
    "FULL_SHARD",
    "SHARD_GRAD_OP",
    "NO_SHARD",
    "HYBRID_SHARD",
    "HYBRID_SHARD_ZERO2",
]
FSDP_AUTO_WRAP_POLICY = [
    "TRANSFORMER_BASED_WRAP",
    "SIZE_BASED_WRAP",
    "NO_WRAP",
]
FSDP_BACKWARD_PREFETCH = [
    "BACKWARD_PRE",
    "BACKWARD_POST",
    "NO_PREFETCH",
]
FSDP_STATE_DICT_TYPE = [
    "FULL_STATE_DICT",
    "LOCAL_STATE_DICT",
    "SHARDED_STATE_DICT",
]
FSDP2_STATE_DICT_TYPE = [
    "SHARDED_STATE_DICT",
    "FULL_STATE_DICT",
]

# Technically should be 2.1.0, but MS-AMP uses this specific prerelease in their Docker image.
FSDP_PYTORCH_VERSION = "2.1.0.a0+32f93b1"
FSDP2_PYTORCH_VERSION = "2.6.0"
DTENSOR_PYTORCH_VERSION = "2.5.0"
FSDP_MODEL_NAME = "pytorch_model_fsdp"

# Launcher / compile mode names.
DEEPSPEED_MULTINODE_LAUNCHERS = [
    "pdsh",
    "standard",
    "openmpi",
    "mvapich",
    "mpich",
    "nossh",
    "slurm",
]
TORCH_DYNAMO_MODES = [
    "default",
    "reduce-overhead",
    "max-autotune",
]

# Version strings, one per feature named in the constant.
ELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION = "2.2.0"
XPU_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.4.0"
MITA_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.1.0"
BETA_TP_AVAILABLE_PYTORCH_VERSION = "2.3.0"

BETA_TP_AVAILABLE_TRANSFORMERS_VERSION = "4.52.0"
BETA_CP_AVAILABLE_PYTORCH_VERSION = "2.6.0"
BETA_SP_AVAILABLE_DEEPSPEED_VERSION = "0.18.2"
59
+
60
# Maps a comparison operator written as a string to the matching function from the `operator` module.
STR_OPERATION_TO_FUNC = {
    ">": op.gt,
    ">=": op.ge,
    "==": op.eq,
    "!=": op.ne,
    "<=": op.le,
    "<": op.lt,
}
61
+
62
# These are the args for `torch.distributed.launch` for pytorch < 1.9
# NOTE(review): the list also carries rendezvous options (`rdzv_*`); confirm the version statement above.
# fmt: off
TORCH_LAUNCH_PARAMS = [
    "nnodes", "nproc_per_node",
    "rdzv_backend", "rdzv_endpoint", "rdzv_id", "rdzv_conf",
    "standalone", "max_restarts", "monitor_interval", "start_method", "role",
    "module", "m", "no_python", "run_path",
    "log_dir", "r", "redirects", "t", "tee",
    "node_rank", "master_addr", "master_port",
]
# fmt: on
88
+
89
CUDA_DISTRIBUTED_TYPES = [
    "DEEPSPEED",
    "MULTI_GPU",
    "FSDP",
    "MEGATRON_LM",
    "TP",
]
# A new list: every entry of `CUDA_DISTRIBUTED_TYPES` first, followed by the `MULTI_*` names below.
TORCH_DISTRIBUTED_OPERATION_TYPES = [
    *CUDA_DISTRIBUTED_TYPES,
    "MULTI_NPU",
    "MULTI_MLU",
    "MULTI_SDAA",
    "MULTI_MUSA",
    "MULTI_XPU",
    "MULTI_CPU",
    "MULTI_HPU",
    "MULTI_NEURON",
]
100
# Tuple of `torch.nn` layer classes, so it can be passed directly as the second argument of `isinstance`.
SUPPORTED_PYTORCH_LAYERS_FOR_UPCASTING = (
    # convolutions
    (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)
    # transposed convolutions
    + (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d)
    # linear
    + (torch.nn.Linear,)
)