Add files using upload-large-folder tool
Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
- accelerate/commands/__init__.py +13 -0
- accelerate/commands/accelerate_cli.py +54 -0
- accelerate/commands/config/__init__.py +52 -0
- accelerate/commands/config/cluster.py +939 -0
- accelerate/commands/config/config.py +89 -0
- accelerate/commands/config/config_args.py +252 -0
- accelerate/commands/config/config_utils.py +122 -0
- accelerate/commands/config/default.py +172 -0
- accelerate/commands/config/sagemaker.py +274 -0
- accelerate/commands/config/update.py +63 -0
- accelerate/commands/env.py +143 -0
- accelerate/commands/estimate.py +318 -0
- accelerate/commands/launch.py +1415 -0
- accelerate/commands/menu/__init__.py +14 -0
- accelerate/commands/menu/cursor.py +65 -0
- accelerate/commands/menu/helpers.py +59 -0
- accelerate/commands/menu/input.py +84 -0
- accelerate/commands/menu/keymap.py +133 -0
- accelerate/commands/menu/selection_menu.py +145 -0
- accelerate/commands/merge.py +69 -0
- accelerate/commands/test.py +65 -0
- accelerate/commands/to_fsdp2.py +172 -0
- accelerate/commands/tpu.py +157 -0
- accelerate/commands/utils.py +123 -0
- accelerate/test_utils/__init__.py +66 -0
- accelerate/test_utils/examples.py +148 -0
- accelerate/test_utils/scripts/__init__.py +13 -0
- accelerate/test_utils/scripts/external_deps/__init__.py +13 -0
- accelerate/test_utils/scripts/external_deps/test_checkpointing.py +269 -0
- accelerate/test_utils/scripts/external_deps/test_ds_alst_ulysses_sp.py +131 -0
- accelerate/test_utils/scripts/external_deps/test_ds_multiple_model.py +331 -0
- accelerate/test_utils/scripts/external_deps/test_metrics.py +307 -0
- accelerate/test_utils/scripts/external_deps/test_peak_memory_usage.py +323 -0
- accelerate/test_utils/scripts/external_deps/test_performance.py +299 -0
- accelerate/test_utils/scripts/external_deps/test_pippy.py +117 -0
- accelerate/test_utils/scripts/external_deps/test_zero3_integration.py +59 -0
- accelerate/test_utils/scripts/test_cli.py +32 -0
- accelerate/test_utils/scripts/test_ddp_comm_hook.py +85 -0
- accelerate/test_utils/scripts/test_distributed_data_loop.py +429 -0
- accelerate/test_utils/scripts/test_merge_weights.py +158 -0
- accelerate/test_utils/scripts/test_notebook.py +125 -0
- accelerate/test_utils/scripts/test_ops.py +181 -0
- accelerate/test_utils/scripts/test_script.py +909 -0
- accelerate/test_utils/scripts/test_sync.py +413 -0
- accelerate/test_utils/testing.py +889 -0
- accelerate/test_utils/training.py +150 -0
- accelerate/utils/__init__.py +304 -0
- accelerate/utils/ao.py +143 -0
- accelerate/utils/bnb.py +464 -0
- accelerate/utils/constants.py +108 -0
accelerate/commands/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
accelerate/commands/accelerate_cli.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
from accelerate.commands.config import get_config_parser
|
| 18 |
+
from accelerate.commands.env import env_command_parser
|
| 19 |
+
from accelerate.commands.estimate import estimate_command_parser
|
| 20 |
+
from accelerate.commands.launch import launch_command_parser
|
| 21 |
+
from accelerate.commands.merge import merge_command_parser
|
| 22 |
+
from accelerate.commands.test import test_command_parser
|
| 23 |
+
from accelerate.commands.to_fsdp2 import to_fsdp2_command_parser
|
| 24 |
+
from accelerate.commands.tpu import tpu_command_parser
|
| 25 |
+
from accelerate.commands.utils import CustomArgumentParser
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def main():
    """Entry point of the ``accelerate`` command-line tool.

    Builds the top-level parser, registers each sub-command parser on it,
    parses the command-line arguments and calls ``args.func(args)``.

    Raises:
        SystemExit: with status 1, after printing the help text, when the
            parsed namespace has no ``func`` attribute.
    """
    parser = CustomArgumentParser("Accelerate CLI tool", usage="accelerate <command> [<args>]", allow_abbrev=False)
    subparsers = parser.add_subparsers(help="accelerate command helpers")

    # Register commands
    get_config_parser(subparsers=subparsers)
    estimate_command_parser(subparsers=subparsers)
    env_command_parser(subparsers=subparsers)
    launch_command_parser(subparsers=subparsers)
    merge_command_parser(subparsers=subparsers)
    tpu_command_parser(subparsers=subparsers)
    test_command_parser(subparsers=subparsers)
    to_fsdp2_command_parser(subparsers=subparsers)

    # Let's go
    args = parser.parse_args()

    if not hasattr(args, "func"):
        parser.print_help()
        # Raise SystemExit directly instead of calling the `exit` helper:
        # `exit` is injected by the `site` module and is not available when
        # the interpreter runs with `-S` or in some frozen/embedded builds.
        raise SystemExit(1)

    # Run
    args.func(args)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
if __name__ == "__main__":
|
| 54 |
+
main()
|
accelerate/commands/config/__init__.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
|
| 19 |
+
from .config import config_command_parser
|
| 20 |
+
from .config_args import default_config_file, load_config_from_file # noqa: F401
|
| 21 |
+
from .default import default_command_parser
|
| 22 |
+
from .update import update_command_parser
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_config_parser(subparsers=None):
    """Build the ``config`` command parser and attach its sub-commands.

    Args:
        subparsers: forwarded unchanged to ``config_command_parser``;
            defaults to ``None``.

    Returns:
        The parser returned by ``config_command_parser``, with the
        ``default`` and ``update`` command parsers registered under a
        "subcommands" group whose selection is stored in ``subcommand``.
    """
    # Help-less parent parser passed to every sub-command through `parents`
    shared_parent = argparse.ArgumentParser(add_help=False, allow_abbrev=False)

    # The main config parser and the group that receives the sub-commands
    config_parser = config_command_parser(subparsers)
    subcommand_group = config_parser.add_subparsers(title="subcommands", dest="subcommand")

    # Register each sub-command with the shared parent, in the original order
    for register_command in (default_command_parser, update_command_parser):
        register_command(subcommand_group, parents=[shared_parent])

    return config_parser
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def main():
    """Run the ``config`` command parser as a standalone entry point.

    Obtains the parser from ``get_config_parser()``, parses the command-line
    arguments and calls ``args.func(args)``.

    Raises:
        SystemExit: with status 1, after printing the help text, when the
            parsed namespace has no ``func`` attribute.
    """
    config_parser = get_config_parser()
    args = config_parser.parse_args()

    if not hasattr(args, "func"):
        config_parser.print_help()
        # Raise SystemExit directly instead of calling the `exit` helper:
        # `exit` is injected by the `site` module and is not available when
        # the interpreter runs with `-S` or in some frozen/embedded builds.
        raise SystemExit(1)

    # Run
    args.func(args)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
if __name__ == "__main__":
|
| 52 |
+
main()
|
accelerate/commands/config/cluster.py
ADDED
|
@@ -0,0 +1,939 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
|
| 19 |
+
from ...utils import (
|
| 20 |
+
ComputeEnvironment,
|
| 21 |
+
DistributedType,
|
| 22 |
+
is_deepspeed_available,
|
| 23 |
+
is_fp8_available,
|
| 24 |
+
is_hpu_available,
|
| 25 |
+
is_mlu_available,
|
| 26 |
+
is_mps_available,
|
| 27 |
+
is_msamp_available,
|
| 28 |
+
is_musa_available,
|
| 29 |
+
is_neuron_available,
|
| 30 |
+
is_npu_available,
|
| 31 |
+
is_sdaa_available,
|
| 32 |
+
is_torchao_available,
|
| 33 |
+
is_transformer_engine_available,
|
| 34 |
+
is_transformers_available,
|
| 35 |
+
is_xpu_available,
|
| 36 |
+
)
|
| 37 |
+
from ...utils.constants import (
|
| 38 |
+
DEEPSPEED_MULTINODE_LAUNCHERS,
|
| 39 |
+
FSDP2_STATE_DICT_TYPE,
|
| 40 |
+
FSDP_AUTO_WRAP_POLICY,
|
| 41 |
+
FSDP_BACKWARD_PREFETCH,
|
| 42 |
+
FSDP_SHARDING_STRATEGY,
|
| 43 |
+
FSDP_STATE_DICT_TYPE,
|
| 44 |
+
TORCH_DYNAMO_MODES,
|
| 45 |
+
)
|
| 46 |
+
from .config_args import ClusterConfig
|
| 47 |
+
from .config_utils import (
|
| 48 |
+
DYNAMO_BACKENDS,
|
| 49 |
+
_ask_field,
|
| 50 |
+
_ask_options,
|
| 51 |
+
_convert_distributed_mode,
|
| 52 |
+
_convert_dynamo_backend,
|
| 53 |
+
_convert_fp8_backend,
|
| 54 |
+
_convert_mixed_precision,
|
| 55 |
+
_convert_yes_no_to_bool,
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def get_cluster_input():
|
| 60 |
+
distributed_type = _ask_options(
|
| 61 |
+
"Which type of machine are you using?",
|
| 62 |
+
[
|
| 63 |
+
"No distributed training",
|
| 64 |
+
"multi-CPU",
|
| 65 |
+
"multi-XPU",
|
| 66 |
+
"multi-HPU",
|
| 67 |
+
"multi-GPU",
|
| 68 |
+
"multi-NPU",
|
| 69 |
+
"multi-MLU",
|
| 70 |
+
"multi-SDAA",
|
| 71 |
+
"multi-MUSA",
|
| 72 |
+
"multi-NEURON",
|
| 73 |
+
"TPU",
|
| 74 |
+
],
|
| 75 |
+
_convert_distributed_mode,
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
machine_rank = 0
|
| 79 |
+
num_machines = 1
|
| 80 |
+
num_processes = 1
|
| 81 |
+
gpu_ids = None
|
| 82 |
+
main_process_ip = None
|
| 83 |
+
main_process_port = None
|
| 84 |
+
rdzv_backend = "static"
|
| 85 |
+
same_network = True
|
| 86 |
+
debug = False
|
| 87 |
+
|
| 88 |
+
if distributed_type in [
|
| 89 |
+
DistributedType.MULTI_GPU,
|
| 90 |
+
DistributedType.MULTI_MLU,
|
| 91 |
+
DistributedType.MULTI_SDAA,
|
| 92 |
+
DistributedType.MULTI_MUSA,
|
| 93 |
+
DistributedType.MULTI_NPU,
|
| 94 |
+
DistributedType.MULTI_XPU,
|
| 95 |
+
DistributedType.MULTI_CPU,
|
| 96 |
+
DistributedType.MULTI_HPU,
|
| 97 |
+
DistributedType.MULTI_NEURON,
|
| 98 |
+
]:
|
| 99 |
+
num_machines = _ask_field(
|
| 100 |
+
"How many different machines will you use (use more than 1 for multi-node training)? [1]: ",
|
| 101 |
+
int,
|
| 102 |
+
default=1,
|
| 103 |
+
)
|
| 104 |
+
if num_machines > 1:
|
| 105 |
+
machine_rank = _ask_options(
|
| 106 |
+
"What is the rank of this machine?",
|
| 107 |
+
list(range(num_machines)),
|
| 108 |
+
int,
|
| 109 |
+
)
|
| 110 |
+
main_process_ip = _ask_field(
|
| 111 |
+
"What is the IP address of the machine that will host the main process? ",
|
| 112 |
+
)
|
| 113 |
+
main_process_port = _ask_field(
|
| 114 |
+
"What is the port you will use to communicate with the main process? ",
|
| 115 |
+
int,
|
| 116 |
+
)
|
| 117 |
+
same_network = _ask_field(
|
| 118 |
+
"Are all the machines on the same local network? Answer `no` if nodes are on the cloud and/or on different network hosts [YES/no]: ",
|
| 119 |
+
_convert_yes_no_to_bool,
|
| 120 |
+
default=True,
|
| 121 |
+
error_message="Please enter yes or no.",
|
| 122 |
+
)
|
| 123 |
+
if not same_network:
|
| 124 |
+
rdzv_backend = _ask_field(
|
| 125 |
+
"What rendezvous backend will you use? ('static', 'c10d', ...): ", default="static"
|
| 126 |
+
)
|
| 127 |
+
debug = _ask_field(
|
| 128 |
+
"Should distributed operations be checked while running for errors? This can avoid timeout issues but will be slower. [yes/NO]: ",
|
| 129 |
+
_convert_yes_no_to_bool,
|
| 130 |
+
default=False,
|
| 131 |
+
error_message="Please enter yes or no.",
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
if distributed_type == DistributedType.NO:
|
| 135 |
+
use_cpu = _ask_field(
|
| 136 |
+
"Do you want to run your training on CPU only (even if a GPU / Apple Silicon / Ascend NPU device is available)? [yes/NO]:",
|
| 137 |
+
_convert_yes_no_to_bool,
|
| 138 |
+
default=False,
|
| 139 |
+
error_message="Please enter yes or no.",
|
| 140 |
+
)
|
| 141 |
+
elif distributed_type == DistributedType.MULTI_CPU:
|
| 142 |
+
use_cpu = True
|
| 143 |
+
else:
|
| 144 |
+
use_cpu = False
|
| 145 |
+
|
| 146 |
+
mpirun_config = {}
|
| 147 |
+
|
| 148 |
+
if use_cpu:
|
| 149 |
+
if distributed_type == DistributedType.MULTI_CPU:
|
| 150 |
+
use_mpirun = _ask_field(
|
| 151 |
+
"Do you want accelerate to launch mpirun? [yes/NO]: ",
|
| 152 |
+
_convert_yes_no_to_bool,
|
| 153 |
+
default=False,
|
| 154 |
+
error_message="Please enter yes or no.",
|
| 155 |
+
)
|
| 156 |
+
if use_mpirun:
|
| 157 |
+
mpirun_hostfile = _ask_field(
|
| 158 |
+
"Please enter the path to the hostfile to use with mpirun [~/hostfile]: ",
|
| 159 |
+
str,
|
| 160 |
+
default="~/hostfile",
|
| 161 |
+
)
|
| 162 |
+
mpirun_config["mpirun_hostfile"] = os.path.expanduser(mpirun_hostfile.strip())
|
| 163 |
+
|
| 164 |
+
dynamo_config = {}
|
| 165 |
+
use_dynamo = _ask_field(
|
| 166 |
+
"Do you wish to optimize your script with torch dynamo?[yes/NO]:",
|
| 167 |
+
_convert_yes_no_to_bool,
|
| 168 |
+
default=False,
|
| 169 |
+
error_message="Please enter yes or no.",
|
| 170 |
+
)
|
| 171 |
+
if use_dynamo:
|
| 172 |
+
prefix = "dynamo_"
|
| 173 |
+
dynamo_config[prefix + "backend"] = _ask_options(
|
| 174 |
+
"Which dynamo backend would you like to use?",
|
| 175 |
+
[x.lower() for x in DYNAMO_BACKENDS],
|
| 176 |
+
_convert_dynamo_backend,
|
| 177 |
+
default=2,
|
| 178 |
+
)
|
| 179 |
+
use_custom_options = _ask_field(
|
| 180 |
+
"Do you want to customize the defaults sent to torch.compile? [yes/NO]: ",
|
| 181 |
+
_convert_yes_no_to_bool,
|
| 182 |
+
default=False,
|
| 183 |
+
error_message="Please enter yes or no.",
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
if use_custom_options:
|
| 187 |
+
dynamo_config[prefix + "mode"] = _ask_options(
|
| 188 |
+
"Which mode do you want to use?",
|
| 189 |
+
TORCH_DYNAMO_MODES,
|
| 190 |
+
lambda x: TORCH_DYNAMO_MODES[int(x)],
|
| 191 |
+
default=0,
|
| 192 |
+
)
|
| 193 |
+
dynamo_config[prefix + "use_fullgraph"] = _ask_field(
|
| 194 |
+
"Do you want the fullgraph mode or it is ok to break model into several subgraphs? [yes/NO]: ",
|
| 195 |
+
_convert_yes_no_to_bool,
|
| 196 |
+
default=False,
|
| 197 |
+
error_message="Please enter yes or no.",
|
| 198 |
+
)
|
| 199 |
+
dynamo_config[prefix + "use_dynamic"] = _ask_field(
|
| 200 |
+
"Do you want to enable dynamic shape tracing? [yes/NO]: ",
|
| 201 |
+
_convert_yes_no_to_bool,
|
| 202 |
+
default=False,
|
| 203 |
+
error_message="Please enter yes or no.",
|
| 204 |
+
)
|
| 205 |
+
dynamo_config[prefix + "use_regional_compilation"] = _ask_field(
|
| 206 |
+
"Do you want to enable regional compilation? [yes/NO]: ",
|
| 207 |
+
_convert_yes_no_to_bool,
|
| 208 |
+
default=False,
|
| 209 |
+
error_message="Please enter yes or no.",
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
use_mps = not use_cpu and is_mps_available()
|
| 213 |
+
deepspeed_config = {}
|
| 214 |
+
if (
|
| 215 |
+
distributed_type
|
| 216 |
+
in [
|
| 217 |
+
DistributedType.MULTI_GPU,
|
| 218 |
+
DistributedType.MULTI_XPU,
|
| 219 |
+
DistributedType.MULTI_HPU,
|
| 220 |
+
DistributedType.MULTI_NPU,
|
| 221 |
+
DistributedType.MULTI_MLU,
|
| 222 |
+
DistributedType.MULTI_SDAA,
|
| 223 |
+
DistributedType.MULTI_MUSA,
|
| 224 |
+
DistributedType.MULTI_NEURON,
|
| 225 |
+
DistributedType.NO,
|
| 226 |
+
]
|
| 227 |
+
and not use_mps
|
| 228 |
+
):
|
| 229 |
+
use_deepspeed = _ask_field(
|
| 230 |
+
"Do you want to use DeepSpeed? [yes/NO]: ",
|
| 231 |
+
_convert_yes_no_to_bool,
|
| 232 |
+
default=False,
|
| 233 |
+
error_message="Please enter yes or no.",
|
| 234 |
+
)
|
| 235 |
+
if use_deepspeed:
|
| 236 |
+
if distributed_type is DistributedType.MULTI_NEURON:
|
| 237 |
+
raise RuntimeError("DeepSpeed is not supported on Neuron devices.")
|
| 238 |
+
|
| 239 |
+
distributed_type = DistributedType.DEEPSPEED
|
| 240 |
+
assert is_deepspeed_available(), (
|
| 241 |
+
"DeepSpeed is not installed => run `pip3 install deepspeed` or build it from source"
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
if distributed_type == DistributedType.DEEPSPEED:
|
| 245 |
+
use_deepspeed_config = _ask_field(
|
| 246 |
+
"Do you want to specify a json file to a DeepSpeed config? [yes/NO]: ",
|
| 247 |
+
_convert_yes_no_to_bool,
|
| 248 |
+
default=False,
|
| 249 |
+
error_message="Please enter yes or no.",
|
| 250 |
+
)
|
| 251 |
+
if use_deepspeed_config:
|
| 252 |
+
deepspeed_config["deepspeed_config_file"] = _ask_field(
|
| 253 |
+
"Please enter the path to the json DeepSpeed config file: ",
|
| 254 |
+
str,
|
| 255 |
+
default="none",
|
| 256 |
+
)
|
| 257 |
+
else:
|
| 258 |
+
deepspeed_config["zero_stage"] = _ask_options(
|
| 259 |
+
"What should be your DeepSpeed's ZeRO optimization stage?",
|
| 260 |
+
[0, 1, 2, 3],
|
| 261 |
+
int,
|
| 262 |
+
default=2,
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
deepspeed_devices = ["none", "cpu", "nvme"]
|
| 266 |
+
if deepspeed_config["zero_stage"] >= 2:
|
| 267 |
+
deepspeed_config["offload_optimizer_device"] = _ask_options(
|
| 268 |
+
"Where to offload optimizer states?", deepspeed_devices, lambda x: deepspeed_devices[int(x)]
|
| 269 |
+
)
|
| 270 |
+
deepspeed_config["offload_param_device"] = _ask_options(
|
| 271 |
+
"Where to offload parameters?", deepspeed_devices, lambda x: deepspeed_devices[int(x)]
|
| 272 |
+
)
|
| 273 |
+
if deepspeed_config["offload_param_device"] == "nvme":
|
| 274 |
+
deepspeed_config["offload_param_nvme_path"] = _ask_field(
|
| 275 |
+
"Nvme Path to offload parameters?",
|
| 276 |
+
str,
|
| 277 |
+
default="/nvme",
|
| 278 |
+
)
|
| 279 |
+
if deepspeed_config["offload_optimizer_device"] == "nvme":
|
| 280 |
+
deepspeed_config["offload_optimizer_nvme_path"] = _ask_field(
|
| 281 |
+
"Nvme Path to offload optimizer states?",
|
| 282 |
+
str,
|
| 283 |
+
default="/nvme",
|
| 284 |
+
)
|
| 285 |
+
deepspeed_config["gradient_accumulation_steps"] = _ask_field(
|
| 286 |
+
"How many gradient accumulation steps you're passing in your script? [1]: ",
|
| 287 |
+
int,
|
| 288 |
+
default=1,
|
| 289 |
+
)
|
| 290 |
+
use_gradient_clipping = _ask_field(
|
| 291 |
+
"Do you want to use gradient clipping? [yes/NO]: ",
|
| 292 |
+
_convert_yes_no_to_bool,
|
| 293 |
+
default=False,
|
| 294 |
+
error_message="Please enter yes or no.",
|
| 295 |
+
)
|
| 296 |
+
if use_gradient_clipping:
|
| 297 |
+
deepspeed_config["gradient_clipping"] = _ask_field(
|
| 298 |
+
"What is the gradient clipping value? [1.0]: ",
|
| 299 |
+
float,
|
| 300 |
+
default=1.0,
|
| 301 |
+
)
|
| 302 |
+
if deepspeed_config["zero_stage"] == 3:
|
| 303 |
+
deepspeed_config["zero3_save_16bit_model"] = _ask_field(
|
| 304 |
+
"Do you want to save 16-bit model weights when using ZeRO Stage-3? [yes/NO]: ",
|
| 305 |
+
_convert_yes_no_to_bool,
|
| 306 |
+
default=False,
|
| 307 |
+
error_message="Please enter yes or no.",
|
| 308 |
+
)
|
| 309 |
+
deepspeed_config["zero3_init_flag"] = _ask_field(
|
| 310 |
+
"Do you want to enable `deepspeed.zero.Init` when using ZeRO Stage-3 for constructing massive models? [yes/NO]: ",
|
| 311 |
+
_convert_yes_no_to_bool,
|
| 312 |
+
default=False,
|
| 313 |
+
error_message="Please enter yes or no.",
|
| 314 |
+
)
|
| 315 |
+
if deepspeed_config["zero3_init_flag"]:
|
| 316 |
+
if not is_transformers_available():
|
| 317 |
+
raise Exception(
|
| 318 |
+
"When `zero3_init_flag` is set, it requires Transformers to be installed. "
|
| 319 |
+
"Please run `pip3 install transformers`."
|
| 320 |
+
)
|
| 321 |
+
use_moe = _ask_field(
|
| 322 |
+
"Do you want to enable Mixture-of-Experts training (MoE)? [yes/NO]: ",
|
| 323 |
+
_convert_yes_no_to_bool,
|
| 324 |
+
default=False,
|
| 325 |
+
error_message="Please enter yes or no.",
|
| 326 |
+
)
|
| 327 |
+
if use_moe:
|
| 328 |
+
deepspeed_config["deepspeed_moe_layer_cls_names"] = _ask_field(
|
| 329 |
+
"Specify the comma-separated list of transformers MoE layer class names (case-sensitive), e.g : "
|
| 330 |
+
" `MixtralSparseMoeBlock`, `Qwen2MoeSparseMoeBlock`, `JetMoEAttention,JetMoEBlock` ... : ",
|
| 331 |
+
str,
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
if num_machines > 1:
|
| 335 |
+
launcher_query = "Which Type of launcher do you want to use?"
|
| 336 |
+
deepspeed_config["deepspeed_multinode_launcher"] = _ask_options(
|
| 337 |
+
launcher_query,
|
| 338 |
+
DEEPSPEED_MULTINODE_LAUNCHERS,
|
| 339 |
+
lambda x: DEEPSPEED_MULTINODE_LAUNCHERS[int(x)],
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
if deepspeed_config["deepspeed_multinode_launcher"] != DEEPSPEED_MULTINODE_LAUNCHERS[1]:
|
| 343 |
+
deepspeed_config["deepspeed_hostfile"] = _ask_field(
|
| 344 |
+
"DeepSpeed configures multi-node compute resources with hostfile. "
|
| 345 |
+
"Each row is of the format `hostname slots=[num_gpus]`, e.g., `localhost slots=2`; "
|
| 346 |
+
"for more information please refer official [documentation]"
|
| 347 |
+
"(https://www.deepspeed.ai/getting-started/#resource-configuration-multi-node). "
|
| 348 |
+
"Please specify the location of hostfile: ",
|
| 349 |
+
str,
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
is_exclusion_filter = _ask_field(
|
| 353 |
+
"Do you want to specify exclusion filter string? [yes/NO]: ",
|
| 354 |
+
_convert_yes_no_to_bool,
|
| 355 |
+
default=False,
|
| 356 |
+
error_message="Please enter yes or no.",
|
| 357 |
+
)
|
| 358 |
+
if is_exclusion_filter:
|
| 359 |
+
deepspeed_config["deepspeed_exclusion_filter"] = _ask_field(
|
| 360 |
+
"DeepSpeed exclusion filter string: ",
|
| 361 |
+
str,
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
is_inclusion_filter = _ask_field(
|
| 365 |
+
"Do you want to specify inclusion filter string? [yes/NO]: ",
|
| 366 |
+
_convert_yes_no_to_bool,
|
| 367 |
+
default=False,
|
| 368 |
+
error_message="Please enter yes or no.",
|
| 369 |
+
)
|
| 370 |
+
if is_inclusion_filter:
|
| 371 |
+
deepspeed_config["deepspeed_inclusion_filter"] = _ask_field(
|
| 372 |
+
"DeepSpeed inclusion filter string: ",
|
| 373 |
+
str,
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
fsdp_config = {}
|
| 377 |
+
|
| 378 |
+
if distributed_type in [
|
| 379 |
+
DistributedType.MULTI_GPU,
|
| 380 |
+
DistributedType.MULTI_NPU,
|
| 381 |
+
DistributedType.MULTI_MLU,
|
| 382 |
+
DistributedType.MULTI_SDAA,
|
| 383 |
+
DistributedType.MULTI_MUSA,
|
| 384 |
+
DistributedType.MULTI_XPU,
|
| 385 |
+
DistributedType.MULTI_HPU,
|
| 386 |
+
DistributedType.MULTI_NEURON,
|
| 387 |
+
]:
|
| 388 |
+
use_fsdp = _ask_field(
|
| 389 |
+
"Do you want to use FullyShardedDataParallel? [yes/NO]: ",
|
| 390 |
+
_convert_yes_no_to_bool,
|
| 391 |
+
default=False,
|
| 392 |
+
error_message="Please enter yes or no.",
|
| 393 |
+
)
|
| 394 |
+
if use_fsdp:
|
| 395 |
+
if distributed_type is DistributedType.MULTI_NEURON:
|
| 396 |
+
raise NotImplementedError("FSDP is not currently supported on Neuron devices.")
|
| 397 |
+
distributed_type = DistributedType.FSDP
|
| 398 |
+
|
| 399 |
+
if distributed_type == DistributedType.FSDP:
|
| 400 |
+
fsdp_config["fsdp_version"] = _ask_options(
|
| 401 |
+
"What should be your FSDP version? [2]: ",
|
| 402 |
+
[1, 2],
|
| 403 |
+
lambda x: int(x) + 1,
|
| 404 |
+
default=1,
|
| 405 |
+
)
|
| 406 |
+
fsdp_version = fsdp_config["fsdp_version"] # extract to a variable to simplify usage later
|
| 407 |
+
|
| 408 |
+
if fsdp_version == 1:
|
| 409 |
+
sharding_strategy_query = "What should be your sharding strategy?"
|
| 410 |
+
fsdp_config["fsdp_reshard_after_forward"] = _ask_options(
|
| 411 |
+
sharding_strategy_query,
|
| 412 |
+
FSDP_SHARDING_STRATEGY,
|
| 413 |
+
lambda x: FSDP_SHARDING_STRATEGY[int(x)],
|
| 414 |
+
)
|
| 415 |
+
else:
|
| 416 |
+
fsdp_config["fsdp_reshard_after_forward"] = _ask_field(
|
| 417 |
+
"Do you want to enable resharding after forward? [YES/no]: ",
|
| 418 |
+
_convert_yes_no_to_bool,
|
| 419 |
+
default=True,
|
| 420 |
+
error_message="Please enter yes or no.",
|
| 421 |
+
)
|
| 422 |
+
|
| 423 |
+
fsdp_config["fsdp_offload_params"] = _ask_field(
|
| 424 |
+
"Do you want to offload parameters and gradients to CPU? [yes/NO]: ",
|
| 425 |
+
_convert_yes_no_to_bool,
|
| 426 |
+
default=False,
|
| 427 |
+
error_message="Please enter yes or no.",
|
| 428 |
+
)
|
| 429 |
+
|
| 430 |
+
fsdp_wrap_query = "What should be your auto wrap policy?"
|
| 431 |
+
fsdp_config["fsdp_auto_wrap_policy"] = _ask_options(
|
| 432 |
+
fsdp_wrap_query,
|
| 433 |
+
FSDP_AUTO_WRAP_POLICY,
|
| 434 |
+
lambda x: FSDP_AUTO_WRAP_POLICY[int(x)],
|
| 435 |
+
)
|
| 436 |
+
if fsdp_config["fsdp_auto_wrap_policy"] == FSDP_AUTO_WRAP_POLICY[0]:
|
| 437 |
+
use_no_split_modules = _ask_field(
|
| 438 |
+
"Do you want to use the model's `_no_split_modules` to wrap. Only applicable for 🤗 Transformers [yes/NO]: ",
|
| 439 |
+
_convert_yes_no_to_bool,
|
| 440 |
+
default=False,
|
| 441 |
+
error_message="Please enter yes or no.",
|
| 442 |
+
)
|
| 443 |
+
if not use_no_split_modules:
|
| 444 |
+
fsdp_config["fsdp_transformer_layer_cls_to_wrap"] = _ask_field(
|
| 445 |
+
"Specify the comma-separated list of transformer layer class names (case-sensitive) to wrap ,e.g, :"
|
| 446 |
+
"`BertLayer`, `GPTJBlock`, `T5Block`, `BertLayer,BertEmbeddings,BertSelfOutput` ...? : ",
|
| 447 |
+
str,
|
| 448 |
+
)
|
| 449 |
+
elif fsdp_config["fsdp_auto_wrap_policy"] == FSDP_AUTO_WRAP_POLICY[1]:
|
| 450 |
+
fsdp_config["fsdp_min_num_params"] = _ask_field(
|
| 451 |
+
"What should be your FSDP's minimum number of parameters for Default Auto Wrapping Policy? [1e8]: ",
|
| 452 |
+
int,
|
| 453 |
+
default=100000000,
|
| 454 |
+
)
|
| 455 |
+
# Removed in FSDP2, ask for user input for FSDP1
|
| 456 |
+
if fsdp_version == 1:
|
| 457 |
+
fsdp_backward_prefetch_query = "What should be your FSDP's backward prefetch policy?"
|
| 458 |
+
fsdp_config["fsdp_backward_prefetch"] = _ask_options(
|
| 459 |
+
fsdp_backward_prefetch_query,
|
| 460 |
+
FSDP_BACKWARD_PREFETCH,
|
| 461 |
+
lambda x: FSDP_BACKWARD_PREFETCH[int(x)],
|
| 462 |
+
)
|
| 463 |
+
|
| 464 |
+
fsdp_state_dict_type_query = "What should be your FSDP's state dict type?"
|
| 465 |
+
fsdp_config["fsdp_state_dict_type"] = _ask_options(
|
| 466 |
+
fsdp_state_dict_type_query,
|
| 467 |
+
FSDP_STATE_DICT_TYPE if fsdp_version == 1 else FSDP2_STATE_DICT_TYPE,
|
| 468 |
+
lambda x: FSDP_STATE_DICT_TYPE[int(x)] if fsdp_version == 1 else FSDP2_STATE_DICT_TYPE[int(x)],
|
| 469 |
+
default=0,
|
| 470 |
+
)
|
| 471 |
+
# Not implemented in FSDP2, ask for user input for FSDP1
|
| 472 |
+
if fsdp_version == 1:
|
| 473 |
+
fsdp_config["fsdp_forward_prefetch"] = _ask_field(
|
| 474 |
+
"Do you want to enable FSDP's forward prefetch policy? [yes/NO]: ",
|
| 475 |
+
_convert_yes_no_to_bool,
|
| 476 |
+
default=False,
|
| 477 |
+
error_message="Please enter yes or no.",
|
| 478 |
+
)
|
| 479 |
+
# Obsolete in FSDP2, ask for user input for FSDP1
|
| 480 |
+
if fsdp_version == 1:
|
| 481 |
+
fsdp_config["fsdp_use_orig_params"] = _ask_field(
|
| 482 |
+
"Do you want to enable FSDP's `use_orig_params` feature? [YES/no]: ",
|
| 483 |
+
_convert_yes_no_to_bool,
|
| 484 |
+
default=True,
|
| 485 |
+
error_message="Please enter yes or no.",
|
| 486 |
+
)
|
| 487 |
+
fsdp_config["fsdp_cpu_ram_efficient_loading"] = _ask_field(
|
| 488 |
+
"Do you want to enable CPU RAM efficient model loading? Only applicable for 🤗 Transformers models. [YES/no]: ",
|
| 489 |
+
_convert_yes_no_to_bool,
|
| 490 |
+
default=True,
|
| 491 |
+
error_message="Please enter yes or no.",
|
| 492 |
+
)
|
| 493 |
+
# Obsolete in FSDP2, ask for user input for FSDP1
|
| 494 |
+
if fsdp_version == 1:
|
| 495 |
+
if fsdp_config["fsdp_cpu_ram_efficient_loading"]:
|
| 496 |
+
fsdp_config["fsdp_sync_module_states"] = True
|
| 497 |
+
else:
|
| 498 |
+
fsdp_config["fsdp_sync_module_states"] = _ask_field(
|
| 499 |
+
"Do you want each individually wrapped FSDP unit to broadcast module parameters from rank 0 at the start? [YES/no]: ",
|
| 500 |
+
_convert_yes_no_to_bool,
|
| 501 |
+
default=True,
|
| 502 |
+
error_message="Please enter yes or no.",
|
| 503 |
+
)
|
| 504 |
+
fsdp_config["fsdp_activation_checkpointing"] = _ask_field(
|
| 505 |
+
"Do you want to enable FSDP activation checkpointing? [yes/NO]: ",
|
| 506 |
+
_convert_yes_no_to_bool,
|
| 507 |
+
default=False,
|
| 508 |
+
error_message="Please enter yes or no.",
|
| 509 |
+
)
|
| 510 |
+
|
| 511 |
+
parallelism_config = {}
|
| 512 |
+
|
| 513 |
+
if fsdp_config.get("fsdp_version", 1) == 2:
|
| 514 |
+
use_parallelism_config = _ask_field(
|
| 515 |
+
"Do you want to use the parallelism config? [yes/NO]: ",
|
| 516 |
+
_convert_yes_no_to_bool,
|
| 517 |
+
default=False,
|
| 518 |
+
error_message="Please enter yes or no.",
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
if use_parallelism_config:
|
| 522 |
+
prefix = "parallelism_config_"
|
| 523 |
+
parallelism_config[prefix + "dp_replicate_size"] = _ask_field(
|
| 524 |
+
"What is the data parallelism replicate size? [1]: ",
|
| 525 |
+
int,
|
| 526 |
+
default=1,
|
| 527 |
+
error_message="Please enter an integer.",
|
| 528 |
+
)
|
| 529 |
+
|
| 530 |
+
parallelism_config[prefix + "dp_shard_size"] = _ask_field(
|
| 531 |
+
"What is the FSDP shard size? [1]: ",
|
| 532 |
+
int,
|
| 533 |
+
default=1,
|
| 534 |
+
error_message="Please enter an integer.",
|
| 535 |
+
)
|
| 536 |
+
|
| 537 |
+
parallelism_config[prefix + "tp_size"] = _ask_field(
|
| 538 |
+
"What is the tensor parallelism size? [1]: ",
|
| 539 |
+
int,
|
| 540 |
+
default=1,
|
| 541 |
+
error_message="Please enter an integer.",
|
| 542 |
+
)
|
| 543 |
+
|
| 544 |
+
parallelism_config[prefix + "cp_size"] = _ask_field(
|
| 545 |
+
"What is the context parallelism size? [1]: ",
|
| 546 |
+
int,
|
| 547 |
+
default=1,
|
| 548 |
+
error_message="Please enter an integer.",
|
| 549 |
+
)
|
| 550 |
+
if parallelism_config[prefix + "cp_size"] > 1:
|
| 551 |
+
parallelism_config[prefix + "cp_comm_strategy"] = _ask_options(
|
| 552 |
+
"What is the compute parallelism communication strategy?",
|
| 553 |
+
["allgather", "alltoall"],
|
| 554 |
+
lambda x: ["allgather", "alltoall"][int(x)],
|
| 555 |
+
default=0,
|
| 556 |
+
)
|
| 557 |
+
|
| 558 |
+
megatron_lm_config = {}
|
| 559 |
+
if distributed_type in [DistributedType.MULTI_GPU]:
|
| 560 |
+
use_megatron_lm = _ask_field(
|
| 561 |
+
"Do you want to use Megatron-LM ? [yes/NO]: ",
|
| 562 |
+
_convert_yes_no_to_bool,
|
| 563 |
+
default=False,
|
| 564 |
+
error_message="Please enter yes or no.",
|
| 565 |
+
)
|
| 566 |
+
if use_megatron_lm:
|
| 567 |
+
distributed_type = DistributedType.MEGATRON_LM
|
| 568 |
+
if distributed_type == DistributedType.MEGATRON_LM:
|
| 569 |
+
prefix = "megatron_lm_"
|
| 570 |
+
megatron_lm_config[prefix + "tp_degree"] = _ask_field(
|
| 571 |
+
"What is the Tensor Parallelism degree/size? [1]:",
|
| 572 |
+
int,
|
| 573 |
+
default=1,
|
| 574 |
+
error_message="Please enter an integer.",
|
| 575 |
+
)
|
| 576 |
+
if megatron_lm_config[prefix + "tp_degree"] > 1:
|
| 577 |
+
megatron_lm_config[prefix + "sequence_parallelism"] = _ask_field(
|
| 578 |
+
"Do you want to enable Sequence Parallelism? [YES/no]: ",
|
| 579 |
+
_convert_yes_no_to_bool,
|
| 580 |
+
default=True,
|
| 581 |
+
error_message="Please enter yes or no.",
|
| 582 |
+
)
|
| 583 |
+
|
| 584 |
+
megatron_lm_config[prefix + "pp_degree"] = _ask_field(
|
| 585 |
+
"What is the Pipeline Parallelism degree/size? [1]:",
|
| 586 |
+
int,
|
| 587 |
+
default=1,
|
| 588 |
+
error_message="Please enter an integer.",
|
| 589 |
+
)
|
| 590 |
+
if megatron_lm_config[prefix + "pp_degree"] > 1:
|
| 591 |
+
megatron_lm_config[prefix + "num_micro_batches"] = _ask_field(
|
| 592 |
+
"What is the number of micro-batches? [1]:",
|
| 593 |
+
int,
|
| 594 |
+
default=1,
|
| 595 |
+
error_message="Please enter an integer.",
|
| 596 |
+
)
|
| 597 |
+
|
| 598 |
+
megatron_lm_config[prefix + "recompute_activations"] = _ask_field(
|
| 599 |
+
"Do you want to enable selective activation recomputation? [YES/no]: ",
|
| 600 |
+
_convert_yes_no_to_bool,
|
| 601 |
+
default=True,
|
| 602 |
+
error_message="Please enter yes or no.",
|
| 603 |
+
)
|
| 604 |
+
|
| 605 |
+
megatron_lm_config[prefix + "use_distributed_optimizer"] = _ask_field(
|
| 606 |
+
"Do you want to use distributed optimizer "
|
| 607 |
+
"which shards optimizer state and gradients across data parallel ranks? [YES/no]: ",
|
| 608 |
+
_convert_yes_no_to_bool,
|
| 609 |
+
default=True,
|
| 610 |
+
error_message="Please enter yes or no.",
|
| 611 |
+
)
|
| 612 |
+
|
| 613 |
+
megatron_lm_config[prefix + "gradient_clipping"] = _ask_field(
|
| 614 |
+
"What is the gradient clipping value based on global L2 Norm (0 to disable)? [1.0]: ",
|
| 615 |
+
float,
|
| 616 |
+
default=1.0,
|
| 617 |
+
)
|
| 618 |
+
# TPU specific defaults
|
| 619 |
+
tpu_commands = None
|
| 620 |
+
tpu_command_file = None
|
| 621 |
+
tpu_downcast_bf16 = "no"
|
| 622 |
+
tpu_env = []
|
| 623 |
+
tpu_name = None
|
| 624 |
+
tpu_vm = None
|
| 625 |
+
tpu_zone = None
|
| 626 |
+
tpu_use_sudo = False
|
| 627 |
+
tpu_use_cluster = False
|
| 628 |
+
|
| 629 |
+
if distributed_type in [
|
| 630 |
+
DistributedType.MULTI_CPU,
|
| 631 |
+
DistributedType.MULTI_XPU,
|
| 632 |
+
DistributedType.MULTI_HPU,
|
| 633 |
+
DistributedType.MULTI_GPU,
|
| 634 |
+
DistributedType.MULTI_MLU,
|
| 635 |
+
DistributedType.MULTI_SDAA,
|
| 636 |
+
DistributedType.MULTI_MUSA,
|
| 637 |
+
DistributedType.MULTI_NPU,
|
| 638 |
+
DistributedType.MULTI_NEURON,
|
| 639 |
+
DistributedType.XLA,
|
| 640 |
+
]:
|
| 641 |
+
machine_type = str(distributed_type).split(".")[1].replace("MULTI_", "")
|
| 642 |
+
if machine_type in ["TPU", "NEURON"]:
|
| 643 |
+
machine_type += " cores"
|
| 644 |
+
elif machine_type == "CPU":
|
| 645 |
+
machine_type = "processes"
|
| 646 |
+
else:
|
| 647 |
+
machine_type += "(s)"
|
| 648 |
+
num_processes = _ask_field(
|
| 649 |
+
f"How many {machine_type} should be used for distributed training? [1]:",
|
| 650 |
+
int,
|
| 651 |
+
default=1,
|
| 652 |
+
error_message="Please enter an integer.",
|
| 653 |
+
)
|
| 654 |
+
elif distributed_type in [DistributedType.FSDP, DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM]:
|
| 655 |
+
num_processes = _ask_field(
|
| 656 |
+
"How many GPU(s) should be used for distributed training? [1]:",
|
| 657 |
+
int,
|
| 658 |
+
default=1,
|
| 659 |
+
error_message="Please enter an integer.",
|
| 660 |
+
)
|
| 661 |
+
else:
|
| 662 |
+
num_processes = 1
|
| 663 |
+
|
| 664 |
+
if (distributed_type == DistributedType.MULTI_GPU) and (num_machines == 1) and (num_processes == 1):
|
| 665 |
+
raise ValueError(
|
| 666 |
+
f"Specified distributed type {distributed_type} but only using 1 GPU on a single machine. Please select `No distributed training` for the type of machine you are using."
|
| 667 |
+
)
|
| 668 |
+
|
| 669 |
+
if (
|
| 670 |
+
distributed_type
|
| 671 |
+
in [
|
| 672 |
+
DistributedType.MULTI_GPU,
|
| 673 |
+
DistributedType.MULTI_MLU,
|
| 674 |
+
DistributedType.MULTI_SDAA,
|
| 675 |
+
DistributedType.MULTI_MUSA,
|
| 676 |
+
DistributedType.MULTI_NPU,
|
| 677 |
+
DistributedType.MULTI_XPU,
|
| 678 |
+
DistributedType.MULTI_HPU,
|
| 679 |
+
DistributedType.MULTI_NEURON,
|
| 680 |
+
DistributedType.NO,
|
| 681 |
+
]
|
| 682 |
+
and not use_cpu
|
| 683 |
+
and not use_mps
|
| 684 |
+
):
|
| 685 |
+
if is_npu_available():
|
| 686 |
+
machine_type = "NPU(s)"
|
| 687 |
+
elif is_mlu_available():
|
| 688 |
+
machine_type = "MLU(s)"
|
| 689 |
+
elif is_sdaa_available():
|
| 690 |
+
machine_type = "SDAA(s)"
|
| 691 |
+
elif is_musa_available():
|
| 692 |
+
machine_type = "MUSA(s)"
|
| 693 |
+
elif is_xpu_available():
|
| 694 |
+
machine_type = "XPU(s)"
|
| 695 |
+
elif is_hpu_available():
|
| 696 |
+
machine_type = "HPU(s)"
|
| 697 |
+
elif is_neuron_available():
|
| 698 |
+
machine_type = "Neuron cores"
|
| 699 |
+
else:
|
| 700 |
+
machine_type = "GPU(s)"
|
| 701 |
+
gpu_ids = _ask_field(
|
| 702 |
+
f"What {machine_type} (by id) should be used for training on this machine as a comma-separated list? [all]:",
|
| 703 |
+
default="all",
|
| 704 |
+
)
|
| 705 |
+
|
| 706 |
+
# CPU affinity is only supported on NVIDIA hardware for now
|
| 707 |
+
enable_cpu_affinity = False
|
| 708 |
+
if distributed_type in (DistributedType.NO, DistributedType.MULTI_GPU) and not use_cpu and not use_mps:
|
| 709 |
+
enable_cpu_affinity = _ask_field(
|
| 710 |
+
"Would you like to enable numa efficiency? (Currently only supported on NVIDIA hardware). [yes/NO]: ",
|
| 711 |
+
_convert_yes_no_to_bool,
|
| 712 |
+
default=False,
|
| 713 |
+
error_message="Please enter yes or no.",
|
| 714 |
+
)
|
| 715 |
+
|
| 716 |
+
fp8_config = None
|
| 717 |
+
if distributed_type == DistributedType.XLA:
|
| 718 |
+
mixed_precision = "no"
|
| 719 |
+
main_training_function = _ask_field(
|
| 720 |
+
"What is the name of the function in your script that should be launched in all parallel scripts? [main]: ",
|
| 721 |
+
default="main",
|
| 722 |
+
)
|
| 723 |
+
tpu_use_cluster = _ask_field(
|
| 724 |
+
"Are you using a TPU cluster? [yes/NO]: ",
|
| 725 |
+
_convert_yes_no_to_bool,
|
| 726 |
+
default=False,
|
| 727 |
+
error_message="Please enter yes or no.",
|
| 728 |
+
)
|
| 729 |
+
if tpu_use_cluster:
|
| 730 |
+
tpu_name = _ask_field(
|
| 731 |
+
"What is the name of your TPU cluster? ",
|
| 732 |
+
default=None,
|
| 733 |
+
error_message="Please enter the name of your TPU cluster.",
|
| 734 |
+
)
|
| 735 |
+
tpu_zone = _ask_field(
|
| 736 |
+
"What is the zone of your TPU cluster? ",
|
| 737 |
+
default=None,
|
| 738 |
+
error_message="Please enter the zone of your TPU cluster.",
|
| 739 |
+
)
|
| 740 |
+
tpu_use_sudo = _ask_field(
|
| 741 |
+
"To run a python script in a TPU pod, should `sudo` be used? [yes/NO]: ",
|
| 742 |
+
default=False,
|
| 743 |
+
error_message="Please enter yes or no.",
|
| 744 |
+
)
|
| 745 |
+
run_commands = _ask_field(
|
| 746 |
+
"Do you have code you wish to run on startup in each pod? [yes/NO]: ",
|
| 747 |
+
_convert_yes_no_to_bool,
|
| 748 |
+
default=False,
|
| 749 |
+
error_message="Please enter yes or no.",
|
| 750 |
+
)
|
| 751 |
+
if run_commands:
|
| 752 |
+
use_command_file = _ask_field(
|
| 753 |
+
"Is this code located in a bash script? [yes/NO]: ",
|
| 754 |
+
_convert_yes_no_to_bool,
|
| 755 |
+
default=False,
|
| 756 |
+
error_message="Please enter yes or no.",
|
| 757 |
+
)
|
| 758 |
+
if use_command_file:
|
| 759 |
+
tpu_command_file = _ask_field(
|
| 760 |
+
"What is the path to your bash script? ",
|
| 761 |
+
default=None,
|
| 762 |
+
error_message="Please enter the path to your bash script.",
|
| 763 |
+
)
|
| 764 |
+
tpu_command_file = os.path.abspath(tpu_command_file)
|
| 765 |
+
else:
|
| 766 |
+
print("Please enter each command separately you wish to run on startup in each pod.")
|
| 767 |
+
tpu_commands = []
|
| 768 |
+
another_command = True
|
| 769 |
+
while another_command:
|
| 770 |
+
tpu_commands.append(
|
| 771 |
+
_ask_field(
|
| 772 |
+
"Please enter a single command to be ran ",
|
| 773 |
+
default=None,
|
| 774 |
+
error_message="Please enter the commands you wish to run on startup in each pod as a single string.",
|
| 775 |
+
)
|
| 776 |
+
)
|
| 777 |
+
another_command = _ask_field(
|
| 778 |
+
"Do you wish to add another command? [yes/NO]: ",
|
| 779 |
+
_convert_yes_no_to_bool,
|
| 780 |
+
default=False,
|
| 781 |
+
error_message="Please enter yes or no.",
|
| 782 |
+
)
|
| 783 |
+
tpu_vm = _ask_field(
|
| 784 |
+
"If not using an instance group, what are the names of the Compute VM instances to be used, separated by a comma: ",
|
| 785 |
+
default="",
|
| 786 |
+
).split(",")
|
| 787 |
+
tpu_env = _ask_field(
|
| 788 |
+
"What environment variables do you wish to set in each pod, separated by a comma: ",
|
| 789 |
+
default="",
|
| 790 |
+
).split(",")
|
| 791 |
+
|
| 792 |
+
else:
|
| 793 |
+
main_training_function = "main"
|
| 794 |
+
if distributed_type == DistributedType.DEEPSPEED and use_deepspeed_config:
|
| 795 |
+
mixed_precision = None
|
| 796 |
+
else:
|
| 797 |
+
mixed_precision = _ask_options(
|
| 798 |
+
"Do you wish to use mixed precision?",
|
| 799 |
+
["no", "fp16", "bf16", "fp8"],
|
| 800 |
+
_convert_mixed_precision,
|
| 801 |
+
)
|
| 802 |
+
if mixed_precision == "fp8":
|
| 803 |
+
if not is_fp8_available():
|
| 804 |
+
raise ValueError(
|
| 805 |
+
"FP8 (either torchao, Transformer Engine or MSAMP) is not installed on this machine."
|
| 806 |
+
)
|
| 807 |
+
fp8_config = {}
|
| 808 |
+
fp8_config["backend"] = _ask_options(
|
| 809 |
+
"Which FP8 backend do you want to use?",
|
| 810 |
+
["ao", "te", "msamp"],
|
| 811 |
+
_convert_fp8_backend,
|
| 812 |
+
)
|
| 813 |
+
if fp8_config["backend"] == "TE":
|
| 814 |
+
if not is_transformer_engine_available():
|
| 815 |
+
raise ValueError("TransformersEngine was selected, but it is not installed on this machine.")
|
| 816 |
+
fp8_config["use_autocast_during_eval"] = _ask_field(
|
| 817 |
+
"Do you want to use FP8 autocast during eval mode? Generally better metrics are found when this is disabled [yes/NO]: ",
|
| 818 |
+
_convert_yes_no_to_bool,
|
| 819 |
+
default=False,
|
| 820 |
+
)
|
| 821 |
+
fp8_config["margin"] = _ask_field(
|
| 822 |
+
"What margin should be used for gradient scaling? [0]: ",
|
| 823 |
+
int,
|
| 824 |
+
default=0,
|
| 825 |
+
)
|
| 826 |
+
fp8_config["interval"] = _ask_field(
|
| 827 |
+
"What interval should be used for for how often the scaling factor is recomputed? [1]: ",
|
| 828 |
+
int,
|
| 829 |
+
default=1,
|
| 830 |
+
)
|
| 831 |
+
fp8_config["fp8_format"] = _ask_options(
|
| 832 |
+
"Which weight format should be used?",
|
| 833 |
+
["HYBRID", "E4M3", "E5M2"],
|
| 834 |
+
lambda i: ["HYBRID", "E4M3", "E5M2"][i],
|
| 835 |
+
default=0,
|
| 836 |
+
)
|
| 837 |
+
fp8_config["amax_history_length"] = _ask_field(
|
| 838 |
+
"What length of history should be used for the amax scaling factor computation? [1024]: ",
|
| 839 |
+
int,
|
| 840 |
+
default=1024,
|
| 841 |
+
)
|
| 842 |
+
fp8_config["amax_compute_algorithm"] = _ask_options(
|
| 843 |
+
"Which algorithm should be used for the amax scaling factor computation?",
|
| 844 |
+
["max", "most_recent"],
|
| 845 |
+
lambda x: "max" if x == 0 else "most_recent",
|
| 846 |
+
default=0,
|
| 847 |
+
)
|
| 848 |
+
fp8_config["override_linear_precision"] = _ask_field(
|
| 849 |
+
"Do you want to to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision? [yes/NO]: ",
|
| 850 |
+
_convert_yes_no_to_bool,
|
| 851 |
+
default=False,
|
| 852 |
+
)
|
| 853 |
+
if fp8_config["override_linear_precision"]:
|
| 854 |
+
fprop = _ask_field(
|
| 855 |
+
"Should `fprop` be executed in higher precision? [yes/NO]: ",
|
| 856 |
+
_convert_yes_no_to_bool,
|
| 857 |
+
default=False,
|
| 858 |
+
)
|
| 859 |
+
dgrad = _ask_field(
|
| 860 |
+
"Should `dgrad` be executed in higher precision? [yes/NO]: ",
|
| 861 |
+
_convert_yes_no_to_bool,
|
| 862 |
+
default=False,
|
| 863 |
+
)
|
| 864 |
+
wgrad = _ask_field(
|
| 865 |
+
"Should `wgrad` be executed in higher precision? [yes/NO]: ",
|
| 866 |
+
_convert_yes_no_to_bool,
|
| 867 |
+
default=False,
|
| 868 |
+
)
|
| 869 |
+
fp8_config["override_linear_precision"] = (fprop, dgrad, wgrad)
|
| 870 |
+
else:
|
| 871 |
+
fp8_config["override_linear_precision"] = (False, False, False)
|
| 872 |
+
|
| 873 |
+
elif fp8_config["backend"] == "MSAMP":
|
| 874 |
+
if not is_msamp_available():
|
| 875 |
+
raise ValueError("MSAMP was selected, but it is not installed on this machine.")
|
| 876 |
+
fp8_config["optimization_level"] = _ask_options(
|
| 877 |
+
"Which optimization level should be used?",
|
| 878 |
+
["O1", "O2"],
|
| 879 |
+
lambda x: "O1" if x == 0 else "O2",
|
| 880 |
+
default=1,
|
| 881 |
+
)
|
| 882 |
+
|
| 883 |
+
elif fp8_config["backend"] == "AO":
|
| 884 |
+
if not is_torchao_available():
|
| 885 |
+
raise ValueError("torchao was selected, but it is not installed on this machine.")
|
| 886 |
+
fp8_config["enable_fsdp_float8_all_gather"] = _ask_field(
|
| 887 |
+
"Do you want to enable FSDP2 float8 all gather? This is recommended for better performance if using FSDP2. [YES/no]: ",
|
| 888 |
+
_convert_yes_no_to_bool,
|
| 889 |
+
default=True,
|
| 890 |
+
)
|
| 891 |
+
fp8_config["pad_inner_dim"] = _ask_field(
|
| 892 |
+
"Do you want to pad the inner dimension of weight matrices before float8 matmuls? This is required for _scaled_mm which has strict alignment requirements. Note: padding may cause memory spikes. [YES/no]: ",
|
| 893 |
+
_convert_yes_no_to_bool,
|
| 894 |
+
default=True,
|
| 895 |
+
)
|
| 896 |
+
|
| 897 |
+
if use_dynamo and mixed_precision == "no" and not use_cpu:
|
| 898 |
+
print(
|
| 899 |
+
"Torch dynamo used without mixed precision requires TF32 to be efficient. Accelerate will enable it by default when launching your scripts."
|
| 900 |
+
)
|
| 901 |
+
|
| 902 |
+
if distributed_type == DistributedType.XLA and mixed_precision == "bf16":
|
| 903 |
+
tpu_downcast_bf16 = _ask_field(
|
| 904 |
+
"Should `torch.float` be cast as `bfloat16` and `torch.double` remain `float32` on TPUs?", default="no"
|
| 905 |
+
)
|
| 906 |
+
|
| 907 |
+
return ClusterConfig(
|
| 908 |
+
compute_environment=ComputeEnvironment.LOCAL_MACHINE,
|
| 909 |
+
distributed_type=distributed_type,
|
| 910 |
+
num_processes=num_processes,
|
| 911 |
+
gpu_ids=gpu_ids,
|
| 912 |
+
mixed_precision=mixed_precision,
|
| 913 |
+
downcast_bf16=tpu_downcast_bf16,
|
| 914 |
+
machine_rank=machine_rank,
|
| 915 |
+
num_machines=num_machines,
|
| 916 |
+
main_process_ip=main_process_ip,
|
| 917 |
+
main_process_port=main_process_port,
|
| 918 |
+
main_training_function=main_training_function,
|
| 919 |
+
fp8_config=fp8_config,
|
| 920 |
+
deepspeed_config=deepspeed_config,
|
| 921 |
+
fsdp_config=fsdp_config,
|
| 922 |
+
parallelism_config=parallelism_config,
|
| 923 |
+
megatron_lm_config=megatron_lm_config,
|
| 924 |
+
mpirun_config=mpirun_config,
|
| 925 |
+
use_cpu=use_cpu,
|
| 926 |
+
rdzv_backend=rdzv_backend,
|
| 927 |
+
same_network=same_network,
|
| 928 |
+
commands=tpu_commands,
|
| 929 |
+
command_file=tpu_command_file,
|
| 930 |
+
tpu_env=tpu_env,
|
| 931 |
+
tpu_name=tpu_name,
|
| 932 |
+
tpu_vm=tpu_vm,
|
| 933 |
+
tpu_zone=tpu_zone,
|
| 934 |
+
tpu_use_sudo=tpu_use_sudo,
|
| 935 |
+
tpu_use_cluster=tpu_use_cluster,
|
| 936 |
+
dynamo_config=dynamo_config,
|
| 937 |
+
debug=debug,
|
| 938 |
+
enable_cpu_affinity=enable_cpu_affinity,
|
| 939 |
+
)
|
accelerate/commands/config/config.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import os
|
| 19 |
+
|
| 20 |
+
from accelerate.utils import ComputeEnvironment
|
| 21 |
+
|
| 22 |
+
from .cluster import get_cluster_input
|
| 23 |
+
from .config_args import cache_dir, default_config_file, default_yaml_config_file, load_config_from_file # noqa: F401
|
| 24 |
+
from .config_utils import _ask_field, _ask_options, _convert_compute_environment # noqa: F401
|
| 25 |
+
from .sagemaker import get_sagemaker_input
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# Description text handed to argparse when the `config` parser is created.
description = (
    "Launches a series of prompts to create and save a `default_config.yaml` configuration file "
    "for your training system. Should always be ran first on your machine"
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def get_user_input():
    """Ask which compute environment is used and run the matching questionnaire.

    Returns:
        The value returned by `get_sagemaker_input()` when the selected environment equals
        `ComputeEnvironment.AMAZON_SAGEMAKER`, otherwise the value returned by
        `get_cluster_input()`.
    """
    environment_choices = ["This machine", "AWS (Amazon SageMaker)"]
    selected_environment = _ask_options(
        "In which compute environment are you running?",
        environment_choices,
        _convert_compute_environment,
    )

    # Every answer other than SageMaker is handled by the local-machine/cluster prompts.
    if selected_environment == ComputeEnvironment.AMAZON_SAGEMAKER:
        return get_sagemaker_input()
    return get_cluster_input()
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def config_command_parser(subparsers=None):
    """Create the argument parser for the `config` command.

    Args:
        subparsers: Optional object exposing `add_parser`. When given, the parser is registered
            on it under the name "config" and `config_command` is stored as its `func` default.
            When `None`, a standalone `argparse.ArgumentParser` is built instead.

    Returns:
        The parser, with the `--config_file` option added.
    """
    attach_to_parent = subparsers is not None

    if attach_to_parent:
        parser = subparsers.add_parser("config", description=description)
    else:
        parser = argparse.ArgumentParser("Accelerate config command", description=description)

    config_file_help = (
        "The path to use to store the config file. Will default to a file named default_config.yaml in the cache "
        "location, which is the content of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have "
        "such an environment variable, your cache directory ('~/.cache' or the content of `XDG_CACHE_HOME`) suffixed "
        "with 'huggingface'."
    )
    parser.add_argument("--config_file", default=None, help=config_file_help)

    # Only a sub-command needs a dispatch target; the standalone parser is driven by `main()`.
    if attach_to_parent:
        parser.set_defaults(func=config_command)
    return parser
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def config_command(args):
    """Run the interactive questionnaire and save the resulting configuration.

    Args:
        args: Parsed command-line namespace. Only `args.config_file` is read: when it is not
            `None` the configuration is written there, otherwise it is written to
            `default_yaml_config_file` inside `cache_dir`.

    The file is written with `to_json_file` when the target path ends in ".json" and with
    `to_yaml_file` otherwise. The final location is printed to stdout.
    """
    config = get_user_input()
    if args.config_file is not None:
        config_file = args.config_file
    else:
        # `exist_ok=True` replaces the former isdir-then-makedirs sequence, which could raise
        # FileExistsError if another process created the directory between the two calls.
        os.makedirs(cache_dir, exist_ok=True)
        config_file = default_yaml_config_file

    if config_file.endswith(".json"):
        config.to_json_file(config_file)
    else:
        config.to_yaml_file(config_file)
    print(f"accelerate configuration saved at {config_file}")
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def main():
    """Parse the command line with the standalone parser and run `config_command`."""
    cli_args = config_command_parser().parse_args()
    config_command(cli_args)


# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
|
accelerate/commands/config/config_args.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import json
|
| 18 |
+
import os
|
| 19 |
+
from dataclasses import dataclass
|
| 20 |
+
from enum import Enum
|
| 21 |
+
from typing import Optional, Union
|
| 22 |
+
|
| 23 |
+
import yaml
|
| 24 |
+
|
| 25 |
+
from ...utils import ComputeEnvironment, DistributedType, SageMakerDistributedType
|
| 26 |
+
from ...utils.constants import SAGEMAKER_PYTHON_VERSION, SAGEMAKER_PYTORCH_VERSION, SAGEMAKER_TRANSFORMERS_VERSION
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Cache root: `HF_HOME` when set, otherwise "<XDG_CACHE_HOME or ~/.cache>/huggingface"; "~" is expanded.
hf_cache_home = os.path.expanduser(
    os.environ.get(
        "HF_HOME",
        os.path.join(os.environ.get("XDG_CACHE_HOME", "~/.cache"), "huggingface"),
    )
)
cache_dir = os.path.join(hf_cache_home, "accelerate")
# NOTE(review): the "json" and "yaml" defaults resolve to the same `default_config.yaml` path, so the
# JSON fallback selected below can never be chosen. Confirm whether `default_config.json` was intended
# before changing it, since other modules may import these names.
default_json_config_file = os.path.join(cache_dir, "default_config.yaml")
default_yaml_config_file = os.path.join(cache_dir, "default_config.yaml")

# For backward compatibility: the default config is the json one if it's the only existing file.
default_config_file = (
    default_json_config_file
    if os.path.isfile(default_json_config_file) and not os.path.isfile(default_yaml_config_file)
    else default_yaml_config_file
)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def load_config_from_file(config_file):
    """Load a saved configuration and return it as the matching config object.

    Args:
        config_file: Path of the file to load, or `None` to use `default_config_file`.

    Returns:
        The result of `ClusterConfig.from_json_file` / `from_yaml_file` when the file's
        "compute_environment" entry is missing or equals `ComputeEnvironment.LOCAL_MACHINE`,
        otherwise the result of the same method on `SageMakerConfig`. Files whose name ends in
        ".json" are parsed as JSON, everything else as YAML.

    Raises:
        FileNotFoundError: If an explicit `config_file` is given and it is not an existing file.
    """
    if config_file is None:
        config_file = default_config_file
    elif not os.path.isfile(config_file):
        raise FileNotFoundError(
            f"The passed configuration file `{config_file}` does not exist. "
            "Please pass an existing file to `accelerate launch`, or use the default one "
            "created through `accelerate config` and run `accelerate launch` "
            "without the `--config_file` argument."
        )

    with open(config_file, encoding="utf-8") as handle:
        is_json = config_file.endswith(".json")
        # First pass over the file: only peek at the compute environment to pick the config class.
        raw_config = json.load(handle) if is_json else yaml.safe_load(handle)
        environment = raw_config.get("compute_environment", ComputeEnvironment.LOCAL_MACHINE)
        if environment == ComputeEnvironment.LOCAL_MACHINE:
            config_class = ClusterConfig
        else:
            config_class = SageMakerConfig

        # The chosen class is handed the path and does the actual loading.
        if is_json:
            return config_class.from_json_file(json_file=config_file)
        return config_class.from_yaml_file(yaml_file=config_file)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@dataclass
|
| 76 |
+
class BaseConfig:
|
| 77 |
+
compute_environment: ComputeEnvironment
|
| 78 |
+
distributed_type: Union[DistributedType, SageMakerDistributedType]
|
| 79 |
+
mixed_precision: str
|
| 80 |
+
use_cpu: bool
|
| 81 |
+
debug: bool
|
| 82 |
+
|
| 83 |
+
def to_dict(self):
    """Return the instance attributes as a new dictionary suitable for serialization.

    Enum members are replaced by their `.value`, at the top level and inside nested dicts.
    Empty dicts are converted to `None`, and top-level entries whose converted value is
    `None` are omitted from the result.

    The instance is left untouched: the previous implementation wrote the converted values
    back into `self.__dict__` and into the nested dicts, so serializing a config silently
    changed its attributes.
    """

    def _convert_enums(value):
        # For serialization, it's best to convert Enums to strings (or their underlying value type).
        if isinstance(value, Enum):
            return value.value
        if isinstance(value, dict):
            if not value:
                return None
            # Build a fresh dict instead of rewriting the caller's one in place.
            return {key: _convert_enums(item) for key, item in value.items()}
        return value

    converted = {key: _convert_enums(value) for key, value in self.__dict__.items()}
    return {key: value for key, value in converted.items() if value is not None}
|
| 101 |
+
|
| 102 |
+
@staticmethod
|
| 103 |
+
def process_config(config_dict):
|
| 104 |
+
"""
|
| 105 |
+
Processes `config_dict` and sets default values for any missing keys
|
| 106 |
+
"""
|
| 107 |
+
if "compute_environment" not in config_dict:
|
| 108 |
+
config_dict["compute_environment"] = ComputeEnvironment.LOCAL_MACHINE
|
| 109 |
+
if "distributed_type" not in config_dict:
|
| 110 |
+
raise ValueError("A `distributed_type` must be specified in the config file.")
|
| 111 |
+
if "num_processes" not in config_dict and config_dict["distributed_type"] == DistributedType.NO:
|
| 112 |
+
config_dict["num_processes"] = 1
|
| 113 |
+
if "mixed_precision" not in config_dict:
|
| 114 |
+
config_dict["mixed_precision"] = "fp16" if ("fp16" in config_dict and config_dict["fp16"]) else None
|
| 115 |
+
if "fp16" in config_dict: # Convert the config to the new format.
|
| 116 |
+
del config_dict["fp16"]
|
| 117 |
+
if "dynamo_backend" in config_dict: # Convert the config to the new format.
|
| 118 |
+
dynamo_backend = config_dict.pop("dynamo_backend")
|
| 119 |
+
config_dict["dynamo_config"] = {} if dynamo_backend == "NO" else {"dynamo_backend": dynamo_backend}
|
| 120 |
+
if "use_cpu" not in config_dict:
|
| 121 |
+
config_dict["use_cpu"] = False
|
| 122 |
+
if "debug" not in config_dict:
|
| 123 |
+
config_dict["debug"] = False
|
| 124 |
+
if "enable_cpu_affinity" not in config_dict:
|
| 125 |
+
config_dict["enable_cpu_affinity"] = False
|
| 126 |
+
return config_dict
|
| 127 |
+
|
| 128 |
+
@classmethod
|
| 129 |
+
def from_json_file(cls, json_file=None):
|
| 130 |
+
json_file = default_json_config_file if json_file is None else json_file
|
| 131 |
+
with open(json_file, encoding="utf-8") as f:
|
| 132 |
+
config_dict = json.load(f)
|
| 133 |
+
config_dict = cls.process_config(config_dict)
|
| 134 |
+
extra_keys = sorted(set(config_dict.keys()) - set(cls.__dataclass_fields__.keys()))
|
| 135 |
+
if len(extra_keys) > 0:
|
| 136 |
+
raise ValueError(
|
| 137 |
+
f"The config file at {json_file} had unknown keys ({extra_keys}), please try upgrading your `accelerate`"
|
| 138 |
+
" version or fix (and potentially remove) these keys from your config file."
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
return cls(**config_dict)
|
| 142 |
+
|
| 143 |
+
def to_json_file(self, json_file):
|
| 144 |
+
with open(json_file, "w", encoding="utf-8") as f:
|
| 145 |
+
content = json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
| 146 |
+
f.write(content)
|
| 147 |
+
|
| 148 |
+
@classmethod
|
| 149 |
+
def from_yaml_file(cls, yaml_file=None):
|
| 150 |
+
yaml_file = default_yaml_config_file if yaml_file is None else yaml_file
|
| 151 |
+
with open(yaml_file, encoding="utf-8") as f:
|
| 152 |
+
config_dict = yaml.safe_load(f)
|
| 153 |
+
config_dict = cls.process_config(config_dict)
|
| 154 |
+
extra_keys = sorted(set(config_dict.keys()) - set(cls.__dataclass_fields__.keys()))
|
| 155 |
+
if len(extra_keys) > 0:
|
| 156 |
+
raise ValueError(
|
| 157 |
+
f"The config file at {yaml_file} had unknown keys ({extra_keys}), please try upgrading your `accelerate`"
|
| 158 |
+
" version or fix (and potentially remove) these keys from your config file."
|
| 159 |
+
)
|
| 160 |
+
return cls(**config_dict)
|
| 161 |
+
|
| 162 |
+
def to_yaml_file(self, yaml_file):
|
| 163 |
+
with open(yaml_file, "w", encoding="utf-8") as f:
|
| 164 |
+
yaml.safe_dump(self.to_dict(), f)
|
| 165 |
+
|
| 166 |
+
def __post_init__(self):
|
| 167 |
+
if isinstance(self.compute_environment, str):
|
| 168 |
+
self.compute_environment = ComputeEnvironment(self.compute_environment)
|
| 169 |
+
if isinstance(self.distributed_type, str):
|
| 170 |
+
if self.compute_environment == ComputeEnvironment.AMAZON_SAGEMAKER:
|
| 171 |
+
self.distributed_type = SageMakerDistributedType(self.distributed_type)
|
| 172 |
+
else:
|
| 173 |
+
self.distributed_type = DistributedType(self.distributed_type)
|
| 174 |
+
if getattr(self, "dynamo_config", None) is None:
|
| 175 |
+
self.dynamo_config = {}
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
@dataclass
class ClusterConfig(BaseConfig):
    # Process / machine topology.
    num_processes: int = -1  # For instance if we use SLURM and the user manually passes it in
    machine_rank: int = 0
    num_machines: int = 1
    gpu_ids: Optional[str] = None
    main_process_ip: Optional[str] = None
    main_process_port: Optional[int] = None
    rdzv_backend: Optional[str] = "static"
    same_network: Optional[bool] = False
    main_training_function: str = "main"
    enable_cpu_affinity: bool = False

    # Optional per-feature sections; `None` is replaced by `{}` in `__post_init__`.
    # args for FP8 training
    fp8_config: Optional[dict] = None
    # args for deepspeed_plugin
    deepspeed_config: Optional[dict] = None
    # args for fsdp
    fsdp_config: Optional[dict] = None
    # args for parallelism config
    parallelism_config: Optional[dict] = None
    # args for megatron_lm
    megatron_lm_config: Optional[dict] = None
    # args for mpirun
    mpirun_config: Optional[dict] = None
    # args for TPU
    downcast_bf16: bool = False

    # args for TPU pods
    tpu_name: Optional[str] = None
    tpu_zone: Optional[str] = None
    tpu_use_cluster: bool = False
    tpu_use_sudo: bool = False
    command_file: Optional[str] = None
    commands: list[str] = None
    tpu_vm: list[str] = None
    tpu_env: list[str] = None

    # args for dynamo
    dynamo_config: Optional[dict] = None

    def __post_init__(self):
        # Make sure every optional section is a dict before the base class finishes initialization.
        section_names = (
            "deepspeed_config",
            "fsdp_config",
            "megatron_lm_config",
            "mpirun_config",
            "fp8_config",
            "parallelism_config",
        )
        for section_name in section_names:
            if getattr(self, section_name) is None:
                setattr(self, section_name, {})
        return super().__post_init__()
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
@dataclass
class SageMakerConfig(BaseConfig):
    # Required SageMaker settings (no defaults).
    ec2_instance_type: str
    iam_role_name: str

    # AWS account / image selection.
    image_uri: Optional[str] = None
    profile: Optional[str] = None
    region: str = "us-east-1"

    # Job sizing and naming.
    num_machines: int = 1
    gpu_ids: str = "all"
    # NOTE(review): this f-string is evaluated once, while the class body runs, using the default
    # `num_machines` above; the default name therefore does not follow a per-instance `num_machines`.
    base_job_name: str = f"accelerate-sagemaker-{num_machines}"

    # Framework versions, taken from module-level constants.
    pytorch_version: str = SAGEMAKER_PYTORCH_VERSION
    transformers_version: str = SAGEMAKER_TRANSFORMERS_VERSION
    py_version: str = SAGEMAKER_PYTHON_VERSION

    # Optional extras.
    sagemaker_inputs_file: Optional[str] = None
    sagemaker_metrics_file: Optional[str] = None
    additional_args: Optional[dict] = None
    dynamo_config: Optional[dict] = None
    enable_cpu_affinity: bool = False
|
accelerate/commands/config/config_utils.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
|
| 19 |
+
from ...utils.dataclasses import (
|
| 20 |
+
ComputeEnvironment,
|
| 21 |
+
DistributedType,
|
| 22 |
+
DynamoBackend,
|
| 23 |
+
FP8BackendType,
|
| 24 |
+
PrecisionType,
|
| 25 |
+
SageMakerDistributedType,
|
| 26 |
+
)
|
| 27 |
+
from ..menu import BulletMenu
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# Names offered in the interactive dynamo backend menu. The position in this list is the index returned
# by the menu, and each entry is passed to `DynamoBackend(...)` by `_convert_dynamo_backend`, so the
# spelling has to match the enum.
DYNAMO_BACKENDS = [
    "EAGER",
    "AOT_EAGER",
    "INDUCTOR",
    "AOT_TS_NVFUSER",
    "NVPRIMS_NVFUSER",
    "CUDAGRAPHS",
    "OFI",
    "FX2TRT",
    "ONNXRT",
    "TENSORRT",
    "AOT_TORCHXLA_TRACE_ONCE",
    "TORCHXLA_TRACE_ONCE",  # was misspelled "TORHCHXLA_TRACE_ONCE"
    "TVM",
]
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _ask_field(input_text, convert_value=None, default=None, error_message=None):
|
| 48 |
+
ask_again = True
|
| 49 |
+
while ask_again:
|
| 50 |
+
result = input(input_text)
|
| 51 |
+
try:
|
| 52 |
+
if default is not None and len(result) == 0:
|
| 53 |
+
return default
|
| 54 |
+
return convert_value(result) if convert_value is not None else result
|
| 55 |
+
except Exception:
|
| 56 |
+
if error_message is not None:
|
| 57 |
+
print(error_message)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _ask_options(input_text, options=None, convert_value=None, default=0):
    """
    Show a `BulletMenu` titled `input_text` and return the user's selection.

    Args:
        input_text: prompt handed to the menu.
        options: the choices to display; omitted or `None` means an empty list.
        convert_value: optional callable applied to the menu result before it is returned.
        default: forwarded to the menu as `default_choice`.
    """
    # A fresh list per call: a `[]` default would be one shared object handed to every menu.
    choices = [] if options is None else options
    menu = BulletMenu(input_text, choices)
    result = menu.run(default_choice=default)
    return convert_value(result) if convert_value is not None else result
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _convert_compute_environment(value):
    """Map a menu index (int or numeric string) to the matching `ComputeEnvironment` member."""
    environment_names = ("LOCAL_MACHINE", "AMAZON_SAGEMAKER")
    return ComputeEnvironment(environment_names[int(value)])
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def _convert_distributed_mode(value):
    """Map a menu index (int or numeric string) to the matching `DistributedType` member."""
    # Order matters: the index comes from the position chosen in the menu.
    mode_names = (
        "NO",
        "MULTI_CPU",
        "MULTI_XPU",
        "MULTI_HPU",
        "MULTI_GPU",
        "MULTI_NPU",
        "MULTI_MLU",
        "MULTI_SDAA",
        "MULTI_MUSA",
        "MULTI_NEURON",
        "XLA",
    )
    return DistributedType(mode_names[int(value)])
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def _convert_dynamo_backend(value):
    """Map an index into `DYNAMO_BACKENDS` to the `value` of the matching `DynamoBackend` member."""
    backend_name = DYNAMO_BACKENDS[int(value)]
    return DynamoBackend(backend_name).value
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def _convert_mixed_precision(value):
    """Map a menu index (int or numeric string) to the matching `PrecisionType` member."""
    precision_names = ("no", "fp16", "bf16", "fp8")
    return PrecisionType(precision_names[int(value)])
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def _convert_sagemaker_distributed_mode(value):
    """Map a menu index (int or numeric string) to the matching `SageMakerDistributedType` member."""
    mode_names = ("NO", "DATA_PARALLEL", "MODEL_PARALLEL")
    return SageMakerDistributedType(mode_names[int(value)])
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def _convert_fp8_backend(value):
    """Map a menu index (int or numeric string) to the matching `FP8BackendType` member."""
    backend_names = ("AO", "TE", "MSAMP")
    return FP8BackendType(backend_names[int(value)])
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def _convert_yes_no_to_bool(value):
|
| 111 |
+
return {"yes": True, "no": False}[value.lower()]
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class SubcommandHelpFormatter(argparse.RawDescriptionHelpFormatter):
    """
    A custom formatter that will remove the usage line from the help message for subcommands.
    """

    def _format_usage(self, usage, actions, groups, prefix):
        # Let argparse build the usage text, then strip the generic subcommand placeholder from it.
        formatted = super()._format_usage(usage, actions, groups, prefix)
        return formatted.replace("<command> [<args>] ", "")
|
accelerate/commands/config/default.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
|
| 21 |
+
from ...utils import (
|
| 22 |
+
is_hpu_available,
|
| 23 |
+
is_mlu_available,
|
| 24 |
+
is_musa_available,
|
| 25 |
+
is_neuron_available,
|
| 26 |
+
is_npu_available,
|
| 27 |
+
is_sdaa_available,
|
| 28 |
+
is_xpu_available,
|
| 29 |
+
)
|
| 30 |
+
from .config_args import ClusterConfig, default_json_config_file
|
| 31 |
+
from .config_utils import SubcommandHelpFormatter
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# Help text for the `default` subcommand; passed as `help=` when the subparser is created below.
description = "Create a default config file for Accelerate with only a few flags set."
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _probe_accelerator():
    """Return `(device_count, multi_device_distributed_type)` for the first available accelerator, else `None`."""
    # Checked in priority order; exactly one backend is picked. Each entry is
    # (availability check, name of the `torch` sub-namespace, distributed type used for 2+ devices).
    probes = (
        (is_mlu_available, "mlu", "MULTI_MLU"),
        (is_sdaa_available, "sdaa", "MULTI_SDAA"),
        (is_musa_available, "musa", "MULTI_MUSA"),
        (is_hpu_available, "hpu", "MULTI_HPU"),
        (torch.cuda.is_available, "cuda", "MULTI_GPU"),
        (is_xpu_available, "xpu", "MULTI_XPU"),
        (is_npu_available, "npu", "MULTI_NPU"),
        (is_neuron_available, "neuron", "MULTI_NEURON"),
    )
    for is_available, namespace, multi_device_type in probes:
        if is_available():
            # Resolve `torch.<namespace>` only once the backend is reported as available.
            return getattr(torch, namespace).device_count(), multi_device_type
    return None


def write_basic_config(mixed_precision="no", save_location: str = default_json_config_file):
    """
    Creates and saves a basic cluster config to be used on a local machine with potentially multiple accelerators.
    Will also set CPU if no accelerator is detected.

    Args:
        mixed_precision (`str`, *optional*, defaults to "no"):
            Mixed Precision to use. Should be one of "no", "fp16", "bf16", or "fp8" (case-insensitive).
        save_location (`str`, *optional*, defaults to `default_json_config_file`):
            Optional custom save location. Should be passed to `--config_file` when using `accelerate launch`.

    Returns:
        The `Path` the config was written to, or `False` when a file already exists at `save_location` (in which
        case nothing is written).

    Raises:
        ValueError: if `mixed_precision` is not one of the supported values.
    """
    path = Path(save_location)
    path.parent.mkdir(parents=True, exist_ok=True)
    if path.exists():
        print(
            f"Configuration already exists at {save_location}, will not override. Run `accelerate config` manually or pass a different `save_location`."
        )
        return False
    mixed_precision = mixed_precision.lower()
    if mixed_precision not in ["no", "fp16", "bf16", "fp8"]:
        raise ValueError(
            f"`mixed_precision` should be one of 'no', 'fp16', 'bf16', or 'fp8'. Received {mixed_precision}"
        )
    config = {
        "compute_environment": "LOCAL_MACHINE",
        "mixed_precision": mixed_precision,
    }
    detected = _probe_accelerator()
    if detected is None:
        # No accelerator at all: single CPU process.
        config["use_cpu"] = True
        config["num_processes"] = 1
        config["distributed_type"] = "NO"
    else:
        num_devices, multi_device_type = detected
        config["num_processes"] = num_devices
        config["use_cpu"] = False
        config["distributed_type"] = multi_device_type if num_devices > 1 else "NO"
    config["debug"] = False
    config["enable_cpu_affinity"] = False
    cluster_config = ClusterConfig(**config)
    cluster_config.to_json_file(path)
    return path
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def default_command_parser(parser, parents):
    """Register the `default` subcommand on `parser` and return the newly created subparser."""
    subparser = parser.add_parser(
        "default", parents=parents, help=description, formatter_class=SubcommandHelpFormatter
    )

    config_file_help = (
        "The path to use to store the config file. Will default to a file named default_config.yaml in the cache "
        "location, which is the content of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have "
        "such an environment variable, your cache directory ('~/.cache' or the content of `XDG_CACHE_HOME`) suffixed "
        "with 'huggingface'."
    )
    # The value is stored under `save_location`, the name `default_config_command` reads.
    subparser.add_argument(
        "--config_file",
        default=default_json_config_file,
        help=config_file_help,
        dest="save_location",
    )

    mixed_precision_help = (
        "Whether or not to use mixed precision training. "
        "Choose between FP16 and BF16 (bfloat16) training. "
        "BF16 training is only supported on Nvidia Ampere GPUs and PyTorch 1.10 or later."
    )
    subparser.add_argument(
        "--mixed_precision",
        choices=["no", "fp16", "bf16"],
        type=str,
        help=mixed_precision_help,
        default="no",
    )

    subparser.set_defaults(func=default_config_command)
    return subparser
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def default_config_command(args):
    """Write the basic config described by `args` and report where it was saved."""
    saved_to = write_basic_config(args.mixed_precision, args.save_location)
    if not saved_to:
        # A falsy result means nothing was written, so there is nothing to report.
        return
    print(f"accelerate configuration saved at {saved_to}")
|
accelerate/commands/config/sagemaker.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
import json
|
| 17 |
+
import os
|
| 18 |
+
|
| 19 |
+
from ...utils.constants import SAGEMAKER_PARALLEL_EC2_INSTANCES, TORCH_DYNAMO_MODES
|
| 20 |
+
from ...utils.dataclasses import ComputeEnvironment, SageMakerDistributedType
|
| 21 |
+
from ...utils.imports import is_boto3_available
|
| 22 |
+
from .config_args import SageMakerConfig
|
| 23 |
+
from .config_utils import (
|
| 24 |
+
DYNAMO_BACKENDS,
|
| 25 |
+
_ask_field,
|
| 26 |
+
_ask_options,
|
| 27 |
+
_convert_dynamo_backend,
|
| 28 |
+
_convert_mixed_precision,
|
| 29 |
+
_convert_sagemaker_distributed_mode,
|
| 30 |
+
_convert_yes_no_to_bool,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
if is_boto3_available():
|
| 35 |
+
import boto3 # noqa: F401
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _create_iam_role_for_sagemaker(role_name):
    """
    Create the IAM role `role_name`, assumable by SageMaker, and attach an inline permission policy to it.

    If a role with that name already exists, a message is printed and nothing else is done.
    """
    iam_client = boto3.client("iam")

    # Trust policy: lets the SageMaker service assume the role.
    trust_statement = {
        "Effect": "Allow",
        "Principal": {"Service": "sagemaker.amazonaws.com"},
        "Action": "sts:AssumeRole",
    }
    sagemaker_trust_policy = {"Version": "2012-10-17", "Statement": [trust_statement]}

    # Permission policy: the actions granted to the role.
    allowed_actions = [
        "sagemaker:*",
        "ecr:GetDownloadUrlForLayer",
        "ecr:BatchGetImage",
        "ecr:BatchCheckLayerAvailability",
        "ecr:GetAuthorizationToken",
        "cloudwatch:PutMetricData",
        "cloudwatch:GetMetricData",
        "cloudwatch:GetMetricStatistics",
        "cloudwatch:ListMetrics",
        "logs:CreateLogGroup",
        "logs:CreateLogStream",
        "logs:DescribeLogStreams",
        "logs:PutLogEvents",
        "logs:GetLogEvents",
        "s3:CreateBucket",
        "s3:ListBucket",
        "s3:GetBucketLocation",
        "s3:GetObject",
        "s3:PutObject",
    ]
    policy_document = {
        "Version": "2012-10-17",
        "Statement": [{"Effect": "Allow", "Action": allowed_actions, "Resource": "*"}],
    }

    try:
        # create the role, associated with the chosen trust policy
        iam_client.create_role(
            RoleName=role_name,
            AssumeRolePolicyDocument=json.dumps(sagemaker_trust_policy, indent=2),
        )
        # attach policy to role
        iam_client.put_role_policy(
            RoleName=role_name,
            PolicyName=f"{role_name}_policy_permission",
            PolicyDocument=json.dumps(policy_document, indent=2),
        )
    except iam_client.exceptions.EntityAlreadyExistsException:
        print(f"role {role_name} already exists. Using existing one")
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def _get_iam_role_arn(role_name):
    """Look up the IAM role `role_name` and return its ARN."""
    response = boto3.client("iam").get_role(RoleName=role_name)
    return response["Role"]["Arn"]
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def _ask_yes_no(question):
    """Ask a yes/no `question` through `_ask_field`; an empty answer counts as "no"."""
    return _ask_field(
        question,
        _convert_yes_no_to_bool,
        default=False,
        error_message="Please enter yes or no.",
    )


def _ask_dynamo_config():
    """Ask for the torch dynamo backend and optional compile settings; return them as a `dynamo_`-prefixed dict."""
    prefix = "dynamo_"
    dynamo_config = {
        prefix + "backend": _ask_options(
            "Which dynamo backend would you like to use?",
            [x.lower() for x in DYNAMO_BACKENDS],
            _convert_dynamo_backend,
            default=2,  # pre-selects index 2 of `DYNAMO_BACKENDS`
        )
    }
    use_custom_options = _ask_yes_no("Do you want to customize the defaults sent to torch.compile? [yes/NO]: ")
    if use_custom_options:
        dynamo_config[prefix + "mode"] = _ask_options(
            "Which mode do you want to use?",
            TORCH_DYNAMO_MODES,
            lambda x: TORCH_DYNAMO_MODES[int(x)],
            # The menu expects the *index* of the pre-selected entry, not the entry itself.
            default=TORCH_DYNAMO_MODES.index("default"),
        )
        dynamo_config[prefix + "use_fullgraph"] = _ask_yes_no(
            "Do you want the fullgraph mode or it is ok to break model into several subgraphs? [yes/NO]: "
        )
        dynamo_config[prefix + "use_dynamic"] = _ask_yes_no("Do you want to enable dynamic shape tracing? [yes/NO]: ")
        dynamo_config[prefix + "use_regional_compilation"] = _ask_yes_no(
            "Do you want to enable regional compilation? [yes/NO]: "
        )
    return dynamo_config


def get_sagemaker_input():
    """
    Interactively collect the settings for an Amazon SageMaker run and return them as a `SageMakerConfig`.

    Side effects: the chosen AWS profile or credentials and the region are exported through `os.environ`, and an
    IAM role is created when the user asks for one.
    """
    credentials_configuration = _ask_options(
        "How do you want to authorize?",
        ["AWS Profile", "Credentials (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY) "],
        int,
    )
    aws_profile = None
    if credentials_configuration == 0:
        aws_profile = _ask_field("Enter your AWS Profile name: [default] ", default="default")
        os.environ["AWS_PROFILE"] = aws_profile
    else:
        print(
            "Note you will need to provide AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY when you launch you training script with,"
            "`accelerate launch --aws_access_key_id XXX --aws_secret_access_key YYY`"
        )
        aws_access_key_id = _ask_field("AWS Access Key ID: ")
        os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id

        aws_secret_access_key = _ask_field("AWS Secret Access Key: ")
        os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key

    aws_region = _ask_field("Enter your AWS Region: [us-east-1]", default="us-east-1")
    os.environ["AWS_DEFAULT_REGION"] = aws_region

    role_management = _ask_options(
        "Do you already have an IAM Role for executing Amazon SageMaker Training Jobs?",
        ["Provide IAM Role name", "Create new IAM role using credentials"],
        int,
    )
    if role_management == 0:
        iam_role_name = _ask_field("Enter your IAM role name: ")
    else:
        iam_role_name = "accelerate_sagemaker_execution_role"
        print(f'Accelerate will create an iam role "{iam_role_name}" using the provided credentials')
        _create_iam_role_for_sagemaker(iam_role_name)

    docker_image = None
    if _ask_yes_no("Do you want to use custom Docker image? [yes/NO]: "):
        docker_image = _ask_field("Enter your Docker image: ", lambda x: str(x).lower())

    # NOTE(review): both file paths below are lower-cased before being stored, which may not match the
    # real path on a case-sensitive file system; kept as-is, confirm whether this is intended.
    sagemaker_inputs_file = None
    if _ask_yes_no("Do you want to provide SageMaker input channels with data locations? [yes/NO]: "):
        sagemaker_inputs_file = _ask_field(
            "Enter the path to the SageMaker inputs TSV file with columns (channel_name, data_location): ",
            lambda x: str(x).lower(),
        )

    sagemaker_metrics_file = None
    if _ask_yes_no("Do you want to enable SageMaker metrics? [yes/NO]: "):
        sagemaker_metrics_file = _ask_field(
            "Enter the path to the SageMaker metrics TSV file with columns (metric_name, metric_regex): ",
            lambda x: str(x).lower(),
        )

    distributed_type = _ask_options(
        "What is the distributed mode?",
        ["No distributed training", "Data parallelism"],
        _convert_sagemaker_distributed_mode,
    )

    use_dynamo = _ask_yes_no("Do you wish to optimize your script with torch dynamo?[yes/NO]:")
    dynamo_config = _ask_dynamo_config() if use_dynamo else {}

    ec2_instance_query = "Which EC2 instance type you want to use for your training?"
    is_distributed = distributed_type != SageMakerDistributedType.NO
    if is_distributed:
        ec2_instance_type = _ask_options(
            ec2_instance_query, SAGEMAKER_PARALLEL_EC2_INSTANCES, lambda x: SAGEMAKER_PARALLEL_EC2_INSTANCES[int(x)]
        )
    else:
        # The question already ends with "?"; only append the default hint.
        ec2_instance_query += " [ml.p3.2xlarge]:"
        ec2_instance_type = _ask_field(ec2_instance_query, lambda x: str(x).lower(), default="ml.p3.2xlarge")

    debug = False
    if is_distributed:
        debug = _ask_yes_no(
            "Should distributed operations be checked while running for errors? This can avoid timeout issues but will be slower. [yes/NO]: "
        )

    num_machines = 1
    if distributed_type in (SageMakerDistributedType.DATA_PARALLEL, SageMakerDistributedType.MODEL_PARALLEL):
        num_machines = _ask_field(
            "How many machines do you want use? [1]: ",
            int,
            default=1,
        )

    mixed_precision = _ask_options(
        "Do you wish to use FP16 or BF16 (mixed precision)?",
        ["no", "fp16", "bf16", "fp8"],
        _convert_mixed_precision,
    )

    if use_dynamo and mixed_precision == "no":
        print(
            "Torch dynamo used without mixed precision requires TF32 to be efficient. Accelerate will enable it by default when launching your scripts."
        )

    return SageMakerConfig(
        image_uri=docker_image,
        compute_environment=ComputeEnvironment.AMAZON_SAGEMAKER,
        distributed_type=distributed_type,
        use_cpu=False,
        dynamo_config=dynamo_config,
        ec2_instance_type=ec2_instance_type,
        profile=aws_profile,
        region=aws_region,
        iam_role_name=iam_role_name,
        mixed_precision=mixed_precision,
        num_machines=num_machines,
        sagemaker_inputs_file=sagemaker_inputs_file,
        sagemaker_metrics_file=sagemaker_metrics_file,
        debug=debug,
    )
|
accelerate/commands/config/update.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
|
| 19 |
+
from .config_args import default_config_file, load_config_from_file
|
| 20 |
+
from .config_utils import SubcommandHelpFormatter
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
description = "Update an existing config file with the latest defaults while maintaining the old configuration."
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def update_config(args):
    """
    Update an existing config file with the latest defaults while maintaining the old configuration.

    Args:
        args:
            Parsed command line arguments. Only `args.config_file` is read; when it is `None` the default config
            file location is used instead.

    Returns:
        The path of the config file that was reloaded and written back.

    Raises:
        ValueError: If the config file to update (explicitly passed or default) does not exist.
    """
    config_file = args.config_file
    if config_file is None:
        # Resolve the default separately: evaluating `Path(None)` would raise an unrelated `TypeError`
        # instead of telling the user that there is nothing to update.
        if not Path(default_config_file).exists():
            raise ValueError(
                f"No config file was passed and the default config file located at {default_config_file} "
                "doesn't exist."
            )
        config_file = default_config_file
    elif not Path(config_file).exists():
        raise ValueError(f"The passed config file located at {config_file} doesn't exist.")
    config = load_config_from_file(config_file)

    # Write the config back in the same format it was stored in.
    if str(config_file).endswith(".json"):
        config.to_json_file(config_file)
    else:
        config.to_yaml_file(config_file)
    return config_file
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def update_command_parser(parser, parents):
    """
    Register the `update` subcommand on `parser` and return the subparser that was created.

    Args:
        parser:
            Object exposing `add_parser` (a subparsers action) on which the `update` command is added.
        parents:
            Parent parsers forwarded to `add_parser`.

    Returns:
        The new `update` parser, with `--config_file` declared and `func` defaulting to `update_config_command`.
    """
    config_file_help = (
        "The path to the config file to update. Will default to a file named default_config.yaml in the cache "
        "location, which is the content of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have "
        "such an environment variable, your cache directory ('~/.cache' or the content of `XDG_CACHE_HOME`) suffixed "
        "with 'huggingface'."
    )

    update_parser = parser.add_parser(
        "update",
        parents=parents,
        help=description,
        formatter_class=SubcommandHelpFormatter,
    )
    update_parser.add_argument("--config_file", default=None, help=config_file_help)
    update_parser.set_defaults(func=update_config_command)
    return update_parser
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def update_config_command(args):
    """Run `update_config` for the parsed `args` and print the path of the file that was updated."""
    updated_path = update_config(args)
    print(f"Successfully updated the configuration file at {updated_path}.")
|
accelerate/commands/env.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import os
|
| 19 |
+
import platform
|
| 20 |
+
import subprocess
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
import psutil
|
| 24 |
+
import torch
|
| 25 |
+
|
| 26 |
+
from accelerate import __version__ as version
|
| 27 |
+
from accelerate.commands.config import default_config_file, load_config_from_file
|
| 28 |
+
|
| 29 |
+
from ..utils import (
|
| 30 |
+
is_mlu_available,
|
| 31 |
+
is_musa_available,
|
| 32 |
+
is_neuron_available,
|
| 33 |
+
is_npu_available,
|
| 34 |
+
is_sdaa_available,
|
| 35 |
+
is_xpu_available,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def env_command_parser(subparsers=None):
    """
    Build the argument parser of the `env` command.

    Args:
        subparsers (optional):
            Subparsers action to attach an `env` subcommand to. When `None`, a standalone
            `argparse.ArgumentParser` is created instead.

    Returns:
        The parser. Only when it was attached to `subparsers` does it carry `func=env_command` as a default.
    """
    standalone = subparsers is None
    parser = argparse.ArgumentParser("Accelerate env command") if standalone else subparsers.add_parser("env")

    parser.add_argument(
        "--config_file",
        default=None,
        help="The config file to use for the default values in the launching script.",
    )

    if not standalone:
        parser.set_defaults(func=env_command)
    return parser
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def env_command(args):
    """
    Print a report about the current environment, formatted to be pasted into a GitHub issue.

    The report lists the `accelerate`, Python, Numpy and PyTorch versions, the platform, the location of the
    `accelerate` executable, the detected accelerator, the system RAM and the `accelerate` config.

    Args:
        args:
            Parsed command line arguments. Only `args.config_file` is read.

    Returns:
        `dict`: The collected information. The config (a dict, or the string `"Not found"`) is stored under the
        "`Accelerate` configs" key.
    """
    pt_version = torch.__version__
    pt_cuda_available = torch.cuda.is_available()
    pt_xpu_available = is_xpu_available()
    pt_mlu_available = is_mlu_available()
    pt_sdaa_available = is_sdaa_available()
    pt_musa_available = is_musa_available()
    pt_npu_available = is_npu_available()
    pt_neuron_available = is_neuron_available()

    # The first available backend wins; the device details below follow the same order.
    accelerator = "N/A"
    if pt_cuda_available:
        accelerator = "CUDA"
    elif pt_xpu_available:
        accelerator = "XPU"
    elif pt_mlu_available:
        accelerator = "MLU"
    elif pt_sdaa_available:
        accelerator = "SDAA"
    elif pt_musa_available:
        accelerator = "MUSA"
    elif pt_npu_available:
        accelerator = "NPU"
    elif pt_neuron_available:
        accelerator = "NEURON"

    accelerate_config = "Not found"
    # Get the default from the config file.
    if args.config_file is not None or os.path.isfile(default_config_file):
        accelerate_config = load_config_from_file(args.config_file).to_dict()

    # if we can run which, get it
    command = None
    bash_location = "Not found"
    if os.name == "nt":
        command = ["where", "accelerate"]
    elif os.name == "posix":
        command = ["which", "accelerate"]
    if command is not None:
        try:
            bash_location = subprocess.check_output(command, text=True, stderr=subprocess.STDOUT).strip()
        except (subprocess.CalledProcessError, OSError):
            # `accelerate` is not on the PATH (non-zero exit status) or the lookup tool itself is missing.
            # The report is best-effort, so fall back to the placeholder instead of aborting.
            bash_location = "Not found"
    info = {
        "`Accelerate` version": version,
        "Platform": platform.platform(),
        "`accelerate` bash location": bash_location,
        "Python version": platform.python_version(),
        "Numpy version": np.__version__,
        "PyTorch version": f"{pt_version}",
        "PyTorch accelerator": accelerator,
        "System RAM": f"{psutil.virtual_memory().total / 1024**3:.2f} GB",
    }
    # Device details, checked in the same precedence as `accelerator` above so both entries always describe
    # the same backend.
    if pt_cuda_available:
        info["GPU type"] = torch.cuda.get_device_name()
    elif pt_xpu_available:
        info["XPU type"] = torch.xpu.get_device_name()
    elif pt_mlu_available:
        info["MLU type"] = torch.mlu.get_device_name()
    elif pt_sdaa_available:
        info["SDAA type"] = torch.sdaa.get_device_name()
    elif pt_musa_available:
        info["MUSA type"] = torch.musa.get_device_name()
    elif pt_npu_available:
        info["CANN version"] = torch.version.cann
    elif pt_neuron_available:
        info["NEURON type"] = torch.neuron.get_device_name()

    print("\nCopy-and-paste the text below in your GitHub issue\n")
    print("\n".join([f"- {prop}: {val}" for prop, val in info.items()]))

    print("- `Accelerate` default config:" if args.config_file is None else "- `Accelerate` config passed:")
    # The config is a dict when a file was loaded, otherwise the "Not found" placeholder string.
    accelerate_config_str = (
        "\n".join([f"\t- {prop}: {val}" for prop, val in accelerate_config.items()])
        if isinstance(accelerate_config, dict)
        else f"\t{accelerate_config}"
    )
    print(accelerate_config_str)

    info["`Accelerate` configs"] = accelerate_config

    return info
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def main() -> int:
    """Parse the command line with the standalone `env` parser, print the report and return exit status 0."""
    env_command(env_command_parser().parse_args())
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
|
accelerate/commands/estimate.py
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
from typing import Optional
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
from huggingface_hub import model_info
|
| 20 |
+
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
|
| 21 |
+
|
| 22 |
+
from accelerate import init_empty_weights
|
| 23 |
+
from accelerate.commands.utils import CustomArgumentParser
|
| 24 |
+
from accelerate.utils import (
|
| 25 |
+
calculate_maximum_sizes,
|
| 26 |
+
convert_bytes,
|
| 27 |
+
is_timm_available,
|
| 28 |
+
is_transformers_available,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
if is_transformers_available():
|
| 33 |
+
import transformers
|
| 34 |
+
from transformers import AutoConfig, AutoModel
|
| 35 |
+
|
| 36 |
+
if is_timm_available():
|
| 37 |
+
import timm
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def verify_on_hub(repo: str, token: Optional[str] = None):
    """
    Verifies that the model is on the hub and returns the model info.

    Args:
        repo (`str`):
            The repository name passed to `model_info`.
        token (`str`, `optional`, defaults to `None`):
            The access token forwarded to `model_info`.

    Returns:
        The result of `model_info(repo, token=token)` on success. Otherwise the string `"gated"` when a
        `GatedRepoError` or another `OSError` was raised, or `"repo"` when a `RepositoryNotFoundError` was raised.
    """
    try:
        return model_info(repo, token=token)
    # Handlers go from most to least specific.
    # NOTE(review): in `huggingface_hub` releases where hub HTTP errors derive from `requests.HTTPError` (an
    # `OSError`), catching `OSError` first swallowed `RepositoryNotFoundError` and made the "repo" result
    # unreachable - confirm against the installed version.
    except GatedRepoError:
        return "gated"
    except RepositoryNotFoundError:
        return "repo"
    except OSError:
        return "gated"
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def check_has_model(error):
    """
    Checks what library spawned `error` when a model is not found

    Args:
        error (`Exception`):
            The exception raised while trying to create the model.

    Returns:
        `str`: `"timm"`, `"transformers"`, or `"unknown"` when the error matches neither library's pattern.
    """
    # Match on the rendered message rather than `error.args[0]`: the latter raises `IndexError` when the
    # exception has no arguments and `TypeError` when the first argument is not a string (for example
    # `OSError(errno, strerror)`), which would hide the original error from the caller.
    message = str(error)
    if is_timm_available() and isinstance(error, RuntimeError) and "Unknown model" in message:
        return "timm"
    if (
        is_transformers_available()
        and isinstance(error, OSError)
        and "does not appear to have a file named" in message
    ):
        return "transformers"
    return "unknown"
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def create_empty_model(
    model_name: str, library_name: str, trust_remote_code: bool = False, access_token: Optional[str] = None
):
    """
    Builds a model with empty weights, in full precision, from the library it belongs to on the `Hub`, so that its
    overall memory consumption can be calculated.

    Args:
        model_name (`str`):
            The model name on the Hub
        library_name (`str`):
            The library the model has an integration with, such as `transformers`. When `None`, the library is read
            from the metadata of `model_name` on the Hub.
        trust_remote_code (`bool`, `optional`, defaults to `False`):
            Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
            should only be set to `True` for repositories you trust and in which you have read the code, as it will
            execute code present on the Hub on your local machine.
        access_token (`str`, `optional`, defaults to `None`):
            The access token to use to access private or gated models on the Hub. (for use on the Gradio app)

    Returns:
        `torch.nn.Module`: The torch model, created inside the `init_empty_weights` context.

    Raises:
        OSError: If the repository is reported as gated or as missing.
        ValueError: If no library could be determined, or the library is not `transformers` or `timm`.
        ImportError: If the library needed for the model is not installed.
        RuntimeError: If a `transformers` model has no config metadata on the Hub.
    """
    hub_info = verify_on_hub(model_name, access_token)

    # `verify_on_hub` reports failures as marker strings; turn them into simplified errors.
    if hub_info == "gated":
        raise OSError(
            f"Repo for model `{model_name}` is gated. You must be authenticated to access it. Please run `huggingface-cli login`."
        )
    if hub_info == "repo":
        raise OSError(
            f"Repo for model `{model_name}` does not exist on the Hub. If you are trying to access a private repo,"
            " make sure you are authenticated via `huggingface-cli login` and have access."
        )

    if library_name is None:
        library_name = getattr(hub_info, "library_name", False)
        if not library_name:
            raise ValueError(
                f"Model `{model_name}` does not have any library metadata on the Hub, please manually pass in a `--library_name` to use (such as `transformers`)"
            )

    if library_name == "transformers":
        if not is_transformers_available():
            raise ImportError(
                f"To check `{model_name}`, `transformers` must be installed. Please install it via `pip install transformers`"
            )
        print(f"Loading pretrained config for `{model_name}` from `transformers`...")
        if hub_info.config is None:
            raise RuntimeError(f"Tried to load `{model_name}` with `transformers` but it does not have any metadata.")

        auto_map = hub_info.config.get("auto_map", False)
        config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code, token=access_token)
        with init_empty_weights():
            # Remote code may name a dedicated `AutoModelFor...` class in its `auto_map`; the first such entry
            # replaces the generic `AutoModel`.
            constructor = AutoModel
            if isinstance(auto_map, dict):
                auto_class_name = next((key for key in auto_map if key.startswith("AutoModelFor")), None)
                if auto_class_name is not None:
                    constructor = getattr(transformers, auto_class_name)
            # The dtype is passed explicitly, otherwise the `torch_dtype` saved in the config would be used.
            model = constructor.from_config(config, torch_dtype=torch.float32, trust_remote_code=trust_remote_code)
        return model

    if library_name == "timm":
        if not is_timm_available():
            raise ImportError(
                f"To check `{model_name}`, `timm` must be installed. Please install it via `pip install timm`"
            )
        print(f"Loading pretrained config for `{model_name}` from `timm`...")
        with init_empty_weights():
            model = timm.create_model(model_name, pretrained=False)
        return model

    raise ValueError(
        f"Library `{library_name}` is not supported yet, please open an issue on GitHub for us to add support."
    )
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def create_ascii_table(headers: list, rows: list, title: str):
    """
    Creates a pretty table from a list of rows, minimal version of `tabulate`.

    Args:
        headers (`list`):
            The column titles (strings). Must contain at least one column.
        rows (`list`):
            The table rows; each row is a list of strings with one entry per header. May be empty.
        title (`str`):
            The text shown in the banner spanning the whole table.

    Returns:
        `str`: The table drawn with box characters, lines joined with newlines and every line having the same
        length.
    """
    sep_char, in_between = "│", "─"

    # Each column is as wide as its widest cell, header included.
    column_widths = [
        max(len(value) for value in [header, *(row[i] for row in rows)]) for i, header in enumerate(headers)
    ]

    # Width available between the two outer borders: every column plus one separator between neighbours.
    inner_width = sum(column_widths) + len(column_widths) - 1
    if len(title) > inner_width:
        # Widen the last column by exactly the shortfall so the title line never overflows the border.
        column_widths[-1] += len(title) - inner_width
        inner_width = len(title)

    def make_row(left_char, middle_char, right_char):
        # Horizontal rule with the given corner and junction characters.
        return f"{left_char}{middle_char.join(in_between * width for width in column_widths)}{right_char}"

    def make_line(cells):
        # Content line with every cell centered in its column.
        centered = (text.center(width) for text, width in zip(cells, column_widths))
        return f"{sep_char}{sep_char.join(centered)}{sep_char}"

    lines = [
        make_row("┌", in_between, "┐"),
        f"{sep_char}{title.center(inner_width)}{sep_char}",
        make_row("├", "┬", "┤"),
        make_line(headers),
        make_row("├", "┼", "┤"),
        *(make_line(row) for row in rows),
        make_row("└", "┴", "┘"),
    ]
    return "\n".join(lines)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def estimate_command_parser(subparsers=None):
    """
    Build the argument parser of the `estimate-memory` command.

    Args:
        subparsers (optional):
            Subparsers action to attach an `estimate-memory` subcommand to. When `None`, a standalone
            `CustomArgumentParser` is created instead.

    Returns:
        The parser. Only when it was attached to `subparsers` does it carry `func=estimate_command` as a default.
    """
    standalone = subparsers is None
    if standalone:
        parser = CustomArgumentParser(
            description="Model size estimator for fitting a model onto device(e.g. cuda, xpu) memory."
        )
    else:
        parser = subparsers.add_parser("estimate-memory")

    supported_dtypes = ["float32", "float16", "int8", "int4"]

    parser.add_argument("model_name", type=str, help="The model name on the Hugging Face Hub.")
    parser.add_argument(
        "--library_name",
        type=str,
        help="The library the model has an integration with, such as `transformers`, needed only if this information is not stored on the Hub.",
        choices=["timm", "transformers"],
    )
    parser.add_argument(
        "--dtypes",
        type=str,
        nargs="+",
        default=list(supported_dtypes),
        help="The dtypes to use for the model, must be one (or many) of `float32`, `float16`, `int8`, and `int4`",
        choices=list(supported_dtypes),
    )
    parser.add_argument(
        "--trust_remote_code",
        action="store_true",
        help="""Whether or not to allow for custom models defined on the Hub in their own modeling files. This flag
should only be used for repositories you trust and in which you have read the code, as it will execute
code present on the Hub on your local machine.""",
        default=False,
    )

    if not standalone:
        parser.set_defaults(func=estimate_command)
    return parser
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def estimate_training_usage(bytes: int, mixed_precision: str, msamp_config: Optional[str] = None) -> dict:
    """
    Computes the training memory needed for a batch size of 1, from the model size in `bytes` and the
    `mixed_precision` mode.

    Args:
        bytes (`int`):
            The size of the model being trained.
        mixed_precision (`str`):
            The mixed precision that would be ran.
        msamp_config (`str`):
            The msamp config to estimate the training memory for if `mixed_precision` is set to `"fp8"`.

    Returns:
        `dict`: Sizes under the keys `"model"`, `"optimizer"`, `"gradients"` and `"step"`. Every value is `-1`
        when no estimate is defined for the requested combination.
    """
    full_size = bytes
    half_size = bytes // 2

    if mixed_precision == "float32":
        return {
            "model": full_size,
            "optimizer": full_size * 2,
            "gradients": full_size,
            "step": full_size * 4,
        }

    # With native `TransformersEngine`, there is no memory savings with FP8, so it is estimated like
    # float16/bfloat16 mixed precision.
    half_precision_like = mixed_precision in ("float16", "bfloat16") or (
        mixed_precision == "fp8" and msamp_config is None
    )
    if half_precision_like:
        # 2x from optimizer states; the step is estimated at the same size.
        optimizer_size = full_size * 2
        return {
            # With mixed precision training, the model has weights stored in FP16 and FP32
            "model": full_size,
            "optimizer": optimizer_size,
            # 1.5 from weight gradient + computation (GEMM)
            "gradients": full_size + half_size,
            "step": optimizer_size,
        }

    return {"model": -1, "optimizer": -1, "gradients": -1, "step": -1}
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def gather_data(args):
    """
    Creates an empty model and gathers the data for the sizes

    Args:
        args:
            Parsed command line arguments; `model_name`, `library_name`, `trust_remote_code` and `dtypes` are read.

    Returns:
        `list`: One `[dtype, largest layer size, total size, training sizes]` entry per requested dtype, where the
        training sizes are the dict returned by `estimate_training_usage`.
    """
    try:
        model = create_empty_model(
            args.model_name,
            library_name=args.library_name,
            trust_remote_code=args.trust_remote_code,
        )
    except (RuntimeError, OSError) as e:
        library = check_has_model(e)
        if library == "unknown":
            raise e
        raise RuntimeError(
            f"Tried to load `{args.model_name}` with `{library}` but a possible model to load was not found inside the repo."
        )

    total_size, largest_layer = calculate_maximum_sizes(model)

    # Sizes are divided by these factors for the reduced dtypes; `float32` keeps them as they are.
    divisors = {"float16": 2, "int8": 4, "int4": 8}

    data = []
    for dtype in args.dtypes:
        layer_size = largest_layer[0]
        # The training estimate is always derived from the undivided total size.
        training_size = estimate_training_usage(total_size, dtype)
        model_size = total_size
        divisor = divisors.get(dtype)
        if divisor is not None:
            model_size = model_size / divisor
            layer_size = layer_size / divisor
        data.append([dtype, layer_size, model_size, training_size])
    return data
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def estimate_command(args):
    """
    Gather the size estimates for `args.model_name` and print them as a table.

    Numeric entries are rendered with `convert_bytes`. The training estimate (a dict) is reduced to its largest
    value, shown as "N/A" when that value is `-1`.
    """

    def render(item):
        # Turn one table cell into its display form; anything that is not a number or dict is kept as is.
        if isinstance(item, (int, float)):
            return convert_bytes(item)
        if isinstance(item, dict):
            peak = max(item.values())
            return "N/A" if peak == -1 else convert_bytes(peak)
        return item

    data = gather_data(args)
    for row in data:
        for position, item in enumerate(row):
            row[position] = render(item)

    headers = ["dtype", "Largest Layer", "Total Size", "Training using Adam"]
    title = f"Memory Usage for loading `{args.model_name}`"
    print(create_ascii_table(headers, data, title))
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def main():
    """Parse the command line with the standalone `estimate-memory` parser and print the estimate table."""
    estimate_command(estimate_command_parser().parse_args())


if __name__ == "__main__":
    main()
|
accelerate/commands/launch.py
ADDED
|
@@ -0,0 +1,1415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import importlib
|
| 19 |
+
import logging
|
| 20 |
+
import os
|
| 21 |
+
import subprocess
|
| 22 |
+
import sys
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
|
| 27 |
+
from accelerate.commands.config import default_config_file, load_config_from_file
|
| 28 |
+
from accelerate.commands.config.config_args import SageMakerConfig
|
| 29 |
+
from accelerate.commands.config.config_utils import DYNAMO_BACKENDS
|
| 30 |
+
from accelerate.commands.utils import CustomArgumentParser
|
| 31 |
+
from accelerate.state import get_int_from_env
|
| 32 |
+
from accelerate.utils import (
|
| 33 |
+
ComputeEnvironment,
|
| 34 |
+
DistributedType,
|
| 35 |
+
PrepareForLaunch,
|
| 36 |
+
_filter_args,
|
| 37 |
+
check_cuda_p2p_ib_support,
|
| 38 |
+
convert_dict_to_env_variables,
|
| 39 |
+
is_bf16_available,
|
| 40 |
+
is_deepspeed_available,
|
| 41 |
+
is_hpu_available,
|
| 42 |
+
is_mlu_available,
|
| 43 |
+
is_musa_available,
|
| 44 |
+
is_neuron_available,
|
| 45 |
+
is_npu_available,
|
| 46 |
+
is_rich_available,
|
| 47 |
+
is_sagemaker_available,
|
| 48 |
+
is_sdaa_available,
|
| 49 |
+
is_torch_xla_available,
|
| 50 |
+
is_xpu_available,
|
| 51 |
+
patch_environment,
|
| 52 |
+
prepare_deepspeed_cmd_env,
|
| 53 |
+
prepare_multi_gpu_env,
|
| 54 |
+
prepare_sagemager_args_inputs,
|
| 55 |
+
prepare_simple_launcher_cmd_env,
|
| 56 |
+
prepare_tpu,
|
| 57 |
+
str_to_bool,
|
| 58 |
+
)
|
| 59 |
+
from accelerate.utils.constants import DEEPSPEED_MULTINODE_LAUNCHERS, TORCH_DYNAMO_MODES
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# When `rich` is installed, send log records through its handler so launcher output is rendered by rich.
if is_rich_available():
    from rich import get_console
    from rich.logging import RichHandler

    # The handler does its own decoration, so the log format is just the bare message.
    FORMAT = "%(message)s"
    logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()])


logger = logging.getLogger(__name__)


# Maps a normalized command-line option name (no leading `--`, underscores instead of hyphens) to the title of the
# argument group whose help entries should be displayed when that option is present on the command line.
options_to_group = {
    "multi_gpu": "Distributed GPUs",
    "tpu": "TPU",
    "use_deepspeed": "DeepSpeed Arguments",
    "use_fsdp": "FSDP Arguments",
    "use_megatron_lm": "Megatron-LM Arguments",
    "fp8_backend": "FP8 Arguments",
}
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def clean_option(option):
    """
    Normalize a command-line token so it can be compared against option names.

    Long options lose their leading `--` and have every remaining `-` replaced with `_` (`--num-processes` becomes
    `num_processes`). Any token containing `fp8_backend` is collapsed to `fp8_backend`. Tokens that are not long
    options (short flags such as `-q`, positional arguments such as the script path) are returned unchanged rather
    than being silently turned into `None`.

    Args:
        option (`str`):
            A single command-line token.

    Returns:
        `str`: The normalized token.
    """
    if "fp8_backend" in option:
        option = "--fp8_backend"
    if option.startswith("--"):
        return option[2:].replace("-", "_")
    # Not a long option: hand it back untouched so callers never have to deal with `None`.
    return option
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class CustomHelpFormatter(argparse.HelpFormatter):
    """
    Help formatter that hides the arguments which are irrelevant to the options given on the command line.

    When more than one argument follows the (sub)command, only the always-visible groups listed in `self.titles` and
    the groups enabled through `options_to_group` are displayed. Inside the "Hardware Selection Arguments" and
    "Training Paradigm Arguments" groups only the flags that were actually passed are kept, and they are marked as
    currently selected.
    """

    # Text appended to the help of a hardware/paradigm flag that is present on the command line.
    _SELECTED_SUFFIX = " (currently selected)"
    # Groups in which only the flags that were actually passed are shown.
    _SELECTION_GROUPS = ("Hardware Selection Arguments", "Training Paradigm Arguments")

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Titles of the groups that are displayed no matter which options were passed.
        self.titles = [
            "Hardware Selection Arguments",
            "Resource Selection Arguments",
            "Training Paradigm Arguments",
            "positional arguments",
            "optional arguments",
        ]

    def add_argument(self, action: argparse.Action):
        """
        Register `action` with the formatter, suppressing or annotating its help based on `sys.argv`.

        Args:
            action (`argparse.Action`):
                The action about to be rendered. Its `help` and `option_strings` attributes are modified in place.
        """
        if "accelerate" in sys.argv[0] and "launch" in sys.argv[1:]:
            args = sys.argv[2:]
        else:
            args = sys.argv[1:]

        if len(args) > 1:
            args = list(map(clean_option, args))
            used_titles = [options_to_group[arg] for arg in args if arg in options_to_group]
            title = action.container.title
            if title not in self.titles + used_titles:
                action.help = argparse.SUPPRESS
            elif title in self._SELECTION_GROUPS:
                # `args` has been normalized by `clean_option`, so the action's own option strings must go through
                # the same normalization before comparing; otherwise the two sets can never intersect.
                passed = {arg for arg in args if arg is not None}
                names = {name for name in map(clean_option, action.option_strings) if name is not None}
                if names.isdisjoint(passed):
                    action.help = argparse.SUPPRESS
                else:
                    current_help = action.help or ""
                    # The action object outlives this formatter; do not stack the marker on repeated formatting.
                    if not current_help.endswith(self._SELECTED_SUFFIX):
                        action.help = current_help + self._SELECTED_SUFFIX

        # Only display the underscore spelling of each long option, dropping hyphenated aliases.
        action.option_strings = [s for s in action.option_strings if "-" not in s[2:]]
        super().add_argument(action)

    def end_section(self):
        """Close the current section, blanking it out when it holds fewer than two items."""
        if len(self._current_section.items) < 2:
            self._current_section.items = []
            self._current_section.heading = ""
        super().end_section()
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def launch_command_parser(subparsers=None):
|
| 142 |
+
description = "Launch a python script in a distributed scenario. Arguments can be passed in with either hyphens (`--num-processes=2`) or underscores (`--num_processes=2`)"
|
| 143 |
+
if subparsers is not None:
|
| 144 |
+
parser = subparsers.add_parser(
|
| 145 |
+
"launch", description=description, add_help=False, allow_abbrev=False, formatter_class=CustomHelpFormatter
|
| 146 |
+
)
|
| 147 |
+
else:
|
| 148 |
+
parser = CustomArgumentParser(
|
| 149 |
+
"Accelerate launch command",
|
| 150 |
+
description=description,
|
| 151 |
+
add_help=False,
|
| 152 |
+
allow_abbrev=False,
|
| 153 |
+
formatter_class=CustomHelpFormatter,
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
parser.add_argument("-h", "--help", action="help", help="Show this help message and exit.")
|
| 157 |
+
|
| 158 |
+
parser.add_argument(
|
| 159 |
+
"--config_file",
|
| 160 |
+
default=None,
|
| 161 |
+
help="The config file to use for the default values in the launching script.",
|
| 162 |
+
)
|
| 163 |
+
parser.add_argument(
|
| 164 |
+
"--quiet",
|
| 165 |
+
"-q",
|
| 166 |
+
action="store_true",
|
| 167 |
+
help="Silence subprocess errors from the launch stack trace and only show the relevant tracebacks. (Only applicable to DeepSpeed and single-process configurations)",
|
| 168 |
+
)
|
| 169 |
+
# Hardware selection arguments
|
| 170 |
+
hardware_args = parser.add_argument_group(
|
| 171 |
+
"Hardware Selection Arguments", "Arguments for selecting the hardware to be used."
|
| 172 |
+
)
|
| 173 |
+
hardware_args.add_argument(
|
| 174 |
+
"--cpu", default=False, action="store_true", help="Whether or not to force the training on the CPU."
|
| 175 |
+
)
|
| 176 |
+
hardware_args.add_argument(
|
| 177 |
+
"--multi_gpu",
|
| 178 |
+
default=False,
|
| 179 |
+
action="store_true",
|
| 180 |
+
help="Whether or not this should launch a distributed GPU training.",
|
| 181 |
+
)
|
| 182 |
+
hardware_args.add_argument(
|
| 183 |
+
"--tpu", default=False, action="store_true", help="Whether or not this should launch a TPU training."
|
| 184 |
+
)
|
| 185 |
+
# Resource selection arguments
|
| 186 |
+
resource_args = parser.add_argument_group(
|
| 187 |
+
"Resource Selection Arguments", "Arguments for fine-tuning how available hardware should be used."
|
| 188 |
+
)
|
| 189 |
+
resource_args.add_argument(
|
| 190 |
+
"--mixed_precision",
|
| 191 |
+
type=str,
|
| 192 |
+
choices=["no", "fp16", "bf16", "fp8"],
|
| 193 |
+
help="Whether or not to use mixed precision training. "
|
| 194 |
+
"Choose between FP16 and BF16 (bfloat16) training. "
|
| 195 |
+
"BF16 training is only supported on Nvidia Ampere GPUs and PyTorch 1.10 or later.",
|
| 196 |
+
)
|
| 197 |
+
resource_args.add_argument(
|
| 198 |
+
"--num_processes", type=int, default=None, help="The total number of processes to be launched in parallel."
|
| 199 |
+
)
|
| 200 |
+
resource_args.add_argument(
|
| 201 |
+
"--num_machines", type=int, default=None, help="The total number of machines used in this training."
|
| 202 |
+
)
|
| 203 |
+
resource_args.add_argument(
|
| 204 |
+
"--num_cpu_threads_per_process",
|
| 205 |
+
type=int,
|
| 206 |
+
default=None,
|
| 207 |
+
help="The number of CPU threads per process. Can be tuned for optimal performance.",
|
| 208 |
+
)
|
| 209 |
+
resource_args.add_argument(
|
| 210 |
+
"--enable_cpu_affinity",
|
| 211 |
+
default=False,
|
| 212 |
+
action="store_true",
|
| 213 |
+
help="Whether or not CPU affinity and balancing should be enabled. Currently only supported on NVIDIA hardware.",
|
| 214 |
+
)
|
| 215 |
+
# Dynamo arguments
|
| 216 |
+
resource_args.add_argument(
|
| 217 |
+
"--dynamo_backend",
|
| 218 |
+
type=str,
|
| 219 |
+
choices=["no"] + [b.lower() for b in DYNAMO_BACKENDS],
|
| 220 |
+
help="Choose a backend to optimize your training with dynamo, see more at "
|
| 221 |
+
"https://github.com/pytorch/torchdynamo.",
|
| 222 |
+
)
|
| 223 |
+
resource_args.add_argument(
|
| 224 |
+
"--dynamo_mode",
|
| 225 |
+
type=str,
|
| 226 |
+
default="default",
|
| 227 |
+
choices=TORCH_DYNAMO_MODES,
|
| 228 |
+
help="Choose a mode to optimize your training with dynamo.",
|
| 229 |
+
)
|
| 230 |
+
resource_args.add_argument(
|
| 231 |
+
"--dynamo_use_fullgraph",
|
| 232 |
+
default=False,
|
| 233 |
+
action="store_true",
|
| 234 |
+
help="Whether to use full graph mode for dynamo or it is ok to break model into several subgraphs",
|
| 235 |
+
)
|
| 236 |
+
resource_args.add_argument(
|
| 237 |
+
"--dynamo_use_dynamic",
|
| 238 |
+
default=False,
|
| 239 |
+
action="store_true",
|
| 240 |
+
help="Whether to enable dynamic shape tracing.",
|
| 241 |
+
)
|
| 242 |
+
resource_args.add_argument(
|
| 243 |
+
"--dynamo_use_regional_compilation",
|
| 244 |
+
default=False,
|
| 245 |
+
action="store_true",
|
| 246 |
+
help="Whether to enable regional compilation.",
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
# Training Paradigm arguments
|
| 250 |
+
paradigm_args = parser.add_argument_group(
|
| 251 |
+
"Training Paradigm Arguments", "Arguments for selecting which training paradigm to be used."
|
| 252 |
+
)
|
| 253 |
+
paradigm_args.add_argument(
|
| 254 |
+
"--use_deepspeed",
|
| 255 |
+
default=False,
|
| 256 |
+
action="store_true",
|
| 257 |
+
help="Whether to use deepspeed.",
|
| 258 |
+
)
|
| 259 |
+
paradigm_args.add_argument(
|
| 260 |
+
"--use_fsdp",
|
| 261 |
+
default=False,
|
| 262 |
+
action="store_true",
|
| 263 |
+
help="Whether to use fsdp.",
|
| 264 |
+
)
|
| 265 |
+
paradigm_args.add_argument(
|
| 266 |
+
"--use_parallelism_config",
|
| 267 |
+
default=False,
|
| 268 |
+
action="store_true",
|
| 269 |
+
help="Whether to use the parallelism config to configure the N-d distributed training.",
|
| 270 |
+
)
|
| 271 |
+
paradigm_args.add_argument(
|
| 272 |
+
"--use_megatron_lm",
|
| 273 |
+
default=False,
|
| 274 |
+
action="store_true",
|
| 275 |
+
help="Whether to use Megatron-LM.",
|
| 276 |
+
)
|
| 277 |
+
|
| 278 |
+
# distributed GPU training arguments
|
| 279 |
+
distributed_args = parser.add_argument_group("Distributed GPUs", "Arguments related to distributed GPU training.")
|
| 280 |
+
distributed_args.add_argument(
|
| 281 |
+
"--gpu_ids",
|
| 282 |
+
default=None,
|
| 283 |
+
help="What GPUs (by id) should be used for training on this machine as a comma-separated list",
|
| 284 |
+
)
|
| 285 |
+
distributed_args.add_argument(
|
| 286 |
+
"--same_network",
|
| 287 |
+
default=False,
|
| 288 |
+
action="store_true",
|
| 289 |
+
help="Whether all machines used for multinode training exist on the same local network.",
|
| 290 |
+
)
|
| 291 |
+
distributed_args.add_argument(
|
| 292 |
+
"--machine_rank", type=int, default=None, help="The rank of the machine on which this script is launched."
|
| 293 |
+
)
|
| 294 |
+
distributed_args.add_argument(
|
| 295 |
+
"--main_process_ip", type=str, default=None, help="The IP address of the machine of rank 0."
|
| 296 |
+
)
|
| 297 |
+
distributed_args.add_argument(
|
| 298 |
+
"--main_process_port",
|
| 299 |
+
type=int,
|
| 300 |
+
default=None,
|
| 301 |
+
help="The port to use to communicate with the machine of rank 0.",
|
| 302 |
+
)
|
| 303 |
+
distributed_args.add_argument(
|
| 304 |
+
"-t",
|
| 305 |
+
"--tee",
|
| 306 |
+
default="0",
|
| 307 |
+
type=str,
|
| 308 |
+
help="Tee std streams into a log file and also to console.",
|
| 309 |
+
)
|
| 310 |
+
distributed_args.add_argument(
|
| 311 |
+
"--log_dir",
|
| 312 |
+
type=str,
|
| 313 |
+
default=None,
|
| 314 |
+
help=(
|
| 315 |
+
"Base directory to use for log files when using torchrun/torch.distributed.run as launcher. "
|
| 316 |
+
"Use with --tee to redirect std streams info log files."
|
| 317 |
+
),
|
| 318 |
+
)
|
| 319 |
+
distributed_args.add_argument(
|
| 320 |
+
"--role",
|
| 321 |
+
type=str,
|
| 322 |
+
default="default",
|
| 323 |
+
help="User-defined role for the workers.",
|
| 324 |
+
)
|
| 325 |
+
# Rendezvous related arguments
|
| 326 |
+
distributed_args.add_argument(
|
| 327 |
+
"--rdzv_backend",
|
| 328 |
+
type=str,
|
| 329 |
+
default="static",
|
| 330 |
+
help="The rendezvous method to use, such as 'static' (the default) or 'c10d'",
|
| 331 |
+
)
|
| 332 |
+
distributed_args.add_argument(
|
| 333 |
+
"--rdzv_conf",
|
| 334 |
+
type=str,
|
| 335 |
+
default="",
|
| 336 |
+
help="Additional rendezvous configuration (<key1>=<value1>,<key2>=<value2>,...).",
|
| 337 |
+
)
|
| 338 |
+
distributed_args.add_argument(
|
| 339 |
+
"--max_restarts",
|
| 340 |
+
type=int,
|
| 341 |
+
default=0,
|
| 342 |
+
help="Maximum number of worker group restarts before failing.",
|
| 343 |
+
)
|
| 344 |
+
distributed_args.add_argument(
|
| 345 |
+
"--monitor_interval",
|
| 346 |
+
type=float,
|
| 347 |
+
default=0.1,
|
| 348 |
+
help="Interval, in seconds, to monitor the state of workers.",
|
| 349 |
+
)
|
| 350 |
+
parser.add_argument(
|
| 351 |
+
"-m",
|
| 352 |
+
"--module",
|
| 353 |
+
action="store_true",
|
| 354 |
+
help="Change each process to interpret the launch script as a Python module, executing with the same behavior as 'python -m'.",
|
| 355 |
+
)
|
| 356 |
+
parser.add_argument(
|
| 357 |
+
"--no_python",
|
| 358 |
+
action="store_true",
|
| 359 |
+
help="Skip prepending the training script with 'python' - just execute it directly. Useful when the script is not a Python script.",
|
| 360 |
+
)
|
| 361 |
+
|
| 362 |
+
# TPU arguments
|
| 363 |
+
tpu_args = parser.add_argument_group("TPU", "Arguments related to TPU.")
|
| 364 |
+
tpu_args.add_argument(
|
| 365 |
+
"--tpu_cluster",
|
| 366 |
+
action="store_true",
|
| 367 |
+
dest="tpu_use_cluster",
|
| 368 |
+
help="Whether to use a GCP TPU pod for training.",
|
| 369 |
+
)
|
| 370 |
+
tpu_args.add_argument(
|
| 371 |
+
"--no_tpu_cluster",
|
| 372 |
+
action="store_false",
|
| 373 |
+
dest="tpu_use_cluster",
|
| 374 |
+
help="Should not be passed explicitly, this is for internal use only.",
|
| 375 |
+
)
|
| 376 |
+
tpu_args.add_argument(
|
| 377 |
+
"--tpu_use_sudo",
|
| 378 |
+
action="store_true",
|
| 379 |
+
help="Whether to use `sudo` when running the TPU training script in each pod.",
|
| 380 |
+
)
|
| 381 |
+
tpu_args.add_argument(
|
| 382 |
+
"--vm",
|
| 383 |
+
type=str,
|
| 384 |
+
action="append",
|
| 385 |
+
help=(
|
| 386 |
+
"List of single Compute VM instance names. "
|
| 387 |
+
"If not provided we assume usage of instance groups. For TPU pods."
|
| 388 |
+
),
|
| 389 |
+
)
|
| 390 |
+
tpu_args.add_argument(
|
| 391 |
+
"--env",
|
| 392 |
+
type=str,
|
| 393 |
+
action="append",
|
| 394 |
+
help="List of environment variables to set on the Compute VM instances. For TPU pods.",
|
| 395 |
+
)
|
| 396 |
+
tpu_args.add_argument(
|
| 397 |
+
"--main_training_function",
|
| 398 |
+
type=str,
|
| 399 |
+
default=None,
|
| 400 |
+
help="The name of the main function to be executed in your script (only for TPU training).",
|
| 401 |
+
)
|
| 402 |
+
tpu_args.add_argument(
|
| 403 |
+
"--downcast_bf16",
|
| 404 |
+
action="store_true",
|
| 405 |
+
help="Whether when using bf16 precision on TPUs if both float and double tensors are cast to bfloat16 or if double tensors remain as float32.",
|
| 406 |
+
)
|
| 407 |
+
|
| 408 |
+
# DeepSpeed arguments
|
| 409 |
+
deepspeed_args = parser.add_argument_group("DeepSpeed Arguments", "Arguments related to DeepSpeed.")
|
| 410 |
+
deepspeed_args.add_argument(
|
| 411 |
+
"--deepspeed_config_file",
|
| 412 |
+
default=None,
|
| 413 |
+
type=str,
|
| 414 |
+
help="DeepSpeed config file.",
|
| 415 |
+
)
|
| 416 |
+
deepspeed_args.add_argument(
|
| 417 |
+
"--zero_stage",
|
| 418 |
+
default=None,
|
| 419 |
+
type=int,
|
| 420 |
+
help="DeepSpeed's ZeRO optimization stage (useful only when `use_deepspeed` flag is passed). "
|
| 421 |
+
"If unspecified, will default to `2`.",
|
| 422 |
+
)
|
| 423 |
+
deepspeed_args.add_argument(
|
| 424 |
+
"--offload_optimizer_device",
|
| 425 |
+
default=None,
|
| 426 |
+
type=str,
|
| 427 |
+
help="Decides where (none|cpu|nvme) to offload optimizer states (useful only when `use_deepspeed` flag is passed). "
|
| 428 |
+
"If unspecified, will default to 'none'.",
|
| 429 |
+
)
|
| 430 |
+
deepspeed_args.add_argument(
|
| 431 |
+
"--offload_param_device",
|
| 432 |
+
default=None,
|
| 433 |
+
type=str,
|
| 434 |
+
help="Decides where (none|cpu|nvme) to offload parameters (useful only when `use_deepspeed` flag is passed). "
|
| 435 |
+
"If unspecified, will default to 'none'.",
|
| 436 |
+
)
|
| 437 |
+
deepspeed_args.add_argument(
|
| 438 |
+
"--offload_optimizer_nvme_path",
|
| 439 |
+
default=None,
|
| 440 |
+
type=str,
|
| 441 |
+
help="Decides Nvme Path to offload optimizer states (useful only when `use_deepspeed` flag is passed). "
|
| 442 |
+
"If unspecified, will default to 'none'.",
|
| 443 |
+
)
|
| 444 |
+
deepspeed_args.add_argument(
|
| 445 |
+
"--offload_param_nvme_path",
|
| 446 |
+
default=None,
|
| 447 |
+
type=str,
|
| 448 |
+
help="Decides Nvme Path to offload parameters (useful only when `use_deepspeed` flag is passed). "
|
| 449 |
+
"If unspecified, will default to 'none'.",
|
| 450 |
+
)
|
| 451 |
+
deepspeed_args.add_argument(
|
| 452 |
+
"--gradient_accumulation_steps",
|
| 453 |
+
default=None,
|
| 454 |
+
type=int,
|
| 455 |
+
help="No of gradient_accumulation_steps used in your training script (useful only when `use_deepspeed` flag is passed). "
|
| 456 |
+
"If unspecified, will default to `1`.",
|
| 457 |
+
)
|
| 458 |
+
deepspeed_args.add_argument(
|
| 459 |
+
"--gradient_clipping",
|
| 460 |
+
default=None,
|
| 461 |
+
type=float,
|
| 462 |
+
help="gradient clipping value used in your training script (useful only when `use_deepspeed` flag is passed). "
|
| 463 |
+
"If unspecified, will default to `1.0`.",
|
| 464 |
+
)
|
| 465 |
+
deepspeed_args.add_argument(
|
| 466 |
+
"--zero3_init_flag",
|
| 467 |
+
default=None,
|
| 468 |
+
type=str,
|
| 469 |
+
help="Decides Whether (true|false) to enable `deepspeed.zero.Init` for constructing massive models. "
|
| 470 |
+
"Only applicable with DeepSpeed ZeRO Stage-3. If unspecified, will default to `true`.",
|
| 471 |
+
)
|
| 472 |
+
deepspeed_args.add_argument(
|
| 473 |
+
"--zero3_save_16bit_model",
|
| 474 |
+
default=None,
|
| 475 |
+
type=str,
|
| 476 |
+
help="Decides Whether (true|false) to save 16-bit model weights when using ZeRO Stage-3. "
|
| 477 |
+
"Only applicable with DeepSpeed ZeRO Stage-3. If unspecified, will default to `false`.",
|
| 478 |
+
)
|
| 479 |
+
deepspeed_args.add_argument(
|
| 480 |
+
"--deepspeed_hostfile",
|
| 481 |
+
default=None,
|
| 482 |
+
type=str,
|
| 483 |
+
help="DeepSpeed hostfile for configuring multi-node compute resources.",
|
| 484 |
+
)
|
| 485 |
+
deepspeed_args.add_argument(
|
| 486 |
+
"--deepspeed_exclusion_filter",
|
| 487 |
+
default=None,
|
| 488 |
+
type=str,
|
| 489 |
+
help="DeepSpeed exclusion filter string when using multi-node setup.",
|
| 490 |
+
)
|
| 491 |
+
deepspeed_args.add_argument(
|
| 492 |
+
"--deepspeed_inclusion_filter",
|
| 493 |
+
default=None,
|
| 494 |
+
type=str,
|
| 495 |
+
help="DeepSpeed inclusion filter string when using multi-node setup.",
|
| 496 |
+
)
|
| 497 |
+
deepspeed_args.add_argument(
|
| 498 |
+
"--deepspeed_multinode_launcher",
|
| 499 |
+
default=None,
|
| 500 |
+
type=str,
|
| 501 |
+
help="DeepSpeed multi-node launcher to use, e.g. `pdsh`, `standard`, `openmpi`, `mvapich`, `mpich`, `slurm`, `nossh` (requires DeepSpeed >= 0.14.5). If unspecified, will default to `pdsh`.",
|
| 502 |
+
)
|
| 503 |
+
deepspeed_args.add_argument(
|
| 504 |
+
"--deepspeed_moe_layer_cls_names",
|
| 505 |
+
default=None,
|
| 506 |
+
type=str,
|
| 507 |
+
help="comma-separated list of transformer MoE layer class names (case-sensitive) to wrap ,e.g, `MixtralSparseMoeBlock`, `Qwen2MoeSparseMoeBlock`, `JetMoEAttention,JetMoEBlock` ..."
|
| 508 |
+
" (useful only when `use_deepspeed` flag is passed).",
|
| 509 |
+
)
|
| 510 |
+
|
| 511 |
+
# fsdp arguments
|
| 512 |
+
fsdp_args = parser.add_argument_group("FSDP Arguments", "Arguments related to Fully Shared Data Parallelism.")
|
| 513 |
+
fsdp_args.add_argument(
|
| 514 |
+
"--fsdp_version",
|
| 515 |
+
type=str,
|
| 516 |
+
default="1",
|
| 517 |
+
choices=["1", "2"],
|
| 518 |
+
help="FSDP version to use. (useful only when `use_fsdp` flag is passed).",
|
| 519 |
+
)
|
| 520 |
+
fsdp_args.add_argument(
|
| 521 |
+
"--fsdp_offload_params",
|
| 522 |
+
default="false",
|
| 523 |
+
type=str,
|
| 524 |
+
help="Decides Whether (true|false) to offload parameters and gradients to CPU. (useful only when `use_fsdp` flag is passed).",
|
| 525 |
+
)
|
| 526 |
+
fsdp_args.add_argument(
|
| 527 |
+
"--fsdp_min_num_params",
|
| 528 |
+
type=int,
|
| 529 |
+
default=int(1e8),
|
| 530 |
+
help="FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `use_fsdp` flag is passed).",
|
| 531 |
+
)
|
| 532 |
+
# We enable this for backwards compatibility, throw a warning if this is set in `FullyShardedDataParallelPlugin`
|
| 533 |
+
fsdp_args.add_argument(
|
| 534 |
+
"--fsdp_sharding_strategy",
|
| 535 |
+
type=str,
|
| 536 |
+
default="FULL_SHARD",
|
| 537 |
+
help="FSDP's sharding strategy. (useful only when `use_fsdp` flag is passed and `fsdp_version=1`).",
|
| 538 |
+
)
|
| 539 |
+
fsdp_args.add_argument(
|
| 540 |
+
"--fsdp_reshard_after_forward",
|
| 541 |
+
type=str,
|
| 542 |
+
default="true",
|
| 543 |
+
help="FSDP's Reshard After Forward Strategy. (useful only when `use_fsdp` flag is passed). Supports either boolean (FSDP2) or `FULL_SHARD | SHARD_GRAD_OP | NO_RESHARD` (FSDP1).",
|
| 544 |
+
)
|
| 545 |
+
fsdp_args.add_argument(
|
| 546 |
+
"--fsdp_auto_wrap_policy",
|
| 547 |
+
type=str,
|
| 548 |
+
default=None,
|
| 549 |
+
help="FSDP's auto wrap policy. (useful only when `use_fsdp` flag is passed).",
|
| 550 |
+
)
|
| 551 |
+
fsdp_args.add_argument(
|
| 552 |
+
"--fsdp_transformer_layer_cls_to_wrap",
|
| 553 |
+
default=None,
|
| 554 |
+
type=str,
|
| 555 |
+
help="Transformer layer class name (case-sensitive) to wrap ,e.g, `BertLayer`, `GPTJBlock`, `T5Block` .... "
|
| 556 |
+
"(useful only when `use_fsdp` flag is passed).",
|
| 557 |
+
)
|
| 558 |
+
fsdp_args.add_argument(
|
| 559 |
+
"--fsdp_backward_prefetch",
|
| 560 |
+
default=None,
|
| 561 |
+
type=str,
|
| 562 |
+
help="FSDP's backward prefetch policy. (useful only when `use_fsdp` flag is passed).",
|
| 563 |
+
)
|
| 564 |
+
fsdp_args.add_argument(
|
| 565 |
+
"--fsdp_state_dict_type",
|
| 566 |
+
default=None,
|
| 567 |
+
type=str,
|
| 568 |
+
help="FSDP's state dict type. (useful only when `use_fsdp` flag is passed).",
|
| 569 |
+
)
|
| 570 |
+
fsdp_args.add_argument(
|
| 571 |
+
"--fsdp_forward_prefetch",
|
| 572 |
+
default="false",
|
| 573 |
+
type=str,
|
| 574 |
+
help="If True, then FSDP explicitly prefetches the next upcoming "
|
| 575 |
+
"all-gather while executing in the forward pass (useful only when `use_fsdp` flag is passed).",
|
| 576 |
+
)
|
| 577 |
+
fsdp_args.add_argument(
|
| 578 |
+
"--fsdp_use_orig_params",
|
| 579 |
+
default="true",
|
| 580 |
+
type=str,
|
| 581 |
+
help="If True, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable parameters."
|
| 582 |
+
" (useful only when `use_fsdp` flag is passed).",
|
| 583 |
+
)
|
| 584 |
+
fsdp_args.add_argument(
|
| 585 |
+
"--fsdp_cpu_ram_efficient_loading",
|
| 586 |
+
default="true",
|
| 587 |
+
type=str,
|
| 588 |
+
help="If True, only the first process loads the pretrained model checkoint while all other processes have empty weights. "
|
| 589 |
+
"Only applicable for 🤗 Transformers. When using this, `--fsdp_sync_module_states` needs to True. "
|
| 590 |
+
"(useful only when `use_fsdp` flag is passed).",
|
| 591 |
+
)
|
| 592 |
+
fsdp_args.add_argument(
|
| 593 |
+
"--fsdp_sync_module_states",
|
| 594 |
+
default="true",
|
| 595 |
+
type=str,
|
| 596 |
+
help="If True, each individually wrapped FSDP unit will broadcast module parameters from rank 0."
|
| 597 |
+
" (useful only when `use_fsdp` flag is passed).",
|
| 598 |
+
)
|
| 599 |
+
fsdp_args.add_argument(
|
| 600 |
+
"--fsdp_activation_checkpointing",
|
| 601 |
+
default="false",
|
| 602 |
+
type=str,
|
| 603 |
+
help="Decides Whether (true|false) intermediate activations are freed during the forward pass, and a checkpoint is left as a placeholder. (useful only when `use_fsdp` flag is passed).",
|
| 604 |
+
)
|
| 605 |
+
|
| 606 |
+
# megatron_lm args
|
| 607 |
+
megatron_lm_args = parser.add_argument_group("Megatron-LM Arguments", "Arguments related to Megatron-LM.")
|
| 608 |
+
megatron_lm_args.add_argument(
|
| 609 |
+
"--megatron_lm_tp_degree",
|
| 610 |
+
type=int,
|
| 611 |
+
default=1,
|
| 612 |
+
help="Megatron-LM's Tensor Parallelism (TP) degree. (useful only when `use_megatron_lm` flag is passed).",
|
| 613 |
+
)
|
| 614 |
+
megatron_lm_args.add_argument(
|
| 615 |
+
"--megatron_lm_use_custom_fsdp",
|
| 616 |
+
type=bool,
|
| 617 |
+
default=False,
|
| 618 |
+
help="Whether to use custom FSDP. (useful only when `use_megatron_lm` flag is passed).",
|
| 619 |
+
)
|
| 620 |
+
megatron_lm_args.add_argument(
|
| 621 |
+
"--megatron_lm_no_load_optim",
|
| 622 |
+
type=bool,
|
| 623 |
+
default=False,
|
| 624 |
+
help="Whether to not load optimizer. (useful only when `use_megatron_lm` flag is passed).",
|
| 625 |
+
)
|
| 626 |
+
megatron_lm_args.add_argument(
|
| 627 |
+
"--megatron_lm_eod_mask_loss",
|
| 628 |
+
type=bool,
|
| 629 |
+
default=False,
|
| 630 |
+
help="Whether to use eod mask loss. (useful only when `use_megatron_lm` flag is passed).",
|
| 631 |
+
)
|
| 632 |
+
megatron_lm_args.add_argument(
|
| 633 |
+
"--megatron_lm_overlap_cpu_optimizer_d2h_h2d",
|
| 634 |
+
type=bool,
|
| 635 |
+
default=False,
|
| 636 |
+
help="Whether to overlap CPU optimizer step, gradients D2H and updated parameters H2D. (useful only when `use_megatron_lm` flag is passed).",
|
| 637 |
+
)
|
| 638 |
+
megatron_lm_args.add_argument(
|
| 639 |
+
"--megatron_lm_no_save_optim",
|
| 640 |
+
type=bool,
|
| 641 |
+
default=False,
|
| 642 |
+
help="Whether to not save optimizer. (useful only when `use_megatron_lm` flag is passed).",
|
| 643 |
+
)
|
| 644 |
+
megatron_lm_args.add_argument(
|
| 645 |
+
"--megatron_lm_optimizer_cpu_offload",
|
| 646 |
+
type=bool,
|
| 647 |
+
default=False,
|
| 648 |
+
help="Whether to use CPU offload for optimizer. (useful only when `use_megatron_lm` flag is passed).",
|
| 649 |
+
)
|
| 650 |
+
megatron_lm_args.add_argument(
|
| 651 |
+
"--megatron_lm_use_precision_aware_optimizer",
|
| 652 |
+
type=bool,
|
| 653 |
+
default=False,
|
| 654 |
+
help="Whether to use precision aware optimizer. (useful only when `use_megatron_lm` flag is passed).",
|
| 655 |
+
)
|
| 656 |
+
megatron_lm_args.add_argument(
|
| 657 |
+
"--megatron_lm_decoder_last_pipeline_num_layers",
|
| 658 |
+
type=int,
|
| 659 |
+
default=None,
|
| 660 |
+
help="Megatron-LM's decoder last pipeline number of layers, default None is even split of transformer layers across all pipeline stages.",
|
| 661 |
+
)
|
| 662 |
+
megatron_lm_args.add_argument(
|
| 663 |
+
"--megatron_lm_pp_degree",
|
| 664 |
+
type=int,
|
| 665 |
+
default=1,
|
| 666 |
+
help="Megatron-LM's Pipeline Parallelism (PP) degree. (useful only when `use_megatron_lm` flag is passed).",
|
| 667 |
+
)
|
| 668 |
+
megatron_lm_args.add_argument(
|
| 669 |
+
"--megatron_lm_num_micro_batches",
|
| 670 |
+
type=int,
|
| 671 |
+
default=None,
|
| 672 |
+
help="Megatron-LM's number of micro batches when PP degree > 1. (useful only when `use_megatron_lm` flag is passed).",
|
| 673 |
+
)
|
| 674 |
+
megatron_lm_args.add_argument(
|
| 675 |
+
"--megatron_lm_sequence_parallelism",
|
| 676 |
+
default=None,
|
| 677 |
+
type=str,
|
| 678 |
+
help="Decides Whether (true|false) to enable Sequence Parallelism when TP degree > 1. "
|
| 679 |
+
"(useful only when `use_megatron_lm` flag is passed).",
|
| 680 |
+
)
|
| 681 |
+
megatron_lm_args.add_argument(
|
| 682 |
+
"--megatron_lm_recompute_activations",
|
| 683 |
+
default=None,
|
| 684 |
+
type=str,
|
| 685 |
+
help="Decides Whether (true|false) to enable Selective Activation Recomputation. "
|
| 686 |
+
"(useful only when `use_megatron_lm` flag is passed).",
|
| 687 |
+
)
|
| 688 |
+
megatron_lm_args.add_argument(
|
| 689 |
+
"--megatron_lm_use_distributed_optimizer",
|
| 690 |
+
default=None,
|
| 691 |
+
type=str,
|
| 692 |
+
help="Decides Whether (true|false) to use distributed optimizer "
|
| 693 |
+
"which shards optimizer state and gradients across Data Pralellel (DP) ranks. "
|
| 694 |
+
"(useful only when `use_megatron_lm` flag is passed).",
|
| 695 |
+
)
|
| 696 |
+
megatron_lm_args.add_argument(
|
| 697 |
+
"--megatron_lm_gradient_clipping",
|
| 698 |
+
default=1.0,
|
| 699 |
+
type=float,
|
| 700 |
+
help="Megatron-LM's gradient clipping value based on global L2 Norm (0 to disable). "
|
| 701 |
+
"(useful only when `use_megatron_lm` flag is passed).",
|
| 702 |
+
)
|
| 703 |
+
megatron_lm_args.add_argument(
|
| 704 |
+
"--megatron_lm_recompute_granularity",
|
| 705 |
+
default=None,
|
| 706 |
+
type=str,
|
| 707 |
+
help="Megatron-LM's recompute granularity (full, selective). "
|
| 708 |
+
"(useful only when `use_megatron_lm` flag is passed).",
|
| 709 |
+
)
|
| 710 |
+
megatron_lm_args.add_argument(
|
| 711 |
+
"--megatron_lm_recompute_method",
|
| 712 |
+
default=None,
|
| 713 |
+
type=str,
|
| 714 |
+
help="Megatron-LM's recompute method (uniform, block). (useful only when `use_megatron_lm` flag is passed).",
|
| 715 |
+
)
|
| 716 |
+
megatron_lm_args.add_argument(
|
| 717 |
+
"--megatron_lm_recompute_num_layers",
|
| 718 |
+
default=None,
|
| 719 |
+
type=int,
|
| 720 |
+
help="Megatron-LM's number of layers to recompute. (useful only when `use_megatron_lm` flag is passed).",
|
| 721 |
+
)
|
| 722 |
+
megatron_lm_args.add_argument(
|
| 723 |
+
"--megatron_lm_attention_backend",
|
| 724 |
+
default=None,
|
| 725 |
+
type=str,
|
| 726 |
+
help="Decides Whether (true|false) to enable attention backend. "
|
| 727 |
+
"(useful only when `use_megatron_lm` flag is passed).",
|
| 728 |
+
)
|
| 729 |
+
megatron_lm_args.add_argument(
|
| 730 |
+
"--megatron_lm_expert_model_parallel_size",
|
| 731 |
+
default=None,
|
| 732 |
+
type=int,
|
| 733 |
+
help="Megatron-LM's expert model parallel size. (useful only when `use_megatron_lm` flag is passed).",
|
| 734 |
+
)
|
| 735 |
+
megatron_lm_args.add_argument(
|
| 736 |
+
"--megatron_lm_context_parallel_size",
|
| 737 |
+
default=None,
|
| 738 |
+
type=int,
|
| 739 |
+
help="Megatron-LM's context parallel size. (useful only when `use_megatron_lm` flag is passed).",
|
| 740 |
+
)
|
| 741 |
+
megatron_lm_args.add_argument(
|
| 742 |
+
"--megatron_lm_attention_dropout",
|
| 743 |
+
default=None,
|
| 744 |
+
type=float,
|
| 745 |
+
help="Megatron-LM's attention dropout rate. (useful only when `use_megatron_lm` flag is passed).",
|
| 746 |
+
)
|
| 747 |
+
megatron_lm_args.add_argument(
|
| 748 |
+
"--megatron_lm_hidden_dropout",
|
| 749 |
+
default=None,
|
| 750 |
+
type=float,
|
| 751 |
+
help="Megatron-LM's hidden dropout rate. (useful only when `use_megatron_lm` flag is passed).",
|
| 752 |
+
)
|
| 753 |
+
megatron_lm_args.add_argument(
|
| 754 |
+
"--megatron_lm_attention_softmax_in_fp32",
|
| 755 |
+
default=None,
|
| 756 |
+
type=str,
|
| 757 |
+
help="Decides Whether (true|false) to use fp32 for attention softmax. "
|
| 758 |
+
"(useful only when `use_megatron_lm` flag is passed).",
|
| 759 |
+
)
|
| 760 |
+
megatron_lm_args.add_argument(
|
| 761 |
+
"--megatron_lm_expert_tensor_parallel_size",
|
| 762 |
+
default=None,
|
| 763 |
+
type=int,
|
| 764 |
+
help="Megatron-LM's expert tensor parallel size. (useful only when `use_megatron_lm` flag is passed).",
|
| 765 |
+
)
|
| 766 |
+
megatron_lm_args.add_argument(
|
| 767 |
+
"--megatron_lm_calculate_per_token_loss",
|
| 768 |
+
default=None,
|
| 769 |
+
type=str,
|
| 770 |
+
help="Decides Whether (true|false) to calculate per token loss. "
|
| 771 |
+
"(useful only when `use_megatron_lm` flag is passed).",
|
| 772 |
+
)
|
| 773 |
+
megatron_lm_args.add_argument(
|
| 774 |
+
"--megatron_lm_use_rotary_position_embeddings",
|
| 775 |
+
default=None,
|
| 776 |
+
type=str,
|
| 777 |
+
help="Decides Whether (true|false) to use rotary position embeddings. "
|
| 778 |
+
"(useful only when `use_megatron_lm` flag is passed).",
|
| 779 |
+
)
|
| 780 |
+
|
| 781 |
+
# FP8 arguments
|
| 782 |
+
fp8_args = parser.add_argument_group(
|
| 783 |
+
"FP8 Arguments", "Arguments related to FP8 training (requires `--mixed_precision=fp8`)"
|
| 784 |
+
)
|
| 785 |
+
fp8_args.add_argument(
|
| 786 |
+
"--fp8_backend",
|
| 787 |
+
type=str,
|
| 788 |
+
choices=["ao", "te", "msamp"],
|
| 789 |
+
help="Choose a backend to train with FP8 (ao: torchao, te: TransformerEngine, msamp: MS-AMP)",
|
| 790 |
+
)
|
| 791 |
+
fp8_args.add_argument(
|
| 792 |
+
"--fp8_use_autocast_during_eval",
|
| 793 |
+
default=False,
|
| 794 |
+
action="store_true",
|
| 795 |
+
help="Whether to use FP8 autocast during eval mode (useful only when `--fp8_backend=te` is passed). Generally better metrics are found when this is not passed.",
|
| 796 |
+
)
|
| 797 |
+
fp8_args.add_argument(
|
| 798 |
+
"--fp8_margin",
|
| 799 |
+
type=int,
|
| 800 |
+
default=0,
|
| 801 |
+
help="The margin to use for the gradient scaling (useful only when `--fp8_backend=te` is passed).",
|
| 802 |
+
)
|
| 803 |
+
fp8_args.add_argument(
|
| 804 |
+
"--fp8_interval",
|
| 805 |
+
type=int,
|
| 806 |
+
default=1,
|
| 807 |
+
help="The interval to use for how often the scaling factor is recomputed (useful only when `--fp8_backend=te` is passed).",
|
| 808 |
+
)
|
| 809 |
+
fp8_args.add_argument(
|
| 810 |
+
"--fp8_format",
|
| 811 |
+
type=str,
|
| 812 |
+
default="HYBRID",
|
| 813 |
+
choices=["HYBRID", "E4M3", "E5M2"],
|
| 814 |
+
help="The format to use for the FP8 recipe (useful only when `--fp8_backend=te` is passed).",
|
| 815 |
+
)
|
| 816 |
+
fp8_args.add_argument(
|
| 817 |
+
"--fp8_amax_history_len",
|
| 818 |
+
type=int,
|
| 819 |
+
default=1024,
|
| 820 |
+
help="The length of the history to use for the scaling factor computation (useful only when `--fp8_backend=te` is passed).",
|
| 821 |
+
)
|
| 822 |
+
fp8_args.add_argument(
|
| 823 |
+
"--fp8_amax_compute_algo",
|
| 824 |
+
type=str,
|
| 825 |
+
default="most_recent",
|
| 826 |
+
choices=["max", "most_recent"],
|
| 827 |
+
help="The algorithm to use for the scaling factor computation. (useful only when `--fp8_backend=te` is passed).",
|
| 828 |
+
)
|
| 829 |
+
fp8_args.add_argument(
|
| 830 |
+
"--fp8_override_linear_precision",
|
| 831 |
+
type=lambda x: tuple(map(str_to_bool, x.split(","))),
|
| 832 |
+
default=(False, False, False),
|
| 833 |
+
help="Whether or not to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision. Should be passed in a comma-separated string of booleans (useful only when `--fp8_backend=te` is passed).",
|
| 834 |
+
)
|
| 835 |
+
fp8_args.add_argument(
|
| 836 |
+
"--fp8_opt_level",
|
| 837 |
+
type=str,
|
| 838 |
+
default="O2",
|
| 839 |
+
choices=["O1", "O2"],
|
| 840 |
+
help="What level of 8-bit collective communication should be used with MS-AMP (useful only when `--fp8_backend=msamp` is passed).",
|
| 841 |
+
)
|
| 842 |
+
fp8_args.add_argument(
|
| 843 |
+
"--fp8_enable_fsdp_float8_all_gather",
|
| 844 |
+
default="true",
|
| 845 |
+
type=str_to_bool,
|
| 846 |
+
help="Whether to enable FSDP2 float8 all gather (useful only when `--fp8_backend=ao` is passed).",
|
| 847 |
+
)
|
| 848 |
+
fp8_args.add_argument(
|
| 849 |
+
"--fp8_pad_inner_dim",
|
| 850 |
+
default="true",
|
| 851 |
+
type=str_to_bool,
|
| 852 |
+
help="Whether to pad the inner dimension for FP8 GEMMs (useful only when `--fp8_backend=ao` is passed).",
|
| 853 |
+
)
|
| 854 |
+
|
| 855 |
+
# AWS arguments
|
| 856 |
+
aws_args = parser.add_argument_group("AWS Arguments", "Arguments related to AWS.")
|
| 857 |
+
aws_args.add_argument(
|
| 858 |
+
"--aws_access_key_id",
|
| 859 |
+
type=str,
|
| 860 |
+
default=None,
|
| 861 |
+
help="The AWS_ACCESS_KEY_ID used to launch the Amazon SageMaker training job",
|
| 862 |
+
)
|
| 863 |
+
aws_args.add_argument(
|
| 864 |
+
"--aws_secret_access_key",
|
| 865 |
+
type=str,
|
| 866 |
+
default=None,
|
| 867 |
+
help="The AWS_SECRET_ACCESS_KEY used to launch the Amazon SageMaker training job.",
|
| 868 |
+
)
|
| 869 |
+
parser.add_argument(
|
| 870 |
+
"--debug",
|
| 871 |
+
action="store_true",
|
| 872 |
+
help="Whether to print out the torch.distributed stack trace when something fails.",
|
| 873 |
+
)
|
| 874 |
+
parser.add_argument(
|
| 875 |
+
"training_script",
|
| 876 |
+
type=str,
|
| 877 |
+
help=(
|
| 878 |
+
"The full path to the script to be launched in parallel, followed by all the arguments for the training "
|
| 879 |
+
"script."
|
| 880 |
+
),
|
| 881 |
+
)
|
| 882 |
+
|
| 883 |
+
# MPI arguments
|
| 884 |
+
mpirun_args = parser.add_argument_group("MPI Arguments", "Arguments related to mpirun for Multi-CPU")
|
| 885 |
+
mpirun_args.add_argument(
|
| 886 |
+
"--mpirun_hostfile",
|
| 887 |
+
type=str,
|
| 888 |
+
default=None,
|
| 889 |
+
help="Location for a hostfile for using Accelerate to launch a multi-CPU training job with mpirun. This will "
|
| 890 |
+
"get passed to the MPI --hostfile or -f parameter, depending on which MPI program is installed.",
|
| 891 |
+
)
|
| 892 |
+
|
| 893 |
+
# ParallelismConfig arguments
|
| 894 |
+
parallelism_config_args = parser.add_argument_group(
|
| 895 |
+
"ParallelismConfig Arguments",
|
| 896 |
+
"Arguments related to the ParallelismConfig used for distributed training.",
|
| 897 |
+
)
|
| 898 |
+
|
| 899 |
+
parallelism_config_args.add_argument(
|
| 900 |
+
"--parallelism_config_dp_replicate_size",
|
| 901 |
+
type=int,
|
| 902 |
+
default=1,
|
| 903 |
+
help="The number of processes for data parallel training. Defaults to 1 (no data parallelism).",
|
| 904 |
+
)
|
| 905 |
+
|
| 906 |
+
parallelism_config_args.add_argument(
|
| 907 |
+
"--parallelism_config_dp_shard_size",
|
| 908 |
+
type=int,
|
| 909 |
+
default=1,
|
| 910 |
+
help="The number of processes for FSDP sharding. Defaults to 1 (No FSDP sharding).",
|
| 911 |
+
)
|
| 912 |
+
|
| 913 |
+
parallelism_config_args.add_argument(
|
| 914 |
+
"--parallelism_config_tp_size",
|
| 915 |
+
type=int,
|
| 916 |
+
default=1,
|
| 917 |
+
help="The number of processes for tensor parallel training. Defaults to 1 (no tensor parallelism).",
|
| 918 |
+
)
|
| 919 |
+
|
| 920 |
+
parallelism_config_args.add_argument(
|
| 921 |
+
"--parallelism_config_cp_size",
|
| 922 |
+
type=int,
|
| 923 |
+
default=1,
|
| 924 |
+
help="The number of processese for context parallel training. Defaults to 1 (no context parallelism).",
|
| 925 |
+
)
|
| 926 |
+
|
| 927 |
+
parallelism_config_args.add_argument(
|
| 928 |
+
"--parallelism_config_cp_backend",
|
| 929 |
+
type=str,
|
| 930 |
+
choices=["torch"],
|
| 931 |
+
default="torch",
|
| 932 |
+
help="Context Parallelism backend: torch (FSDP2) or deepspeed (ALST/Ulysses)",
|
| 933 |
+
)
|
| 934 |
+
|
| 935 |
+
parallelism_config_args.add_argument(
|
| 936 |
+
"--parallelism_config_cp_comm_strategy",
|
| 937 |
+
type=str,
|
| 938 |
+
default="allgather",
|
| 939 |
+
help="The communication strategy for context parallel training. Defaults to 'allgather'. Other option is alltoall",
|
| 940 |
+
)
|
| 941 |
+
|
| 942 |
+
parallelism_config_args.add_argument(
|
| 943 |
+
"--parallelism_config_sp_size",
|
| 944 |
+
type=int,
|
| 945 |
+
default=1,
|
| 946 |
+
help="The number of processese for context parallel training. Defaults to 1 (no context parallelism).",
|
| 947 |
+
)
|
| 948 |
+
|
| 949 |
+
parallelism_config_args.add_argument(
|
| 950 |
+
"--parallelism_config_sp_backend",
|
| 951 |
+
type=str,
|
| 952 |
+
choices=["deepspeed"],
|
| 953 |
+
default="deepspeed",
|
| 954 |
+
help="Sequence Parallelism backend: deepspeed (ALST/Ulysses)",
|
| 955 |
+
)
|
| 956 |
+
|
| 957 |
+
parallelism_config_args.add_argument(
|
| 958 |
+
"--parallelism_config_sp_seq_length",
|
| 959 |
+
type=str,
|
| 960 |
+
default=None,
|
| 961 |
+
help="Sequence length for when batches are all of the same length. For variable sequence lengths across batches set `parallelism_config_sp_seq_length_is_variable=True`",
|
| 962 |
+
)
|
| 963 |
+
|
| 964 |
+
parallelism_config_args.add_argument(
|
| 965 |
+
"--parallelism_config_sp_seq_length_is_variable",
|
| 966 |
+
type=bool,
|
| 967 |
+
default=True,
|
| 968 |
+
help="If `True` will work with a sequence length that may change between batches, in which case `parallelism_config_sp_seq_length` value can be set to anything divisible by sp size or remain unset. If `False` then `parallelism_config_sp_seq_length` needs to match the batch's sequence length dimension. The default is `True`.",
|
| 969 |
+
)
|
| 970 |
+
|
| 971 |
+
parallelism_config_args.add_argument(
|
| 972 |
+
"--parallelism_config_sp_attn_implementation",
|
| 973 |
+
type=str,
|
| 974 |
+
default="sdpa",
|
| 975 |
+
help="Attention implementation to use. Can be one of 'flash_attention_2', 'flash_attention_3' or 'sdpa'. Defaults to `sdpa`.",
|
| 976 |
+
)
|
| 977 |
+
|
| 978 |
+
# Other arguments of the training scripts
|
| 979 |
+
parser.add_argument("training_script_args", nargs=argparse.REMAINDER, help="Arguments of the training script.")
|
| 980 |
+
|
| 981 |
+
if subparsers is not None:
|
| 982 |
+
parser.set_defaults(func=launch_command)
|
| 983 |
+
return parser
|
| 984 |
+
|
| 985 |
+
|
| 986 |
+
def simple_launcher(args):
    """Run the training command as one child process and wait for it to finish.

    The command line and environment are produced by
    `prepare_simple_launcher_cmd_env`. A zero exit status returns normally.

    Raises:
        subprocess.CalledProcessError: the child exited with a non-zero status and
            `args.quiet` is not set.
        SystemExit: the child exited with a non-zero status and `args.quiet` is set
            (the exit status is always 1).
    """
    command, child_env = prepare_simple_launcher_cmd_env(args)

    child = subprocess.Popen(command, env=child_env)
    child.wait()
    if child.returncode == 0:
        return
    if args.quiet:
        # Quiet mode: fail without a traceback.
        sys.exit(1)
    raise subprocess.CalledProcessError(returncode=child.returncode, cmd=command)
|
| 996 |
+
|
| 997 |
+
|
| 998 |
+
def multi_gpu_launcher(args):
    """Launch the training script in-process through `torch.distributed.run`.

    The environment returned by `prepare_multi_gpu_env` is patched in for the
    duration of the run. When `check_cuda_p2p_ib_support()` reports no support,
    `NCCL_P2P_DISABLE` and `NCCL_IB_DISABLE` are set to "1" unless already present,
    and a warning is logged if either one had to be added.

    Any exception from the run is re-raised, except when `--debug` is active and
    `rich` is available: then the stack trace is printed to the console instead.
    """
    import torch.distributed.run as distrib_run

    env = prepare_multi_gpu_env(args)
    if not check_cuda_p2p_ib_support():
        added_any = False
        for nccl_var in ("NCCL_P2P_DISABLE", "NCCL_IB_DISABLE"):
            # Never override a value that is already in the environment.
            if nccl_var not in env:
                env[nccl_var] = "1"
                added_any = True
        if added_any:
            logger.warning(
                "Using RTX 4000 series which doesn't support faster communication speedups. Ensuring P2P and IB communications are disabled."
            )

    # Read `debug` before `args` is rebound to the filtered namespace below.
    debug = getattr(args, "debug", False)
    extra_cli = ["--training_script", args.training_script, "--training_script_args", args.training_script_args]
    args = _filter_args(args, distrib_run.get_args_parser(), extra_cli)

    with patch_environment(**env):
        try:
            distrib_run.run(args)
        except Exception:
            if not (is_rich_available() and debug):
                raise
            console = get_console()
            console.print("\n[bold red]Using --debug, `torch.distributed` Stack Trace:[/bold red]")
            console.print_exception(suppress=[__file__], show_locals=False)
|
| 1031 |
+
|
| 1032 |
+
|
| 1033 |
+
def deepspeed_launcher(args):
    """Launch a DeepSpeed training job.

    Two paths exist:

    * More than one machine, and `args.deepspeed_multinode_launcher` differs from the
      second entry of `DEEPSPEED_MULTINODE_LAUNCHERS`: the environment is appended to
      the file named by `DEEPSPEED_ENVIRONMENT_NAME`, and the prepared command is run
      as a child process.
    * Otherwise: the script is run in-process through `torch.distributed.run`.

    Raises:
        ImportError: DeepSpeed is not installed.
        subprocess.CalledProcessError: the child process failed and `args.quiet` is
            not set (with `args.quiet` the interpreter exits with status 1 instead).
    """
    import torch.distributed.run as distrib_run

    if not is_deepspeed_available():
        raise ImportError("DeepSpeed is not installed => run `pip3 install deepspeed` or build it from source.")
    from deepspeed.launcher.runner import DEEPSPEED_ENVIRONMENT_NAME

    command, env = prepare_deepspeed_cmd_env(args)
    if not check_cuda_p2p_ib_support():
        added_any = False
        for nccl_var in ("NCCL_P2P_DISABLE", "NCCL_IB_DISABLE"):
            # Never override a value that is already in the environment.
            if nccl_var not in env:
                env[nccl_var] = "1"
                added_any = True
        if added_any:
            logger.warning(
                "Using RTX 4000 series which doesn't support faster communication speedups. Ensuring P2P and IB communications are disabled."
            )

    run_as_subprocess = (
        args.num_machines > 1 and args.deepspeed_multinode_launcher != DEEPSPEED_MULTINODE_LAUNCHERS[1]
    )
    if not run_as_subprocess:
        # Read `debug` before `args` is rebound to the filtered namespace below.
        debug = getattr(args, "debug", False)
        extra_cli = ["--training_script", args.training_script, "--training_script_args", args.training_script_args]
        args = _filter_args(args, distrib_run.get_args_parser(), extra_cli)
        with patch_environment(**env):
            try:
                distrib_run.run(args)
            except Exception:
                if not (is_rich_available() and debug):
                    raise
                console = get_console()
                console.print("\n[bold red]Using --debug, `torch.distributed` Stack Trace:[/bold red]")
                console.print_exception(suppress=[__file__], show_locals=False)
        return

    with open(DEEPSPEED_ENVIRONMENT_NAME, "a") as env_file:
        env_lines = convert_dict_to_env_variables(env)
        if len(env_lines) > 1:
            env_file.writelines(env_lines)

    child = subprocess.Popen(command, env=env)
    child.wait()
    if child.returncode == 0:
        return
    if args.quiet:
        # Quiet mode: fail without a traceback.
        sys.exit(1)
    raise subprocess.CalledProcessError(returncode=child.returncode, cmd=command)
|
| 1084 |
+
|
| 1085 |
+
|
| 1086 |
+
def tpu_launcher(args):
    """Import the training script as a module and spawn its entry function via XLA.

    The function named by `args.main_training_function` is looked up on the imported
    module, wrapped in `PrepareForLaunch`, and handed to `xmp.spawn` with the
    environment from `prepare_tpu` patched in. `sys.argv` is replaced so the training
    code sees its own file path followed by `args.training_script_args`.

    Raises:
        ValueError: `--no_python` was passed, or the module has no attribute named
            `args.main_training_function`.
    """
    import torch_xla.distributed.xla_multiprocessing as xmp

    if args.no_python:
        raise ValueError("--no_python cannot be used with TPU launcher")

    args, env = prepare_tpu(args, {})

    if not args.module:
        # A file path was given: make its directory importable and import by stem.
        script = Path(args.training_script)
        sys.path.append(str(script.parent.resolve()))
        module_name = script.stem
    else:
        module_name = args.training_script

    module = importlib.import_module(module_name)
    entry_name = args.main_training_function
    if not hasattr(module, entry_name):
        raise ValueError(
            f"Your training script should have a function named {args.main_training_function}, or you should pass a "
            "different value to `--main_training_function`."
        )

    # Make the training code see its own argv instead of the launcher's.
    sys.argv = [module.__file__] + args.training_script_args

    entry_point = getattr(module, entry_name)
    with patch_environment(**env):
        xmp.spawn(PrepareForLaunch(entry_point), args=())
|
| 1115 |
+
|
| 1116 |
+
|
| 1117 |
+
def tpu_pod_launcher(args):
    """Launch training on a TPU pod by delegating to `torch_xla.distributed.xla_dist`.

    An `accelerate-launch --tpu --no_tpu_cluster ...` command line (prefixed with
    `sudo` when `args.tpu_use_sudo` is set) is built and stored as the positional
    arguments of the `xla_dist` namespace, together with the environment from
    `prepare_tpu` plus `ACCELERATE_IN_TPU_POD=1`.

    Raises:
        ValueError: any `docker_*` option on the filtered namespace has a value other
            than "" or None.

    Any exception from `xla_dist.resolve_and_execute` is re-raised, except when
    `--debug` is active and `rich` is available: then the stack trace is printed.
    """
    from torch_xla.distributed import xla_dist

    args, env = prepare_tpu(args, {}, True)
    debug = getattr(args, "debug", False)

    script = args.training_script
    script_args = args.training_script_args
    xla_cli = ["--tpu", args.tpu_name, "--positional", "", "--restart-tpuvm-pod-server"]
    new_args = _filter_args(args, xla_dist.get_args_parser(), xla_cli)

    pod_cmd = ["sudo"] if args.tpu_use_sudo else []
    pod_cmd.extend(
        [
            "accelerate-launch",
            "--tpu",
            "--no_tpu_cluster",
            "--num_machines",
            "1",
            "--mixed_precision",
            "no",
            "--dynamo_backend",
            "no",
            "--num_processes",
            str(args.num_processes),
            "--main_training_function",
            str(args.main_training_function),
            script,
        ]
        + script_args
    )
    new_args.positional = pod_cmd

    # Collect every docker option that carries a value; none of them are supported here.
    docker_settings = [
        (option, getattr(new_args, option)) for option in vars(new_args) if option.startswith("docker_")
    ]
    bad_flags = "".join(
        f'{option}="{value}"\n' for option, value in docker_settings if value != "" and value is not None
    )
    if bad_flags:
        raise ValueError(
            f"Docker containers are not supported for TPU pod launcher currently, please remove the following flags:\n{bad_flags}"
        )

    pod_env = [f"{k}={v}" for k, v in env.items()]
    pod_env.append("ACCELERATE_IN_TPU_POD=1")
    new_args.env = pod_env

    try:
        xla_dist.resolve_and_execute(new_args)
    except Exception:
        if not (is_rich_available() and debug):
            raise
        console = get_console()
        console.print("\n[bold red]Using --debug, `torch_xla.xla_dist` Stack Trace:[/bold red]")
        console.print_exception(suppress=[__file__], show_locals=False)
|
| 1174 |
+
|
| 1175 |
+
|
| 1176 |
+
def sagemaker_launcher(sagemaker_config: SageMakerConfig, args):
    """Start a training job through the SageMaker `HuggingFace` estimator.

    Estimator keyword arguments and training inputs are produced by
    `prepare_sagemager_args_inputs`. After `fit` returns, the estimator's
    `model_data` is printed.

    Raises:
        ImportError: the `sagemaker` package is not available.
        ValueError: `--module` or `--no_python` was requested.
    """
    if not is_sagemaker_available():
        raise ImportError(
            "Please install sagemaker to be able to launch training on Amazon SageMaker with `pip install accelerate[sagemaker]`"
        )
    if args.module or args.no_python:
        raise ValueError(
            "SageMaker requires a python training script file and cannot be used with --module or --no_python"
        )

    # Imported only after the availability check so a missing package yields the message above.
    from sagemaker.huggingface import HuggingFace

    estimator_kwargs, training_inputs = prepare_sagemager_args_inputs(sagemaker_config, args)

    estimator = HuggingFace(**estimator_kwargs)
    estimator.fit(inputs=training_inputs)
    print(f"You can find your model data at: {estimator.model_data}")
|
| 1194 |
+
|
| 1195 |
+
|
| 1196 |
+
def _validate_launch_command(args):
    """Validate the launch arguments and fill in values that were not passed.

    Rejects incompatible flag combinations, then completes `args` either from a
    loaded config file or, when none is used, from the detected device count and
    fixed fallbacks. Fallbacks applied without a config are collected and reported in
    a single warning.

    Returns:
        A tuple `(args, defaults, mp_from_config_flag)`: `args` is the same namespace,
        mutated in place; `defaults` is the loaded config or `None`;
        `mp_from_config_flag` is `True` only when `mixed_precision` was taken from a
        non-`None` value in the config.

    Raises:
        ValueError: on any of the failed sanity checks below.
    """
    # Sanity checks
    if sum([args.multi_gpu, args.cpu, args.tpu, args.use_deepspeed, args.use_fsdp]) > 1:
        raise ValueError(
            "You can only use one of `--cpu`, `--multi_gpu`, `--tpu`, `--use_deepspeed`, `--use_fsdp` at a time."
        )
    if args.multi_gpu and (args.num_processes is not None) and (args.num_processes < 2):
        raise ValueError("You need to use at least 2 processes to use `--multi_gpu`.")

    if (not args.use_fsdp or args.fsdp_version == 1) and args.use_parallelism_config:
        raise ValueError("You cannot use `--use_parallelism_config` without `--use_fsdp` and `--fsdp_version=2`. ")

    defaults = None
    # Human-readable notes about values that had to be defaulted; logged once at the end.
    warned = []
    mp_from_config_flag = False
    # Get the default from the config file.
    # `and` binds tighter than `or`: an explicit `--config_file` is always loaded, while
    # the default config file is only used when `--cpu` is not set.
    if args.config_file is not None or os.path.isfile(default_config_file) and not args.cpu:
        defaults = load_config_from_file(args.config_file)
        # No distributed mode was chosen on the command line: derive every mode flag
        # from the config's `distributed_type`.
        if (
            not args.multi_gpu
            and not args.tpu
            and not args.tpu_use_cluster
            and not args.use_deepspeed
            and not args.use_fsdp
            and not args.use_megatron_lm
        ):
            args.use_deepspeed = defaults.distributed_type == DistributedType.DEEPSPEED
            args.multi_gpu = (
                True
                if defaults.distributed_type
                in (
                    DistributedType.MULTI_GPU,
                    DistributedType.MULTI_NPU,
                    DistributedType.MULTI_MLU,
                    DistributedType.MULTI_SDAA,
                    DistributedType.MULTI_MUSA,
                    DistributedType.MULTI_XPU,
                    DistributedType.MULTI_HPU,
                    DistributedType.MULTI_NEURON,
                )
                else False
            )
            args.tpu = defaults.distributed_type == DistributedType.XLA
            args.use_fsdp = defaults.distributed_type == DistributedType.FSDP
            args.use_megatron_lm = defaults.distributed_type == DistributedType.MEGATRON_LM
            args.tpu_use_cluster = defaults.tpu_use_cluster if args.tpu else False
            args.use_parallelism_config = defaults.parallelism_config != {}
        # `gpu_ids` precedence: command line, then config, then "all".
        if args.gpu_ids is None:
            if defaults.gpu_ids is not None:
                args.gpu_ids = defaults.gpu_ids
            else:
                args.gpu_ids = "all"

        if args.multi_gpu and args.num_machines is None:
            args.num_machines = defaults.num_machines

        if len(args.gpu_ids.split(",")) < 2 and (args.gpu_ids != "all") and args.multi_gpu and args.num_machines <= 1:
            raise ValueError(
                "Less than two GPU ids were configured and tried to run on on multiple GPUs. "
                "Please ensure at least two are specified for `--gpu_ids`, or use `--gpu_ids='all'`."
            )
        if defaults.compute_environment == ComputeEnvironment.LOCAL_MACHINE:
            # Update args with the defaults
            for name, attr in defaults.__dict__.items():
                if isinstance(attr, dict):
                    # Copy defaults.somedict.somearg to args.somearg and
                    # defaults.fsdp_config.x to args.fsdp_x
                    for key, value in attr.items():
                        if name == "fsdp_config" and not key.startswith("fsdp"):
                            key = "fsdp_" + key
                        elif name == "fp8_config" and not key.startswith("fp8"):
                            key = "fp8_" + key
                        # Copied only when `args` has a `nondefault` attribute that does not
                        # list this key. NOTE(review): `nondefault` presumably holds the
                        # options the user set explicitly; confirm against the parser.
                        if hasattr(args, "nondefault") and key not in args.nondefault:
                            setattr(args, key, value)
                elif (
                    name not in ["compute_environment", "mixed_precision", "distributed_type"]
                    and getattr(args, name, None) is None
                ):
                    # Those args are handled separately
                    setattr(args, name, attr)
        if not args.debug:
            args.debug = defaults.debug

        if not args.mixed_precision:
            if defaults.mixed_precision is None:
                args.mixed_precision = "no"
            else:
                args.mixed_precision = defaults.mixed_precision
                # Record that the value came from the config rather than the command line.
                mp_from_config_flag = True
        else:
            native_amp = is_bf16_available(True)
            if (
                args.mixed_precision == "bf16"
                and not native_amp
                and not (args.tpu and is_torch_xla_available(check_is_tpu=True))
            ):
                raise ValueError("bf16 mixed precision requires PyTorch >= 1.10 and a supported device.")

        # Silently set the default here
        if args.dynamo_backend is None:
            args.dynamo_backend = "no"
        if args.num_processes == -1:
            raise ValueError("You need to manually pass in `--num_processes` using this config yaml.")
    else:
        # No config in use: fall back to detected hardware and fixed defaults, noting
        # each fallback in `warned`.
        if args.num_processes is None:
            # The first available backend in this order supplies the device count;
            # CUDA is the final fallback.
            if is_xpu_available():
                args.num_processes = torch.xpu.device_count()
            elif is_mlu_available():
                args.num_processes = torch.mlu.device_count()
            elif is_sdaa_available():
                args.num_processes = torch.sdaa.device_count()
            elif is_musa_available():
                args.num_processes = torch.musa.device_count()
            elif is_npu_available():
                args.num_processes = torch.npu.device_count()
            elif is_hpu_available():
                args.num_processes = torch.hpu.device_count()
            elif is_neuron_available():
                args.num_processes = torch.neuron.device_count()
            else:
                args.num_processes = torch.cuda.device_count()
            warned.append(f"\t`--num_processes` was set to a value of `{args.num_processes}`")
        if args.debug is None:
            args.debug = False
        # Switch on multi-GPU when several processes are requested and any backend
        # reports more than one device.
        if (
            not args.multi_gpu
            and args.num_processes > 1
            and (
                (is_xpu_available() and torch.xpu.device_count() > 1)
                or (is_npu_available() and torch.npu.device_count() > 1)
                or (is_hpu_available() and torch.hpu.device_count() > 1)
                or (is_mlu_available() and torch.mlu.device_count() > 1)
                or (is_sdaa_available() and torch.sdaa.device_count() > 1)
                or (is_musa_available() and torch.musa.device_count() > 1)
                or (is_neuron_available() and torch.neuron.device_count() > 1)
                or (torch.cuda.is_available() and torch.cuda.device_count() > 1)
            )
        ):
            warned.append(
                "\t\tMore than one GPU was found, enabling multi-GPU training.\n"
                "\t\tIf this was unintended please pass in `--num_processes=1`."
            )
            args.multi_gpu = True
        if args.num_machines is None:
            warned.append("\t`--num_machines` was set to a value of `1`")
            args.num_machines = 1
        if args.mixed_precision is None:
            warned.append("\t`--mixed_precision` was set to a value of `'no'`")
            args.mixed_precision = "no"
        if not hasattr(args, "use_cpu"):
            args.use_cpu = args.cpu
        if args.dynamo_backend is None:
            warned.append("\t`--dynamo_backend` was set to a value of `'no'`")
            args.dynamo_backend = "no"
    if args.debug:
        logger.debug("Running script in debug mode, expect distributed operations to be slightly slower.")

    # True unless a config was loaded and its compute environment is Amazon SageMaker.
    is_aws_env_disabled = defaults is None or (
        defaults is not None and defaults.compute_environment != ComputeEnvironment.AMAZON_SAGEMAKER
    )
    if is_aws_env_disabled and args.num_cpu_threads_per_process is None:
        args.num_cpu_threads_per_process = get_int_from_env(["OMP_NUM_THREADS"], 1)
        # CPU run with OMP_NUM_THREADS reading as 0 (the fallback passed here): divide
        # the physical cores among the local processes.
        if args.use_cpu and args.num_processes >= 1 and get_int_from_env(["OMP_NUM_THREADS"], 0) == 0:
            local_size = get_int_from_env(
                ["MPI_LOCALNRANKS", "OMPI_COMM_WORLD_LOCAL_SIZE", "MV2_COMM_WORLD_LOCAL_SIZE"],
                max(int(args.num_processes / args.num_machines), 1),
            )
            import psutil

            threads_per_process = int(psutil.cpu_count(logical=False) / local_size)
            if threads_per_process > 1:
                args.num_cpu_threads_per_process = threads_per_process
                warned.append(
                    f"\t`--num_cpu_threads_per_process` was set to `{args.num_cpu_threads_per_process}` to improve out-of-box performance when training on CPUs"
                )

    if any(warned):
        message = "The following values were not passed to `accelerate launch` and had defaults used instead:\n"
        message += "\n".join(warned)
        message += (
            "\nTo avoid this warning pass in values for each of the problematic parameters or run `accelerate config`."
        )
        logger.warning(message)
    return args, defaults, mp_from_config_flag
|
| 1380 |
+
|
| 1381 |
+
|
| 1382 |
+
def launch_command(args):
    """
    Validate the parsed `accelerate launch` arguments and dispatch to the matching launcher.

    The launcher is chosen from the flags on `args` in priority order: DeepSpeed, then FSDP / Megatron-LM /
    multi-GPU (all served by `multi_gpu_launcher`), then TPU, then SageMaker (taken from the loaded config
    defaults), and finally the simple launcher. Every accelerator-specific branch is skipped when `args.cpu`
    is set.
    """
    args, defaults, mp_from_config_flag = _validate_launch_command(args)

    # Use the proper launcher
    if args.use_deepspeed and not args.cpu:
        config_fields = list(defaults.deepspeed_config.keys()) if defaults else []
        if mp_from_config_flag:
            config_fields.append("mixed_precision")
        # The field names are stored on `args` as one comma-separated string.
        args.deepspeed_fields_from_accelerate_config = ",".join(config_fields)
        deepspeed_launcher(args)
        return

    if (args.use_fsdp or args.use_megatron_lm or args.multi_gpu) and not args.cpu:
        multi_gpu_launcher(args)
        return

    if args.tpu and not args.cpu:
        tpu_launch = tpu_pod_launcher if args.tpu_use_cluster else tpu_launcher
        tpu_launch(args)
        return

    if defaults is not None and defaults.compute_environment == ComputeEnvironment.AMAZON_SAGEMAKER:
        sagemaker_launcher(defaults, args)
        return

    simple_launcher(args)
|
| 1406 |
+
|
| 1407 |
+
|
| 1408 |
+
def main():
    """Command-line entry point: parse the `accelerate launch` arguments and run the launch."""
    launch_command(launch_command_parser().parse_args())


if __name__ == "__main__":
    main()
|
accelerate/commands/menu/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from .selection_menu import BulletMenu
|
accelerate/commands/menu/cursor.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team and Brian Chao. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
A utility for showing and hiding the terminal cursor on Windows and Linux, based on https://github.com/bchao1/bullet
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import os
|
| 20 |
+
import sys
|
| 21 |
+
from contextlib import contextmanager
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# Windows only: the console cursor is controlled through the Win32 console API via ctypes.
if os.name == "nt":
    import ctypes
    import msvcrt  # noqa

    class CursorInfo(ctypes.Structure):
        """Structure passed by reference to `GetConsoleCursorInfo` / `SetConsoleCursorInfo`."""

        # `_fields_` is the attribute ctypes reads to lay out the structure.
        _fields_ = [
            ("size", ctypes.c_int),
            ("visible", ctypes.c_byte),
        ]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def hide_cursor():
    """
    Hide the terminal cursor.

    On POSIX an escape sequence is written to stdout; on Windows the console cursor info is read, marked
    invisible and written back. Other platforms are left untouched.
    """
    platform = os.name
    if platform == "posix":
        sys.stdout.write("\033[?25l")
        sys.stdout.flush()
    elif platform == "nt":
        cursor_info = CursorInfo()
        kernel32 = ctypes.windll.kernel32
        # NOTE(review): -11 is presumably the Win32 STD_OUTPUT_HANDLE constant -- confirm.
        console = kernel32.GetStdHandle(-11)
        kernel32.GetConsoleCursorInfo(console, ctypes.byref(cursor_info))
        cursor_info.visible = False
        kernel32.SetConsoleCursorInfo(console, ctypes.byref(cursor_info))
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def show_cursor():
    """
    Make the terminal cursor visible again.

    On POSIX an escape sequence is written to stdout; on Windows the console cursor info is read, marked
    visible and written back. Other platforms are left untouched.
    """
    platform = os.name
    if platform == "posix":
        sys.stdout.write("\033[?25h")
        sys.stdout.flush()
    elif platform == "nt":
        cursor_info = CursorInfo()
        kernel32 = ctypes.windll.kernel32
        # NOTE(review): -11 is presumably the Win32 STD_OUTPUT_HANDLE constant -- confirm.
        console = kernel32.GetStdHandle(-11)
        kernel32.GetConsoleCursorInfo(console, ctypes.byref(cursor_info))
        cursor_info.visible = True
        kernel32.SetConsoleCursorInfo(console, ctypes.byref(cursor_info))
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@contextmanager
def hide():
    """Context manager that hides the terminal cursor for the duration of the `with` body."""
    try:
        hide_cursor()
        yield
    finally:
        # Always restore the cursor, including when the body (or `hide_cursor` itself) raises.
        show_cursor()
|
accelerate/commands/menu/helpers.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team and Brian Chao. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
A variety of helper functions and constants when dealing with terminal menu choices, based on
|
| 17 |
+
https://github.com/bchao1/bullet
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import enum
|
| 21 |
+
import shutil
|
| 22 |
+
import sys
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# Terminal width in columns, sampled once when this module is imported.
TERMINAL_WIDTH = shutil.get_terminal_size().columns

# Final letter of the cursor-movement escape sequence for each direction name.
CURSOR_TO_CHAR = {
    "UP": "A",
    "DOWN": "B",
    "RIGHT": "C",
    "LEFT": "D",
}
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class Direction(enum.Enum):
    """Vertical direction of a cursor move."""

    UP = 0
    DOWN = 1
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def forceWrite(content, end=""):
    """Write `str(content)` followed by `end` to stdout and flush immediately."""
    text = str(content) + end
    sys.stdout.write(text)
    sys.stdout.flush()
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def writeColor(content, color, end=""):
    """Write `content` wrapped in the `ESC[<color>m` ... `ESC[0m` escape sequences, followed by `end`."""
    colored = f"\u001b[{color}m{content}\u001b[0m"
    forceWrite(colored, end)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def reset_cursor():
    """Write a carriage return so the cursor goes back to the start of the current line."""
    carriage_return = "\r"
    forceWrite(carriage_return)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def move_cursor(num_lines: int, direction: str):
    """
    Emit the escape sequence that moves the cursor `num_lines` steps in `direction`.

    `direction` is matched case-insensitively against the keys of `CURSOR_TO_CHAR`; an unknown name raises
    `KeyError`.
    """
    suffix = CURSOR_TO_CHAR[direction.upper()]
    forceWrite(f"\033[{num_lines}{suffix}")
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def clear_line():
    """Overwrite the current line with `TERMINAL_WIDTH` spaces, then return the cursor to its start."""
    blank = " " * TERMINAL_WIDTH
    forceWrite(blank)
    reset_cursor()
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def linebreak():
    """Return the cursor to the start of the line and draw a rule of `TERMINAL_WIDTH` dashes."""
    reset_cursor()
    rule = "-" * TERMINAL_WIDTH
    forceWrite(rule)
|
accelerate/commands/menu/input.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team and Brian Chao. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
This file contains utilities for handling input from the user and registering specific keys to specific functions,
|
| 17 |
+
based on https://github.com/bchao1/bullet
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from .keymap import KEYMAP, get_character
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def mark(key: int):
    """
    Mark the function with the key code so it can be handled in the register

    Args:
        key (`int`):
            Key code the decorated function should respond to. It is appended to the function's `handle_key`
            list, which is created on first use.

    Returns:
        A decorator that tags the function and returns it unchanged.
    """

    def decorator(func):
        # Keep any codes attached by an earlier `mark` / `mark_multiple` on the same function.
        handle = getattr(func, "handle_key", [])
        handle.append(key)
        func.handle_key = handle
        return func

    return decorator
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def mark_multiple(*keys: int):
    """
    Mark the function with the key codes so it can be handled in the register

    Args:
        *keys (`int`):
            Key codes the decorated function should respond to. They are appended, in order, to the function's
            `handle_key` list, which is created on first use.

    Returns:
        A decorator that tags the function and returns it unchanged.
    """

    def decorator(func):
        # Keep any codes attached by an earlier `mark` / `mark_multiple` on the same function.
        handle = getattr(func, "handle_key", [])
        handle.extend(keys)
        func.handle_key = handle
        return func

    return decorator
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class KeyHandler(type):
    """
    Metaclass that collects the methods tagged with `handle_key` into a `key_handler` table on the class
    """

    def __new__(cls, name, bases, attrs):
        klass = super().__new__(cls, name, bases, attrs)
        # Create the table only when none is visible on the class yet (for example through a base class).
        if not hasattr(klass, "key_handler"):
            klass.key_handler = {}
        klass.handle_input = KeyHandler.handle_input

        for member in attrs.values():
            for code in getattr(member, "handle_key", []):
                klass.key_handler[code] = member
        return klass

    @staticmethod
    def handle_input(cls):
        "Reads one key and returns the result of the handler registered for it, or `None` if there is none"
        char = get_character()
        # `get_character` returns the "undefined" sentinel as-is; anything else is converted to its code.
        if char != KEYMAP["undefined"]:
            char = ord(char)
        handler = cls.key_handler.get(char)
        if not handler:
            return None
        cls.current_selection = char
        return handler(cls)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def register(cls):
    """Rebuilds `cls` through the `KeyHandler` metaclass from its name, bases and a copy of its namespace"""
    namespace = cls.__dict__.copy()
    return KeyHandler(cls.__name__, cls.__bases__, namespace)
|
accelerate/commands/menu/keymap.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team and Brian Chao. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Utilities relating to parsing raw characters from the keyboard, based on https://github.com/bchao1/bullet
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import os
|
| 20 |
+
import string
|
| 21 |
+
import sys
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# Added to the final character code of an arrow-key sequence to form that arrow's key code.
ARROW_KEY_FLAG = 1 << 8

KEYMAP = {
    "tab": ord("\t"),
    "newline": ord("\r"),
    "esc": 27,
    "up": ord("A") + ARROW_KEY_FLAG,
    "down": ord("B") + ARROW_KEY_FLAG,
    "right": ord("C") + ARROW_KEY_FLAG,
    "left": ord("D") + ARROW_KEY_FLAG,
    "mod_int": ord("["),
    "undefined": sys.maxsize,
    "interrupt": 3,
    "insert": 50,
    "delete": 51,
    "pg_up": 53,
    "pg_down": 54,
}

# Inclusive bounds of the arrow-key code range.
KEYMAP["arrow_begin"] = KEYMAP["up"]
KEYMAP["arrow_end"] = KEYMAP["left"]

if sys.platform == "win32":
    # Characters queued for later reads after one Windows keystroke expands into several.
    WIN_CH_BUFFER = []
    # Two-byte Windows key sequences (either prefix byte) mapped to the unflagged arrow character codes.
    WIN_KEYMAP = {
        prefix + scan: KEYMAP[arrow] - ARROW_KEY_FLAG
        for scan, arrow in ((b"H", "up"), (b"P", "down"), (b"M", "right"), (b"K", "left"))
        for prefix in (b"\xe0", b"\x00")
    }

# Digit keys map to their own character codes.
for i in range(10):
    KEYMAP[str(i)] = ord("0") + i
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def get_raw_chars():
    """
    Gets one raw character from the keyboard and returns it as a one-character `str`.

    On Windows the key is read with `msvcrt.getch()`. A mapped arrow key is expanded into the characters
    escape, `[` and the arrow letter: escape is returned right away and the rest is queued in `WIN_CH_BUFFER`
    for the following calls. On POSIX the terminal is switched to raw mode for a single `sys.stdin.read(1)`
    and its previous settings are restored afterwards.
    """
    if os.name == "nt":
        import msvcrt

        encoding = "mbcs"
        # Flush the keyboard buffer
        while msvcrt.kbhit():
            msvcrt.getch()
        if len(WIN_CH_BUFFER) == 0:
            # Read the keystroke
            ch = msvcrt.getch()

            # If it is a prefix char, get second part
            if ch in (b"\x00", b"\xe0"):
                ch2 = ch + msvcrt.getch()
                # Translate actual Win chars to bullet char types
                try:
                    chx = chr(WIN_KEYMAP[ch2])
                except KeyError:
                    # Indexing `bytes` yields an int; convert it so callers always receive a `str`
                    # (the caller applies `ord()` to the result, which fails on an int).
                    ch = chr(ch2[1])
                else:
                    WIN_CH_BUFFER.append(chr(KEYMAP["mod_int"]))
                    WIN_CH_BUFFER.append(chx)
                    # NOTE(review): `a - 1 << 9` parses as `(a - 1) << 9`, which no value in WIN_KEYMAP
                    # equals, so this branch looks unreachable. `a - (1 << 9)` may have been intended --
                    # confirm before changing; behavior is kept as written.
                    if ord(chx) in (
                        (KEYMAP["insert"] - 1) << 9,
                        (KEYMAP["delete"] - 1) << 9,
                        (KEYMAP["pg_up"] - 1) << 9,
                        (KEYMAP["pg_down"] - 1) << 9,
                    ):
                        WIN_CH_BUFFER.append(chr(126))
                    ch = chr(KEYMAP["esc"])
            else:
                ch = ch.decode(encoding)
        else:
            ch = WIN_CH_BUFFER.pop(0)
    elif os.name == "posix":
        import termios
        import tty

        fd = sys.stdin.fileno()
        old_settings = termios.tcgetattr(fd)
        try:
            tty.setraw(fd)
            ch = sys.stdin.read(1)
        finally:
            # Always put the terminal back into its previous mode, even if the read fails.
            termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
    return ch
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def get_character():
    """
    Gets a key press from the keyboard and returns it.

    The return value is the character itself for interrupt, newline and printable keys, a single character
    whose code is the arrow letter plus `ARROW_KEY_FLAG` for arrow keys, and `KEYMAP["undefined"]` otherwise.
    """
    char = get_raw_chars()
    code = ord(char)
    if code in [KEYMAP["interrupt"], KEYMAP["newline"]]:
        return char

    if code != KEYMAP["esc"]:
        return char if char in string.printable else KEYMAP["undefined"]

    # Escape was read: the next character tells whether a bracketed sequence follows.
    combo = get_raw_chars()
    if ord(combo) != KEYMAP["mod_int"]:
        return get_raw_chars()

    key = get_raw_chars()
    first_arrow = KEYMAP["arrow_begin"] - ARROW_KEY_FLAG
    last_arrow = KEYMAP["arrow_end"] - ARROW_KEY_FLAG
    if first_arrow <= ord(key) <= last_arrow:
        return chr(ord(key) + ARROW_KEY_FLAG)
    return KEYMAP["undefined"]
|
accelerate/commands/menu/selection_menu.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team and Brian Chao. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Main driver for the selection menu, based on https://github.com/bchao1/bullet
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import builtins
|
| 20 |
+
import sys
|
| 21 |
+
from typing import Optional
|
| 22 |
+
|
| 23 |
+
from ...utils.imports import _is_package_available
|
| 24 |
+
from . import cursor, input
|
| 25 |
+
from .helpers import Direction, clear_line, forceWrite, linebreak, move_cursor, reset_cursor, writeColor
|
| 26 |
+
from .keymap import KEYMAP
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Result of probing for the `google.colab` package; falls back to False when the probe itself raises
# `ModuleNotFoundError`.
try:
    in_colab = _is_package_available("google.colab")
except ModuleNotFoundError:
    in_colab = False
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@input.register
class BulletMenu:
    """
    A CLI menu to select a choice from a list of choices using the keyboard.

    Args:
        prompt (`str`, *optional*):
            Text shown above the choices. When empty or `None`, no header and no usage hint are printed.
        choices (`list`, *optional*):
            The entries to choose from. Defaults to a new empty list.
    """

    def __init__(self, prompt: Optional[str] = None, choices: Optional[list] = None):
        self.position = 0
        # Use a fresh list per instance: a literal `[]` default would be one object shared by every menu
        # created without explicit choices.
        self.choices = [] if choices is None else choices
        self.prompt = prompt
        if sys.platform == "win32":
            self.arrow_char = "*"
        else:
            self.arrow_char = "➔ "

    def write_choice(self, index, end: str = ""):
        "Writes the choice at the given index, in color everywhere except on Windows"
        if sys.platform != "win32":
            writeColor(self.choices[index], 32, end)
        else:
            forceWrite(self.choices[index], end)

    def print_choice(self, index: int):
        "Prints the choice at the given index"
        if index == self.position:
            forceWrite(f" {self.arrow_char} ")
            self.write_choice(index)
        else:
            forceWrite(f" {self.choices[index]}")
        reset_cursor()

    def move_direction(self, direction: Direction, num_spaces: int = 1):
        "Should not be directly called, used to move a direction of either up or down"
        old_position = self.position
        if direction == Direction.DOWN:
            # Already on the last choice: nothing to do.
            if self.position + 1 >= len(self.choices):
                return
            self.position += num_spaces
        else:
            # Already on the first choice: nothing to do.
            if self.position - 1 < 0:
                return
            self.position -= num_spaces
        # Redraw the previously selected row without the arrow, then draw the newly selected row.
        clear_line()
        self.print_choice(old_position)
        move_cursor(num_spaces, direction.name)
        self.print_choice(self.position)

    @input.mark(KEYMAP["up"])
    def move_up(self):
        "Moves the selection up by one row"
        self.move_direction(Direction.UP)

    @input.mark(KEYMAP["down"])
    def move_down(self):
        "Moves the selection down by one row"
        self.move_direction(Direction.DOWN)

    @input.mark(KEYMAP["newline"])
    def select(self):
        "Moves the cursor below the menu and returns the index of the selected choice"
        move_cursor(len(self.choices) - self.position, "DOWN")
        return self.position

    @input.mark(KEYMAP["interrupt"])
    def interrupt(self):
        "Moves the cursor below the menu and raises `KeyboardInterrupt`"
        move_cursor(len(self.choices) - self.position, "DOWN")
        raise KeyboardInterrupt

    @input.mark_multiple(*[KEYMAP[str(number)] for number in range(10)])
    def select_row(self):
        "Jumps the selection to the row whose index matches the digit key that was pressed"
        index = int(chr(self.current_selection))
        movement = index - self.position
        if index == self.position:
            return
        if index < len(self.choices):
            if self.position > index:
                self.move_direction(Direction.UP, -movement)
            elif self.position < index:
                self.move_direction(Direction.DOWN, movement)
            else:
                return
        else:
            return

    def run(self, default_choice: int = 0):
        "Start the menu and return the selected choice"
        if self.prompt:
            linebreak()
            forceWrite(self.prompt, "\n")
            if in_colab:
                forceWrite("Please input a choice index (starting from 0), and press enter", "\n")
            else:
                forceWrite("Please select a choice using the arrow or number keys, and selecting with enter", "\n")
        self.position = default_choice
        for i in range(len(self.choices)):
            self.print_choice(i)
            forceWrite("\n")
        # Put the cursor back on the row of the default choice.
        move_cursor(len(self.choices) - self.position, "UP")
        with cursor.hide():
            while True:
                if in_colab:
                    # In Colab the index is typed in; input that is not an integer selects the default.
                    try:
                        choice = int(builtins.input())
                    except ValueError:
                        choice = default_choice
                else:
                    choice = self.handle_input()
                # Handlers that only move the selection return None, so keep reading keys.
                if choice is not None:
                    reset_cursor()
                    for _ in range(len(self.choices) + 1):
                        move_cursor(1, "UP")
                        clear_line()
                    self.write_choice(choice, "\n")
                    return choice
|
accelerate/commands/merge.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
from accelerate.commands.utils import CustomArgumentParser
|
| 17 |
+
from accelerate.utils import merge_fsdp_weights
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Text shown as the description of the weight-merging command's argument parser.
description = "\n".join(
    (
        "Utility to merge the weights from multiple FSDP checkpoints into a single combined checkpoint. Should be used if",
        "`SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}`.",
        "",
        "This is a CPU-bound process and requires enough RAM to load the entire model state dict.",
    )
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def merge_command(args):
    """
    Run `merge_fsdp_weights` with the values parsed from the command line.

    The third argument is the negation of `--unsafe_serialization`.
    """
    use_safe_format = not args.unsafe_serialization
    merge_fsdp_weights(
        args.checkpoint_directory,
        args.output_path,
        use_safe_format,
        args.remove_checkpoint_dir,
    )
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def merge_command_parser(subparsers=None):
    """
    Build the argument parser for the weight-merging command.

    Args:
        subparsers (*optional*):
            When given, the command is registered on it as the `merge-weights` sub-command and `merge_command`
            is set as its `func` default. When `None`, a standalone `CustomArgumentParser` is created.

    Returns:
        The configured parser.
    """
    if subparsers is not None:
        parser = subparsers.add_parser("merge-weights", description=description)
    else:
        parser = CustomArgumentParser(description=description)

    parser.add_argument("checkpoint_directory", type=str, help="A directory containing sharded weights saved by FSDP.")
    parser.add_argument(
        "output_path",
        type=str,
        # This positional is required (no `nargs` / `default`), so the help must not advertise a default.
        help="The path to save the merged weights.",
    )
    parser.add_argument(
        "--unsafe_serialization",
        action="store_true",
        default=False,
        help="Whether to save the merged weights as `.bin` rather than `.safetensors` (not recommended).",
    )
    parser.add_argument(
        "--remove_checkpoint_dir",
        action="store_true",
        help="Whether to remove the checkpoint directory after merging.",
        default=False,
    )

    if subparsers is not None:
        parser.set_defaults(func=merge_command)
    return parser
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def main():
    """Command-line entry point: parse the merge arguments and run the merge."""
    merge_command(merge_command_parser().parse_args())


if __name__ == "__main__":
    main()
|
accelerate/commands/test.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
|
| 19 |
+
from accelerate.test_utils import execute_subprocess_async, path_in_accelerate_package
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def test_command_parser(subparsers=None):
    """
    Build the argument parser for the `test` command.

    When `subparsers` is given, the command is registered on it as `test` and `test_command` is set as its
    `func` default; otherwise a standalone `argparse.ArgumentParser` is created.
    """
    config_file_help = (
        "The path to use to store the config file. Will default to a file named default_config.yaml in the cache "
        "location, which is the content of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have "
        "such an environment variable, your cache directory ('~/.cache' or the content of `XDG_CACHE_HOME`) suffixed "
        "with 'huggingface'."
    )

    standalone = subparsers is None
    if standalone:
        parser = argparse.ArgumentParser("Accelerate test command")
    else:
        parser = subparsers.add_parser("test")

    parser.add_argument("--config_file", default=None, help=config_file_help)

    if not standalone:
        parser.set_defaults(func=test_command)
    return parser
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def test_command(args):
    """
    Run the bundled `test_script.py` through `accelerate-launch` and print a confirmation on success.

    Args:
        args:
            Parsed arguments; `args.config_file`, when not `None`, is forwarded to the launcher as
            `--config_file=<path>`.
    """
    script_name = path_in_accelerate_package("test_utils", "scripts", "test_script.py")

    if args.config_file is None:
        test_args = [script_name]
    else:
        # Build the argument list directly: formatting one string and splitting it on whitespace would cut
        # any config-file or script path that contains spaces into several arguments.
        test_args = [f"--config_file={args.config_file}", str(script_name)]

    cmd = ["accelerate-launch"] + test_args
    result = execute_subprocess_async(cmd)
    if result.returncode == 0:
        print("Test is a success! You are ready for your distributed training!")
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def main():
    """Command-line entry point: parse the test arguments and run the test command."""
    test_command(test_command_parser().parse_args())


if __name__ == "__main__":
    main()
|
accelerate/commands/to_fsdp2.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import enum
|
| 18 |
+
import logging
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
|
| 21 |
+
import yaml
|
| 22 |
+
|
| 23 |
+
from accelerate.commands.utils import CustomArgumentParser
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class ConversionStatus(enum.Enum):
    """
    Sentinel statuses for FSDP1 config keys that do not map onto an FSDP2 key.

    Members of this enum are used as values in `ARGUMENT_KEY_MAPPING` in place of a new key name.
    """

    # The option has no FSDP2 equivalent yet.
    NOT_YET_IMPLEMENTED = 0
    # The option no longer exists in FSDP2.
    REMOVED = -1
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Maps every known FSDP1 option name either to the FSDP2 option name that replaces it, or to a
# `ConversionStatus` member when there is no replacement.
# References:
#   https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md
#   https://huggingface.co/docs/accelerate/en/usage_guides/fsdp
ARGUMENT_KEY_MAPPING = dict(
    # New keys in FSDP2
    fsdp_version="fsdp_version",
    fsdp_reshard_after_forward="fsdp_reshard_after_forward",
    # Keys that already existed in FSDP1
    fsdp_auto_wrap_policy="fsdp_auto_wrap_policy",
    fsdp_backward_prefetch=ConversionStatus.REMOVED,
    fsdp_forward_prefetch=ConversionStatus.NOT_YET_IMPLEMENTED,
    fsdp_cpu_ram_efficient_loading="fsdp_cpu_ram_efficient_loading",
    fsdp_offload_params="fsdp_offload_params",
    # The only renamed key: its value is translated through ARGUMENT_VALUE_MAPPING.
    fsdp_sharding_strategy="fsdp_reshard_after_forward",
    fsdp_state_dict_type="fsdp_state_dict_type",
    fsdp_sync_module_states=ConversionStatus.REMOVED,
    fsdp_transformer_layer_cls_to_wrap="fsdp_transformer_layer_cls_to_wrap",
    fsdp_min_num_params="fsdp_min_num_params",
    fsdp_use_orig_params=ConversionStatus.REMOVED,
    fsdp_activation_checkpointing="fsdp_activation_checkpointing",
)
|
| 50 |
+
|
| 51 |
+
# Translates FSDP1 sharding-strategy names into the boolean stored under FSDP2's
# `fsdp_reshard_after_forward`. The same table is registered under `fsdp_reshard_after_forward`
# itself: this is needed to convert newly created configs using FSDP1 to FSDP2.
# Each option gets its own dict object, so mutating one table never affects the other.
ARGUMENT_VALUE_MAPPING = {
    option_name: {
        "FULL_SHARD": True,
        "SHARD_GRAD_OP": False,
        "HYBRID_SHARD": True,
        "HYBRID_SHARD_ZERO2": False,
        "NO_SHARD": False,
    }
    for option_name in ("fsdp_sharding_strategy", "fsdp_reshard_after_forward")
}

logger = logging.getLogger(__name__)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def _validate_to_fsdp2_args(args):
    """
    Check that the parsed `to-fsdp2` arguments describe a conversion that can be carried out.

    Args:
        args:
            Parsed arguments exposing `config_file`, `overwrite` and `output_file`.

    Raises:
        FileNotFoundError: `args.config_file` does not exist.
        ValueError: `--overwrite` is not set and no `--output_file` was given.
        FileExistsError: `--overwrite` is not set and `args.output_file` already exists.
    """
    if not Path(args.config_file).exists():
        raise FileNotFoundError(f"Config file {args.config_file} not found")

    # With --overwrite, any output path (or none at all) is acceptable.
    if args.overwrite:
        return

    if args.output_file is None:
        raise ValueError("If --overwrite is not set, --output_file must be provided")

    if Path(args.output_file).exists():
        raise FileExistsError(f"Output file {args.output_file} already exists and --overwrite is not set")
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def convert_config_to_fsdp2(config: dict) -> dict:
    """
    Convert the `fsdp_config` section of an Accelerate config from FSDP1 to FSDP2 option names.

    Each key of `fsdp_config` is looked up in `ARGUMENT_KEY_MAPPING`:

    - keys mapped to `ConversionStatus.REMOVED` or `ConversionStatus.NOT_YET_IMPLEMENTED` are dropped, with a warning;
    - keys missing from the mapping are copied unchanged, with a warning;
    - all other keys are stored under their mapped name, and their value is translated through
      `ARGUMENT_VALUE_MAPPING` when an entry exists for the original key (untranslatable values are kept as-is).

    Args:
        config (`dict`):
            The loaded Accelerate config. It is updated in place: its `fsdp_config` entry is replaced by the converted
            one.

    Returns:
        `dict`: The same `config` object. It is returned untouched when it has no (or an empty) `fsdp_config`, or when
        `fsdp_config` already declares `fsdp_version: 2`.
    """
    fsdp_config = config.get("fsdp_config", {})

    if not fsdp_config:
        logger.info("No FSDP config found in the config file, skipping conversion...")
        return config

    if fsdp_config.get("fsdp_version", 1) == 2:
        logger.warning("Config already specifies FSDP2, skipping conversion...")
        logger.warning(
            "If the config doesn't use new argument names, change `fsdp_version` to `1` and rerun the command."
        )
        return config

    new_fsdp_config = {}

    for key, value in fsdp_config.items():
        conversion_status = ARGUMENT_KEY_MAPPING.get(key)

        # The status checks must run before any key is copied: a previous catch-all branch copied every
        # `ConversionStatus` entry first, which made the branches below unreachable and kept removed keys.
        if conversion_status is ConversionStatus.REMOVED:
            logger.warning(f"Argument {key} has been removed in FSDP2, skipping this key...")
            continue

        if conversion_status is ConversionStatus.NOT_YET_IMPLEMENTED:
            logger.warning(f"Argument {key} is not yet implemented in FSDP2, skipping this key...")
            continue

        if conversion_status is None:
            # Unknown key: nothing to convert, keep it verbatim.
            logger.warning(f"Argument {key} is not being converted, skipping this key...")
            new_fsdp_config[key] = value
            continue

        if key in ARGUMENT_VALUE_MAPPING:
            value = ARGUMENT_VALUE_MAPPING[key].get(value, value)
        new_fsdp_config[conversion_status] = value

    # Always stamp the result as FSDP2, overriding any `fsdp_version` copied from the input.
    new_fsdp_config["fsdp_version"] = 2
    config["fsdp_config"] = new_fsdp_config
    return config
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def to_fsdp2_command_parser(subparsers=None):
    """
    Build the argument parser of the `to-fsdp2` command.

    Args:
        subparsers (*optional*):
            Sub-parser collection to register the `to-fsdp2` command on. When given, the returned parser also gets
            `func=to_fsdp2_command` as a default. When `None`, a standalone `CustomArgumentParser` is created.

    Returns:
        The parser accepting `--config_file` (required), `--overwrite` and `--output_file`.
    """
    description = "Convert an Accelerate config from FSDP1 to FSDP2"
    standalone = subparsers is None

    if standalone:
        parser = CustomArgumentParser(description=description)
    else:
        parser = subparsers.add_parser("to-fsdp2", description=description)

    parser.add_argument("--config_file", required=True, type=str, help="The config file to convert to FSDP2")
    parser.add_argument(
        "--overwrite",
        default=False,
        action="store_true",
        help="Overwrite the config file if it exists",
    )
    parser.add_argument(
        "--output_file",
        default=None,
        type=str,
        help="The path to the output file to write the converted config to. If not provided, the input file will be overwritten (if --overwrite is set)",
    )

    if not standalone:
        parser.set_defaults(func=to_fsdp2_command)

    return parser
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def load_config(config_file: str) -> dict:
    """
    Read `config_file` and parse it with `yaml.safe_load`.

    Args:
        config_file (`str`):
            Path of the YAML config file to read.

    Returns:
        The parsed content of the file.

    Raises:
        ValueError: The parsed content is empty (falsy).
    """
    with open(config_file) as stream:
        parsed = yaml.safe_load(stream)

    if parsed:
        return parsed

    raise ValueError("Config file is empty")
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def to_fsdp2_command(args):
    """
    Entry point of the `to-fsdp2` command: validate `args`, convert the config and write it out as YAML.

    When `--overwrite` is set and no `--output_file` was given, `args.output_file` is set to `args.config_file`, so
    the input file is rewritten in place.
    """
    _validate_to_fsdp2_args(args)
    loaded = load_config(args.config_file)

    if args.output_file is None and args.overwrite:
        args.output_file = args.config_file

    # Convert before opening the destination, so a failing conversion never truncates it.
    converted = convert_config_to_fsdp2(loaded)

    with open(args.output_file, "w") as stream:
        yaml.dump(converted, stream)
|
accelerate/commands/tpu.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import os
|
| 19 |
+
import subprocess
|
| 20 |
+
|
| 21 |
+
from packaging.version import Version, parse
|
| 22 |
+
|
| 23 |
+
from accelerate.commands.config.config_args import default_config_file, load_config_from_file
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
_description = "Run commands across TPU VMs for initial setup before running `accelerate launch`."


def tpu_command_parser(subparsers=None):
    """
    Build the argument parser of the `tpu-config` command.

    Args:
        subparsers (*optional*):
            Sub-parser collection to register the `tpu-config` command on. When given, the returned parser also gets
            `func=tpu_command_launcher` as a default. When `None`, a standalone `argparse.ArgumentParser` is created.

    Returns:
        The configured parser.
    """
    standalone = subparsers is None
    if standalone:
        parser = argparse.ArgumentParser("Accelerate tpu-config command", description=_description)
    else:
        parser = subparsers.add_parser("tpu-config", description=_description)

    # Core arguments
    config_options = (
        ("--config_file", {"type": str, "default": None, "help": "Path to the config file to use for accelerate."}),
        (
            "--tpu_name",
            {
                "default": None,
                "help": "The name of the TPU to use. If not specified, will use the TPU specified in the config file.",
            },
        ),
        (
            "--tpu_zone",
            {
                "default": None,
                "help": "The zone of the TPU to use. If not specified, will use the zone specified in the config file.",
            },
        ),
    )
    config_group = parser.add_argument_group(
        "Config Arguments", "Arguments that can be configured through `accelerate config`."
    )
    for flag, options in config_options:
        config_group.add_argument(flag, **options)

    # Options are registered in this exact order, which is also the order shown in `--help`.
    pod_options = (
        (
            "--use_alpha",
            {
                "action": "store_true",
                "help": "Whether to use `gcloud alpha` when running the TPU training script instead of `gcloud`.",
            },
        ),
        (
            "--command_file",
            {
                "default": None,
                "help": "The path to the file containing the commands to run on the pod on startup.",
            },
        ),
        (
            "--command",
            {
                "action": "append",
                "nargs": "+",
                "help": "A command to run on the pod. Can be passed multiple times.",
            },
        ),
        (
            "--install_accelerate",
            {
                "action": "store_true",
                "help": "Whether to install accelerate on the pod. Defaults to False.",
            },
        ),
        (
            "--accelerate_version",
            {
                "default": "latest",
                "help": "The version of accelerate to install on the pod. If not specified, will use the latest pypi version. Specify 'dev' to install from GitHub.",
            },
        ),
        (
            "--debug",
            {
                "action": "store_true",
                "help": "If set, will print the command that would be run instead of running it.",
            },
        ),
    )
    pod_group = parser.add_argument_group("TPU Arguments", "Arguments for options ran inside the TPU.")
    for flag, options in pod_options:
        pod_group.add_argument(flag, **options)

    if not standalone:
        parser.set_defaults(func=tpu_command_launcher)
    return parser
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def tpu_command_launcher(args):
    """
    Assemble the `gcloud ... tpu-vm ssh` invocation described by `args` and run it on all workers of the pod.

    Values missing from `args` (`command_file`, `command`, `tpu_name`, `tpu_zone`) are filled in from the config file
    when `--config_file` is given or the default config file exists. The commands are prefixed with `cd /usr/share`
    and, with `--install_accelerate`, a `pip install` of the requested accelerate version, then joined with `; `.

    With `--debug`, the assembled command is only printed.

    Raises:
        ValueError: Neither a command file nor a command is available, or the TPU name/zone could not be resolved.
        subprocess.CalledProcessError: The `gcloud` command exited with a non-zero status.
    """
    defaults = None

    # Get the default from the config file if it exists.
    if args.config_file is not None or os.path.isfile(default_config_file):
        defaults = load_config_from_file(args.config_file)
        if not args.command_file and defaults.command_file is not None and not args.command:
            args.command_file = defaults.command_file
        if not args.command and defaults.commands is not None:
            args.command = defaults.commands
        if not args.tpu_name:
            args.tpu_name = defaults.tpu_name
        if not args.tpu_zone:
            args.tpu_zone = defaults.tpu_zone

    # Turn the requested version into the argument handed to `pip install`.
    if args.accelerate_version == "dev":
        args.accelerate_version = "git+https://github.com/huggingface/accelerate.git"
    elif args.accelerate_version == "latest":
        args.accelerate_version = "accelerate -U"
    elif isinstance(parse(args.accelerate_version), Version):
        args.accelerate_version = f"accelerate=={args.accelerate_version}"

    if not args.command_file and not args.command:
        raise ValueError("You must specify either a command file or a command to run on the pod.")

    # Fail early with a clear message: a missing name/zone would otherwise put `None` in the argv list
    # and surface as an opaque `TypeError` from `str.join` or `subprocess.run`.
    if not args.tpu_name or not args.tpu_zone:
        raise ValueError("You must specify a TPU name and a TPU zone, either as arguments or in the config file.")

    if args.command_file:
        with open(args.command_file) as f:
            args.command = [f.read().splitlines()]

    # To turn list of lists into list of strings
    if isinstance(args.command[0], list):
        args.command = [line for cmd in args.command for line in cmd]

    # Default to the shared folder and install accelerate
    new_cmd = ["cd /usr/share"]
    if args.install_accelerate:
        new_cmd += [f"pip install {args.accelerate_version}"]
    new_cmd += args.command
    args.command = "; ".join(new_cmd)

    # Then send it to gcloud
    # Eventually try to use google-api-core to do this instead of subprocess
    cmd = ["gcloud"]
    if args.use_alpha:
        cmd += ["alpha"]
    cmd += [
        "compute",
        "tpus",
        "tpu-vm",
        "ssh",
        args.tpu_name,
        "--zone",
        args.tpu_zone,
        "--command",
        args.command,
        "--worker",
        "all",
    ]
    if args.debug:
        print(f"Running {' '.join(cmd)}")
        return

    # `check=True` raises on a non-zero exit status, so success is only reported when gcloud succeeded.
    subprocess.run(cmd, check=True)
    print("Successfully setup pod.")
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def main():
    """Parse the command line with the standalone `tpu-config` parser and hand the result to the launcher."""
    cli_args = tpu_command_parser().parse_args()
    tpu_command_launcher(cli_args)
|
accelerate/commands/utils.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class _StoreAction(argparse.Action):
    """
    Store action whose option strings can be written with `-` or `_`.

    Each option string containing an underscore after its first two characters also gets a variant with every `_`
    replaced by `-`. When invoked, the action stores the value and records its `dest` in the `nondefault` set of the
    namespace.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        expanded = []
        for flag in self.option_strings:
            # Keep the original spelling first, then its hyphenated twin (if any).
            expanded += [flag, flag.replace("_", "-")] if "_" in flag[2:] else [flag]
        self.option_strings = expanded

    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, values)
        # `nondefault` is created lazily by the first action that fires on this namespace.
        if not hasattr(namespace, "nondefault"):
            namespace.nondefault = set()
        namespace.nondefault.add(self.dest)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class _StoreConstAction(_StoreAction):
    """
    Counterpart of `argparse._StoreConstAction` built on the custom `_StoreAction`.
    """

    def __init__(self, option_strings, dest, const, default=None, required=False, help=None):
        # `nargs=0`: the option takes no value on the command line; `const` is what gets stored.
        fixed = {"nargs": 0, "const": const, "default": default, "required": required, "help": help}
        super().__init__(option_strings=option_strings, dest=dest, **fixed)

    def __call__(self, parser, namespace, values, option_string=None):
        # Ignore `values` and store the constant through the parent implementation.
        super().__call__(parser, namespace, self.const, option_string)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class _StoreTrueAction(_StoreConstAction):
    """
    Counterpart of `argparse._StoreTrueAction` built on the custom `_StoreConstAction`.
    """

    def __init__(self, option_strings, dest, default=None, required=False, help=None):
        # A store-true flag is a store-const flag whose constant is `True`.
        super().__init__(
            option_strings=option_strings,
            dest=dest,
            const=True,
            default=default,
            required=required,
            help=help,
        )
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class CustomArgumentGroup(argparse._ArgumentGroup):
    """
    Argument group that swaps the stock argparse store actions for the custom ones of this module, so that arguments
    can be passed with either `-` or `_`.
    """

    def _add_action(self, action):
        # The most specific argparse class is tested first: `_StoreTrueAction` subclasses `_StoreConstAction`.
        if isinstance(action, argparse._StoreTrueAction):
            replacement = _StoreTrueAction(
                action.option_strings, action.dest, action.default, action.required, action.help
            )
        elif isinstance(action, argparse._StoreConstAction):
            replacement = _StoreConstAction(
                action.option_strings,
                action.dest,
                action.const,
                action.default,
                action.required,
                action.help,
            )
        elif isinstance(action, argparse._StoreAction):
            # A plain store action is rebuilt from all of its attributes.
            replacement = _StoreAction(**vars(action))
        else:
            # Any other action type is registered unchanged.
            replacement = action
        return super()._add_action(replacement)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class CustomArgumentParser(argparse.ArgumentParser):
    """
    Custom argument parser that allows for the use of `-` or `_` in arguments passed and overrides the help for each
    when applicable.
    """

    def add_argument(self, *args, **kwargs):
        """
        Register an argument, substituting the custom actions of this module where applicable.

        Arguments declared without an `action` use `_StoreAction`, `action="store_true"` uses `_StoreTrueAction`, and
        any other explicit action is passed through untouched.

        Returns:
            The created action, as `argparse.ArgumentParser.add_argument` does.
        """
        if "action" not in kwargs:
            kwargs["action"] = _StoreAction
        elif kwargs["action"] == "store_true":
            # Translate action -> class
            kwargs["action"] = _StoreTrueAction
        # Hand the created action back to the caller; the base-class method returns it.
        return super().add_argument(*args, **kwargs)

    def add_argument_group(self, *args, **kwargs):
        """Create a `CustomArgumentGroup` attached to this parser, append it to the parser's action groups and return it."""
        group = CustomArgumentGroup(self, *args, **kwargs)
        self._action_groups.append(group)
        return group
|
accelerate/test_utils/__init__.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from .testing import (
|
| 15 |
+
DEFAULT_LAUNCH_COMMAND,
|
| 16 |
+
are_the_same_tensors,
|
| 17 |
+
assert_exception,
|
| 18 |
+
capture_call_output,
|
| 19 |
+
device_count,
|
| 20 |
+
execute_subprocess_async,
|
| 21 |
+
get_launch_command,
|
| 22 |
+
get_torch_dist_unique_port,
|
| 23 |
+
memory_allocated_func,
|
| 24 |
+
path_in_accelerate_package,
|
| 25 |
+
pytest_xdist_worker_id,
|
| 26 |
+
require_bnb,
|
| 27 |
+
require_cpu,
|
| 28 |
+
require_cuda,
|
| 29 |
+
require_cuda_or_hpu,
|
| 30 |
+
require_cuda_or_xpu,
|
| 31 |
+
require_fp8,
|
| 32 |
+
require_fp16,
|
| 33 |
+
require_huggingface_suite,
|
| 34 |
+
require_mlu,
|
| 35 |
+
require_mps,
|
| 36 |
+
require_multi_device,
|
| 37 |
+
require_multi_gpu,
|
| 38 |
+
require_multi_gpu_or_xpu,
|
| 39 |
+
require_multi_xpu,
|
| 40 |
+
require_musa,
|
| 41 |
+
require_non_cpu,
|
| 42 |
+
require_non_hpu,
|
| 43 |
+
require_non_torch_xla,
|
| 44 |
+
require_non_xpu,
|
| 45 |
+
require_npu,
|
| 46 |
+
require_pippy,
|
| 47 |
+
require_sdaa,
|
| 48 |
+
require_single_device,
|
| 49 |
+
require_single_gpu,
|
| 50 |
+
require_single_xpu,
|
| 51 |
+
require_torch_min_version,
|
| 52 |
+
require_torchao,
|
| 53 |
+
require_torchvision,
|
| 54 |
+
require_tpu,
|
| 55 |
+
require_transformer_engine,
|
| 56 |
+
require_transformer_engine_mxfp8,
|
| 57 |
+
require_xpu,
|
| 58 |
+
run_first,
|
| 59 |
+
skip,
|
| 60 |
+
slow,
|
| 61 |
+
torch_device,
|
| 62 |
+
)
|
| 63 |
+
from .training import RegressionDataset, RegressionModel
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
from .scripts import test_script, test_sync, test_ops # isort: skip
|
accelerate/test_utils/examples.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""
|
| 17 |
+
A collection of utilities for comparing `examples/complete_*_example.py` scripts with the capabilities inside of each
|
| 18 |
+
`examples/by_feature` example. `compare_against_test` is the main function that should be used when testing, while the
|
| 19 |
+
others are used to either get the code that matters, or to preprocess them (such as stripping comments)
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import os
|
| 23 |
+
from typing import Optional
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def get_function_contents_by_name(lines: list[str], name: str):
    """
    Extracts a function from `lines` of segmented source code with the name `name`.

    Collection starts at the first line containing `def {name}` and stops just before the line that marks the next
    section: a line containing `def main` when extracting `training_function`, or a line containing `if __name__`
    when extracting `main`. If the input ends before that marker is seen, everything collected so far is returned.

    Args:
        lines (`List[str]`):
            Source code of a script separated by line.
        name (`str`):
            The name of the function to extract. Should be either `training_function` or `main`

    Returns:
        `List[str]`: The lines of the function, or `None` when no line containing `def {name}` was found.

    Raises:
        ValueError: `name` is neither `training_function` nor `main`.
    """
    if name not in ("training_function", "main"):
        raise ValueError(f"Incorrect function name passed: {name}, choose either 'main' or 'training_function'")

    # Substring that marks the first line *after* the function being extracted.
    end_marker = "def main" if name == "training_function" else "if __name__"

    good_lines, found_start = [], False
    for line in lines:
        if not found_start:
            if f"def {name}" in line:
                found_start = True
                good_lines.append(line)
            continue
        if end_marker in line:
            return good_lines
        good_lines.append(line)

    # End of input reached without the end marker: keep what was collected instead of discarding it.
    return good_lines if found_start else None
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def clean_lines(lines: list[str]):
    """
    Drops from `lines` every entry that is exactly a newline ('\n') or whose first non-blank character is '#'.

    Args:
        lines (`List[str]`):
            Source code of a script separated by line.

    Returns:
        `List[str]`: The remaining lines, in their original order.
    """
    kept = []
    for line in lines:
        if line == "\n" or line.lstrip().startswith("#"):
            continue
        kept.append(line)
    return kept
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def compare_against_test(
    base_filename: str, feature_filename: str, parser_only: bool, secondary_filename: Optional[str] = None
):
    """
    Tests whether the additional code inside of `feature_filename` was implemented in `base_filename`. This should be
    used when testing to see if `complete_*_.py` examples have all of the implementations from each of the
    `examples/by_feature/*` scripts.

    The template script `examples/nlp_example.py` (resolved against the current working directory) is used to strip
    the shared training code, so that only the code that is new in the feature script is checked. When a template
    other than `nlp_example.py` applies, such as `cv_example.py` for `complete_cv_example.py`, pass it as
    `secondary_filename`.

    Args:
        base_filename (`str` or `os.PathLike`):
            The filepath of a single "complete" example script to test, such as `examples/complete_cv_example.py`
        feature_filename (`str` or `os.PathLike`):
            The filepath of a single feature example script. The contents of this script are checked to see if they
            exist in `base_filename`
        parser_only (`bool`):
            Whether to compare the `main` functions of the files (`True`) or their `training_function` functions
            (`False`)
        secondary_filename (`str`, *optional*):
            A potential secondary filepath that should be included in the check, such as `examples/cv_example.py`

    Returns:
        `List[str]`: The lines that are new in the feature script and were not found in the complete example.
    """

    def read_lines(path):
        with open(path) as handle:
            return handle.readlines()

    has_secondary = secondary_filename is not None

    base_contents = read_lines(base_filename)
    template_contents = read_lines(os.path.abspath(os.path.join("examples", "nlp_example.py")))
    feature_contents = read_lines(feature_filename)
    if has_secondary:
        secondary_contents = read_lines(secondary_filename)

    # The same function is extracted from every file; the template one is the base that gets subtracted.
    target = "main" if parser_only else "training_function"

    def extract(contents):
        return clean_lines(get_function_contents_by_name(contents, target))

    base_func = extract(base_contents)
    template_func = extract(template_contents)
    feature_func = extract(feature_contents)
    if has_secondary:
        secondary_func = extract(secondary_contents)

    _dl_line = "train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size)\n"

    def is_new(line):
        return (line not in template_func) and (line.lstrip() != _dl_line)

    # Specific code in the feature script that differs from the template, aka what is new.
    # The walk performs one step fewer than there are lines and consumes them through an iterator, because a
    # mocked-dataloader line also swallows the line that follows it.
    new_feature_code = []
    remaining = iter(feature_func)
    for _ in range(len(feature_func) - 1):
        candidate = next(remaining)
        if not is_new(candidate):
            continue
        if "TESTING_MOCKED_DATALOADERS" in candidate:
            # Skip over the `config['num_epochs'] = 2` statement
            next(remaining)
        else:
            new_feature_code.append(candidate)

    # Extract out just the new parts of the complete example
    new_full_example_parts = [
        line for line in base_func if is_new(line) and "TESTING_MOCKED_DATALOADERS" not in line
    ]

    # Finally, get the overall diff
    diff_from_example = [line for line in new_feature_code if line not in new_full_example_parts]
    if has_secondary:
        diff_from_two = [line for line in template_contents if line not in secondary_func]
        diff_from_example = [line for line in diff_from_example if line not in diff_from_two]

    return diff_from_example
|
accelerate/test_utils/scripts/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
accelerate/test_utils/scripts/external_deps/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
accelerate/test_utils/scripts/external_deps/test_checkpointing.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import argparse
|
| 15 |
+
import json
|
| 16 |
+
import os
|
| 17 |
+
|
| 18 |
+
import evaluate
|
| 19 |
+
import torch
|
| 20 |
+
from datasets import load_dataset
|
| 21 |
+
from torch.optim import AdamW
|
| 22 |
+
from torch.utils.data import DataLoader
|
| 23 |
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
|
| 24 |
+
|
| 25 |
+
from accelerate import Accelerator, DistributedType
|
| 26 |
+
from accelerate.utils.deepspeed import DummyOptim, DummyScheduler
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
MAX_GPU_BATCH_SIZE = 16
|
| 30 |
+
EVAL_BATCH_SIZE = 32
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_dataloaders(accelerator: Accelerator, batch_size: int = 16, model_name: str = "bert-base-cased"):
    """
    Build the training and validation `DataLoader`s for the `mrpc` configuration of the `glue` dataset.

    Args:
        accelerator (`Accelerator`):
            The `Accelerator` whose `distributed_type` selects the padding strategy used when collating.
        batch_size (`int`, *optional*, defaults to 16):
            Batch size of the training DataLoader. The validation DataLoader always uses `EVAL_BATCH_SIZE`.
        model_name (`str`, *optional*, defaults to `"bert-base-cased"`):
            Identifier handed to `AutoTokenizer.from_pretrained` to load the tokenizer.

    Returns:
        A `(train_dataloader, eval_dataloader)` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    raw_datasets = load_dataset("glue", "mrpc")

    def tokenize_function(examples):
        # `max_length=None` (the default) means the model max length is used
        return tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)

    # Tokenize every split and drop the raw text columns, then expose the targets under `labels`,
    # which is the column name the models of the transformers library expect.
    tokenized_datasets = raw_datasets.map(
        tokenize_function,
        batched=True,
        remove_columns=["idx", "sentence1", "sentence2"],
        load_from_cache_file=False,
    ).rename_column("label", "labels")

    def collate_fn(examples):
        # On TPU it's best to pad everything to the same length or training will be very slow;
        # everywhere else each batch is padded to its longest example.
        if accelerator.distributed_type == DistributedType.XLA:
            pad_kwargs = {"padding": "max_length", "max_length": 128}
        else:
            pad_kwargs = {"padding": "longest"}
        return tokenizer.pad(examples, return_tensors="pt", **pad_kwargs)

    # Only the training split is shuffled.
    train_dataloader = DataLoader(
        tokenized_datasets["train"],
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=batch_size,
    )
    eval_dataloader = DataLoader(
        tokenized_datasets["validation"],
        shuffle=False,
        collate_fn=collate_fn,
        batch_size=EVAL_BATCH_SIZE,
    )

    return train_dataloader, eval_dataloader
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def evaluation_loop(accelerator, model, eval_dataloader, metric):
    """
    Run `model` over `eval_dataloader` and return the `"accuracy"` entry computed by `metric`.

    Predictions and labels are collected with `accelerator.gather`. When `accelerator.use_distributed` is set, the
    last gathered batch is cut down to the number of samples of `eval_dataloader.dataset` that have not been counted
    yet, because in a multiprocess environment the last batch has duplicates.
    """
    model.eval()
    n_seen = 0
    for batch_idx, batch in enumerate(eval_dataloader):
        # We could avoid this line since we set the accelerator with `device_placement=True`.
        batch.to(accelerator.device)
        with torch.no_grad():
            logits = model(**batch).logits
        # It is slightly faster to gather both tensors in one call than in two separate calls
        predictions, references = accelerator.gather((logits.argmax(dim=-1), batch["labels"]))
        if accelerator.use_distributed:
            if batch_idx == len(eval_dataloader) - 1:
                # Keep only the samples that are still missing from the dataset
                leftover = len(eval_dataloader.dataset) - n_seen
                predictions = predictions[:leftover]
                references = references[:leftover]
            else:
                n_seen += references.shape[0]
        metric.add_batch(predictions=predictions, references=references)

    return metric.compute()["accuracy"]
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def training_function(config, args):
    """
    Train a sequence classification model, saving a checkpoint and a JSON state file after every epoch, or, when
    `args.resume_from_checkpoint` is set, reload that checkpoint and assert that it matches the recorded state.

    Args:
        config (`dict`):
            Hyper-parameters; the keys `"lr"`, `"num_epochs"`, `"seed"` and `"batch_size"` are read.
        args:
            Parsed command line arguments; `model_name_or_path`, `output_dir`, `resume_from_checkpoint` and
            `partial_train_epoch` are read.
    """
    accelerator = Accelerator()

    # Hyper-parameters for learning rate, number of epochs, seed and batch size
    lr = config["lr"]
    num_epochs = int(config["num_epochs"])
    seed = int(config["seed"])
    batch_size = int(config["batch_size"])
    model_name = args.model_name_or_path

    set_seed(seed)
    train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size, model_name)

    # The model is built after seeding so that the seed also controls new weights initialization
    model = AutoModelForSequenceClassification.from_pretrained(model_name, return_dict=True)

    ds_plugin = accelerator.state.deepspeed_plugin

    # A real AdamW is used unless the DeepSpeed config declares its own optimizer
    if ds_plugin is not None and "optimizer" in ds_plugin.deepspeed_config:
        optimizer_cls = DummyOptim
    else:
        optimizer_cls = AdamW
    optimizer = optimizer_cls(params=model.parameters(), lr=lr)

    gradient_accumulation_steps = (
        1 if ds_plugin is None else ds_plugin.deepspeed_config["gradient_accumulation_steps"]
    )
    max_training_steps = (len(train_dataloader) * num_epochs) // gradient_accumulation_steps

    # Same rule for the scheduler: a placeholder is created only when the DeepSpeed config declares one
    if ds_plugin is not None and "scheduler" in ds_plugin.deepspeed_config:
        lr_scheduler = DummyScheduler(optimizer, total_num_steps=max_training_steps, warmup_num_steps=0)
    else:
        lr_scheduler = get_linear_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=0,
            num_training_steps=max_training_steps,
        )

    # The objects come back from `prepare` in the same order they were passed in
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )

    # Total number of batches iterated over, across epochs
    overall_step = 0
    # Epoch to start from; it is part of the checkpoint folder and state file names
    starting_epoch = 0
    metric = evaluate.load("glue", "mrpc")
    ending_epoch = num_epochs if args.partial_train_epoch is None else args.partial_train_epoch

    if args.resume_from_checkpoint:
        accelerator.load_state(args.resume_from_checkpoint)
        # The epoch number is the run of digits that directly follows "epoch_" in the checkpoint path
        epoch_suffix = args.resume_from_checkpoint.split("epoch_")[1]
        digits = ""
        for char in epoch_suffix:
            if not char.isdigit():
                break
            digits += char
        starting_epoch = int(digits) + 1

        accuracy = evaluation_loop(accelerator, model, eval_dataloader, metric)
        accelerator.print("resumed checkpoint performance:", accuracy)
        accelerator.print("resumed checkpoint's scheduler's lr:", lr_scheduler.get_lr()[0])
        accelerator.print("resumed optimizers's lr:", optimizer.param_groups[0]["lr"])

        # Compare against the state file written when that epoch was originally trained
        state_file = os.path.join(args.output_dir, f"state_{starting_epoch - 1}.json")
        with open(state_file) as f:
            resumed_state = json.load(f)
        assert resumed_state["accuracy"] == accuracy, "Accuracy mismatch, loading from checkpoint failed"
        assert resumed_state["lr"] == lr_scheduler.get_lr()[0], (
            "Scheduler learning rate mismatch, loading from checkpoint failed"
        )
        assert resumed_state["optimizer_lr"] == optimizer.param_groups[0]["lr"], (
            "Optimizer learning rate mismatch, loading from checkpoint failed"
        )
        assert resumed_state["epoch"] == starting_epoch - 1, "Epoch mismatch, loading from checkpoint failed"
        # Resuming only verifies the checkpoint; no further training happens
        return

    # Now we train the model
    for epoch in range(starting_epoch, ending_epoch):
        model.train()
        for step, batch in enumerate(train_dataloader):
            loss = model(**batch).loss / gradient_accumulation_steps
            accelerator.backward(loss)
            if step % gradient_accumulation_steps == 0:
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            overall_step += 1

        # Checkpoint the epoch, then record what a resumed run is expected to reproduce
        accelerator.save_state(os.path.join(args.output_dir, f"epoch_{epoch}"))
        state = {
            "accuracy": evaluation_loop(accelerator, model, eval_dataloader, metric),
            "lr": lr_scheduler.get_lr()[0],
            "optimizer_lr": optimizer.param_groups[0]["lr"],
            "epoch": epoch,
            "step": overall_step,
        }
        accelerator.print(f"epoch {epoch}:", state)

        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            with open(os.path.join(args.output_dir, f"state_{epoch}.json"), "w") as f:
                json.dump(state, f)
    accelerator.end_training()
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def main():
    """Parse the command line arguments and run `training_function` with a fixed hyper-parameter configuration."""
    # The previous description mentioned peak GPU memory tracking, which this script does not do.
    parser = argparse.ArgumentParser(
        description="Simple example of a training script that saves checkpoints and verifies resuming from them."
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="bert-base-cased",
        help="Path to pretrained model or model identifier from huggingface.co/models.",
        required=False,
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=".",
        help="Optional save directory where all checkpoint folders will be stored. Default is the current working directory.",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="If the training should continue from a checkpoint folder.",
    )
    parser.add_argument(
        "--partial_train_epoch",
        type=int,
        default=None,
        help="If passed, the training will stop after this number of epochs.",
    )
    parser.add_argument(
        "--num_epochs",
        type=int,
        default=2,
        help="Number of train epochs.",
    )
    args = parser.parse_args()
    # Only the number of epochs is configurable from the command line; the rest is fixed
    config = {"lr": 2e-5, "num_epochs": args.num_epochs, "seed": 42, "batch_size": 16}

    training_function(config, args)


if __name__ == "__main__":
    main()
|
accelerate/test_utils/scripts/external_deps/test_ds_alst_ulysses_sp.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Test script for verifying ALST/Ulysses SP works
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
from deepspeed.runtime.utils import move_to_device
|
| 21 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 22 |
+
|
| 23 |
+
from accelerate import Accelerator
|
| 24 |
+
from accelerate.utils import ParallelismConfig, set_seed
|
| 25 |
+
from accelerate.utils.dataclasses import DeepSpeedSequenceParallelConfig
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
set_seed(42)

# Used as the sequence parallel size below
world_size = 2
model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM"

micro_batch_size = 1

parallelism_config = ParallelismConfig(
    sp_backend="deepspeed",
    sp_size=world_size,
    # dp_shard_size=1,  # set if dp is wanted as well
    sp_handler=DeepSpeedSequenceParallelConfig(
        sp_seq_length=256,
        sp_seq_length_is_variable=True,
        sp_attn_implementation="sdpa",
    ),
)

accelerator = Accelerator(
    parallelism_config=parallelism_config,
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Synthetic data: `samples` rows of `seqlen` consecutive integers each
samples = 4
seqlen = 32
input_ids = torch.arange(1, seqlen * samples + 1).view(-1, seqlen) + 100
position_ids = torch.arange(seqlen * samples).view(-1, seqlen)

ds = torch.utils.data.TensorDataset(input_ids, position_ids)


def collate_fn(batch):
    """Turn the first sample of `batch` into a model input dict with a leading batch dimension of 1."""
    input_ids, position_ids = batch[0]
    return dict(
        input_ids=input_ids.unsqueeze(0),
        position_ids=position_ids.unsqueeze(0),
        # The labels are the input ids themselves
        labels=input_ids.unsqueeze(0),
    )


dl = torch.utils.data.DataLoader(ds, batch_size=micro_batch_size, collate_fn=collate_fn)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

rank = torch.distributed.get_rank()

if rank == 0:
    print(f"DL orig: {len(dl)} samples")

model, optimizer, dl = accelerator.prepare(model, optimizer, dl)

if rank == 0:
    print(f"DL w/ adapter: {len(dl)} samples")

sp_size = parallelism_config.sp_size if parallelism_config else 1
if sp_size > 1:
    from deepspeed.utils import groups

    sp_group = groups._get_sequence_parallel_group()
    sp_world_size = parallelism_config.sp_size

unwrapped_model = accelerator.unwrap_model(model)

# Normal training loop
# The loop index is called `step` rather than `iter` so the builtin `iter` is not shadowed.
for step, batch in enumerate(dl):
    optimizer.zero_grad()

    if rank == 0:
        print(f"batch {step}: seqlen: {len(batch['input_ids'][0])}")
    batch = move_to_device(batch, model.device)
    outputs = model(**batch)

    # NOTE(review): `collate_fn` does not produce "shift_labels"; presumably the prepared dataloader adds it - confirm
    shift_labels = batch["shift_labels"]
    loss = unwrapped_model.loss_function(
        logits=outputs.logits,
        labels=None,
        shift_labels=shift_labels,
        vocab_size=unwrapped_model.config.vocab_size,
    )

    if sp_size > 1:
        # differentiable weighted per-shard-loss aggregation across ranks
        losses_per_rank = torch.distributed.nn.functional.all_gather(loss, group=sp_group)
        # special dealing with SFT that has prompt tokens that aren't used in loss computation
        good_tokens = (shift_labels != -100).view(-1).sum()
        good_tokens_per_rank = torch.distributed.nn.functional.all_gather(good_tokens, group=sp_group)
        # `sp_rank` is deliberately distinct from the module-level process `rank`
        total_loss = sum(
            losses_per_rank[sp_rank] * good_tokens_per_rank[sp_rank]
            for sp_rank in range(sp_world_size)
            if good_tokens_per_rank[sp_rank] > 0
        )
        total_good_tokens = sum(good_tokens_per_rank)
        # `max(..., 1)` avoids dividing by zero when no token contributes to the loss
        loss = total_loss / max(total_good_tokens, 1)

    if rank == 0:
        accelerator.print(f"{step}: {loss=}")
    accelerator.log(dict(train_loss=loss, step=step))

    accelerator.backward(loss)
    optimizer.step()

accelerator.end_training()
|
accelerate/test_utils/scripts/external_deps/test_ds_multiple_model.py
ADDED
|
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Test script for verifying multiple models can be utilized with Accelerate + DeepSpeed:
|
| 17 |
+
|
| 18 |
+
Scenario 1: One model is training, another model is being used for inference/logits to impact training in some form.
|
| 19 |
+
Scenario 2: Two models are training simultaneously, which means two optimizers, etc.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import argparse
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
|
| 25 |
+
import evaluate
|
| 26 |
+
import torch
|
| 27 |
+
from datasets import load_dataset
|
| 28 |
+
from torch.optim import AdamW
|
| 29 |
+
from torch.utils.data import DataLoader
|
| 30 |
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup
|
| 31 |
+
|
| 32 |
+
from accelerate import Accelerator, DeepSpeedPlugin, DistributedType
|
| 33 |
+
from accelerate.state import AcceleratorState
|
| 34 |
+
from accelerate.utils.deepspeed import get_active_deepspeed_plugin
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
EVAL_BATCH_SIZE = 16
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class NoiseModel(torch.nn.Module):
    """Module holding a single float32 scalar parameter, `noise_factor`, that scales whatever loss it is given."""

    def __init__(self, noise_factor=0.1):
        super().__init__()
        initial_value = torch.tensor(noise_factor, dtype=torch.float32)
        self.noise_factor = torch.nn.Parameter(initial_value)

    def forward(self, loss):
        """Return `loss` multiplied by `noise_factor`."""
        scaled_loss = loss * self.noise_factor
        return scaled_loss
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def get_dataloaders(accelerator: Accelerator, batch_size: int = 16, model_name: str = "bert-base-cased"):
    """
    Build the training and validation `DataLoader`s for the `mrpc` configuration of the `glue` dataset.

    Args:
        accelerator (`Accelerator`):
            The `Accelerator` whose `distributed_type` selects the padding strategy used when collating.
        batch_size (`int`, *optional*, defaults to 16):
            Batch size of the training DataLoader. The validation DataLoader always uses `EVAL_BATCH_SIZE`.
        model_name (`str`, *optional*, defaults to `"bert-base-cased"`):
            Identifier handed to `AutoTokenizer.from_pretrained` to load the tokenizer.

    Returns:
        A `(train_dataloader, eval_dataloader)` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    raw_datasets = load_dataset("glue", "mrpc")

    def tokenize_function(examples):
        # `max_length=None` (the default) means the model max length is used
        return tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)

    # Tokenize every split and drop the raw text columns, then expose the targets under `labels`,
    # which is the column name the models of the transformers library expect.
    tokenized_datasets = raw_datasets.map(
        tokenize_function,
        batched=True,
        remove_columns=["idx", "sentence1", "sentence2"],
        load_from_cache_file=False,
    ).rename_column("label", "labels")

    def collate_fn(examples):
        # On TPU it's best to pad everything to the same length or training will be very slow;
        # everywhere else each batch is padded to its longest example.
        if accelerator.distributed_type == DistributedType.XLA:
            pad_kwargs = {"padding": "max_length", "max_length": 128}
        else:
            pad_kwargs = {"padding": "longest"}
        return tokenizer.pad(examples, return_tensors="pt", **pad_kwargs)

    # Only the training split is shuffled.
    train_dataloader = DataLoader(
        tokenized_datasets["train"],
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=batch_size,
    )
    eval_dataloader = DataLoader(
        tokenized_datasets["validation"],
        shuffle=False,
        collate_fn=collate_fn,
        batch_size=EVAL_BATCH_SIZE,
    )

    return train_dataloader, eval_dataloader
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
# Resolved location of this script on disk
test_file_path = __file__
path = Path(test_file_path).resolve()
# Six directory levels above this script; the DeepSpeed configs are looked up under `tests/deepspeed` from there
test_file_dir_str = str(path.parent.parent.parent.parent.parent.parent)

# Create our DS plugins
# We use custom schedulers and optimizers, hence `model_only`
ds_config_file = {
    "zero2": f"{test_file_dir_str}/tests/deepspeed/ds_config_zero2_model_only.json",
    "zero3": f"{test_file_dir_str}/tests/deepspeed/ds_config_zero3_model_only.json",
}
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def single_model_training(config, args):
    """
    Scenario 1: train one model under the zero2 plugin while a `NoiseModel` prepared under the zero3 plugin
    contributes an extra term to the loss.

    The noise model is only ever called inside `torch.no_grad()` and is not given an optimizer. After training, the
    function asserts that the best accuracy seen is greater than the accuracy recorded for the first epoch.

    Args:
        config (`dict`):
            Hyper-parameters; the keys `"num_epochs"`, `"batch_size"` and `"lr"` are read.
        args:
            Parsed command line arguments; `model_name_or_path` is read.
    """
    num_epochs = config["num_epochs"]
    zero2_plugin = DeepSpeedPlugin(hf_ds_config=ds_config_file["zero2"])
    zero3_plugin = DeepSpeedPlugin(hf_ds_config=ds_config_file["zero3"])

    # One accelerator that knows about both plugins, registered under the names used to select them below
    accelerator = Accelerator(
        deepspeed_plugins={"training": zero2_plugin, "inference": zero3_plugin},
        mixed_precision="bf16",
    )

    # The zero2 plugin is expected to be the active one at this point, so the trained model is set up under it
    assert get_active_deepspeed_plugin(accelerator.state) is zero2_plugin
    train_model = AutoModelForSequenceClassification.from_pretrained(args.model_name_or_path)
    train_dataloader, eval_dataloader = get_dataloaders(
        accelerator, batch_size=config["batch_size"], model_name=args.model_name_or_path
    )
    max_training_steps = len(train_dataloader) * config["num_epochs"]
    optimizer = AdamW(train_model.parameters(), lr=config["lr"])
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=max_training_steps,
    )

    train_dataloader, eval_dataloader, train_model, optimizer, lr_scheduler = accelerator.prepare(
        train_dataloader, eval_dataloader, train_model, optimizer, lr_scheduler
    )

    # Switch to the zero3 plugin before preparing the noise model
    accelerator.state.select_deepspeed_plugin("inference")
    assert get_active_deepspeed_plugin(accelerator.state) is zero3_plugin
    inference_model = accelerator.prepare(NoiseModel())
    inference_model.eval()

    # Switch back to the zero2 plugin for the training loop
    accelerator.state.select_deepspeed_plugin("training")
    starting_epoch = 0

    best_performance = 0
    metric = evaluate.load("glue", "mrpc")
    performance_metric = {}
    for epoch in range(starting_epoch, num_epochs):
        train_model.train()
        inference_model.train()
        for batch in train_dataloader:
            with accelerator.accumulate(train_model):
                outputs_1 = train_model(**batch)
                with torch.no_grad():
                    outputs_2 = inference_model(outputs_1.loss)
                # The total loss is the model loss plus the term produced by the noise model
                accelerator.backward(outputs_1.loss + outputs_2)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

        train_model.eval()
        for batch in eval_dataloader:
            with torch.no_grad():
                logits = train_model(**batch).logits
            # It is slightly faster to gather both tensors in one call than in two separate calls
            predictions, references = accelerator.gather_for_metrics((logits.argmax(dim=-1), batch["labels"]))
            metric.add_batch(predictions=predictions, references=references)

        eval_metric = metric.compute()
        # Use accelerator.print to print only on the main process.
        accelerator.print(f"epoch {epoch}:", eval_metric)
        epoch_accuracy = eval_metric["accuracy"]
        performance_metric[f"epoch-{epoch}"] = epoch_accuracy

        best_performance = max(best_performance, epoch_accuracy)
    assert best_performance > performance_metric["epoch-0"]
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def multiple_model_training(config, args):
    """
    Trains two sequence-classification models in the same loop, one prepared under the "zero2" DeepSpeed plugin and
    one under the "zero3" plugin, and evaluates both on the GLUE MRPC metric after every epoch.

    Args:
        config:
            Mapping read for the keys `"num_epochs"`, `"batch_size"` and `"lr"`.
        args:
            Object read for the attribute `model_name_or_path`.

    Raises:
        AssertionError: If the active DeepSpeed plugin or the wrapped engine is not the expected one, or if the best
            accuracy of either model is not strictly greater than its accuracy after epoch 0.
    """
    # This will essentially be like a k-fold model, but one model is Zero-2 and another model is Zero-3
    num_epochs = config["num_epochs"]
    zero2_plugin = DeepSpeedPlugin(hf_ds_config=ds_config_file["zero2"])
    zero3_plugin = DeepSpeedPlugin(hf_ds_config=ds_config_file["zero3"])

    deepspeed_plugins = {"zero2": zero2_plugin, "zero3": zero3_plugin}

    # Initialize accelerator
    zero2_accelerator = Accelerator(
        deepspeed_plugins=deepspeed_plugins,
        mixed_precision="bf16",
    )

    # Since an `AcceleratorState` has already been made, we can just reuse it here
    zero3_accelerator = Accelerator()

    # Initialize model under zero2 plugin
    assert get_active_deepspeed_plugin(zero2_accelerator.state) is zero2_plugin
    zero2_model = AutoModelForSequenceClassification.from_pretrained(args.model_name_or_path)
    train_dataloader, eval_dataloader = get_dataloaders(
        zero2_accelerator, batch_size=config["batch_size"], model_name=args.model_name_or_path
    )
    # Both schedulers are sized from the (not yet prepared) train dataloader length
    max_training_steps = len(train_dataloader) * config["num_epochs"]
    zero2_optimizer = AdamW(zero2_model.parameters(), lr=config["lr"])
    zero2_lr_scheduler = get_linear_schedule_with_warmup(
        zero2_optimizer, num_warmup_steps=0, num_training_steps=max_training_steps
    )

    train_dataloader, eval_dataloader, zero2_model, zero2_optimizer, zero2_lr_scheduler = zero2_accelerator.prepare(
        train_dataloader, eval_dataloader, zero2_model, zero2_optimizer, zero2_lr_scheduler
    )
    assert zero2_accelerator.deepspeed_engine_wrapped.engine is zero2_model

    # now do Zero3
    zero3_accelerator.state.select_deepspeed_plugin("zero3")
    # Copy the micro batch size from the zero2 config; no dataloader is passed to the zero3 `prepare` call below
    zero3_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = zero2_plugin.deepspeed_config[
        "train_micro_batch_size_per_gpu"
    ]
    assert get_active_deepspeed_plugin(zero3_accelerator.state) is zero3_plugin
    zero3_model = AutoModelForSequenceClassification.from_pretrained(args.model_name_or_path)
    zero3_optimizer = AdamW(zero3_model.parameters(), lr=config["lr"])
    zero3_lr_scheduler = get_linear_schedule_with_warmup(
        zero3_optimizer, num_warmup_steps=0, num_training_steps=max_training_steps
    )
    zero3_model, zero3_optimizer, zero3_lr_scheduler = zero3_accelerator.prepare(
        zero3_model, zero3_optimizer, zero3_lr_scheduler
    )
    assert zero3_accelerator.deepspeed_engine_wrapped.engine is zero3_model

    # Run training loop
    starting_epoch = 0

    # Now we train the model
    # The `_a` variables track the zero2 model, the `_b` variables track the zero3 model
    best_performance_a = 0
    best_performance_b = 0
    metric_a = evaluate.load("glue", "mrpc")
    metric_b = evaluate.load("glue", "mrpc")
    performance_metric_a = {}
    performance_metric_b = {}
    for epoch in range(starting_epoch, num_epochs):
        zero2_model.train()
        zero3_model.train()
        for step, batch in enumerate(train_dataloader):
            with zero2_accelerator.accumulate(zero2_model, zero3_model):
                # Each model gets its own forward/backward/step on the same batch
                outputs_1 = zero2_model(**batch)
                zero2_accelerator.backward(outputs_1.loss)
                zero2_optimizer.step()
                zero2_lr_scheduler.step()
                zero2_optimizer.zero_grad()
                outputs_2 = zero3_model(**batch)
                zero3_accelerator.backward(outputs_2.loss)
                zero3_optimizer.step()
                zero3_lr_scheduler.step()
                zero3_optimizer.zero_grad()

        zero2_model.eval()
        zero3_model.eval()
        for step, batch in enumerate(eval_dataloader):
            with torch.no_grad():
                logits_a = zero2_model(**batch).logits
                logits_b = zero3_model(**batch).logits
            # Turn each model's logits into class predictions (the two models are scored separately)
            predictions_a = logits_a.argmax(dim=-1)
            predictions_b = logits_b.argmax(dim=-1)
            # It is slightly faster to call this once, than multiple times
            predictions_a, predictions_b, references = zero2_accelerator.gather_for_metrics(
                (predictions_a, predictions_b, batch["labels"])
            )
            metric_a.add_batch(
                predictions=predictions_a,
                references=references,
            )
            metric_b.add_batch(
                predictions=predictions_b,
                references=references,
            )

        eval_metric_a = metric_a.compute()
        eval_metric_b = metric_b.compute()
        # Use accelerator.print to print only on the main process.
        zero2_accelerator.print(f"epoch {epoch}:", eval_metric_a, eval_metric_b)
        performance_metric_a[f"epoch-{epoch}"] = eval_metric_a["accuracy"]
        performance_metric_b[f"epoch-{epoch}"] = eval_metric_b["accuracy"]

        if best_performance_a < eval_metric_a["accuracy"]:
            best_performance_a = eval_metric_a["accuracy"]
        if best_performance_b < eval_metric_b["accuracy"]:
            best_performance_b = eval_metric_b["accuracy"]
    # Both models must have improved on their epoch-0 accuracy at some point
    assert best_performance_a > performance_metric_a["epoch-0"]
    assert best_performance_b > performance_metric_b["epoch-0"]
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def main():
    """
    Parses the command line, then runs `single_model_training` followed by `multiple_model_training` with the same
    config, resetting `AcceleratorState` in between.
    """
    cli = argparse.ArgumentParser(description="Simple example of training script tracking peak GPU memory usage.")
    # (flag, keyword arguments) pairs, registered in this order
    option_specs = (
        (
            "--model_name_or_path",
            {
                "type": str,
                "default": "bert-base-cased",
                "help": "Path to pretrained model or model identifier from huggingface.co/models.",
                "required": False,
            },
        ),
        (
            "--performance_lower_bound",
            {
                "type": float,
                "default": None,
                "help": "Optional lower bound for the performance metric. If set, the training will throw error when the performance metric drops below this value.",
            },
        ),
        (
            "--num_epochs",
            {
                "type": int,
                "default": 3,
                "help": "Number of train epochs.",
            },
        ),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    args = cli.parse_args()

    config = dict(lr=2e-5, num_epochs=args.num_epochs, seed=42, batch_size=8)
    single_model_training(config, args)
    # Reset the accelerator state before the second scenario builds its own accelerators
    AcceleratorState._reset_state(True)
    multiple_model_training(config, args)
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
if __name__ == "__main__":
|
| 331 |
+
main()
|
accelerate/test_utils/scripts/external_deps/test_metrics.py
ADDED
|
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import logging
|
| 16 |
+
import math
|
| 17 |
+
import os
|
| 18 |
+
from copy import deepcopy
|
| 19 |
+
|
| 20 |
+
import datasets
|
| 21 |
+
import evaluate
|
| 22 |
+
import torch
|
| 23 |
+
import transformers
|
| 24 |
+
from datasets import load_dataset
|
| 25 |
+
from torch.utils.data import DataLoader, IterableDataset
|
| 26 |
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
| 27 |
+
|
| 28 |
+
from accelerate import Accelerator, DataLoaderConfiguration, DistributedType
|
| 29 |
+
from accelerate.data_loader import DataLoaderDispatcher
|
| 30 |
+
from accelerate.test_utils import RegressionDataset, RegressionModel, torch_device
|
| 31 |
+
from accelerate.utils import is_torch_xla_available, set_seed
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class ListHandler(logging.Handler):
    """Logging handler that keeps every record it receives in the `logs` list, in arrival order."""

    def __init__(self, *args, **kwargs):
        logging.Handler.__init__(self, *args, **kwargs)
        # Records captured so far
        self.logs = list()

    def emit(self, record):
        """Stores `record` instead of writing it anywhere."""
        self.logs += [record]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def get_basic_setup(accelerator, num_samples=82, batch_size=16):
    """
    Returns everything needed to perform basic training.

    Args:
        accelerator: Accelerator used to place the reference model and to prepare the copy and the dataloader.
        num_samples: Length passed to `RegressionDataset`.
        batch_size: Batch size of the dataloader before preparation.

    Returns:
        A `(model, ddp_model, dataloader)` tuple: the unprepared model moved to `accelerator.device`, a deep copy of
        it passed through `accelerator.prepare`, and the prepared dataloader.
    """
    set_seed(42)
    reference_model = RegressionModel()
    # Copy before anything is moved or wrapped
    prepared_model = deepcopy(reference_model)
    loader = DataLoader(RegressionDataset(length=num_samples), batch_size=batch_size)
    reference_model.to(accelerator.device)
    prepared_model, loader = accelerator.prepare(prepared_model, loader)
    return reference_model, prepared_model, loader
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def get_dataloader(accelerator: Accelerator, use_longest=False):
    """
    Builds an unshuffled `DataLoader` (batch size 16) over the tokenized GLUE MRPC validation split.

    Args:
        accelerator: Used only for its `main_process_first` context around the tokenization `map` call.
        use_longest: When true, batches are padded to their longest example; otherwise to a fixed length of 128.
    """
    checkpoint = "hf-internal-testing/mrpc-bert-base-cased"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    raw_validation = load_dataset("glue", "mrpc", split="validation")

    def tokenize_function(examples):
        return tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)

    with accelerator.main_process_first():
        tokenized = raw_validation.map(
            tokenize_function,
            batched=True,
            remove_columns=["idx", "sentence1", "sentence2"],
        )

    tokenized = tokenized.rename_column("label", "labels")

    def collate_fn(examples):
        if use_longest:
            pad_options = {"padding": "longest"}
        else:
            pad_options = {"padding": "max_length", "max_length": 128}
        return tokenizer.pad(examples, return_tensors="pt", **pad_options)

    return DataLoader(tokenized, shuffle=False, collate_fn=collate_fn, batch_size=16)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def get_mrpc_setup(dispatch_batches, split_batches):
    """
    Creates an accelerator with the requested dataloader configuration plus a prepared and an unprepared
    model/dataloader pair for MRPC.

    Returns:
        A `(setup, accelerator)` tuple. `setup["ddp"]` is `[prepared model, prepared dataloader, torch_device]` and
        `setup["no"]` is `[original model, original dataloader, accelerator.device]`.
    """
    accelerator = Accelerator(
        dataloader_config=DataLoaderConfiguration(dispatch_batches=dispatch_batches, split_batches=split_batches)
    )
    # Pad to the longest example only when batches are not dispatched
    raw_loader = get_dataloader(accelerator, not dispatch_batches)
    raw_model = AutoModelForSequenceClassification.from_pretrained(
        "hf-internal-testing/mrpc-bert-base-cased", return_dict=True
    )
    prepared_model, prepared_loader = accelerator.prepare(raw_model, raw_loader)
    setup = {
        "ddp": [prepared_model, prepared_loader, torch_device],
        "no": [raw_model, raw_loader, accelerator.device],
    }
    return setup, accelerator
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def generate_predictions(model, dataloader, accelerator):
    """
    Runs `model` without gradients over every batch of `dataloader` and gathers outputs and targets.

    Each batch must be a mapping with exactly two values, taken in order as the model input and the target. Both the
    model output and the target go through `accelerator.gather_for_metrics` before being collected.

    Returns:
        A `(logits, targets)` tuple, each the concatenation of the gathered per-batch tensors.
    """
    gathered_logits = []
    gathered_targets = []
    for batch in dataloader:
        features, labels = batch.values()
        with torch.no_grad():
            prediction = model(features)
        prediction, labels = accelerator.gather_for_metrics((prediction, labels))
        gathered_logits.append(prediction)
        gathered_targets.append(labels)
    return torch.cat(gathered_logits), torch.cat(gathered_targets)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def test_torch_metrics(
    accelerator: Accelerator, num_samples=82, dispatch_batches=False, split_batches=False, batch_size=16
):
    """
    Checks that gathering predictions over a regression dataset of `num_samples` items yields exactly `num_samples`
    outputs.

    `dispatch_batches` and `split_batches` are accepted but not read here; the dataloader behaviour comes from how
    `accelerator` was configured by the caller.
    """
    setup = get_basic_setup(accelerator, num_samples, batch_size)
    prepared_model, loader = setup[1], setup[2]
    logits = generate_predictions(prepared_model, loader, accelerator)[0]
    assert len(logits) == num_samples, (
        f"Unexpected number of inputs:\n Expected: {num_samples}\n Actual: {len(logits)}"
    )
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def test_mrpc(dispatch_batches: bool = False, split_batches: bool = False):
    """
    Compares GLUE MRPC accuracy and F1 computed with the unprepared model/dataloader against the values computed
    with the prepared pair and `gather_for_metrics`; the two must be close for both keys.
    """
    metric = evaluate.load("glue", "mrpc")
    setup, accelerator = get_mrpc_setup(dispatch_batches, split_batches)

    # Baseline: unprepared model and dataloader, batches moved to the device by hand
    plain_model, plain_loader, plain_device = setup["no"]
    plain_model.to(plain_device)
    plain_model.eval()
    for batch in plain_loader:
        batch.to(plain_device)
        with torch.inference_mode():
            logits = plain_model(**batch).logits
        metric.add_batch(predictions=logits.argmax(dim=-1), references=batch["labels"])
    baseline = metric.compute()

    # Distributed: prepared model and dataloader, predictions gathered before scoring
    prepared_model, prepared_loader, _ = setup["ddp"]
    prepared_model.eval()
    for batch in prepared_loader:
        with torch.inference_mode():
            logits = prepared_model(**batch).logits
        preds, references = accelerator.gather_for_metrics((logits.argmax(dim=-1), batch["labels"]))
        metric.add_batch(predictions=preds, references=references)
    distributed = metric.compute()

    for key in ("accuracy", "f1"):
        assert math.isclose(baseline[key], distributed[key]), (
            f"Baseline and Distributed are not the same for key {key}:\n\tBaseline: {baseline[key]}\n\tDistributed: {distributed[key]}\n"
        )
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def test_gather_for_metrics_with_non_tensor_objects_iterable_dataset():
    """
    Iterates a prepared dataloader over an iterable dataset of 30 plain integers and checks that the gathered batches
    add up to 30 items and that, on the main process, nothing was logged on the "accelerate.accelerator" logger.
    """

    class DummyIterableDataset(IterableDataset):
        """Iterable dataset over an in-memory collection."""

        def __init__(self, data):
            self.data = data

        def __len__(self):
            return len(self.data)

        def __iter__(self):
            for item in self.data:
                yield item

    source = DummyIterableDataset(list(range(30)))
    loader = DataLoader(source, batch_size=4)
    accelerator = Accelerator()
    prepared_dataloader = accelerator.prepare(loader)

    if accelerator.is_main_process:
        # Capture every record emitted by the accelerator module's logger while iterating
        logger = logging.root.manager.loggerDict["accelerate.accelerator"]
        list_handler = ListHandler()
        logger.addHandler(list_handler)

    batches_for_metrics = [accelerator.gather_for_metrics(batch) for batch in prepared_dataloader]

    assert torch.cat(batches_for_metrics).size(0) == 30

    if accelerator.is_main_process:
        assert len(list_handler.logs) == 0
        logger.removeHandler(list_handler)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def test_gather_for_metrics_with_iterable_dataset():
    """
    Iterates a prepared dataloader over an iterable dataset backed by a 30-element tensor. Checks that the prepared
    dataloader is a `DataLoaderDispatcher`, that the gathered batches add up to 30 items and that, on the main
    process, nothing was logged on the "accelerate.accelerator" logger.
    """

    class DummyIterableDataset(IterableDataset):
        """Iterable dataset over an in-memory collection."""

        def __init__(self, data):
            self.data = data

        def __len__(self):
            return len(self.data)

        def __iter__(self):
            for item in self.data:
                yield item

    source = DummyIterableDataset(torch.as_tensor(range(30)))
    loader = DataLoader(source, batch_size=4)

    accelerator = Accelerator()
    prepared_dataloader = accelerator.prepare(loader)

    assert isinstance(prepared_dataloader, DataLoaderDispatcher)

    if accelerator.is_main_process:
        # Capture every record emitted by the accelerator module's logger while iterating
        logger = logging.root.manager.loggerDict["accelerate.accelerator"]
        list_handler = ListHandler()
        logger.addHandler(list_handler)

    batches_for_metrics = [accelerator.gather_for_metrics(batch) for batch in prepared_dataloader]

    assert torch.cat(batches_for_metrics).size(0) == 30

    if accelerator.is_main_process:
        assert len(list_handler.logs) == 0

        logger.removeHandler(list_handler)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def test_gather_for_metrics_drop_last():
    """
    Checks that gathering the second batch of a `drop_last=True` dataloader returns one full per-device batch from
    every process.

    The dataset holds `10 * num_processes + 1` items, so there is one leftover item beyond the full batches.
    """
    accelerator = Accelerator()
    per_device_batch_size = 5
    total_items = 10 * accelerator.num_processes + 1
    prepared_loader = accelerator.prepare(
        DataLoader(range(total_items), batch_size=per_device_batch_size, drop_last=True)
    )

    batches = iter(prepared_loader)
    next(batches)  # Skip first batch tensor([0, 1, 2, 3, 4], device='cuda:0')
    gathered_items = accelerator.gather_for_metrics(next(batches))

    # Should return a full set of complete batches from each GPU
    num_expected_items = per_device_batch_size * accelerator.num_processes
    assert gathered_items.size(0) == (num_expected_items), (
        f"Expected number of items: {num_expected_items}, Actual: {gathered_items.size(0)}"
    )
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def main():
    """
    Runs the `gather_for_metrics` test suite of this script.

    The MRPC and iterable-dataset tests run only when the accelerator device is not a CPU and TorchXLA is not
    available. The torch-metrics tests over every `split_batches`/`dispatch_batches` combination run only when the
    distributed type is not XLA. The "perfectly divisible" and `drop_last` tests always run.
    """
    dataloader_config = DataLoaderConfiguration(split_batches=False, dispatch_batches=False)
    accelerator = Accelerator(dataloader_config=dataloader_config)
    # Only the local main process keeps warnings; the other processes log errors only
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
    # TorchXLA does not support batch dispatching. 'put_on_device' is always False for
    # TorchXLA, which can cause a value error in 'prepare_data_loader' function.
    dispatch_batches_options = [False] if accelerator.state.distributed_type == DistributedType.XLA else [True, False]

    # Temporarily close this test for TorchXLA due to the 'Cannot set version_counter for
    # inference tensor' error in inference mode. Reopen it after TorchXLA fixes this bug.
    # These are a bit slower so they should only be ran on the GPU or TPU
    if accelerator.device.type != "cpu" and not is_torch_xla_available():
        if accelerator.is_local_main_process:
            print("**Testing gather_for_metrics**")
        for split_batches in [True, False]:
            for dispatch_batches in dispatch_batches_options:
                if accelerator.is_local_main_process:
                    print(f"With: `split_batches={split_batches}`, `dispatch_batches={dispatch_batches}`")
                test_mrpc(dispatch_batches, split_batches)
                # `test_mrpc` builds its own accelerator, so the state is reset after each combination
                accelerator.state._reset_state()
        print("test_gather_for_metrics_with_iterable_dataset")
        test_gather_for_metrics_with_iterable_dataset()
        print("test gather_for_metrics_with_non_tensor_objects_iterable_dataset")
        test_gather_for_metrics_with_non_tensor_objects_iterable_dataset()

    # MpDeviceLoader in TorchXLA is an asynchronous loader that preloads several batches into cache.
    # This can cause the 'end_of_dataloader' of DataLoaderStateMixin to be set earlier than intended.
    # Skip this test when TorchXLA is enabled.
    if accelerator.state.distributed_type != DistributedType.XLA:
        if accelerator.is_local_main_process:
            print("**Test torch metrics**")
        for split_batches in [True, False]:
            for dispatch_batches in dispatch_batches_options:
                dataloader_config = DataLoaderConfiguration(
                    split_batches=split_batches, dispatch_batches=dispatch_batches
                )
                # A fresh accelerator is created for every combination and its state reset afterwards
                accelerator = Accelerator(dataloader_config=dataloader_config)
                if accelerator.is_local_main_process:
                    print(f"With: `split_batches={split_batches}`, `dispatch_batches={dispatch_batches}`, length=99")
                test_torch_metrics(accelerator, 99)
                accelerator.state._reset_state()
    if accelerator.is_local_main_process:
        print("**Test last batch is not dropped when perfectly divisible**")
    accelerator = Accelerator()
    test_torch_metrics(accelerator, 512)
    accelerator.state._reset_state()
    if accelerator.is_local_main_process:
        print("**Test that `drop_last` is taken into account**")
    test_gather_for_metrics_drop_last()
    accelerator.end_training()
    accelerator.state._reset_state()
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
def _mp_fn(index):
    """Entry point with the `(index)` signature; `index` is not used and the call is forwarded to `main`."""
    # For xla_spawn (TPUs)
    main()
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
if __name__ == "__main__":
|
| 307 |
+
main()
|
accelerate/test_utils/scripts/external_deps/test_peak_memory_usage.py
ADDED
|
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import argparse
|
| 15 |
+
import gc
|
| 16 |
+
import json
|
| 17 |
+
import os
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
from datasets import load_dataset
|
| 21 |
+
from torch.optim import AdamW
|
| 22 |
+
from torch.utils.data import DataLoader
|
| 23 |
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
|
| 24 |
+
|
| 25 |
+
from accelerate import Accelerator, DistributedType
|
| 26 |
+
from accelerate.utils import (
|
| 27 |
+
is_hpu_available,
|
| 28 |
+
is_mlu_available,
|
| 29 |
+
is_musa_available,
|
| 30 |
+
is_neuron_available,
|
| 31 |
+
is_npu_available,
|
| 32 |
+
is_sdaa_available,
|
| 33 |
+
is_xpu_available,
|
| 34 |
+
)
|
| 35 |
+
from accelerate.utils.deepspeed import DummyOptim, DummyScheduler
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
MAX_GPU_BATCH_SIZE = 16
|
| 39 |
+
EVAL_BATCH_SIZE = 32
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# Converting Bytes to Megabytes
|
| 43 |
+
def b2mb(x):
    """Converts a byte count to whole mebibytes, truncating the fractional part toward zero."""
    bytes_per_mb = 1 << 20
    megabytes = x / bytes_per_mb
    return int(megabytes)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# This context manager is used to track the peak memory usage of the process
|
| 48 |
+
class TorchTracemalloc:
    """
    Context manager that records accelerator memory allocated while the `with` body runs.

    On entry the first available backend (checked in the order CUDA, MLU, SDAA, MUSA, NPU, XPU, HPU, Neuron) has its
    peak counter reset and its currently allocated bytes stored in `begin`. On exit the same backend is queried
    again and these attributes are set:

    - `end`: bytes allocated on exit.
    - `peak`: maximum bytes allocated since the reset.
    - `used`: `end - begin`, converted to whole megabytes with `b2mb`.
    - `peaked`: `peak - begin`, converted to whole megabytes with `b2mb`.

    NOTE: when none of the backends is available no attribute is set on entry and `__exit__` raises
    `AttributeError`.
    """

    def __enter__(self):
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_max_memory_allocated()  # reset the peak gauge to zero
            self.begin = torch.cuda.memory_allocated()
        elif is_mlu_available():
            torch.mlu.empty_cache()
            torch.mlu.reset_max_memory_allocated()  # reset the peak gauge to zero
            self.begin = torch.mlu.memory_allocated()
        elif is_sdaa_available():
            torch.sdaa.empty_cache()
            torch.sdaa.reset_max_memory_allocated()  # reset the peak gauge to zero
            self.begin = torch.sdaa.memory_allocated()
        elif is_musa_available():
            torch.musa.empty_cache()
            torch.musa.reset_max_memory_allocated()  # reset the peak gauge to zero
            self.begin = torch.musa.memory_allocated()
        elif is_npu_available():
            torch.npu.empty_cache()
            torch.npu.reset_max_memory_allocated()  # reset the peak gauge to zero
            self.begin = torch.npu.memory_allocated()
        elif is_xpu_available():
            torch.xpu.empty_cache()
            torch.xpu.reset_peak_memory_stats()  # reset the peak gauge to zero
            self.begin = torch.xpu.memory_allocated()
        elif is_hpu_available():
            # torch.hpu.empty_cache() # not available on hpu as it reserves all device memory for the current process
            torch.hpu.reset_peak_memory_stats()  # reset the peak gauge to zero
            self.begin = torch.hpu.memory_allocated()
        elif is_neuron_available():
            torch.neuron.empty_cache()
            torch.neuron.reset_peak_memory_stats()  # reset the peak gauge to zero
            self.begin = torch.neuron.memory_allocated()
        return self

    def __exit__(self, *exc):
        gc.collect()
        # Every branch must set `end` and `peak` and leave `begin` untouched: the deltas below need all three.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            self.end = torch.cuda.memory_allocated()
            self.peak = torch.cuda.max_memory_allocated()
        elif is_mlu_available():
            torch.mlu.empty_cache()
            self.end = torch.mlu.memory_allocated()
            self.peak = torch.mlu.max_memory_allocated()
        elif is_sdaa_available():
            torch.sdaa.empty_cache()
            self.end = torch.sdaa.memory_allocated()
            self.peak = torch.sdaa.max_memory_allocated()
        elif is_musa_available():
            torch.musa.empty_cache()
            self.end = torch.musa.memory_allocated()
            self.peak = torch.musa.max_memory_allocated()
        elif is_npu_available():
            torch.npu.empty_cache()
            self.end = torch.npu.memory_allocated()
            self.peak = torch.npu.max_memory_allocated()
        elif is_xpu_available():
            torch.xpu.empty_cache()
            self.end = torch.xpu.memory_allocated()
            self.peak = torch.xpu.max_memory_allocated()
        elif is_hpu_available():
            # torch.hpu.empty_cache() # not available on hpu as it reserves all device memory for the current process
            self.end = torch.hpu.memory_allocated()
            self.peak = torch.hpu.max_memory_allocated()
        elif is_neuron_available():
            torch.neuron.empty_cache()
            self.end = torch.neuron.memory_allocated()
            self.peak = torch.neuron.max_memory_allocated()
        self.used = b2mb(self.end - self.begin)
        self.peaked = b2mb(self.peak - self.begin)
        # print(f"delta used/peak {self.used:4d}/{self.peaked:4d}")
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def get_dataloaders(
    accelerator: Accelerator,
    batch_size: int = 16,
    model_name: str = "bert-base-cased",
    n_train: int = 320,
    n_val: int = 160,
):
    """
    Builds the train and validation `DataLoader`s for a slice of the GLUE MRPC dataset.

    Args:
        accelerator (`Accelerator`):
            Read for its `distributed_type`, which selects the padding strategy.
        batch_size (`int`, *optional*):
            Batch size of the train DataLoader. The validation DataLoader uses `EVAL_BATCH_SIZE`.
        model_name (`str`, *optional*):
            Name of the model whose tokenizer is loaded.
        n_train (`int`, *optional*):
            Number of training examples to keep.
        n_val (`int`, *optional*):
            Number of validation examples to keep.

    Returns:
        A `(train_dataloader, eval_dataloader)` tuple; only the train DataLoader is shuffled.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    split_spec = {"train": f"train[:{n_train}]", "validation": f"validation[:{n_val}]"}
    raw_splits = load_dataset("glue", "mrpc", split=split_spec)

    def tokenize_function(examples):
        # max_length=None => use the model max length (it's actually the default)
        return tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)

    # Tokenize every split, dropping the raw text columns, then rename 'label' to 'labels', which is the name the
    # models of the transformers library expect
    tokenized = raw_splits.map(
        tokenize_function, batched=True, remove_columns=["idx", "sentence1", "sentence2"], load_from_cache_file=False
    ).rename_column("label", "labels")

    def collate_fn(examples):
        # On TPU it's best to pad everything to the same length or training will be very slow.
        if accelerator.distributed_type == DistributedType.XLA:
            pad_options = {"padding": "max_length", "max_length": 128}
        else:
            pad_options = {"padding": "longest"}
        return tokenizer.pad(examples, return_tensors="pt", **pad_options)

    # Instantiate dataloaders.
    train_loader = DataLoader(tokenized["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size)
    eval_loader = DataLoader(
        tokenized["validation"], shuffle=False, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE
    )
    return train_loader, eval_loader
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def training_function(config, args):
    """
    Train a sequence-classification model and record the total peak accelerator memory of each epoch.

    The per-epoch values are printed, optionally checked against an upper bound, and finally written by the
    main process to `peak_memory_utilization.json` inside `args.output_dir`.

    Args:
        config (`dict`):
            Hyper-parameters; the keys `lr`, `num_epochs`, `seed` and `batch_size` are read.
        args (`argparse.Namespace`):
            Parsed command-line options; `model_name_or_path`, `n_train`, `n_val`, `peak_memory_upper_bound`
            and `output_dir` are read.

    Raises:
        AssertionError: If `args.peak_memory_upper_bound` is set and an epoch's total peak memory exceeds it.
    """
    # Initialize accelerator
    accelerator = Accelerator()

    # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs
    lr = config["lr"]
    num_epochs = int(config["num_epochs"])
    seed = int(config["seed"])
    batch_size = int(config["batch_size"])
    model_name = args.model_name_or_path

    set_seed(seed)
    train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size, model_name, args.n_train, args.n_val)

    # Instantiate the model (we build the model here so that the seed also control new weights initialization)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, return_dict=True)

    # Instantiate optimizer: a real AdamW unless the DeepSpeed config already declares an optimizer
    optimizer_cls = (
        AdamW
        if accelerator.state.deepspeed_plugin is None
        or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config
        else DummyOptim
    )
    optimizer = optimizer_cls(params=model.parameters(), lr=lr)

    if accelerator.state.deepspeed_plugin is not None:
        gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[
            "gradient_accumulation_steps"
        ]
    else:
        gradient_accumulation_steps = 1
    max_training_steps = (len(train_dataloader) * num_epochs) // gradient_accumulation_steps

    # Instantiate scheduler: a real linear schedule unless the DeepSpeed config already declares a scheduler
    if (
        accelerator.state.deepspeed_plugin is None
        or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
    ):
        lr_scheduler = get_linear_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=0,
            num_training_steps=max_training_steps,
        )
    else:
        lr_scheduler = DummyScheduler(optimizer, total_num_steps=max_training_steps, warmup_num_steps=0)

    # Prepare everything
    # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the
    # prepare method.
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )

    # We also need to keep track of the starting epoch so files are named properly
    starting_epoch = 0

    # Now we train the model
    train_total_peak_memory = {}
    for epoch in range(starting_epoch, num_epochs):
        with TorchTracemalloc() as tracemalloc:
            model.train()
            last_step = len(train_dataloader) - 1
            for step, batch in enumerate(train_dataloader):
                outputs = model(**batch)
                loss = outputs.loss
                loss = loss / gradient_accumulation_steps
                accelerator.backward(loss)
                # Step once a full accumulation window has been consumed, and flush the final (possibly
                # partial) window so that no accumulated gradient is dropped at the end of the epoch.
                window_complete = (step + 1) % gradient_accumulation_steps == 0
                if window_complete or step == last_step:
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()

        # `used` and `peaked` are filled in when the tracemalloc context exits, so they are read here.
        # Printing the GPU memory usage details such as allocated memory, peak memory, and total memory usage
        accelerator.print(f"Memory before entering the train : {b2mb(tracemalloc.begin)}")
        accelerator.print(f"Memory consumed at the end of the train (end-begin): {tracemalloc.used}")
        accelerator.print(f"Peak Memory consumed during the train (max-begin): {tracemalloc.peaked}")
        accelerator.print(
            f"Total Peak Memory consumed during the train (max): {tracemalloc.peaked + b2mb(tracemalloc.begin)}"
        )
        train_total_peak_memory[f"epoch-{epoch}"] = tracemalloc.peaked + b2mb(tracemalloc.begin)
        if args.peak_memory_upper_bound is not None:
            assert train_total_peak_memory[f"epoch-{epoch}"] <= args.peak_memory_upper_bound, (
                "Peak memory usage exceeded the upper bound"
            )

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        with open(os.path.join(args.output_dir, "peak_memory_utilization.json"), "w") as f:
            json.dump(train_total_peak_memory, f)
    accelerator.end_training()
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def main():
    """Parse the command-line options and launch the peak-memory training run."""
    parser = argparse.ArgumentParser(description="Simple example of training script tracking peak GPU memory usage.")

    # (flag, add_argument keyword arguments), registered in this order
    option_table = (
        (
            "--model_name_or_path",
            {
                "type": str,
                "default": "bert-base-cased",
                "help": "Path to pretrained model or model identifier from huggingface.co/models.",
                "required": False,
            },
        ),
        (
            "--output_dir",
            {
                "type": str,
                "default": ".",
                "help": "Optional save directory where all checkpoint folders will be stored. Default is the current working directory.",
            },
        ),
        (
            "--peak_memory_upper_bound",
            {
                "type": float,
                "default": None,
                "help": "The upper bound of peak memory usage in MB. If set, the training will throw an error if the peak memory usage exceeds this value.",
            },
        ),
        ("--n_train", {"type": int, "default": 320, "help": "Number of training examples to use."}),
        ("--n_val", {"type": int, "default": 160, "help": "Number of validation examples to use."}),
        ("--num_epochs", {"type": int, "default": 1, "help": "Number of train epochs."}),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    cli_args = parser.parse_args()
    hyper_parameters = {"lr": 2e-5, "num_epochs": cli_args.num_epochs, "seed": 42, "batch_size": 16}
    training_function(hyper_parameters, cli_args)
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
# Run the peak-memory script only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
|
accelerate/test_utils/scripts/external_deps/test_performance.py
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import argparse
|
| 15 |
+
import json
|
| 16 |
+
import os
|
| 17 |
+
from contextlib import nullcontext
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
|
| 20 |
+
import evaluate
|
| 21 |
+
import torch
|
| 22 |
+
from datasets import load_dataset
|
| 23 |
+
from torch.optim import AdamW
|
| 24 |
+
from torch.utils.data import DataLoader
|
| 25 |
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup
|
| 26 |
+
|
| 27 |
+
from accelerate import Accelerator, DistributedType
|
| 28 |
+
from accelerate.parallelism_config import ParallelismConfig
|
| 29 |
+
from accelerate.utils import SAFE_WEIGHTS_NAME, set_seed
|
| 30 |
+
from accelerate.utils.deepspeed import DummyOptim, DummyScheduler
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# Not referenced anywhere else in this script; presumably a per-device batch-size cap — TODO confirm or remove.
MAX_GPU_BATCH_SIZE = 16
# Batch size of the validation DataLoader built in `get_dataloaders`.
EVAL_BATCH_SIZE = 32
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def get_dataloaders(accelerator: Accelerator, batch_size: int = 16, model_name: str = "bert-base-cased"):
    """
    Build the train and validation `DataLoader`s for the GLUE MRPC dataset.

    Args:
        accelerator (`Accelerator`):
            The `Accelerator` whose `distributed_type` selects the padding strategy.
        batch_size (`int`, *optional*):
            Batch size of the training DataLoader (the validation DataLoader uses `EVAL_BATCH_SIZE`).
        model_name (`str`, *optional*):
            Name or path of the model whose tokenizer is loaded.

    Returns:
        `tuple`: `(train_dataloader, eval_dataloader)`.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    mrpc = load_dataset("glue", "mrpc")

    def tokenize_pairs(rows):
        # max_length=None => fall back to the model max length (which is the default anyway)
        return tokenizer(rows["sentence1"], rows["sentence2"], truncation=True, max_length=None)

    # Tokenize every example of every split and drop the raw text columns
    text_columns = ["idx", "sentence1", "sentence2"]
    tokenized = mrpc.map(tokenize_pairs, batched=True, remove_columns=text_columns, load_from_cache_file=False)

    # The transformers models expect the target column to be called 'labels'
    tokenized = tokenized.rename_column("label", "labels")

    def collate(samples):
        # On TPU it's best to pad everything to the same length or training will be very slow.
        on_xla = accelerator.distributed_type == DistributedType.XLA
        if not on_xla:
            return tokenizer.pad(samples, padding="longest", return_tensors="pt")
        return tokenizer.pad(samples, padding="max_length", max_length=128, return_tensors="pt")

    loaders = (
        DataLoader(tokenized["train"], shuffle=True, collate_fn=collate, batch_size=batch_size),
        DataLoader(tokenized["validation"], shuffle=False, collate_fn=collate, batch_size=EVAL_BATCH_SIZE),
    )
    return loaders
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def training_function(config, args):
    """
    Fine-tune a sequence-classification model on GLUE MRPC and check accuracy, LR schedule and model saving.

    The per-epoch accuracy is written by the main process to `all_results.json` inside `args.output_dir`.

    Args:
        config (`dict`):
            Hyper-parameters; the keys `lr`, `num_epochs`, `seed` and `batch_size` are read.
        args (`argparse.Namespace`):
            Parsed command-line options; `tp_size`, `tp_plan`, `model_name_or_path`, `add_pad_token`,
            `performance_lower_bound` and `output_dir` are read.

    Raises:
        AssertionError: If the learning rate after the first optimizer step or after training is not the expected
            value, if the best accuracy is below `args.performance_lower_bound`, or if the saved weights file is
            missing.
    """
    accelerator_kwargs = {}
    # need this for DeepSpeed tests as `args.tp_size` would be None and `torch.distributed.init_device_mesh` would fail
    if args.tp_size is not None:
        accelerator_kwargs["parallelism_config"] = ParallelismConfig(tp_size=args.tp_size)

    # Initialize accelerator
    accelerator = Accelerator(**accelerator_kwargs)

    # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs
    lr = config["lr"]
    num_epochs = int(config["num_epochs"])
    seed = int(config["seed"])
    batch_size = int(config["batch_size"])
    model_name = args.model_name_or_path

    set_seed(seed)
    train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size, model_name)

    # Add TP related kwargs if provided
    model_kwargs = {}
    if args.tp_plan is not None:
        model_kwargs["tp_plan"] = args.tp_plan
    if args.tp_size is not None:
        model_kwargs["tp_size"] = args.tp_size

    # Instantiate the model (we build the model here so that the seed also control new weights initialization)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, return_dict=True, **model_kwargs)

    if args.add_pad_token:
        if model.config.pad_token_id is None:
            model.config.pad_token_id = 0

    # Instantiate optimizer: a real AdamW unless the DeepSpeed config already declares an optimizer
    optimizer_cls = (
        AdamW
        if accelerator.state.deepspeed_plugin is None
        or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config
        else DummyOptim
    )
    optimizer = optimizer_cls(params=model.parameters(), lr=lr)

    max_training_steps = len(train_dataloader) * num_epochs

    # Instantiate scheduler; `linear_decay_scheduler` records whether the LR assertions below apply
    linear_decay_scheduler = False
    if (
        accelerator.state.deepspeed_plugin is None
        or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
    ):
        lr_scheduler = get_linear_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=0,
            num_training_steps=max_training_steps,
        )
        linear_decay_scheduler = True
    else:
        lr_scheduler = DummyScheduler(optimizer, total_num_steps=max_training_steps, warmup_num_steps=0)

    # Prepare everything
    # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the
    # prepare method.
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )

    # We also need to keep track of the starting epoch so files are named properly
    starting_epoch = 0

    # The context wrapping `optimizer.step()` depends only on `args.tp_plan`, so it is chosen once here
    # instead of being re-evaluated (and re-imported) for every batch.
    optimizer_step_context = nullcontext
    if args.tp_plan is not None:
        from torch.distributed._tensor.experimental import implicit_replication

        optimizer_step_context = implicit_replication

    # Now we train the model
    metric = evaluate.load("glue", "mrpc")
    best_performance = 0
    performance_metric = {}
    expected_lr_after_first_optim_step = lr * (
        1 - 1 / (max_training_steps / accelerator.num_processes / accelerator.gradient_accumulation_steps)
    )
    lr_scheduler_check_completed = False
    for epoch in range(starting_epoch, num_epochs):
        model.train()
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(model):
                outputs = model(**batch)
                loss = outputs.loss
                accelerator.backward(loss)
                with optimizer_step_context():
                    optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # assert the learning rate after first optimizer step
            if (
                accelerator.sync_gradients
                and not lr_scheduler_check_completed
                and linear_decay_scheduler
                and accelerator.state.mixed_precision == "no"
            ):
                assert lr_scheduler.get_last_lr()[0] == expected_lr_after_first_optim_step, (
                    f"Wrong lr found at second step, expected {expected_lr_after_first_optim_step}, got {lr_scheduler.get_last_lr()[0]}"
                )
                lr_scheduler_check_completed = True

        model.eval()
        samples_seen = 0
        for step, batch in enumerate(eval_dataloader):
            # We could avoid this line since we set the accelerator with `device_placement=True`.
            batch.to(accelerator.device)
            with torch.no_grad():
                outputs = model(**batch)
            predictions = outputs.logits.argmax(dim=-1)
            # It is slightly faster to call this once, than multiple times
            predictions, references = accelerator.gather(
                (predictions, batch["labels"])
            )  # If we are in a multiprocess environment, the last batch has duplicates
            if accelerator.use_distributed:
                if step == len(eval_dataloader) - 1:
                    # Keep only the samples that have not been counted yet
                    predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
                    references = references[: len(eval_dataloader.dataset) - samples_seen]
                else:
                    samples_seen += references.shape[0]
            metric.add_batch(
                predictions=predictions,
                references=references,
            )

        eval_metric = metric.compute()
        # Use accelerator.print to print only on the main process.
        accelerator.print(f"epoch {epoch}:", eval_metric)
        performance_metric[f"epoch-{epoch}"] = eval_metric["accuracy"]

        if best_performance < eval_metric["accuracy"]:
            best_performance = eval_metric["accuracy"]

    # check that the LR is 0
    if linear_decay_scheduler and accelerator.state.mixed_precision == "no":
        assert lr_scheduler.get_last_lr()[0] == 0, (
            f"Wrong lr found at last step, expected 0, got {lr_scheduler.get_last_lr()[0]}"
        )

    if args.performance_lower_bound is not None:
        assert args.performance_lower_bound <= best_performance, (
            f"Best performance metric {best_performance} is lower than the lower bound {args.performance_lower_bound}"
        )

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
            json.dump(performance_metric, f)

    # TODO: skip saving of the model test for TP until the feature lands
    if args.tp_plan is None:
        # Finally try saving the model
        accelerator.save_model(model, args.output_dir)
    accelerator.wait_for_everyone()
    if args.tp_plan is None:
        assert Path(args.output_dir, SAFE_WEIGHTS_NAME).exists(), (
            "Model was not saved when calling `Accelerator.save_model`"
        )
    accelerator.end_training()
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def main():
    """
    Parse the command-line options and launch the performance training run.

    `--add_pad_token` is parsed with `_str_to_bool` rather than `bool`: `bool("False")` is `True`, so with
    `type=bool` every non-empty value (including `False`) enabled the option.
    """
    parser = argparse.ArgumentParser(
        description="Simple example of training script tracking model performance (accuracy)."
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="bert-base-cased",
        help="Path to pretrained model or model identifier from huggingface.co/models.",
        required=False,
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=".",
        help="Optional save directory where all checkpoint folders will be stored. Default is the current working directory.",
    )
    parser.add_argument(
        "--performance_lower_bound",
        type=float,
        default=None,
        help="Optional lower bound for the performance metric. If set, the training will throw error when the performance metric drops below this value.",
    )
    parser.add_argument(
        "--num_epochs",
        type=int,
        default=3,
        help="Number of train epochs.",
    )
    parser.add_argument(
        "--add_pad_token",
        type=_str_to_bool,
        default=False,
        help="To add pad token if not exists.",
    )
    parser.add_argument(
        "--tp_plan",
        type=str,
        default=None,
        help="pass 'auto' to use TP",
    )
    parser.add_argument(
        "--tp_size",
        type=int,
        default=None,
        help="TP size to be used to shard the model",
    )
    args = parser.parse_args()
    config = {"lr": 2e-5, "num_epochs": args.num_epochs, "seed": 42, "batch_size": 16}
    training_function(config, args)


def _str_to_bool(value):
    """Convert a command-line string such as `true`/`false`/`1`/`0` into a boolean; reject anything else."""
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays falsy, as it was under `type=bool`.
    if normalized in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
# Run the performance script only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
|
accelerate/test_utils/scripts/external_deps/test_pippy.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import torch
|
| 15 |
+
from transformers import (
|
| 16 |
+
BertConfig,
|
| 17 |
+
BertForMaskedLM,
|
| 18 |
+
GPT2Config,
|
| 19 |
+
GPT2ForSequenceClassification,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
from accelerate import PartialState
|
| 23 |
+
from accelerate.inference import prepare_pippy
|
| 24 |
+
from accelerate.test_utils import torch_device
|
| 25 |
+
from accelerate.utils import DistributedType, set_seed
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# Maps a short model name to `(model class, config class, sequence length)`; the sequence length is the width of
# the random token-id inputs generated in `get_model_and_data_for_text`.
model_to_config = {
    "bert": (BertForMaskedLM, BertConfig, 512),
    "gpt2": (GPT2ForSequenceClassification, GPT2Config, 1024),
}
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_model_and_data_for_text(model_name, device, num_processes: int = 2):
    """
    Instantiate a model from its default config together with random token-id inputs.

    Args:
        model_name (`str`):
            Key of `model_to_config` selecting the model class, config class and sequence length.
        device:
            Device on which the random input tensors are created.
        num_processes (`int`, *optional*):
            Number of rows of the inference input.

    Returns:
        `tuple`: `(model, trace_input, inference_inputs)` where `trace_input` has a single row and
        `inference_inputs` has `num_processes` rows.
    """
    model_cls, config_cls, sequence_length = model_to_config[model_name]
    # GPT-2 will eventually need `pad_token_id=0` in its config for batch inference tests with bs != 1.
    model_config = config_cls()
    model = model_cls(model_config)

    def random_token_ids(num_rows):
        return torch.randint(
            low=0,
            high=model_config.vocab_size,
            size=(num_rows, sequence_length),
            device=device,
            dtype=torch.int64,
            requires_grad=False,
        )

    # Drawn in this order so the random stream is consumed the same way on every run
    single_example = random_token_ids(1)
    batched_examples = random_token_ids(num_processes)
    return model, single_example, batched_examples
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def test_bert(batch_size: int = 2):
    """
    Run a pippy-prepared BERT model and check that only the last process receives an output.

    Args:
        batch_size (`int`, *optional*):
            Number of rows in the inference batch.
    """
    set_seed(42)
    partial_state = PartialState()
    bert, example_input, batch = get_model_and_data_for_text("bert", "cpu", batch_size)
    bert = prepare_pippy(bert, example_args=(example_input,), no_split_module_classes=bert._no_split_modules)
    # The inference batch has to live on the test device
    device_batch = batch.to(torch_device)
    with torch.no_grad():
        result = bert(device_batch)
    # Only the last process is expected to hold the real output
    if partial_state.is_last_process:
        assert result is not None, "Output was not generated in the last process!"
    else:
        assert result is None, "Output was not generated on just the last process!"
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def test_gpt2(batch_size: int = 2):
    """
    Run a pippy-prepared GPT-2 model and check that only the last process receives an output.

    Args:
        batch_size (`int`, *optional*):
            Number of rows in the inference batch.
    """
    set_seed(42)
    partial_state = PartialState()
    gpt2, example_input, batch = get_model_and_data_for_text("gpt2", "cpu", batch_size)
    gpt2 = prepare_pippy(gpt2, example_args=(example_input,), no_split_module_classes=gpt2._no_split_modules)
    # The inference batch has to live on the test device
    device_batch = batch.to(torch_device)
    with torch.no_grad():
        result = gpt2(device_batch)
    # Only the last process is expected to hold the real output
    if partial_state.is_last_process:
        assert result is not None, "Output was not generated in the last process!"
    else:
        assert result is None, "Output was not generated on just the last process!"
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
# Currently disabled, enable again once PyTorch pippy interface can trace a resnet34
|
| 81 |
+
# def test_resnet(batch_size: int = 2):
|
| 82 |
+
# set_seed(42)
|
| 83 |
+
# state = PartialState()
|
| 84 |
+
# model = resnet34()
|
| 85 |
+
# input_tensor = torch.rand(1, 3, 224, 224)
|
| 86 |
+
# model = prepare_pippy(
|
| 87 |
+
# model,
|
| 88 |
+
# example_args=(input_tensor,),
|
| 89 |
+
# )
|
| 90 |
+
# inference_inputs = torch.rand(batch_size, 3, 224, 224)
|
| 91 |
+
# inputs = send_to_device(inference_inputs, torch_device)
|
| 92 |
+
# with torch.no_grad():
|
| 93 |
+
# output = model(inputs)
|
| 94 |
+
# # Zach: Check that we just grab the real outputs we need at the end
|
| 95 |
+
# if not state.is_last_process:
|
| 96 |
+
# assert output is None, "Output was not generated on just the last process!"
|
| 97 |
+
# else:
|
| 98 |
+
# assert output is not None, "Output was not generated in the last process!"
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# Script entry point: run the GPT-2 and BERT pippy tests when launched on a multi-GPU/XPU/HPU setup, and always
# tear the process group down afterwards.
if __name__ == "__main__":
    state = PartialState()
    state.print("Testing pippy integration...")
    try:
        # The tests only run for these multi-device distributed types
        if state.distributed_type in [DistributedType.MULTI_GPU, DistributedType.MULTI_XPU, DistributedType.MULTI_HPU]:
            state.print("Testing GPT2...")
            test_gpt2()
            # Issue: When modifying the tokenizer for batch GPT2 inference, there's an issue
            # due to references
            # NameError: cannot access free variable 'chunk_args_list' where it is not associated with a value in enclosing scope
            # test_gpt2(3)
            state.print("Testing BERT...")
            test_bert()
        else:
            print("Less than two GPUs found, not running tests!")
    finally:
        # Runs even when a test fails, so the process group is always destroyed
        state.destroy_process_group()
|
accelerate/test_utils/scripts/external_deps/test_zero3_integration.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import torch.distributed
|
| 16 |
+
|
| 17 |
+
from accelerate.test_utils import require_huggingface_suite, torch_device
|
| 18 |
+
from accelerate.utils import is_transformers_available
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
if is_transformers_available():
|
| 22 |
+
from transformers import AutoModel, TrainingArguments
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
GPT2_TINY = "sshleifer/tiny-gpt2"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@require_huggingface_suite
def init_torch_dist_then_launch_deepspeed():
    """
    Initialize `torch.distributed` first, then build `TrainingArguments` with a ZeRO stage-3 DeepSpeed config
    and load a tiny GPT-2 model, asserting that both objects were created.
    """
    # Communication backend per device type; anything else falls back to NCCL
    backend_by_device = {"xpu": "xccl", "hpu": "hccl"}
    torch.distributed.init_process_group(backend=backend_by_device.get(torch_device, "nccl"))

    zero3_config = {
        "zero_optimization": {"stage": 3},
        "train_batch_size": "auto",
        "train_micro_batch_size_per_gpu": "auto",
    }
    training_arguments = TrainingArguments(output_dir="./", deepspeed=zero3_config)
    tiny_model = AutoModel.from_pretrained(GPT2_TINY)
    assert training_arguments is not None
    assert tiny_model is not None
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def main():
    """Script entry point: run `init_torch_dist_then_launch_deepspeed`."""
    init_torch_dist_then_launch_deepspeed()


if __name__ == "__main__":
    main()
|
accelerate/test_utils/scripts/test_cli.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import torch
|
| 15 |
+
|
| 16 |
+
from accelerate.utils import is_xpu_available
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def main():
    """
    Print how many accelerators are visible to this process.

    CUDA devices are counted first and reported as "GPU"; otherwise XPU devices are counted and reported as "XPU".
    When neither is available the message reports 0 "GPU"s.
    """
    if torch.cuda.is_available():
        kind, count = "GPU", torch.cuda.device_count()
    elif is_xpu_available():
        kind, count = "XPU", torch.xpu.device_count()
    else:
        kind, count = "GPU", 0
    print(f"Successfully ran on {count} {kind}s")


if __name__ == "__main__":
    main()
|
accelerate/test_utils/scripts/test_ddp_comm_hook.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import torch
|
| 15 |
+
|
| 16 |
+
from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs, PartialState
|
| 17 |
+
from accelerate.utils import is_hpu_available
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class MockModel(torch.nn.Module):
    """Module with a single 40x20 parameter `p` whose output depends on the `rank` given to `forward`."""

    def __init__(self):
        super().__init__()
        # Seed right before sampling so that every instance starts from identical parameter values.
        torch.manual_seed(0)
        initial_values = torch.randn(40, 20)
        self.p = torch.nn.Parameter(initial_values)

    def forward(self, x, rank):
        """Return `p * x ** (1 + rank)`."""
        exponent = 1 + rank
        return self.p * (x**exponent)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _run_and_get_grads(model, rank):
|
| 31 |
+
torch.manual_seed(2024)
|
| 32 |
+
input = torch.randn(40, 20)
|
| 33 |
+
output = model(input, rank)
|
| 34 |
+
output.mean().backward()
|
| 35 |
+
param = next(model.parameters())
|
| 36 |
+
return param.grad
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def test_ddp_comm_hook(comm_hook, comm_wrapper, comm_state_option):
    """
    Check that gradients of an Accelerate-prepared `MockModel` configured with the given DDP communication hook
    settings are close to those of a `MockModel` wrapped directly in `DistributedDataParallel`.
    """
    accelerator = Accelerator(
        kwargs_handlers=[
            DistributedDataParallelKwargs(
                comm_hook=comm_hook,
                comm_wrapper=comm_wrapper,
                comm_state_option=comm_state_option,
            )
        ]
    )
    local_rank = accelerator.local_process_index

    hooked_model = accelerator.prepare(MockModel())
    hook_grads = _run_and_get_grads(hooked_model, local_rank)

    baseline = MockModel().to(accelerator.device)
    reference_model = torch.nn.parallel.DistributedDataParallel(
        baseline, device_ids=[local_rank], output_device=local_rank
    )
    reference_grads = _run_and_get_grads(reference_model, local_rank)

    # The comparison uses rtol/atol of 1e-2 rather than exact equality.
    torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-2, atol=1e-2)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def main():
    """
    Run `test_ddp_comm_hook` for each hook / wrapper / state-option combination, then destroy the process group.

    When `is_hpu_available()` is true, combinations whose hook or wrapper is FP16 or BF16 are skipped.
    """
    hook = DDPCommunicationHookType
    cases = [
        (hook.NO, hook.NO, {}),
        (hook.FP16, hook.NO, {}),
        (hook.BF16, hook.NO, {}),
        (hook.POWER_SGD, hook.NO, {}),
        (hook.POWER_SGD, hook.FP16, {}),
        (hook.POWER_SGD, hook.BF16, {}),
        (hook.POWER_SGD, hook.NO, {"matrix_approximation_rank": 2}),
        (hook.BATCHED_POWER_SGD, hook.NO, {}),
        (hook.BATCHED_POWER_SGD, hook.FP16, {}),
        (hook.BATCHED_POWER_SGD, hook.BF16, {}),
    ]
    hpu_unsupported = {hook.FP16, hook.BF16}

    for comm_hook, comm_wrapper, comm_state_option in cases:
        if is_hpu_available() and (comm_hook in hpu_unsupported or comm_wrapper in hpu_unsupported):
            print(f"Skipping test DDP comm hook: {comm_hook}, comm wrapper: {comm_wrapper} on HPU")
            continue

        print(f"Test DDP comm hook: {comm_hook}, comm wrapper: {comm_wrapper}")
        test_ddp_comm_hook(comm_hook, comm_wrapper, comm_state_option)
    PartialState().destroy_process_group()


if __name__ == "__main__":
    main()
|
accelerate/test_utils/scripts/test_distributed_data_loop.py
ADDED
|
@@ -0,0 +1,429 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import pickle
|
| 18 |
+
import tempfile
|
| 19 |
+
import warnings
|
| 20 |
+
from unittest.mock import Mock
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
from torch.utils.data import (
|
| 24 |
+
BatchSampler,
|
| 25 |
+
DataLoader,
|
| 26 |
+
Dataset,
|
| 27 |
+
IterableDataset,
|
| 28 |
+
RandomSampler,
|
| 29 |
+
TensorDataset,
|
| 30 |
+
default_collate,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
from accelerate.accelerator import Accelerator, DataLoaderConfiguration
|
| 34 |
+
from accelerate.utils.dataclasses import DistributedType
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# Length reported by `DummyDataset.__len__` and the number of distinct indices `test_data_loader` expects to see.
NUM_ELEMENTS = 22
|
| 38 |
+
# `num_workers` value for the DataLoaders built in `main`.
NUM_WORKERS = 4
|
| 39 |
+
# Batch size for the DataLoaders and BatchSamplers built in `main`.
BATCH_SIZE = 4
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class DummyDataset(Dataset):
    """
    Map-style dataset of `NUM_ELEMENTS` synthetic samples.

    Every sample is a dict with the keys `"index"`, `"label"` (`index % 2`) and `"random_augmentation"` (a new
    `torch.rand` draw on each access).
    """

    def __len__(self):
        return NUM_ELEMENTS

    def __getitem__(self, index):
        """
        Return one sample dict for an `int` index, or a list of sample dicts for a slice or an iterable of indices.
        """
        single = isinstance(index, int)
        if single:
            indices = [index]
        elif isinstance(index, slice):
            # Resolve the slice against the dataset length; this class defines no `size` attribute.
            indices = list(range(*index.indices(len(self))))
        else:
            indices = list(index)

        batch = [{"index": i, "label": i % 2, "random_augmentation": torch.rand(1).item()} for i in indices]
        return batch[0] if single else batch
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class DummyIterableDataset(IterableDataset):
    """Iterable-style dataset that yields the items of the wrapped `data` object in order."""

    def __init__(self, data):
        self.data = data

    def __iter__(self):
        for item in self.data:
            yield item
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def create_accelerator(even_batches=True):
    """
    Build an `Accelerator` whose dataloader configuration uses the given `even_batches` value.

    Asserts that the accelerator reports exactly two processes.
    """
    accelerator = Accelerator(dataloader_config=DataLoaderConfiguration(even_batches=even_batches))
    assert accelerator.num_processes == 2, "this script expects that two GPUs are available"
    return accelerator
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def create_dataloader(
    accelerator: Accelerator, dataset_size: int, batch_size: int, iterable: bool = False, shuffle: bool = False
):
    """
    Create a simple DataLoader to use during the test cases

    The dataset holds the integers `0 .. dataset_size - 1`, randomly permuted when `shuffle` is set. It is wrapped in
    a `DummyIterableDataset` when `iterable` is true and in a `TensorDataset` otherwise. The DataLoader is returned
    after being passed through `accelerator.prepare`.
    """
    values = torch.as_tensor(range(dataset_size))
    if shuffle:
        values = values[torch.randperm(values.size(0))]

    # Build both dataset flavours from `values` so that `shuffle` also takes effect for the map-style dataset.
    dataset = DummyIterableDataset(values) if iterable else TensorDataset(values)

    dl = DataLoader(dataset, batch_size=batch_size)
    return accelerator.prepare(dl)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def verify_dataloader_batch_sizes(
    accelerator: Accelerator,
    dataset_size: int,
    batch_size: int,
    process_0_expected_batch_sizes: list[int],
    process_1_expected_batch_sizes: list[int],
):
    """
    Iterate a freshly prepared dataloader and assert that the batch sizes seen by this process match the expected
    list for its process index (0 or 1). Other process indices are not checked.
    """
    dataloader = create_dataloader(accelerator=accelerator, dataset_size=dataset_size, batch_size=batch_size)

    observed_sizes = []
    for batch in dataloader:
        observed_sizes.append(len(batch[0]))

    rank = accelerator.process_index
    if rank == 0:
        assert observed_sizes == process_0_expected_batch_sizes
    elif rank == 1:
        assert observed_sizes == process_1_expected_batch_sizes
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def test_default_ensures_even_batch_sizes():
    """With the default accelerator settings, both processes must see the same number and sizes of batches."""
    accelerator = create_accelerator()

    # (dataset_size, batch_size, expected sizes on process 0, expected sizes on process 1)
    scenarios = [
        # without padding, we would expect a different number of batches
        (3, 1, [1, 1], [1, 1]),
        # without padding, we would expect the same number of batches, but different sizes
        (7, 2, [2, 2], [2, 2]),
    ]
    for dataset_size, batch_size, expected_0, expected_1 in scenarios:
        verify_dataloader_batch_sizes(
            accelerator,
            dataset_size=dataset_size,
            batch_size=batch_size,
            process_0_expected_batch_sizes=expected_0,
            process_1_expected_batch_sizes=expected_1,
        )
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def test_can_disable_even_batches():
    """With `even_batches=False`, process 1 is expected to receive fewer or smaller batches than process 0."""
    accelerator = create_accelerator(even_batches=False)

    # (dataset_size, batch_size, expected sizes on process 0, expected sizes on process 1)
    scenarios = [
        (3, 1, [1, 1], [1]),
        (7, 2, [2, 2], [2, 1]),
    ]
    for dataset_size, batch_size, expected_0, expected_1 in scenarios:
        verify_dataloader_batch_sizes(
            accelerator,
            dataset_size=dataset_size,
            batch_size=batch_size,
            process_0_expected_batch_sizes=expected_0,
            process_1_expected_batch_sizes=expected_1,
        )
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def test_can_join_uneven_inputs():
    """
    Train a tiny linear model inside `join_uneven_inputs` with `even_batches=False` and check how many steps each
    process ran: batch indices `[0, 1]` on process 0 and `[0]` on process 1.
    """
    accelerator = create_accelerator(even_batches=False)
    ddp_model = accelerator.prepare(torch.nn.Linear(1, 1))
    dl = create_dataloader(accelerator, dataset_size=3, batch_size=1)

    seen_steps = []
    with accelerator.join_uneven_inputs([ddp_model]):
        for step, batch in enumerate(dl):
            loss = ddp_model(batch[0].float()).sum()
            loss.backward()
            seen_steps.append(step)

    accelerator.wait_for_everyone()

    rank = accelerator.process_index
    if rank == 0:
        assert seen_steps == [0, 1]
    elif rank == 1:
        assert seen_steps == [0]
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def test_join_raises_warning_for_non_ddp_distributed(accelerator):
    """
    Check that entering `accelerator.join_uneven_inputs` emits a `UserWarning` whose message contains
    "only supported for multi-GPU".

    Args:
        accelerator: The accelerator whose `join_uneven_inputs` context manager is exercised.

    Raises:
        AssertionError: If no warning, or no matching warning, was recorded.
    """
    with warnings.catch_warnings(record=True) as caught:
        # Record every warning regardless of the filters active in the calling process.
        warnings.simplefilter("always")
        with accelerator.join_uneven_inputs([Mock()]):
            pass

    # Fail with a readable message instead of an IndexError when nothing was recorded.
    assert caught, "join_uneven_inputs did not emit any warning"
    # Look through all records: with every warning recorded, the expected one need not be the last.
    assert any(
        issubclass(record.category, UserWarning) and "only supported for multi-GPU" in str(record.message)
        for record in caught
    ), f"expected warning not found in: {[str(record.message) for record in caught]}"
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def test_join_can_override_even_batches():
    """
    Check that `join_uneven_inputs(..., even_batches=False)` overrides `batch_sampler.even_batches` on both prepared
    dataloaders inside the context and that the default value is back in place afterwards.
    """
    default_even_batches = True
    overridden_even_batches = False
    accelerator = create_accelerator(even_batches=default_even_batches)
    ddp_model = accelerator.prepare(torch.nn.Linear(1, 1))
    # First entry plays the role of the train dataloader, second of the validation dataloader.
    dataloaders = [create_dataloader(accelerator, dataset_size=3, batch_size=1) for _ in range(2)]

    with accelerator.join_uneven_inputs([ddp_model], even_batches=overridden_even_batches):
        values_inside_context = [dl.batch_sampler.even_batches for dl in dataloaders]

    for value in values_inside_context:
        assert value == overridden_even_batches
    for dl in dataloaders:
        assert dl.batch_sampler.even_batches == default_even_batches
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def test_join_can_override_for_mixed_type_dataloaders():
    """
    Check that overriding `even_batches` in `join_uneven_inputs` works when both an iterable and a map-style
    dataloader were prepared: no `AttributeError` may escape, and the map-style dataloader sees the override inside
    the context and the default afterwards.
    """
    default_even_batches = True
    overridden_even_batches = False
    accelerator = create_accelerator(even_batches=default_even_batches)
    ddp_model = accelerator.prepare(torch.nn.Linear(1, 1))
    # The iterable dataloader is created for its effect on the accelerator only; its return value is unused.
    create_dataloader(accelerator, dataset_size=3, batch_size=1, iterable=True)
    batch_dl = create_dataloader(accelerator, dataset_size=3, batch_size=1)

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")
        try:
            with accelerator.join_uneven_inputs([ddp_model], even_batches=overridden_even_batches):
                value_inside_context = batch_dl.batch_sampler.even_batches
        except AttributeError:
            # ensure attribute error is not raised when processing iterable dl
            raise AssertionError

    assert value_inside_context == overridden_even_batches
    assert batch_dl.batch_sampler.even_batches == default_even_batches
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def test_join_raises_warning_for_iterable_when_overriding_even_batches():
    """
    Check that overriding `even_batches` in `join_uneven_inputs` while an iterable dataloader is prepared emits a
    `UserWarning` whose message contains "only supported for map-style datasets".

    Raises:
        AssertionError: If no warning, or no matching warning, was recorded.
    """
    accelerator = create_accelerator()
    ddp_model = accelerator.prepare(torch.nn.Linear(1, 1))
    # Created for its effect on the accelerator only; the returned dataloader is not used.
    create_dataloader(accelerator, dataset_size=3, batch_size=1, iterable=True)

    with warnings.catch_warnings(record=True) as caught:
        # Record every warning regardless of the filters active in the calling process.
        warnings.simplefilter("always")
        with accelerator.join_uneven_inputs([ddp_model], even_batches=False):
            pass

    # Fail with a readable message instead of an IndexError when nothing was recorded.
    assert caught, "join_uneven_inputs did not emit any warning"
    # Look through all records: with every warning recorded, the expected one need not be the last.
    assert any(
        issubclass(record.category, UserWarning) and "only supported for map-style datasets" in str(record.message)
        for record in caught
    ), f"expected warning not found in: {[str(record.message) for record in caught]}"
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def test_pickle_accelerator():
    """Round-trip an `Accelerator` through pickle and compare the `state.__dict__` of the original and the copy."""
    accelerator = create_accelerator()
    data_loader = create_dataloader(accelerator, dataset_size=32, batch_size=4)
    accelerator.prepare(data_loader)

    restored_accelerator = pickle.loads(pickle.dumps(accelerator))
    # TODO: Maybe this should be implemented as __eq__ for AcceleratorState?
    assert accelerator.state.__dict__ == restored_accelerator.state.__dict__
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def test_data_loader(data_loader, accelerator):
    """
    Prepare `data_loader`, iterate it once while gathering the "index" field with `gather_for_metrics`, and assert
    that `NUM_ELEMENTS` distinct indices were collected.
    """
    data_loader = accelerator.prepare(data_loader)

    gathered_indices = []
    for batch in data_loader:
        index, _ = accelerator.gather_for_metrics((batch["index"], batch["label"]))
        gathered_indices += index.detach().cpu().numpy().tolist()

    # Duplicates collapse in the set, so the count only matches when every element was seen.
    unique_indices = set(sorted(gathered_indices))
    assert len(unique_indices) == NUM_ELEMENTS, (
        "Not all the dataset elements have been iterated in an epoch due to duplication of samples across processes."
    )
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def _test_stateful_dataloader_resume(accelerator, iterable):
    """
    Helper: iterate a stateful dataloader, capture its `state_dict()` at step index 2 (the third batch), restore it
    with `load_state_dict`, iterate again, and verify that the resumed batches equal the batches that followed the
    capture in the first pass.

    The state is captured early so that many batches remain after the resume point. `iterable` selects between the
    iterable and the map-style dataset. The accelerator's original `dataloader_config` is restored on exit.
    """
    old_dataloader_config = accelerator.dataloader_config
    try:
        accelerator.dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
        prepared_dl = create_dataloader(
            accelerator, dataset_size=32 * accelerator.num_processes, batch_size=4, iterable=iterable, shuffle=True
        )

        save_step = 2
        untrained_batches = []
        for step, batch in enumerate(prepared_dl):
            if step == save_step:
                state_dict = prepared_dl.state_dict()
            elif step > save_step:
                untrained_batches.append(batch)
        not_skipped_batches = accelerator.gather(untrained_batches)

        prepared_dl.load_state_dict(state_dict)
        resumed_batches = accelerator.gather([batch for batch in prepared_dl])

        assert len(not_skipped_batches) == len(resumed_batches), (
            f"Expected {len(not_skipped_batches)} batches after resume, got {len(resumed_batches)}"
        )
        for b1, b2 in zip(not_skipped_batches, resumed_batches):
            for v1, v2 in zip(b1, b2):
                assert torch.equal(v1, v2), f"Batch {b1} and {b2} are not equal"
    finally:
        accelerator.dataloader_config = old_dataloader_config
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def test_stateful_dataloader(accelerator):
    """
    Run the `state_dict` / `load_state_dict` resume check first for an iterable dataset and then for a map-style
    dataset.

    After resuming, the dataloader should yield the same batches that followed the save point in the first pass.
    """
    for iterable in (True, False):
        _test_stateful_dataloader_resume(accelerator, iterable=iterable)
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def _test_stateful_dataloader_save_state_resume(accelerator, iterable):
    """
    Helper: iterate a stateful dataloader, call `accelerator.save_state` into a temporary directory at step index 2,
    call `accelerator.load_state` from it, iterate again, and verify that the resumed batches equal the batches that
    followed the save in the first pass.

    `iterable` selects between the iterable and the map-style dataset. The accelerator's original
    `dataloader_config` is restored on exit.
    """
    old_dataloader_config = accelerator.dataloader_config
    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            accelerator.dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
            prepared_dl = create_dataloader(
                accelerator, dataset_size=32 * accelerator.num_processes, batch_size=4, iterable=iterable, shuffle=True
            )

            save_step = 2
            untrained_batches = []
            for step, batch in enumerate(prepared_dl):
                if step == save_step:
                    accelerator.save_state(tmpdir)
                elif step > save_step:
                    untrained_batches.append(batch)
            not_skipped_batches = accelerator.gather(untrained_batches)

            accelerator.load_state(tmpdir)
            resumed_batches = accelerator.gather([batch for batch in prepared_dl])

            assert len(not_skipped_batches) == len(resumed_batches), (
                f"Expected {len(not_skipped_batches)} batches after resume, got {len(resumed_batches)}"
            )
            for b1, b2 in zip(not_skipped_batches, resumed_batches):
                for v1, v2 in zip(b1, b2):
                    assert torch.equal(v1, v2), f"Batch {b1} and {b2} are not equal"
    finally:
        accelerator.dataloader_config = old_dataloader_config
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def test_stateful_dataloader_save_state(accelerator):
    """
    Run the `Accelerator.save_state` / `load_state` resume check first for an iterable dataset and then for a
    map-style dataset.

    After resuming, the dataloader should yield the same batches that followed the save point in the first pass.
    """
    for iterable in (True, False):
        _test_stateful_dataloader_save_state_resume(accelerator, iterable=iterable)
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def main():
    """
    Run every check of this script in sequence, then call `accelerator.end_training()`.

    The global torch seed is set to the process index before the checks run. For the non-DDP join check the
    accelerator state's `distributed_type` is temporarily switched to `DistributedType.FSDP`.
    """
    accelerator = create_accelerator()
    torch.manual_seed(accelerator.process_index)

    accelerator.print("Test that even_batches variable ensures uniform batches across processes")
    test_default_ensures_even_batch_sizes()

    accelerator.print("Run tests with even_batches disabled")
    test_can_disable_even_batches()

    accelerator.print("Test joining uneven inputs")
    test_can_join_uneven_inputs()

    accelerator.print("Test overriding even_batches when joining uneven inputs")
    test_join_can_override_even_batches()

    accelerator.print("Test overriding even_batches for mixed dataloader types")
    test_join_can_override_for_mixed_type_dataloaders()

    accelerator.print("Test overriding even_batches raises a warning for iterable dataloaders")
    test_join_raises_warning_for_iterable_when_overriding_even_batches()

    accelerator.print("Test join with non DDP distributed raises warning")
    original_state = accelerator.state.distributed_type
    accelerator.state.distributed_type = DistributedType.FSDP
    try:
        test_join_raises_warning_for_non_ddp_distributed(accelerator)
    finally:
        # Undo the temporary override even when the check fails, so no later code sees the FSDP value.
        accelerator.state.distributed_type = original_state

    accelerator.print("Test pickling an accelerator")
    test_pickle_accelerator()

    dataset = DummyDataset()

    accelerator.print("Test DataLoader with shuffle=False")
    loader = DataLoader(dataset, shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
    test_data_loader(loader, accelerator)

    accelerator.print("Test DataLoader with shuffle=True")
    loader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
    test_data_loader(loader, accelerator)

    accelerator.print("Test DataLoader with batch_sampler")
    sampler = BatchSampler(RandomSampler(dataset), batch_size=BATCH_SIZE, drop_last=False)
    loader = DataLoader(dataset, batch_sampler=sampler, num_workers=NUM_WORKERS)
    test_data_loader(loader, accelerator)

    accelerator.print("Test DataLoader with sampler as an instance of `BatchSampler`")
    sampler = BatchSampler(RandomSampler(dataset), batch_size=BATCH_SIZE, drop_last=False)
    loader = DataLoader(dataset, sampler=sampler, batch_size=None, collate_fn=default_collate, num_workers=NUM_WORKERS)
    test_data_loader(loader, accelerator)
    test_stateful_dataloader(accelerator)
    test_stateful_dataloader_save_state(accelerator)

    accelerator.end_training()


if __name__ == "__main__":
    main()
|
accelerate/test_utils/scripts/test_merge_weights.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import gc
|
| 16 |
+
import logging
|
| 17 |
+
import shutil
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
from safetensors.torch import load_file
|
| 22 |
+
from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy, StateDictType
|
| 23 |
+
from torch.utils.data import DataLoader
|
| 24 |
+
|
| 25 |
+
from accelerate import Accelerator, FullyShardedDataParallelPlugin
|
| 26 |
+
from accelerate.commands.merge import merge_command, merge_command_parser
|
| 27 |
+
from accelerate.state import AcceleratorState
|
| 28 |
+
from accelerate.test_utils import torch_device
|
| 29 |
+
from accelerate.test_utils.training import RegressionDataset
|
| 30 |
+
from accelerate.utils import merge_fsdp_weights, patch_environment, save_fsdp_model
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# Configure the root logger at INFO level as soon as this module is loaded.
logging.basicConfig(level=logging.INFO)
|
| 34 |
+
|
| 35 |
+
# Module-level parser returned by `merge_command_parser()`.
# NOTE(review): presumably used to build arguments for `merge_command` further down — confirm against its call sites.
parser = merge_command_parser()
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class TinyModel(torch.nn.Module):
    """
    Two 16x16 linear layers with a ReLU in between.

    A `Softmax` module is also registered as `softmax`, but `forward` does not apply it.
    """

    def __init__(self):
        super().__init__()
        width = 16
        self.linear1 = torch.nn.Linear(width, width)
        self.activation = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(width, width)
        self.softmax = torch.nn.Softmax()

    def forward(self, x):
        """Return `linear2(relu(linear1(x)))`."""
        hidden = self.linear1(x)
        hidden = self.activation(hidden)
        return self.linear2(hidden)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def setup():
    """
    Build a `TinyModel` prepared by an `Accelerator` that uses an FSDP plugin (full sharding, sharded state dicts).

    Any existing shared `AcceleratorState` is reset first. The plugin's auto-wrap policy is set while the
    environment is patched with `fsdp_auto_wrap_policy="SIZE_BASED_WRAP"`.

    Returns:
        The tuple `(prepared model, plugin, accelerator)`.
    """
    if AcceleratorState._shared_state != {}:
        AcceleratorState()._reset_state()

    plugin = FullyShardedDataParallelPlugin(
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        state_dict_type=StateDictType.SHARDED_STATE_DICT,
    )
    raw_model = TinyModel()
    with patch_environment(fsdp_auto_wrap_policy="SIZE_BASED_WRAP"):
        plugin.set_auto_wrap_policy(raw_model)

    accelerator = Accelerator(fsdp_plugin=plugin)
    prepared_model = accelerator.prepare(raw_model)
    return prepared_model, plugin, accelerator
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def mock_training(accelerator, model):
    """
    Run a short SGD training loop so the model's weights move away from their initial values.

    Args:
        accelerator: The `Accelerator` used to prepare the objects and run the backward pass.
        model: The model to train.

    Returns:
        The model returned by `accelerator.prepare`, after three passes over the dataset.
    """
    dataset = RegressionDataset(length=128, seed=42)
    loader = DataLoader(dataset, batch_size=16, shuffle=False)
    sgd = torch.optim.SGD(model.parameters(), lr=0.1)

    loader, model, sgd = accelerator.prepare(loader, model, sgd)

    num_epochs = 3
    for _epoch in range(num_epochs):
        for batch in loader:
            model.zero_grad()
            prediction = model(batch["x"])
            loss = torch.nn.functional.mse_loss(prediction, batch["y"])
            accelerator.backward(loss)
            sgd.step()
    return model
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def check_weights(operation, state_1, state_2):
    """
    Compare two state dicts entry by entry, pairing values in iteration order.

    Args:
        operation (`str`): `"same"` to require every pair of tensors to be close, `"diff"` to require
            every pair to differ.
        state_1: First state dict (mapping of name -> tensor).
        state_2: Second state dict; its values are paired positionally with those of `state_1`.

    Raises:
        ValueError: If `operation` is neither `"same"` nor `"diff"`.
        AssertionError: If any pair of tensors violates the requested relation.
    """
    # Previously any unknown value silently fell through to the "diff" branch.
    if operation not in ("same", "diff"):
        raise ValueError(f'`operation` must be "same" or "diff", got {operation!r}')

    expect_same = operation == "same"
    for (name, weight_1), weight_2 in zip(state_1.items(), state_2.values()):
        is_close = torch.allclose(weight_1, weight_2)
        if expect_same:
            assert is_close, f"Expected weights for `{name}` to match, but they differ"
        else:
            assert not is_close, f"Expected weights for `{name}` to differ, but they match"
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def check_safetensors_weights(path, model):
    """
    Verify that `path/model.safetensors` holds the same weights as `model`.

    A freshly initialised `TinyModel` must first differ from `model`; after loading the file into it,
    the two must match.

    Args:
        path (`Path`): Directory containing `model.safetensors`.
        model: The trained model to compare against.
    """
    merged_state = load_file(path / "model.safetensors")
    reference = TinyModel().to(torch_device)
    # Sanity check: an untrained model must not already match the trained one.
    check_weights("diff", model.state_dict(), reference.state_dict())
    reference.load_state_dict(merged_state)
    check_weights("same", model.state_dict(), reference.state_dict())
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def check_pytorch_weights(path, model):
    """
    Verify that `path/pytorch_model.bin` holds the same weights as `model`.

    A freshly initialised `TinyModel` must first differ from `model`; after loading the file into it,
    the two must match.

    Args:
        path (`Path`): Directory containing `pytorch_model.bin`.
        model: The trained model to compare against.
    """
    merged_state = torch.load(path / "pytorch_model.bin", weights_only=True)
    reference = TinyModel().to(torch_device)
    # Sanity check: an untrained model must not already match the trained one.
    check_weights("diff", model.state_dict(), reference.state_dict())
    reference.load_state_dict(merged_state)
    check_weights("same", model.state_dict(), reference.state_dict())
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def test_merge_weights_safetensors(model, path):
    """
    Merge the sharded checkpoint with `merge_fsdp_weights` (safetensors output) and verify it.

    Args:
        model: The trained model whose weights the merged file must match.
        path (`Path`): Directory holding the `pytorch_model_fsdp_0` shards; also the output directory.
    """
    # The merged weights are read back from `path/model.safetensors` by the check below.
    merge_fsdp_weights(path / "pytorch_model_fsdp_0", path, safe_serialization=True)
    check_safetensors_weights(path, model)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def test_merge_weights_command_safetensors(model, path):
    """
    Merge the sharded checkpoint through the CLI entry point and verify the safetensors output.

    Args:
        model: The trained model whose weights the merged file must match.
        path (`Path`): Directory holding the `pytorch_model_fsdp_0` shards; also the output directory.
    """
    cli_args = [str(path / "pytorch_model_fsdp_0"), str(path)]
    merge_command(parser.parse_args(cli_args))
    check_safetensors_weights(path, model)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def test_merge_weights_pytorch(model, path):
    """
    Merge the sharded checkpoint with `merge_fsdp_weights` (non-safetensors output) and verify it.

    Args:
        model: The trained model whose weights the merged file must match.
        path (`Path`): Directory holding the `pytorch_model_fsdp_0` shards; also the output directory.
    """
    # The merged weights are read back from `path/pytorch_model.bin` by the check below.
    merge_fsdp_weights(path / "pytorch_model_fsdp_0", path, safe_serialization=False)
    check_pytorch_weights(path, model)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def test_merge_weights_command_pytorch(model, path):
    """
    Merge the sharded checkpoint through the CLI entry point with `--unsafe_serialization` and verify
    the `pytorch_model.bin` output.

    Args:
        model: The trained model whose weights the merged file must match.
        path (`Path`): Directory holding the `pytorch_model_fsdp_0` shards; also the output directory.
    """
    cli_args = [str(path / "pytorch_model_fsdp_0"), str(path), "--unsafe_serialization"]
    merge_command(parser.parse_args(cli_args))
    check_pytorch_weights(path, model)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
if __name__ == "__main__":
    # Note this test requires at least two accelerators!
    model, plugin, accelerator = setup()
    if accelerator.num_processes > 1:
        # Defined before the `try` so the cleanup in `finally` can always reference it.
        out_path = Path("test_merge_weights_fsdp_weights")
        try:
            # Initial setup for things. `exist_ok=True` already tolerates an existing directory,
            # including when several processes create it at the same time.
            out_path.mkdir(parents=True, exist_ok=True)

            # Train briefly so the weights aren't the baseline
            model = mock_training(accelerator, model)
            accelerator.wait_for_everyone()

            gc.collect()  # Needed for some lingering refs after training
            save_fsdp_model(plugin, accelerator, model, out_path)
            accelerator.wait_for_everyone()

            # Finally we can test
            test_merge_weights_safetensors(model, out_path)
            test_merge_weights_command_safetensors(model, out_path)
            test_merge_weights_pytorch(model, out_path)
            test_merge_weights_command_pytorch(model, out_path)
        finally:
            # Cleanup in case of any failures. Errors during removal are ignored so that a cleanup
            # problem cannot replace the exception that actually failed the test.
            if accelerator.is_main_process:
                shutil.rmtree(out_path, ignore_errors=True)
            accelerator.wait_for_everyone()
    # NOTE(review): nesting reconstructed from a flattened diff; confirm `end_training` is meant to
    # run regardless of the number of processes.
    accelerator.end_training()
|
accelerate/test_utils/scripts/test_notebook.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
Test file to ensure that in general certain situational setups for notebooks work.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import os
|
| 19 |
+
import time
|
| 20 |
+
|
| 21 |
+
from pytest import mark, raises
|
| 22 |
+
from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
|
| 23 |
+
|
| 24 |
+
from accelerate import PartialState, notebook_launcher
|
| 25 |
+
from accelerate.test_utils import require_bnb
|
| 26 |
+
from accelerate.utils import is_bnb_available, is_xpu_available
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def basic_function():
    """Print the current `PartialState`; the simplest payload handed to `notebook_launcher` here."""
    current_state = PartialState()
    print(f"PartialState:\n{current_state}")
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def tough_nut_function(queue):
    """
    Fail until the countdown stored in `queue` reaches zero.

    Args:
        queue: Queue holding the number of failures still to produce. If it is empty the function
            returns immediately.

    Raises:
        RuntimeError: While the value taken from the queue is positive; the decremented value is put
            back before raising.
    """
    if queue.empty():
        return
    remaining = queue.get()
    if remaining <= 0:
        # Countdown exhausted: succeed and report the state.
        print(f"PartialState:\n{PartialState()}")
        return
    queue.put(remaining - 1)
    raise RuntimeError("The nut hasn't cracked yet! Try again.")
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def bipolar_sleep_function(sleep_sec: int):
    """
    Raise on even-indexed processes and sleep on odd-indexed ones.

    Args:
        sleep_sec (`int`): How long odd-indexed processes sleep, in seconds.

    Raises:
        RuntimeError: On every process whose `process_index` is even.
    """
    state = PartialState()
    is_odd_process = state.process_index % 2 != 0
    if is_odd_process:
        time.sleep(sleep_sec)
        return
    raise RuntimeError("I'm an even process. I don't like to sleep.")
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# Number of processes to launch, read from `ACCELERATE_NUM_PROCESSES` (defaults to 1 when unset).
NUM_PROCESSES = int(os.getenv("ACCELERATE_NUM_PROCESSES", 1))
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def test_can_initialize():
    """Launch `basic_function` with `NUM_PROCESSES` processes and no extra launcher options."""
    launch_kwargs = {"num_processes": NUM_PROCESSES}
    notebook_launcher(basic_function, (), **launch_kwargs)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@mark.skipif(NUM_PROCESSES < 2, reason="Need at least 2 processes to test static rendezvous backends")
def test_static_rdzv_backend():
    """Launch `basic_function` using the `static` rendezvous backend."""
    launch_kwargs = {"num_processes": NUM_PROCESSES, "rdzv_backend": "static"}
    notebook_launcher(basic_function, (), **launch_kwargs)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
@mark.skipif(NUM_PROCESSES < 2, reason="Need at least 2 processes to test c10d rendezvous backends")
def test_c10d_rdzv_backend():
    """Launch `basic_function` using the `c10d` rendezvous backend."""
    launch_kwargs = {"num_processes": NUM_PROCESSES, "rdzv_backend": "c10d"}
    notebook_launcher(basic_function, (), **launch_kwargs)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
@mark.skipif(NUM_PROCESSES < 2, reason="Need at least 2 processes to test fault tolerance")
def test_fault_tolerant(max_restarts: int = 3):
    """
    Launch `tough_nut_function` with a countdown of `max_restarts` and the same number of allowed
    restarts.

    Args:
        max_restarts (`int`, defaults to 3): Value seeded into the queue and passed to the launcher as
            `max_restarts`.
    """
    # Use torch.multiprocessing to get the right context for the current device
    import torch.multiprocessing as mp

    # Get appropriate context - 'spawn' for XPU, 'fork' for others
    start_method = "spawn" if is_xpu_available() else "fork"
    countdown = mp.get_context(start_method).Queue()
    countdown.put(max_restarts)
    notebook_launcher(tough_nut_function, (countdown,), num_processes=NUM_PROCESSES, max_restarts=max_restarts)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
@mark.skipif(NUM_PROCESSES < 2, reason="Need at least 2 processes to test monitoring")
def test_monitoring(monitor_interval: float = 0.01, sleep_sec: int = 100):
    """
    Check that a failing process makes the launch raise before the sleeping processes finish.

    Args:
        monitor_interval (`float`, defaults to 0.01): Forwarded to `notebook_launcher`.
        sleep_sec (`int`, defaults to 100): Sleep duration of the odd processes; also the upper bound
            on how long the launch may take.
    """
    launch_kwargs = {"num_processes": NUM_PROCESSES, "monitor_interval": monitor_interval}
    started_at = time.time()
    with raises(ChildFailedError, match="I'm an even process. I don't like to sleep."):
        notebook_launcher(bipolar_sleep_function, (sleep_sec,), **launch_kwargs)
    elapsed = time.time() - started_at
    assert elapsed < sleep_sec, "Monitoring did not stop the process in time."
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
@require_bnb
def test_problematic_imports():
    """Expect `notebook_launcher` to raise a `RuntimeError` once `bitsandbytes` has been imported."""
    expected_message = "Please keep these imports"
    with raises(RuntimeError, match=expected_message):
        import bitsandbytes as bnb  # noqa: F401

        notebook_launcher(basic_function, (), num_processes=NUM_PROCESSES)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def main():
    """Run every notebook-launcher check in order, printing a banner before each one."""
    checks = (
        ("Test basic notebook can be ran", test_can_initialize),
        ("Test static rendezvous backend", test_static_rdzv_backend),
        ("Test c10d rendezvous backend", test_c10d_rdzv_backend),
        ("Test fault tolerant", test_fault_tolerant),
        ("Test monitoring", test_monitoring),
    )
    for banner, check in checks:
        print(banner)
        check()

    # The bitsandbytes check only runs when the package is available.
    if is_bnb_available():
        print("Test problematic imports (bnb)")
        test_problematic_imports()

    if NUM_PROCESSES > 1:
        PartialState().destroy_process_group()


if __name__ == "__main__":
    main()
|
accelerate/test_utils/scripts/test_ops.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
from accelerate import PartialState
|
| 20 |
+
from accelerate.test_utils.testing import assert_exception
|
| 21 |
+
from accelerate.utils.dataclasses import DistributedType
|
| 22 |
+
from accelerate.utils.operations import (
|
| 23 |
+
DistributedOperationException,
|
| 24 |
+
broadcast,
|
| 25 |
+
copy_tensor_to_devices,
|
| 26 |
+
gather,
|
| 27 |
+
gather_object,
|
| 28 |
+
pad_across_processes,
|
| 29 |
+
reduce,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def create_tensor(state):
    """
    Build a per-process tensor of `num_processes` consecutive floats.

    Process `i` gets the values `i * num_processes + 1` through `(i + 1) * num_processes`, placed on
    `state.device`.
    """
    world_size = state.num_processes
    one_based = torch.arange(world_size) + 1.0
    shifted = one_based + (world_size * state.process_index)
    return shifted.to(state.device)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def test_gather(state):
    """Gather each process's `create_tensor` output and expect the values 1..num_processes**2 in order."""
    world_size = state.num_processes
    gathered = gather(create_tensor(state))
    expected = list(range(1, world_size**2 + 1))
    assert gathered.tolist() == expected
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def test_gather_object(state):
    """Gather `[process_index]` from every process and expect `[0, 1, ..., num_processes - 1]`."""
    # Gather objects in TorchXLA is not supported.
    if state.distributed_type == DistributedType.XLA:
        return
    world_size = state.num_processes
    gathered = gather_object([state.process_index])
    expected = list(range(world_size))
    assert len(gathered) == world_size, f"{gathered}, {len(gathered)} != {world_size}"
    assert gathered == expected, f"{gathered} != {expected}"
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def test_gather_non_contiguous(state):
    """Check that `gather` accepts a non-contiguous tensor without raising."""
    # Skip this test because the 'is_contiguous' function of XLA tensor always returns True.
    if state.distributed_type == DistributedType.XLA:
        return

    # Create a non-contiguous tensor (enforce non-contiguity after device memory allocation)
    on_device = torch.arange(12, device=state.device)
    transposed = on_device.view(4, 3).t()
    assert not transposed.is_contiguous()
    # Shouldn't error out
    _ = gather(transposed)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def test_broadcast(state):
    """Broadcast the per-process tensor and expect every process to hold the values 1..num_processes."""
    world_size = state.num_processes
    received = broadcast(create_tensor(state))
    assert received.shape == torch.Size([world_size])
    assert received.tolist() == list(range(1, world_size + 1))
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def test_pad_across_processes(state):
    """Pad tensors of unequal length across processes and check the padded shape and contents."""
    world_size = state.num_processes
    # We need to pad the tensor with one more element if we are the main process
    # to ensure that we can pad
    length = world_size + 1 if state.is_main_process else world_size
    tensor = torch.arange(length).to(state.device)

    padded = pad_across_processes(tensor)
    assert padded.shape == torch.Size([world_size + 1])
    if state.is_main_process:
        return
    # Shorter tensors are expected to end with a single zero of padding.
    assert padded.tolist() == list(range(0, world_size)) + [0]
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def test_reduce_sum(state):
|
| 86 |
+
# For now runs on only two processes
|
| 87 |
+
if state.num_processes != 2:
|
| 88 |
+
return
|
| 89 |
+
tensor = create_tensor(state)
|
| 90 |
+
reduced_tensor = reduce(tensor, "sum")
|
| 91 |
+
truth_tensor = torch.tensor([4.0, 6]).to(state.device)
|
| 92 |
+
assert torch.allclose(reduced_tensor, truth_tensor), f"{reduced_tensor} != {truth_tensor}"
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def test_reduce_mean(state):
|
| 96 |
+
# For now runs on only two processes
|
| 97 |
+
if state.num_processes != 2:
|
| 98 |
+
return
|
| 99 |
+
tensor = create_tensor(state)
|
| 100 |
+
reduced_tensor = reduce(tensor, "mean")
|
| 101 |
+
truth_tensor = torch.tensor([2.0, 3]).to(state.device)
|
| 102 |
+
assert torch.allclose(reduced_tensor, truth_tensor), f"{reduced_tensor} != {truth_tensor}"
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def test_op_checker(state):
    """
    Check that, with `state.debug` enabled, mismatched inputs across processes make
    `pad_across_processes`, `reduce` and `broadcast` raise `DistributedOperationException`.

    Process 0 always supplies a 2-D tensor while the other processes supply a 3-D one, so the inputs
    differ between processes. `state.debug` is reset to `False` on exit, including when a check fails.
    """
    # Must be in a distributed state, and gathering is currently not supported in TorchXLA.
    if state.distributed_type in [DistributedType.NO, DistributedType.XLA]:
        return
    state.debug = True
    try:
        # `pad_across_processes`
        if state.process_index == 0:
            data = {"tensor": torch.tensor([[0.0, 1, 2, 3, 4]]).to(state.device)}
        else:
            data = {"tensor": torch.tensor([[[0.0, 1, 2, 3, 4, 5]]]).to(state.device)}

        with assert_exception(DistributedOperationException):
            pad_across_processes(data, dim=0)

        # `reduce`
        if state.process_index == 0:
            data = {"tensor": torch.tensor([[0.0, 1, 2, 3, 4]]).to(state.device)}
        else:
            data = {"tensor": torch.tensor([[[0.0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]]).to(state.device)}

        with assert_exception(DistributedOperationException):
            reduce(data)

        # `broadcast`
        if state.process_index == 0:
            data = {"tensor": torch.tensor([[0.0, 1, 2, 3, 4]]).to(state.device)}
        else:
            data = {"tensor": torch.tensor([[[0.0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]]).to(state.device)}

        with assert_exception(DistributedOperationException):
            broadcast(data)
    finally:
        # Always switch debug mode back off so a failure here cannot leak it into later tests.
        state.debug = False
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def test_copy_tensor_to_devices(state):
    """Send `[1, 2, 3]` from the main process via `copy_tensor_to_devices` and check every process gets it."""
    supported = [DistributedType.MULTI_GPU, DistributedType.XLA]
    if state.distributed_type not in supported:
        return
    # Only the main process provides a tensor; the others pass `None`.
    source = torch.tensor([1, 2, 3], dtype=torch.int).to(state.device) if state.is_main_process else None
    received = copy_tensor_to_devices(source)
    expected = torch.tensor([1, 2, 3], dtype=torch.int, device=state.device)
    assert torch.allclose(received, expected)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def _mp_fn(index):
    """Entry point for xla_spawn (TPUs); `index` is accepted but not used."""
    main()
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def main():
    """Run every distributed-operation check in order, then destroy the process group."""
    state = PartialState()
    state.print(f"State: {state}")

    checks = (
        ("testing gather", test_gather),
        ("testing gather_object", test_gather_object),
        ("testing gather non-contiguous", test_gather_non_contiguous),
        ("testing broadcast", test_broadcast),
        ("testing pad_across_processes", test_pad_across_processes),
        ("testing reduce_sum", test_reduce_sum),
        ("testing reduce_mean", test_reduce_mean),
        ("testing op_checker", test_op_checker),
        ("testing sending tensors across devices", test_copy_tensor_to_devices),
    )
    for label, check in checks:
        state.print(label)
        check(state)

    state.destroy_process_group()


if __name__ == "__main__":
    main()
|
accelerate/test_utils/scripts/test_script.py
ADDED
|
@@ -0,0 +1,909 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import contextlib
|
| 18 |
+
import io
|
| 19 |
+
import math
|
| 20 |
+
import time
|
| 21 |
+
from copy import deepcopy
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
import torch
|
| 26 |
+
from torch.utils.data import DataLoader, Dataset
|
| 27 |
+
|
| 28 |
+
from accelerate import Accelerator
|
| 29 |
+
from accelerate.data_loader import SeedableRandomSampler, prepare_data_loader
|
| 30 |
+
from accelerate.state import AcceleratorState
|
| 31 |
+
from accelerate.test_utils import RegressionDataset, RegressionModel, are_the_same_tensors
|
| 32 |
+
from accelerate.utils import (
|
| 33 |
+
DataLoaderConfiguration,
|
| 34 |
+
DistributedType,
|
| 35 |
+
gather,
|
| 36 |
+
gather_object,
|
| 37 |
+
is_bf16_available,
|
| 38 |
+
is_cuda_available,
|
| 39 |
+
is_datasets_available,
|
| 40 |
+
is_fp16_available,
|
| 41 |
+
is_hpu_available,
|
| 42 |
+
is_mps_available,
|
| 43 |
+
is_pytest_available,
|
| 44 |
+
set_seed,
|
| 45 |
+
synchronize_rng_states,
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# Absolute and relative tolerances: 1e-3 when an HPU is available, 1e-6 otherwise.
ATOL, RTOL = (1e-3, 1e-3) if is_hpu_available() else (1e-6, 1e-6)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler=False):
    """
    Create a shuffling `DataLoader` over `train_set`, optionally driven by a `SeedableRandomSampler`.

    Args:
        train_set: Dataset to iterate over.
        generator: Random generator passed to the sampler or to the `DataLoader`.
        batch_size (`int`): Batch size of the returned loader.
        use_seedable_sampler (`bool`, defaults to `False`): Whether to sample with a
            `SeedableRandomSampler` instead of the loader's own `shuffle=True`.

    Returns:
        `DataLoader`: The configured data loader.
    """
    if not use_seedable_sampler:
        return DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)

    # The SeedableRandomSampler is needed during distributed setups
    # for full reproducibility across processes with the `DataLoader`
    seedable_sampler = SeedableRandomSampler(
        generator=generator,
        data_source=train_set,
        num_samples=len(train_set),
    )
    return DataLoader(train_set, batch_size=batch_size, sampler=seedable_sampler)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def print_main(state):
    """Print a marker line containing `state.process_index`; meant to be run on the main process."""
    message = f"Printing from the main process {state.process_index}"
    print(message)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def print_local_main(state):
    """Print a marker line containing `state.local_process_index`; meant for the local main process."""
    message = f"Printing from the local main process {state.local_process_index}"
    print(message)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def print_last(state):
    """Print a marker line containing `state.process_index`; meant to be run on the last process."""
    message = f"Printing from the last process {state.process_index}"
    print(message)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def print_on(state, process_idx):
    """Print the targeted `process_idx` followed by the actual `state.process_index`."""
    message = f"Printing from process {process_idx}: {state.process_index}"
    print(message)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _capture_stdout(func, *args):
    """Call `func(*args)` and return whatever it printed to stdout, with trailing whitespace removed."""
    buffer = io.StringIO()
    with contextlib.redirect_stdout(buffer):
        func(*args)
    return buffer.getvalue().rstrip()


def process_execution_check():
    """
    Exercise the process-control helpers of `Accelerator`.

    First checks `main_process_first` by having every process append a line to a shared file and
    verifying that the main process's line comes first. Then checks that the `on_main_process`,
    `on_local_main_process`, `on_last_process` and `on_process` decorators only run the wrapped
    function on the intended process, by inspecting what each process printed.

    Raises:
        AssertionError: If any of the checks fails.
    """
    accelerator = Accelerator()
    num_processes = accelerator.num_processes

    # Test main_process_first context manager
    path = Path("check_main_process_first.txt")
    # The file is opened in append mode below, so a leftover file from an aborted run would corrupt
    # the ordering check. Remove it and synchronize before any process writes.
    if accelerator.is_main_process and path.exists():
        path.unlink()
    accelerator.wait_for_everyone()

    with accelerator.main_process_first():
        if accelerator.is_main_process:
            time.sleep(0.1)  # ensure main process takes longest
            with open(path, "a+") as f:
                f.write("Currently in the main process\n")
        else:
            with open(path, "a+") as f:
                f.write("Now on another process\n")
    accelerator.wait_for_everyone()

    if accelerator.is_main_process:
        try:
            text = path.read_text()
            assert text.startswith("Currently in the main process\n"), "Main process was not first"
            if num_processes > 1:
                assert text.endswith("Now on another process\n"), "Main process was not first"
            assert text.count("Now on another process\n") == accelerator.num_processes - 1, (
                f"Only wrote to file {text.count('Now on another process') + 1} times, not {accelerator.num_processes}"
            )
        finally:
            # Remove the scratch file on success as well as on any failure.
            if path.exists():
                path.unlink()
    accelerator.wait_for_everyone()

    # Test the decorators
    result = _capture_stdout(accelerator.on_main_process(print_main), accelerator.state)
    if accelerator.is_main_process:
        assert result == "Printing from the main process 0", f"{result} != Printing from the main process 0"
    else:
        assert result == "", f'{result} != ""'

    result = _capture_stdout(accelerator.on_local_main_process(print_local_main), accelerator.state)
    if accelerator.is_local_main_process:
        assert result == "Printing from the local main process 0"
    else:
        assert result == ""

    result = _capture_stdout(accelerator.on_last_process(print_last), accelerator.state)
    if accelerator.is_last_process:
        assert result == f"Printing from the last process {accelerator.state.num_processes - 1}"
    else:
        assert result == ""

    for process_idx in range(num_processes):
        result = _capture_stdout(
            accelerator.on_process(print_on, process_index=process_idx), accelerator.state, process_idx
        )
        if accelerator.process_index == process_idx:
            assert result == f"Printing from process {process_idx}: {accelerator.process_index}"
        else:
            assert result == ""
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def init_state_check():
    """Instantiate `AcceleratorState` and print it, with an extra banner on local process 0."""
    # Test we can instantiate this twice in a row.
    state = AcceleratorState()
    is_local_main = state.local_process_index == 0
    if is_local_main:
        print("Testing, testing. 1, 2, 3.")
    # NOTE(review): nesting reconstructed from a flattened diff; confirm the state is meant to be
    # printed on every process rather than only on local process 0.
    print(state)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def rng_sync_check():
    """Synchronize RNG states and assert they match via `are_the_same_tensors`.

    The torch CPU RNG is always checked, then the CUDA or XPU RNG depending on
    `state.distributed_type`, and finally a fresh `torch.Generator`.
    """
    state = AcceleratorState()

    def sync_then_verify(kind, read_state, failure_message, **sync_kwargs):
        # Synchronize one RNG kind, then compare the state read back afterwards.
        synchronize_rng_states([kind], **sync_kwargs)
        assert are_the_same_tensors(read_state()), failure_message

    sync_then_verify("torch", torch.get_rng_state, "RNG states improperly synchronized on CPU.")

    # `torch.cuda` / `torch.xpu` are only touched inside their own branch.
    distributed_type = state.distributed_type
    if distributed_type == DistributedType.MULTI_GPU:
        sync_then_verify("cuda", torch.cuda.get_rng_state, "RNG states improperly synchronized on GPU.")
    elif distributed_type == DistributedType.MULTI_XPU:
        sync_then_verify("xpu", torch.xpu.get_rng_state, "RNG states improperly synchronized on XPU.")

    generator = torch.Generator()
    sync_then_verify(
        "generator",
        generator.get_state,
        "RNG states improperly synchronized in generator.",
        generator=generator,
    )

    if state.local_process_index == 0:
        print("All rng are properly synched.")
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def dl_preparation_check():
    """Check that prepared dataloaders yield every item of `range(length)`.

    Four loaders are built over `range(32 * num_processes)` with batch size 8:
    non-shuffled and shuffled, each with and without `split_batches=True`.
    Batches are gathered and concatenated; the non-shuffled result must equal
    `arange(length)` exactly, the shuffled result must equal it once sorted.
    """
    state = AcceleratorState()
    length = 32 * state.num_processes
    placement = (state.device, state.num_processes, state.process_index)

    def gather_epoch(loader):
        # Gather each batch and concatenate the gathered batches.
        return torch.cat([gather(batch) for batch in loader])

    # Non-shuffled: first without, then with `split_batches=True`.
    for extra_kwargs in ({}, {"split_batches": True}):
        dl = DataLoader(range(length), batch_size=8)
        dl = prepare_data_loader(dl, *placement, put_on_device=True, **extra_kwargs)
        result = gather_epoch(dl)
        assert torch.equal(result.cpu(), torch.arange(0, length).long()), "Wrong non-shuffled dataloader result."

    if state.process_index == 0:
        print("Non-shuffled dataloader passing.")

    # Shuffled: order is arbitrary, so compare after sorting.
    for extra_kwargs in ({}, {"split_batches": True}):
        dl = DataLoader(range(length), batch_size=8, shuffle=True)
        dl = prepare_data_loader(dl, *placement, put_on_device=True, **extra_kwargs)
        result = sorted(gather_epoch(dl).tolist())
        assert result == list(range(length)), "Wrong shuffled dataloader result."

    if state.local_process_index == 0:
        print("Shuffled dataloader passing.")
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def central_dl_preparation_check():
    """Repeat the dataloader checks with `dispatch_batches=True`.

    Same four configurations as the plain dataloader check (non-shuffled and
    shuffled, with and without `split_batches=True`), every loader being
    prepared with `dispatch_batches=True`.
    """
    state = AcceleratorState()
    length = 32 * state.num_processes
    placement = (state.device, state.num_processes, state.process_index)

    def gather_epoch(loader):
        # Gather each batch and concatenate the gathered batches.
        return torch.cat([gather(batch) for batch in loader])

    # Non-shuffled: first without, then with `split_batches=True`.
    for extra_kwargs in ({}, {"split_batches": True}):
        dl = DataLoader(range(length), batch_size=8)
        dl = prepare_data_loader(dl, *placement, put_on_device=True, dispatch_batches=True, **extra_kwargs)
        result = gather_epoch(dl)
        assert torch.equal(result.cpu(), torch.arange(0, length).long()), "Wrong non-shuffled dataloader result."

    if state.process_index == 0:
        print("Non-shuffled central dataloader passing.")

    # Shuffled: order is arbitrary, so compare after sorting.
    for extra_kwargs in ({}, {"split_batches": True}):
        dl = DataLoader(range(length), batch_size=8, shuffle=True)
        dl = prepare_data_loader(dl, *placement, put_on_device=True, dispatch_batches=True, **extra_kwargs)
        result = sorted(gather_epoch(dl).tolist())
        assert result == list(range(length)), "Wrong shuffled dataloader result."

    if state.local_process_index == 0:
        print("Shuffled central dataloader passing.")
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def custom_sampler_check():
    """Check that `prepare_data_loader` keeps a user-supplied batch sampler.

    After preparation, either `dl.batch_sampler` or, when it exposes one,
    `dl.batch_sampler.batch_sampler` must still be a `CustomBatchSampler`.
    """
    state = AcceleratorState()

    class CustomDataset(Dataset):
        """Thin dataset wrapper around an indexable `data` object."""

        def __init__(self, data):
            self.data = data

        def __len__(self):
            return len(self.data)

        def __getitem__(self, index):
            return self.data[index]

    class CustomBatchSampler:
        """Yields index arrays produced by `np.array_split`, optionally permuted."""

        def __init__(self, dataset_length: int, batch_size: int, shuffle: bool = True):
            self.batch_size = batch_size
            self.data_index = np.arange(dataset_length)
            self.shuffle = shuffle

        def __iter__(self):
            batch_count = len(self)
            order = np.random.permutation(self.data_index) if self.shuffle else self.data_index
            yield from np.array_split(order, batch_count)

        def __len__(self):
            return math.ceil(len(self.data_index) / self.batch_size)

    dataset = CustomDataset(range(32 * state.num_processes))
    sampler = CustomBatchSampler(len(dataset), batch_size=8)
    dl = prepare_data_loader(
        DataLoader(dataset, batch_sampler=sampler), state.device, state.num_processes, state.process_index
    )

    # We need just ensure that `dl.batch_sampler` (or `dl.batch_sampler.batch_sampler`) is indeed the old batch sampler
    outer_sampler = dl.batch_sampler
    innermost = outer_sampler.batch_sampler if hasattr(outer_sampler, "batch_sampler") else outer_sampler
    assert isinstance(innermost, CustomBatchSampler), (
        "Custom sampler was changed after calling `prepare_data_loader`"
    )
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def check_seedable_sampler():
    """Check that re-seeding and resetting the epoch reproduces the same items.

    A shuffled loader prepared with `use_seedable_sampler=True` is iterated
    three times; after `set_seed(42)` and `set_epoch(0)` it is iterated three
    times again and the concatenated `"x"` values must be close.
    """

    def collect_three_passes(loader):
        # Concatenate the "x" entry of every batch over three full passes.
        return torch.cat([batch["x"] for _ in range(3) for batch in loader])

    # Set seed
    set_seed(42)
    train_set = RegressionDataset(length=10, seed=42)
    train_dl = DataLoader(train_set, batch_size=2, shuffle=True)

    config = DataLoaderConfiguration(use_seedable_sampler=True)
    accelerator = Accelerator(dataloader_config=config)
    train_dl = accelerator.prepare(train_dl)
    original_items = collect_three_passes(train_dl)

    # Set seed again and the epoch
    set_seed(42)
    train_dl.set_epoch(0)
    new_items = collect_three_passes(train_dl)

    assert torch.allclose(original_items, new_items), "Did not obtain the same items with the same seed and epoch."
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
def check_seedable_sampler_in_batch_sampler_shard():
    """Check the sampler nested in the prepared loader is a `SeedableRandomSampler`.

    Asserts more than one process is in use, prepares a shuffled loader with
    `use_seedable_sampler=True`, and inspects
    `batch_sampler.batch_sampler.sampler` on the result.
    """
    set_seed(42)

    accelerator = Accelerator(dataloader_config=DataLoaderConfiguration(use_seedable_sampler=True))
    assert accelerator.num_processes > 1, "This test requires more than one process."

    prepared_data_loader = prepare_data_loader(
        dataloader=DataLoader(list(range(10)), batch_size=1, shuffle=True),
        use_seedable_sampler=True,
    )

    sharded_batch_sampler = prepared_data_loader.batch_sampler
    target_sampler = sharded_batch_sampler.batch_sampler.sampler
    assert isinstance(target_sampler, SeedableRandomSampler), (
        "Sampler in BatchSamplerShard is not SeedableRandomSampler."
    )
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
def check_seedable_sampler_with_data_seed():
    """Check that changing `data_seed` changes the items a seedable loader yields.

    The same base loader is prepared twice, with `data_seed` 42 and then 43;
    the `"x"` values collected over three passes must not be close.
    """

    def collect_three_passes(loader):
        # Concatenate the "x" entry of every batch over three full passes.
        return torch.cat([batch["x"] for _ in range(3) for batch in loader])

    # Set seed
    set_seed(42)
    data_seed = 42
    train_set = RegressionDataset(length=10, seed=42)
    train_dl = DataLoader(train_set, batch_size=2, shuffle=True)

    config = DataLoaderConfiguration(use_seedable_sampler=True, data_seed=data_seed)
    accelerator = Accelerator(dataloader_config=config)
    original_items = collect_three_passes(accelerator.prepare(train_dl))

    # Set new data seed
    config.data_seed = 43
    accelerator = Accelerator(dataloader_config=config)
    new_items = collect_three_passes(accelerator.prepare(train_dl))

    assert not torch.allclose(original_items, new_items), "Obtained the same items with different data seed."
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
def mock_training(length, batch_size, generator, use_seedable_sampler=False):
    """Train a `RegressionModel` for three epochs without an `Accelerator`.

    Seeds are set to 42 (globally and on `generator`), a `RegressionDataset`
    of `length` items is built and fed through the baseline dataloader, and
    the model is optimized with SGD (lr=0.1) on an MSE loss.

    Returns:
        The `(train_set, model)` pair.
    """
    set_seed(42)
    generator.manual_seed(42)
    train_set = RegressionDataset(length=length, seed=42)

    baseline_loader = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler)
    model = RegressionModel()
    sgd = torch.optim.SGD(model.parameters(), lr=0.1)
    mse = torch.nn.functional.mse_loss

    for _ in range(3):
        for batch in baseline_loader:
            model.zero_grad()
            prediction = model(batch["x"])
            mse(prediction, batch["y"]).backward()
            sgd.step()

    return train_set, model
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
def _train_and_compare(accelerator, train_dl, generator, reference_model):
    """Train a fresh `RegressionModel` through `accelerator` and compare it to a reference.

    The model and an SGD optimizer (lr=0.1) are created, prepared together
    with `train_dl`, and trained for three epochs on an MSE loss after seeding
    everything with 42. The unwrapped model is moved to CPU and its `a` and
    `b` attributes must match `reference_model` within `ATOL` / `RTOL`.

    Raises:
        AssertionError: if either attribute differs beyond the tolerances.
    """
    model = RegressionModel()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
    # Seed only after `prepare`, exactly where each inlined copy used to do it.
    set_seed(42)
    generator.manual_seed(42)
    for _ in range(3):
        for batch in train_dl:
            model.zero_grad()
            output = model(batch["x"])
            loss = torch.nn.functional.mse_loss(output, batch["y"])
            accelerator.backward(loss)
            optimizer.step()

    model = accelerator.unwrap_model(model).cpu()
    for attribute in ("a", "b"):
        torch.testing.assert_close(
            getattr(reference_model, attribute),
            getattr(model, attribute),
            atol=ATOL,
            rtol=RTOL,
            msg=lambda msg: f"Did not obtain the same model on CPU or distributed training.\n{msg}",
        )


def training_check(use_seedable_sampler=False):
    """Compare accelerated training against the `mock_training` baseline.

    The baseline model is checked with `are_the_same_tensors`, then the same
    training is repeated through an `Accelerator` in up to four
    configurations: default, `split_batches=True`, bf16 and fp16 mixed
    precision (the last two only when the corresponding availability helper
    returns true). An additional forward-only check exercises
    `unwrap_model(..., keep_fp32_wrapper=True)` when CUDA or MPS is available.

    Args:
        use_seedable_sampler: forwarded to the baseline dataloader factory and
            to every `DataLoaderConfiguration` built here.
    """
    state = AcceleratorState()
    generator = torch.Generator()
    batch_size = 8
    length = batch_size * 4 * state.num_processes

    train_set, old_model = mock_training(length, batch_size * state.num_processes, generator, use_seedable_sampler)
    assert are_the_same_tensors(old_model.a), "Did not obtain the same model on both processes."
    assert are_the_same_tensors(old_model.b), "Did not obtain the same model on both processes."

    # Default configuration: per-process batches of `batch_size`.
    accelerator = Accelerator()
    train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler)
    _train_and_compare(accelerator, train_dl, generator, old_model)

    accelerator.print("Training yielded the same results on one CPU or distributed setup with no batch split.")

    # Split-batches configuration: the loader is built with the global batch size.
    dataloader_config = DataLoaderConfiguration(split_batches=True, use_seedable_sampler=use_seedable_sampler)
    accelerator = Accelerator(dataloader_config=dataloader_config)
    train_dl = generate_baseline_dataloader(
        train_set, generator, batch_size * state.num_processes, use_seedable_sampler
    )
    _train_and_compare(accelerator, train_dl, generator, old_model)

    accelerator.print("Training yielded the same results on one CPU or distributed setup with batch split.")

    # FP32 wrapper check
    if is_cuda_available() or is_mps_available():
        # Mostly a test that model.forward will have autocast when running unwrap_model(model, keep_fp32_wrapper=True)
        print("Keep fp32 wrapper check.")
        AcceleratorState._reset_state()
        accelerator = Accelerator(mixed_precision="fp16")

        model = torch.nn.Linear(2, 4)
        model = accelerator.prepare(model)
        model_with_fp32_wrapper = accelerator.unwrap_model(model, keep_fp32_wrapper=True)

        # Run forward with fp16 as input.
        # When the model is with mixed precision wrapper, no error will be raised.
        input_tensor = torch.Tensor([1, 2]).to(dtype=torch.float16, device=accelerator.device)
        model_with_fp32_wrapper(input_tensor)

    # BF16 support
    if is_bf16_available():
        # Mostly a test that BF16 doesn't crash as the operation inside the model is not converted to BF16
        print("BF16 training check.")
        AcceleratorState._reset_state()
        dataloader_config = DataLoaderConfiguration(use_seedable_sampler=use_seedable_sampler)
        accelerator = Accelerator(mixed_precision="bf16", dataloader_config=dataloader_config)
        train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler)
        _train_and_compare(accelerator, train_dl, generator, old_model)

    # FP16 support (HPU fp16 model seems to be off by 10% from the CPU, which is a lot of numerical error)
    if is_fp16_available() and not is_hpu_available():
        # Mostly a test that FP16 doesn't crash as the operation inside the model is not converted to FP16
        print("FP16 training check.")
        AcceleratorState._reset_state()
        dataloader_config = DataLoaderConfiguration(use_seedable_sampler=use_seedable_sampler)
        accelerator = Accelerator(mixed_precision="fp16", dataloader_config=dataloader_config)
        train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler)
        _train_and_compare(accelerator, train_dl, generator, old_model)
|
| 621 |
+
|
| 622 |
+
|
| 623 |
+
def test_split_between_processes_dataset(datasets_Dataset):
    """Check `split_between_processes` on datasets built via `datasets_Dataset.from_list`.

    Without padding, an even dataset must give two rows per process, and an
    odd one must give a single row to the last process. With padding, the
    per-process length is checked and the gathered `"k"` values (truncated to
    the dataset length) must equal the dataset's `"k"` column.

    Args:
        datasets_Dataset: class exposing `from_list`, passed in by the caller.
    """
    state = AcceleratorState()

    def build(num_rows):
        # One row per integer, stored under the "k" key.
        return datasets_Dataset.from_list([{"k": value} for value in range(num_rows)])

    data = build(2 * state.num_processes)
    with state.split_between_processes(data, apply_padding=False) as results:
        assert len(results) == 2, (
            f"Each process did not have two items. Process index: {state.process_index}; Length: {len(results)}"
        )

    data = build(2 * state.num_processes - 1)
    with state.split_between_processes(data, apply_padding=False) as results:
        if not state.is_last_process:
            assert len(results) == 2, (
                f"One of the intermediate processes did not receive two items. Process index: {state.process_index}; Length: {len(results)}"
            )
        else:
            assert len(results) == 1, (
                f"Last process did not receive a single item. Process index: {state.process_index}; Length: {len(results)}"
            )
    state.wait_for_everyone()

    odd_data = build(2 * state.num_processes - 1)
    even_data = build(2 * state.num_processes)

    for data in (odd_data, even_data):
        expected_output = data["k"]

        with state.split_between_processes(data, apply_padding=True) as results:
            if state.num_processes != 1:
                assert len(results) == 2, (
                    f"Each process did not have two items. Process index: {state.process_index}; Length: {len(results)}"
                )
            else:
                assert len(results) == len(data), (
                    f"Single process did not receive all items. Process index: {state.process_index}; Length: {len(results)}"
                )

            results_per_process = list(results)

            state.wait_for_everyone()

            gathered_results = gather_object(results_per_process)
            output = [row["k"] for row in gathered_results[: len(data)]]

            assert expected_output == output, f"Gathered results is incorrect. Expected: {expected_output}; Got: {output}"
|
| 669 |
+
|
| 670 |
+
|
| 671 |
+
def test_split_between_processes_list():
    """Check `split_between_processes` on plain lists of integers.

    A list of `2 * num_processes` items must give two items per process. With
    `apply_padding=True`, each process must receive
    `ceil(len(data) / num_processes)` items, and the gathered items truncated
    to the input length must equal the input.
    """
    state = AcceleratorState()
    doubled = 2 * state.num_processes

    data = list(range(0, doubled))
    with state.split_between_processes(data) as results:
        assert len(results) == 2, (
            f"Each process did not have two items. Process index: {state.process_index}; Length: {len(results)}"
        )
    state.wait_for_everyone()

    even_data = list(range(0, doubled))
    odd_data = list(range(0, doubled - 1))
    for data in (odd_data, even_data):
        expected_output = data

        with state.split_between_processes(data, apply_padding=True) as results:
            num_samples_per_device = math.ceil(len(data) / state.num_processes)
            # Test all processes gets the correct number of item(s)
            assert len(results) == num_samples_per_device, (
                f"Process {state.device} did not get the correct number of item(s). Process index: {state.process_index}; Length: {len(results)}"
            )

            results_per_process = list(results)

            state.wait_for_everyone()

            output = gather_object(results_per_process)[: len(data)]

            assert expected_output == output, f"Gathered results is incorrect. Expected: {expected_output}; Got: {output}"
|
| 702 |
+
|
| 703 |
+
|
| 704 |
+
def test_split_between_processes_nested_dict():
    """Check `split_between_processes` on a dict of a list, a list and a tensor.

    Only runs the checks for 1, 2 or 4 processes, where the eight items divide
    evenly. Every process verifies that its shard of each value equals the
    contiguous slice `[index * chunk : (index + 1) * chunk]` of an untouched
    deep copy of the input. For the ranks the previous version covered
    (rank 0, rank 1 of 2, rank 3 of 4) this is the same slice as before; the
    middle ranks of a 4-process run are now checked as well.
    """
    state = AcceleratorState()
    a = [1, 2, 3, 4, 5, 6, 7, 8]
    b = ["a", "b", "c", "d", "e", "f", "g", "h"]
    c = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
    if state.num_processes in (1, 2, 4):
        data = {"a": a, "b": b, "c": c}
        # Expectations are taken from a pristine copy rather than from `data`
        # itself, and failure messages report that same copy.
        data_copy = deepcopy(data)
        chunk = len(a) // state.num_processes
        start = state.process_index * chunk
        stop = start + chunk
        with state.split_between_processes(data) as results:
            # We return a list each time
            for key in ("a", "b"):
                expected = data_copy[key][start:stop]
                assert results[key] == expected, (
                    f"Did not obtain expected `{key}` values on process {state.process_index}, "
                    f"expected `{expected}`, received: {results[key]}"
                )

            expected_c = data_copy["c"][start:stop]
            assert torch.allclose(results["c"], expected_c), (
                f"Did not obtain expected `c` values on process {state.process_index}, "
                f"expected `{expected_c}`, received: {results['c']}"
            )

    state.wait_for_everyone()
|
| 740 |
+
|
| 741 |
+
|
| 742 |
+
def test_split_between_processes_tensor():
    """Check `split_between_processes` on tensors placed on `state.device`.

    With more than one process, a 2x4 tensor is split and process 0 must get
    the first row while the others are compared against the second row. With
    `apply_padding=True`, each process must receive
    `ceil(len(data) / num_processes)` rows, and the gathered rows truncated to
    the input length must equal the input rows.
    """
    state = AcceleratorState()
    if state.num_processes > 1:
        data = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]]).to(state.device)
        with state.split_between_processes(data) as results:
            expected_rows = [[0, 1, 2, 3]] if state.process_index == 0 else [[4, 5, 6, 7]]
            expected = torch.tensor(expected_rows).to(state.device)
            torch.testing.assert_close(results, expected)
    state.wait_for_everyone()

    doubled = 2 * state.num_processes
    even_data = torch.tensor([[i] for i in range(doubled)]).to(state.device)
    odd_data = torch.tensor([[i] for i in range(doubled - 1)]).to(state.device)
    for data in (even_data, odd_data):
        expected_output = [torch.tensor(row) for row in data.tolist()]

        with state.split_between_processes(data, apply_padding=True) as results:
            num_samples_per_device = math.ceil(len(data) / state.num_processes)
            assert len(results) == num_samples_per_device, (
                f"Process {state.device} did not get the correct number of item(s). Process index: {state.process_index}; Length: {len(results)}"
            )
            # Move every row to CPU before gathering the Python objects.
            results_per_process = [row.to("cpu") for row in results]

            state.wait_for_everyone()

            output = gather_object(results_per_process)[: len(data)]

            assert expected_output == output, f"Gathered results is incorrect. Expected: {expected_output}; Got: {output}"
|
| 774 |
+
|
| 775 |
+
|
| 776 |
+
def test_split_between_processes_evenly():
    """Check that 17 items are spread with at most one extra item per process.

    Only checked for 1, 2, 4 or 8 processes: processes whose index is below
    `17 % num_processes` must receive one more item than the others.
    """
    state = AcceleratorState()
    if state.num_processes in (1, 2, 4, 8):
        data = list(range(17))
        base_count, num_extras = divmod(len(data), state.num_processes)
        with state.split_between_processes(data) as results:
            # The first `num_extras` processes carry one additional item.
            expected_count = base_count + 1 if state.process_index < num_extras else base_count
            assert len(results) == expected_count, (
                f"Each Process should have even elements. Expected: {expected_count}, Actual: {len(results)}"
            )
    state.wait_for_everyone()
|
| 792 |
+
|
| 793 |
+
|
| 794 |
+
def test_trigger():
    """Check the `set_trigger` / `check_trigger` round trip on an `Accelerator`.

    `check_trigger()` must be False at first, True after the main process
    calls `set_trigger()`, and False again on the following call.
    """
    accelerator = Accelerator()

    # should start with being false
    before = accelerator.check_trigger()
    assert before is False

    # set a breakpoint on the main process
    if accelerator.is_main_process:
        accelerator.set_trigger()

    # check it's been activated across all processes
    # calls `all_reduce` and triggers a sync
    after_set = accelerator.check_trigger()
    assert after_set is True

    # check it's been reset after the sync
    after_sync = accelerator.check_trigger()
    assert after_sync is False
|
| 809 |
+
|
| 810 |
+
|
| 811 |
+
def test_reinstantiated_state():
    """Check the error raised when `prepare` runs after the state was reset.

    An `Accelerator` is created, `AcceleratorState._reset_state()` is called,
    and `accelerator.prepare(...)` must then raise an `AttributeError` whose
    first argument contains both expected message fragments.
    """
    import pytest

    AcceleratorState._reset_state()
    simple_model = torch.nn.Linear(1, 1)
    # First define an accelerator
    accelerator = Accelerator()
    # Then call `reset_state`, breaking the state existing in the accelerator
    AcceleratorState._reset_state()
    # Now try and prepare a simple model, should raise the custom error early
    with pytest.raises(AttributeError) as cm:
        accelerator.prepare(simple_model)

    error_text = str(cm.value.args[0])
    assert "`AcceleratorState` object has no attribute" in error_text
    assert "This happens if `AcceleratorState._reset_state()`" in error_text
|
| 825 |
+
|
| 826 |
+
|
| 827 |
+
def main():
    """Run the script's checks one after another on a freshly created `Accelerator`.

    Section banners are printed by a single process (local index 0 or global
    index 0, depending on the section). The process group is destroyed at the
    end, except when the distributed type is DeepSpeed: that path returns
    before the training checks and before `destroy_process_group()`.
    """
    accelerator = Accelerator()
    state = accelerator.state

    def _banner_local(message):
        # Print only from the process whose local_process_index is 0.
        if state.local_process_index == 0:
            print(message)

    def _banner_global(message):
        # Print only from the process whose process_index is 0.
        if state.process_index == 0:
            print(message)

    _banner_local("**Initialization**")
    init_state_check()
    state.wait_for_everyone()

    num_processes_per_node = (
        torch.cuda.device_count() if state.distributed_type == DistributedType.MULTI_GPU else state.num_processes
    )

    # We only run this test on non-multinode
    if num_processes_per_node == state.num_processes:
        _banner_global("\n**Test process execution**")
        process_execution_check()

        _banner_global("\n**Test split between processes as a list**")
        test_split_between_processes_list()

        _banner_global("\n**Test split between processes as a dict**")
        test_split_between_processes_nested_dict()

        _banner_global("\n**Test split between processes as a tensor**")
        test_split_between_processes_tensor()

        _banner_global("\n**Test split between processes evenly**")
        test_split_between_processes_evenly()

        _banner_global("\n**Test split between processes as a datasets.Dataset**")
        if is_datasets_available():
            # Imported lazily so the script still runs without `datasets` installed.
            from datasets import Dataset as datasets_Dataset

            test_split_between_processes_dataset(datasets_Dataset)
        else:
            print("Skipped because Hugging Face datasets is not available")

    _banner_local("\n**Test random number generator synchronization**")
    rng_sync_check()

    _banner_local("\n**DataLoader integration test**")
    dl_preparation_check()
    if state.distributed_type != DistributedType.XLA:
        central_dl_preparation_check()
        custom_sampler_check()
        check_seedable_sampler()
        check_seedable_sampler_with_data_seed()

    if state.num_processes > 1:
        check_seedable_sampler_in_batch_sampler_shard()

    # Trainings are not exactly the same in DeepSpeed and CPU mode
    if state.distributed_type == DistributedType.DEEPSPEED:
        return

    _banner_local("\n**Training integration test**")
    training_check(use_seedable_sampler=False)
    training_check(use_seedable_sampler=True)

    _banner_local("\n**Breakpoint trigger test**")
    test_trigger()

    if is_pytest_available():
        _banner_local("\n**Test reinstantiated state**")
        test_reinstantiated_state()

    state.destroy_process_group()
|
| 906 |
+
|
| 907 |
+
|
| 908 |
+
# Script entry point: run every check when the file is executed directly.
if __name__ == "__main__":
    main()
|
accelerate/test_utils/scripts/test_sync.py
ADDED
|
@@ -0,0 +1,413 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from copy import deepcopy
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
+
from torch.optim import AdamW
|
| 20 |
+
from torch.optim.lr_scheduler import LambdaLR
|
| 21 |
+
from torch.utils.data import DataLoader
|
| 22 |
+
|
| 23 |
+
from accelerate.accelerator import Accelerator, DataLoaderConfiguration, GradientAccumulationPlugin
|
| 24 |
+
from accelerate.state import GradientState
|
| 25 |
+
from accelerate.test_utils import RegressionDataset, RegressionModel
|
| 26 |
+
from accelerate.utils import DistributedType, set_seed
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def check_model_parameters(model_a, model_b, did_step, iteration, **kwargs):
    """Assert that the gradients of two models agree (or disagree) as expected.

    Parameters of `model_a` and `model_b` are paired positionally; pairs whose
    `model_a` parameter does not require grad are skipped.

    Args:
        model_a: First model; its parameters decide which pairs are checked.
        model_b: Second model, compared against `model_a`.
        did_step: When truthy the gradients must be close, otherwise they must differ.
        iteration: Only used in the assertion message.
        **kwargs: Forwarded to `torch.allclose` (e.g. `rtol`, `atol`).

    Raises:
        AssertionError: If a pair of gradients violates the expectation.
    """
    for p_a, p_b in zip(model_a.parameters(), model_b.parameters()):
        if not p_a.requires_grad:
            continue
        grads_match = torch.allclose(p_a.grad, p_b.grad, **kwargs)
        if did_step:
            # Grads should be in sync
            assert grads_match is True, (
                f"Gradients not in sync when they should be at iteration {iteration}:\nmodel_a grad ({p_a.grad}) != model_b grad ({p_b.grad})"
            )
        else:
            # Grads should not be in sync
            assert grads_match is False, (
                f"Gradients in sync when they should not be at iteration {iteration}:\nmodel_a grad ({p_a.grad}) == model_b grad ({p_b.grad})"
            )
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def step_model(model, input, target, accelerator, do_backward=True):
    """Run one forward pass with an MSE loss and backpropagate it.

    Args:
        model: Module to train; it is switched to training mode first.
        input: Batch fed to `model`.
        target: Expected output; moved to the device of the model output.
        accelerator: Provides `backward` and `gradient_accumulation_steps`.
        do_backward: When truthy the loss goes through `accelerator.backward`.
            Otherwise the loss is divided by
            `accelerator.gradient_accumulation_steps` and `loss.backward()`
            is called directly.
    """
    model.train()
    prediction = model(input)
    loss = F.mse_loss(prediction, target.to(prediction.device))
    if do_backward:
        accelerator.backward(loss)
        return
    # Manual path: scale the loss by hand before the plain backward call.
    loss /= accelerator.gradient_accumulation_steps
    loss.backward()
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def get_training_setup(accelerator, sched=False):
    """Build the objects needed for a basic training comparison.

    Seeds with `set_seed(42)`, creates a `RegressionModel` plus a deep copy of
    it, and a `DataLoader` (batch size 16) over an 80-sample
    `RegressionDataset`. The original model is moved to `accelerator.device`
    by hand; the copy and the dataloader go through `accelerator.prepare`.

    Args:
        accelerator: Supplies the target device and `prepare`.
        sched: When truthy, AdamW optimizers and LambdaLR schedulers are
            created for both models as well.

    Returns:
        `(model, ddp_model, dataloader)` when `sched` is falsy, otherwise
        `(model, opt, sched, dataloader, ddp_model, ddp_opt, ddp_sched)`.
    """
    set_seed(42)
    model = RegressionModel()
    ddp_model = deepcopy(model)
    dataset = RegressionDataset(length=80)
    dataloader = DataLoader(dataset, batch_size=16)
    model.to(accelerator.device)

    if not sched:
        ddp_model, dataloader = accelerator.prepare(ddp_model, dataloader)
        return model, ddp_model, dataloader

    opt = AdamW(params=model.parameters(), lr=1e-3)
    ddp_opt = AdamW(params=ddp_model.parameters(), lr=1e-3)
    scheduler = LambdaLR(opt, lr_lambda=lambda epoch: epoch**0.65)
    ddp_sched = LambdaLR(ddp_opt, lr_lambda=lambda epoch: epoch**0.65)
    # Only the copied model and its optimizer/scheduler are prepared.
    ddp_model, ddp_opt, ddp_sched, dataloader = accelerator.prepare(ddp_model, ddp_opt, ddp_sched, dataloader)
    return (model, opt, scheduler, dataloader, ddp_model, ddp_opt, ddp_sched)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def test_noop_sync(accelerator):
    """Check that `accelerator.no_sync` changes nothing on a single CPU or GPU.

    A reference model and the prepared model are stepped on the same batch for
    three iterations; on even iterations the prepared model is stepped inside
    `no_sync`. Gradients must match after every iteration.
    """
    model, ddp_model, dataloader = get_training_setup(accelerator)
    # Use a single batch
    ddp_input, ddp_target = next(iter(dataloader)).values()
    for iteration in range(3):
        # Gather the per-process inputs and targets for the reference model
        ref_input, ref_target = accelerator.gather((ddp_input, ddp_target))
        ref_input = ref_input.to(accelerator.device)
        ref_target = ref_target.to(accelerator.device)
        # Reference step on the model that was not prepared
        step_model(model, ref_input, ref_target, accelerator)
        if iteration % 2 != 0:
            # Regular step
            step_model(ddp_model, ddp_input, ddp_target, accelerator)
        else:
            # Step under the `no_sync` context manager
            with accelerator.no_sync(ddp_model):
                step_model(ddp_model, ddp_input, ddp_target, accelerator)

        # Since `no_sync` is a noop, `ddp_model` and `model` grads should always be in sync
        check_model_parameters(model, ddp_model, True, iteration)
        for ref_param, ddp_param in zip(model.parameters(), ddp_model.parameters()):
            if not ref_param.requires_grad:
                continue
            assert torch.allclose(ref_param.grad, ddp_param.grad), (
                f"Gradients not in sync when they should be:\nModel grad ({ref_param.grad}) != DDP grad ({ddp_param.grad})"
            )

        # Shuffle ddp_input on each iteration
        torch.manual_seed(1337 + iteration)
        ddp_input = ddp_input[torch.randperm(len(ddp_input))]
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def test_distributed_sync(accelerator):
    """Check `accelerator.no_sync` on a distributed setup.

    For three iterations a reference model is stepped on the gathered batch
    while the prepared model is stepped on its own batch. On even iterations
    the prepared model is stepped inside `no_sync` and its gradients must
    differ from the reference; on odd iterations they must match.
    """
    model, ddp_model, dataloader = get_training_setup(accelerator)
    # Use a single batch
    ddp_input, ddp_target = next(iter(dataloader)).values()
    for iteration in range(3):
        # Gather the per-process inputs and targets for the reference model
        ref_input, ref_target = accelerator.gather((ddp_input, ddp_target))
        ref_input = ref_input.to(accelerator.device)
        ref_target = ref_target.to(accelerator.device)
        # Reference step on the model that was not prepared
        step_model(model, ref_input, ref_target, accelerator)

        expect_sync = iteration % 2 != 0
        if expect_sync:
            # Sync grads
            step_model(ddp_model, ddp_input, ddp_target, accelerator)
        else:
            # Accumulate grads locally
            with accelerator.no_sync(ddp_model):
                step_model(ddp_model, ddp_input, ddp_target, accelerator)

        # DDP model and model should only be in sync on odd iterations
        for ref_param, ddp_param in zip(model.parameters(), ddp_model.parameters()):
            if not ref_param.requires_grad:
                continue
            if expect_sync:
                # Grads should be in sync
                assert torch.allclose(ref_param.grad, ddp_param.grad) is True, (
                    f"Gradients not in sync when they should be:\nModel grad ({ref_param.grad}) != DDP grad ({ddp_param.grad})"
                )
            else:
                # Grads should not be in sync
                assert torch.allclose(ref_param.grad, ddp_param.grad) is False, (
                    f"Gradients in sync when they should not be:\nModel grad ({ref_param.grad}) == DDP grad ({ddp_param.grad})"
                )

        # Shuffle ddp_input on each iteration
        torch.manual_seed(1337 + iteration)
        ddp_input = ddp_input[torch.randperm(len(ddp_input))]
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def test_distributed_sync_multiple_fwd(accelerator):
    """Check `no_sync` / `trigger_sync_in_backward` with several forwards before the backwards.

    Three forward passes are run inside `accelerator.no_sync` and their losses
    kept. The losses are then backpropagated one by one; only the last
    backward runs inside `accelerator.trigger_sync_in_backward`. Gradients
    must differ from the reference model until that last backward and match
    after it.
    """
    model, ddp_model, dataloader = get_training_setup(accelerator)
    # Do multiple forwards
    num_iterations = 3
    losses = []
    for _ in range(num_iterations):
        ddp_input, ddp_target = next(iter(dataloader)).values()

        # Gather the per-process inputs and targets for the reference model
        ref_input, ref_target = accelerator.gather((ddp_input, ddp_target))
        ref_input = ref_input.to(accelerator.device)
        ref_target = ref_target.to(accelerator.device)

        # Reference step on the model that was not prepared
        step_model(model, ref_input, ref_target, accelerator)

        # Forward only, without syncing
        with accelerator.no_sync(ddp_model):
            ddp_output = ddp_model(ddp_input)
            losses.append(F.mse_loss(ddp_output, ddp_target.to(ddp_output.device)))

    # Do multiple backwards and sync only at the last backward
    for iteration, loss in enumerate(losses):
        is_last_backward = iteration >= num_iterations - 1
        if is_last_backward:
            # Sync grads if last backward
            with accelerator.trigger_sync_in_backward(ddp_model):
                accelerator.backward(loss)

            for ref_param, ddp_param in zip(model.parameters(), ddp_model.parameters()):
                if not ref_param.requires_grad:
                    continue
                # Grads should be in sync
                assert torch.allclose(ref_param.grad, ddp_param.grad) is True, (
                    f"Gradients not in sync when they should be:\nModel grad ({ref_param.grad}) != DDP grad ({ddp_param.grad})"
                )
        else:
            # Accumulate grads locally
            accelerator.backward(loss)

            for ref_param, ddp_param in zip(model.parameters(), ddp_model.parameters()):
                if not ref_param.requires_grad:
                    continue
                # Grads should not be in sync
                assert torch.allclose(ref_param.grad, ddp_param.grad) is False, (
                    f"Gradients in sync when they should not be:\nModel grad ({ref_param.grad}) == DDP grad ({ddp_param.grad})"
                )
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def test_gradient_accumulation(split_batches=False, dispatch_batches=False, sync_each_batch=False):
    """Check `accelerator.accumulate` with two accumulation steps.

    A reference model is stepped with a manually scaled loss while the
    prepared model is stepped inside `accumulate`. Gradients must match on
    every second batch, on the last batch, or always when `sync_each_batch`
    is set; otherwise they must differ. `GradientState` is reset at the end.

    Args:
        split_batches: Forwarded to `DataLoaderConfiguration`.
        dispatch_batches: Forwarded to `DataLoaderConfiguration`.
        sync_each_batch: Forwarded to `GradientAccumulationPlugin`.
    """
    accumulation_plugin = GradientAccumulationPlugin(num_steps=2, sync_each_batch=sync_each_batch)
    loader_config = DataLoaderConfiguration(split_batches=split_batches, dispatch_batches=dispatch_batches)
    accelerator = Accelerator(
        dataloader_config=loader_config,
        gradient_accumulation_plugin=accumulation_plugin,
    )
    # Test that context manager behaves properly
    model, ddp_model, dataloader = get_training_setup(accelerator)
    for iteration, batch in enumerate(dataloader):
        ddp_input, ddp_target = batch.values()
        # Gather the per-process inputs and targets for the reference model
        ref_input, ref_target = accelerator.gather((ddp_input, ddp_target))
        ref_input = ref_input.to(accelerator.device)
        ref_target = ref_target.to(accelerator.device)
        # Reference step: loss scaled by hand, plain backward
        step_model(model, ref_input, ref_target, accelerator, False)
        # Step the prepared model under the accumulation context manager
        with accelerator.accumulate(ddp_model):
            step_model(ddp_model, ddp_input, ddp_target, accelerator)

        # Sync is expected every second batch, on the last batch, or always with sync_each_batch
        expect_sync = ((iteration + 1) % 2 == 0) or (iteration == len(dataloader) - 1) or sync_each_batch
        for ref_param, ddp_param in zip(model.parameters(), ddp_model.parameters()):
            if not ref_param.requires_grad:
                continue
            if expect_sync:
                # Grads should be in sync
                assert torch.allclose(ref_param.grad, ddp_param.grad) is True, (
                    f"Gradients not in sync when they should be at iteration {iteration}:\nModel grad ({ref_param.grad}) != DDP grad ({ddp_param.grad})"
                )
            else:
                # Grads should not be in sync
                assert torch.allclose(ref_param.grad, ddp_param.grad) is False, (
                    f"Gradients in sync when they should not be at iteration {iteration}:\nModel grad ({ref_param.grad}) == DDP grad ({ddp_param.grad})"
                )

        # Shuffle ddp_input on each iteration
        torch.manual_seed(1337 + iteration)
        ddp_input = ddp_input[torch.randperm(len(ddp_input))]
    GradientState._reset_state()
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def test_gradient_accumulation_with_opt_and_scheduler(
    split_batches=False, dispatch_batches=False, sync_each_batch=False
):
    """Check `accelerator.accumulate` together with an optimizer and a scheduler.

    The reference model, optimizer and scheduler are stepped by hand; the
    prepared ones are stepped inside `accumulate`. After each batch the
    learning rates of both optimizers must be equal and, with more than one
    process, the gradients are compared through `check_model_parameters`.
    `GradientState` is reset at the end.

    Args:
        split_batches: Forwarded to `DataLoaderConfiguration`; also decides how
            many times the reference scheduler is stepped.
        dispatch_batches: Forwarded to `DataLoaderConfiguration`.
        sync_each_batch: Forwarded to `GradientAccumulationPlugin`.
    """
    accumulation_plugin = GradientAccumulationPlugin(num_steps=2, sync_each_batch=sync_each_batch)
    loader_config = DataLoaderConfiguration(split_batches=split_batches, dispatch_batches=dispatch_batches)
    accelerator = Accelerator(
        dataloader_config=loader_config,
        gradient_accumulation_plugin=accumulation_plugin,
    )
    # Test that context manager behaves properly
    model, opt, sched, dataloader, ddp_model, ddp_opt, ddp_sched = get_training_setup(accelerator, True)
    for iteration, batch in enumerate(dataloader):
        ddp_input, ddp_target = batch.values()
        # Gather the per-process inputs and targets for the reference model
        ref_input, ref_target = accelerator.gather((ddp_input, ddp_target))
        ref_input = ref_input.to(accelerator.device)
        ref_target = ref_target.to(accelerator.device)
        # Reference step on the model that was not prepared
        model.train()
        ddp_model.train()
        step_model(model, ref_input, ref_target, accelerator, False)
        opt.step()

        # True on every second batch and on the last batch
        did_step = ((iteration + 1) % 2 == 0) or ((iteration + 1) == len(dataloader))
        if did_step:
            # One scheduler step with split batches, otherwise one per process
            for _ in range(1 if split_batches else accelerator.num_processes):
                sched.step()

        # Perform gradient accumulation under wrapper
        with accelerator.accumulate(ddp_model):
            step_model(ddp_model, ddp_input, ddp_target, accelerator)
            ddp_opt.step()
            ddp_sched.step()

        # Learning rates should be the same
        assert opt.param_groups[0]["lr"] == ddp_opt.param_groups[0]["lr"], (
            f"Learning rates found in each optimizer did not align\nopt: {opt.param_groups[0]['lr']}\nDDP opt: {ddp_opt.param_groups[0]['lr']}\n"
        )
        if accelerator.num_processes > 1:
            check_model_parameters(
                model,
                ddp_model,
                did_step or sync_each_batch,  # syncs at each grad_accum interval, or always if sync_each_batch
                iteration,
                rtol=1e-3,  # needs a relative tolerance due to roundoff errors
            )

        if did_step:
            # flush gradients every accum step
            opt.zero_grad()
            ddp_opt.zero_grad()

        # Reseed on each iteration
        torch.manual_seed(1337 + iteration)
    GradientState._reset_state()
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def test_dataloader_break():
    """Check the gradient state's dataloader tracking with nested dataloaders.

    Two prepared dataloaders are used; the second is fully iterated inside the
    second iteration of the first. At every point `active_dataloader` must be
    the dataloader currently being iterated, and `end_of_dataloader` must be
    set only on that dataloader's last batch. No dataloader may be active
    before or after the loops.
    """
    accelerator = Accelerator()
    first_dataloader = DataLoader(RegressionDataset(length=80), batch_size=16)
    second_dataloader = DataLoader(RegressionDataset(length=96), batch_size=16)
    first_dataloader, second_dataloader = accelerator.prepare(first_dataloader, second_dataloader)

    assert accelerator.gradient_state.active_dataloader is None
    for iteration, _ in enumerate(first_dataloader):
        assert accelerator.gradient_state.active_dataloader is first_dataloader
        if iteration >= len(first_dataloader) - 1:
            # Last batch of the outer dataloader
            assert accelerator.gradient_state.end_of_dataloader
            continue
        assert not accelerator.gradient_state.end_of_dataloader
        if iteration != 1:
            continue
        # Run the whole second dataloader inside the outer loop
        for batch_num, _ in enumerate(second_dataloader):
            assert accelerator.gradient_state.active_dataloader is second_dataloader
            if batch_num < len(second_dataloader) - 1:
                assert not accelerator.gradient_state.end_of_dataloader
            else:
                assert accelerator.gradient_state.end_of_dataloader
    assert accelerator.gradient_state.active_dataloader is None
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def main():
    """Run the sync and gradient-accumulation checks that apply to the current distributed type.

    Banners are printed only by the process whose local index is 0. The
    process group is destroyed once all checks have run.
    """
    accelerator = Accelerator()
    state = accelerator.state

    def _banner(*parts):
        # Print only from the process whose local_process_index is 0.
        if state.local_process_index == 0:
            print(*parts)

    _banner("**Test `accumulate` gradient accumulation with dataloader break**")
    if state.distributed_type != DistributedType.XLA:
        test_dataloader_break()

    if state.distributed_type == DistributedType.NO:
        _banner("**Test NOOP `no_sync` context manager**")
        test_noop_sync(accelerator)

    no_sync_types = (
        DistributedType.MULTI_GPU,
        DistributedType.MULTI_NPU,
        DistributedType.MULTI_MLU,
        DistributedType.MULTI_SDAA,
        DistributedType.MULTI_MUSA,
        DistributedType.MULTI_CPU,
        DistributedType.MULTI_HPU,
        DistributedType.MULTI_NEURON,
    )
    if state.distributed_type in no_sync_types:
        _banner("**Test Distributed `no_sync` context manager**")
        test_distributed_sync(accelerator)
        _banner("**Test Distributed `no_sync` context manager with multiple forwards**")
        test_distributed_sync_multiple_fwd(accelerator)

    # Same list as above except for MULTI_CPU
    accumulation_types = (
        DistributedType.MULTI_GPU,
        DistributedType.MULTI_NPU,
        DistributedType.MULTI_MLU,
        DistributedType.MULTI_SDAA,
        DistributedType.MULTI_MUSA,
        DistributedType.MULTI_HPU,
        DistributedType.MULTI_NEURON,
    )
    flag_values = (True, False)
    if state.distributed_type in accumulation_types:
        for split_batch in flag_values:
            for dispatch_batches in flag_values:
                for sync_each_batch in flag_values:
                    _banner(
                        "**Test `accumulate` gradient accumulation, ",
                        f"`split_batches={split_batch}` and `dispatch_batches={dispatch_batches}` and `sync_each_batch={sync_each_batch}`**",
                    )
                    test_gradient_accumulation(split_batch, dispatch_batches, sync_each_batch)

    # Currently will break on torch 2.0 +, need to investigate why
    _banner(
        "**Test `accumulate` gradient accumulation with optimizer and scheduler, ",
        "`split_batches=False`, `dispatch_batches=False`, `sync_each_batch=False`**",
    )
    test_gradient_accumulation_with_opt_and_scheduler()
    if state.distributed_type in accumulation_types:
        for split_batch in flag_values:
            for dispatch_batches in flag_values:
                for sync_each_batch in flag_values:
                    # The all-False combination was already run just above.
                    if not (split_batch or dispatch_batches or sync_each_batch):
                        continue
                    _banner(
                        "**Test `accumulate` gradient accumulation with optimizer and scheduler, ",
                        f"`split_batches={split_batch}` and `dispatch_batches={dispatch_batches}` and `sync_each_batch={sync_each_batch}`**",
                    )
                    test_gradient_accumulation_with_opt_and_scheduler(split_batch, dispatch_batches, sync_each_batch)
    state.destroy_process_group()
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def _mp_fn(index):
    """Entry point for xla_spawn (TPUs); `index` is accepted but not used."""
    main()
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
# Script entry point: run every check when the file is executed directly.
if __name__ == "__main__":
    main()
|
accelerate/test_utils/testing.py
ADDED
|
@@ -0,0 +1,889 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import asyncio
|
| 16 |
+
import inspect
|
| 17 |
+
import io
|
| 18 |
+
import os
|
| 19 |
+
import re
|
| 20 |
+
import shutil
|
| 21 |
+
import subprocess
|
| 22 |
+
import sys
|
| 23 |
+
import tempfile
|
| 24 |
+
import unittest
|
| 25 |
+
from contextlib import contextmanager
|
| 26 |
+
from functools import partial
|
| 27 |
+
from pathlib import Path
|
| 28 |
+
from typing import Optional, Union
|
| 29 |
+
from unittest import mock
|
| 30 |
+
|
| 31 |
+
import torch
|
| 32 |
+
|
| 33 |
+
import accelerate
|
| 34 |
+
|
| 35 |
+
from ..state import AcceleratorState
|
| 36 |
+
from ..utils import (
|
| 37 |
+
check_cuda_fp8_capability,
|
| 38 |
+
compare_versions,
|
| 39 |
+
gather,
|
| 40 |
+
is_aim_available,
|
| 41 |
+
is_bnb_available,
|
| 42 |
+
is_clearml_available,
|
| 43 |
+
is_comet_ml_available,
|
| 44 |
+
is_cuda_available,
|
| 45 |
+
is_datasets_available,
|
| 46 |
+
is_deepspeed_available,
|
| 47 |
+
is_dvclive_available,
|
| 48 |
+
is_fp8_available,
|
| 49 |
+
is_fp16_available,
|
| 50 |
+
is_habana_gaudi1,
|
| 51 |
+
is_hpu_available,
|
| 52 |
+
is_import_timer_available,
|
| 53 |
+
is_matplotlib_available,
|
| 54 |
+
is_mlflow_available,
|
| 55 |
+
is_mlu_available,
|
| 56 |
+
is_mps_available,
|
| 57 |
+
is_musa_available,
|
| 58 |
+
is_neuron_available,
|
| 59 |
+
is_npu_available,
|
| 60 |
+
is_pandas_available,
|
| 61 |
+
is_pippy_available,
|
| 62 |
+
is_pytest_available,
|
| 63 |
+
is_schedulefree_available,
|
| 64 |
+
is_sdaa_available,
|
| 65 |
+
is_swanlab_available,
|
| 66 |
+
is_tensorboard_available,
|
| 67 |
+
is_timm_available,
|
| 68 |
+
is_torch_version,
|
| 69 |
+
is_torch_xla_available,
|
| 70 |
+
is_torchao_available,
|
| 71 |
+
is_torchdata_stateful_dataloader_available,
|
| 72 |
+
is_torchvision_available,
|
| 73 |
+
is_trackio_available,
|
| 74 |
+
is_transformer_engine_available,
|
| 75 |
+
is_transformer_engine_mxfp8_available,
|
| 76 |
+
is_transformers_available,
|
| 77 |
+
is_triton_available,
|
| 78 |
+
is_wandb_available,
|
| 79 |
+
is_xpu_available,
|
| 80 |
+
str_to_bool,
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def get_backend():
    """
    Probe the available backends in a fixed priority order and describe the first match.

    Returns:
        A 3-tuple `(device_name, device_count, memory_allocated_callable)`. When no accelerator check succeeds the
        result is `("cpu", 1, <callable returning 0>)`.
    """
    # Note: the XLA branch reports the `torch.cuda` device count and memory function.
    if is_torch_xla_available():
        return "xla", torch.cuda.device_count(), torch.cuda.memory_allocated
    if is_cuda_available():
        return "cuda", torch.cuda.device_count(), torch.cuda.memory_allocated
    if is_mps_available(min_version="2.0"):
        return "mps", 1, torch.mps.current_allocated_memory
    if is_mps_available():
        # No memory query is used for this MPS branch, so a constant-zero callable stands in.
        return "mps", 1, lambda: 0
    if is_mlu_available():
        return "mlu", torch.mlu.device_count(), torch.mlu.memory_allocated
    if is_sdaa_available():
        return "sdaa", torch.sdaa.device_count(), torch.sdaa.memory_allocated
    if is_musa_available():
        return "musa", torch.musa.device_count(), torch.musa.memory_allocated
    if is_npu_available():
        return "npu", torch.npu.device_count(), torch.npu.memory_allocated
    if is_xpu_available():
        return "xpu", torch.xpu.device_count(), torch.xpu.memory_allocated
    if is_hpu_available():
        return "hpu", torch.hpu.device_count(), torch.hpu.memory_allocated
    if is_neuron_available():
        return "neuron", torch.neuron.device_count(), torch.neuron.memory_allocated
    return "cpu", 1, lambda: 0


# Resolved once at import time; the decorators below read `torch_device` and `device_count`.
torch_device, device_count, memory_allocated_func = get_backend()
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def get_launch_command(**kwargs) -> list:
    """
    Build an `accelerate launch` argument list from keyword arguments, for use with `subprocess`.

    A value that is the boolean `True` becomes a bare `--name` flag, `None` values are dropped, and every other value
    (including `False`) is rendered as `--name=value`.

    Example:
    ```python
    # returns ['accelerate', 'launch', '--num_processes=2', '--device_count=2']
    get_launch_command(num_processes=2, device_count=2)
    ```
    """

    def _render(name, value):
        if isinstance(value, bool) and value:
            return f"--{name}"
        return f"--{name}={value}"

    options = [_render(name, value) for name, value in kwargs.items() if value is not None]
    return ["accelerate", "launch", *options]
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
# Launch prefix built once at import time: one process per detected device.
DEFAULT_LAUNCH_COMMAND = get_launch_command(
    num_processes=device_count,
    monitor_interval=0.1,
)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def parse_flag_from_env(key, default=False):
    """
    Read the environment variable `key` as a boolean flag.

    Returns `default` when the variable is unset; otherwise returns `str_to_bool` of its value.

    Raises:
        ValueError: if the variable is set to something `str_to_bool` rejects.
    """
    raw = os.environ.get(key)
    if raw is None:
        # KEY isn't set, fall back to `default`.
        return default
    try:
        return str_to_bool(raw)
    except ValueError:
        # More values are supported, but let's keep the message simple.
        raise ValueError(f"If set, {key} must be yes or no.")
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
# Evaluated once at import time; the `slow` decorator reads this flag.
_run_slow_tests = parse_flag_from_env(key="RUN_SLOW", default=False)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def skip(test_case):
    """
    Decorator that unconditionally marks `test_case` as skipped.
    """
    always_skip = unittest.skip("Test was skipped")
    return always_skip(test_case)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def slow(test_case):
    """
    Decorator for slow tests. The test is skipped unless the module-level `_run_slow_tests` flag (parsed from the
    RUN_SLOW environment variable) is truthy.
    """
    gate = unittest.skipUnless(_run_slow_tests, "test is slow")
    return gate(test_case)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def require_cpu(test_case):
    """
    Decorator for tests that must run on the CPU only. The test is skipped when the detected `torch_device` is
    anything other than "cpu".
    """
    on_cpu = torch_device == "cpu"
    return unittest.skipUnless(on_cpu, "test requires only a CPU")(test_case)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def require_non_cpu(test_case):
    """
    Decorator for tests that need a hardware accelerator backend. The test is skipped when the detected
    `torch_device` is "cpu".
    """
    return unittest.skipIf(torch_device == "cpu", "test requires a GPU")(test_case)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def require_cuda(test_case):
    """
    Decorator for tests that need CUDA. The test is skipped when CUDA is unavailable or when TorchXLA is available.
    """
    cuda_without_xla = is_cuda_available() and not is_torch_xla_available()
    gate = unittest.skipUnless(cuda_without_xla, "test requires a GPU")
    return gate(test_case)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def require_cuda_or_hpu(test_case):
    """
    Decorator for tests that need CUDA or HPU. The test runs when CUDA is available without TorchXLA, or when an HPU
    is available; otherwise it is skipped.
    """
    cuda_without_xla = is_cuda_available() and not is_torch_xla_available()
    # The HPU check is only consulted when the CUDA condition fails.
    supported = cuda_without_xla or is_hpu_available()
    return unittest.skipUnless(supported, "test requires a GPU or HPU")(test_case)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def require_xpu(test_case):
    """
    Decorator for tests that need an XPU. The test is skipped when `is_xpu_available()` is false.
    """
    gate = unittest.skipUnless(is_xpu_available(), "test requires a XPU")
    return gate(test_case)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def require_cuda_or_xpu(test_case):
    """
    Decorator for tests that need CUDA or XPU. The test runs when CUDA is available without TorchXLA, or when an XPU
    is available; otherwise it is skipped.
    """
    has_cuda = is_cuda_available() and not is_torch_xla_available()
    has_xpu = is_xpu_available()
    gate = unittest.skipUnless(has_cuda or has_xpu, "test requires a CUDA GPU or XPU")
    return gate(test_case)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def require_non_xpu(test_case):
    """
    Decorator that skips the test when the detected `torch_device` is "xpu".
    """
    return unittest.skipIf(torch_device == "xpu", "test requires a non-XPU")(test_case)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def require_non_hpu(test_case):
    """
    Decorator that skips the test when the detected `torch_device` is "hpu".
    """
    return unittest.skipIf(torch_device == "hpu", "test requires a non-HPU")(test_case)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def require_fp16(test_case):
    """
    Decorator for tests that need FP16. The test is skipped when `is_fp16_available()` is false.
    """
    fp16_supported = is_fp16_available()
    return unittest.skipUnless(fp16_supported, "test requires FP16 support")(test_case)
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def require_fp8(test_case):
    """
    Decorator for tests that need FP8. The test is skipped when FP8 is not supported, which here means any of:
    `is_fp8_available()` is false, CUDA is present but fails `check_cuda_fp8_capability()`, or the HPU is a Gaudi1.
    """
    # is_fp8_available only checks for libraries; the two checks below add the device capability side.
    libraries_ok = is_fp8_available()
    cuda_lacks_fp8 = torch.cuda.is_available() and not check_cuda_fp8_capability()
    runs_on_gaudi1 = is_hpu_available() and is_habana_gaudi1()

    supported = False if (cuda_lacks_fp8 or runs_on_gaudi1) else libraries_ok
    return unittest.skipUnless(supported, "test requires FP8 support")(test_case)
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def require_fsdp2(test_case):
    """
    Decorator for tests that need FSDP2. The test is skipped unless the installed torch version is >= 2.5.0.
    """
    new_enough = is_torch_version(">=", "2.5.0")
    return unittest.skipUnless(new_enough, "test requires FSDP2 (torch >= 2.5.0)")(test_case)
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def require_mlu(test_case):
    """
    Decorator for tests that need an MLU. The test is skipped when `is_mlu_available()` is false.
    """
    gate = unittest.skipUnless(is_mlu_available(), "test require a MLU")
    return gate(test_case)
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def require_sdaa(test_case):
    """
    Decorator for tests that need SDAA. The test is skipped when `is_sdaa_available()` is false.
    """
    gate = unittest.skipUnless(is_sdaa_available(), "test require a SDAA")
    return gate(test_case)
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def require_musa(test_case):
    """
    Decorator for tests that need MUSA. The test is skipped when `is_musa_available()` is false.
    """
    gate = unittest.skipUnless(is_musa_available(), "test require a MUSA")
    return gate(test_case)
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def require_npu(test_case):
    """
    Decorator for tests that need an NPU. The test is skipped when `is_npu_available()` is false.
    """
    gate = unittest.skipUnless(is_npu_available(), "test require a NPU")
    return gate(test_case)
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def require_neuron(test_case):
    """
    Decorator for tests that need Neuron. The test is skipped when `is_neuron_available()` is false.
    """
    gate = unittest.skipUnless(is_neuron_available(), "test require Neuron Cores")
    return gate(test_case)
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def require_mps(test_case):
    """
    Decorator for tests that need the `mps` backend. The test is skipped when `is_mps_available()` is false.
    """
    mps_supported = is_mps_available()
    return unittest.skipUnless(mps_supported, "test requires a `mps` backend support in `torch`")(test_case)
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def require_huggingface_suite(test_case):
    """
    Decorator for tests that need both transformers and datasets. The test is skipped when either one is
    unavailable.
    """
    suite_installed = is_transformers_available() and is_datasets_available()
    gate = unittest.skipUnless(suite_installed, "test requires the Hugging Face suite")
    return gate(test_case)
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def require_transformers(test_case):
    """
    Decorator for tests that need transformers. The test is skipped when `is_transformers_available()` is false.
    """
    gate = unittest.skipUnless(is_transformers_available(), "test requires the transformers library")
    return gate(test_case)
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def require_timm(test_case):
    """
    Decorator for tests that need timm. The test is skipped when `is_timm_available()` is false.
    """
    gate = unittest.skipUnless(is_timm_available(), "test requires the timm library")
    return gate(test_case)
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
def require_torchvision(test_case):
    """
    Decorator for tests that need torchvision. The test is skipped when `is_torchvision_available()` is false.
    """
    gate = unittest.skipUnless(is_torchvision_available(), "test requires the torchvision library")
    return gate(test_case)
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
def require_triton(test_case):
    """
    Decorator for tests that need triton. The test is skipped when `is_triton_available()` is false.
    """
    gate = unittest.skipUnless(is_triton_available(), "test requires the triton library")
    return gate(test_case)
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def require_schedulefree(test_case):
    """
    Decorator for tests that need schedulefree. The test is skipped when `is_schedulefree_available()` is false.
    """
    gate = unittest.skipUnless(is_schedulefree_available(), "test requires the schedulefree library")
    return gate(test_case)
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
def require_bnb(test_case):
    """
    Decorator for tests that need bitsandbytes. The test is skipped when `is_bnb_available()` is false.
    """
    gate = unittest.skipUnless(is_bnb_available(), "test requires the bitsandbytes library")
    return gate(test_case)
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
def require_tpu(test_case):
    """
    Decorator for tests that need TPUs. The test is skipped unless `is_torch_xla_available(check_is_tpu=True)` is
    true.
    """
    tpu_present = is_torch_xla_available(check_is_tpu=True)
    return unittest.skipUnless(tpu_present, "test requires TPU")(test_case)
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def require_non_torch_xla(test_case):
    """
    Decorator for tests that need an environment without TorchXLA. The test is skipped when
    `is_torch_xla_available()` is true.
    """
    return unittest.skipIf(is_torch_xla_available(), "test requires an env without TorchXLA")(test_case)
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def require_single_device(test_case):
    """
    Decorator for tests that need exactly one accelerator device. The test is skipped when the detected
    `torch_device` is "cpu" or when `device_count` differs from one.
    """
    exactly_one_accelerator = torch_device != "cpu" and device_count == 1
    gate = unittest.skipUnless(exactly_one_accelerator, "test requires a single device accelerator")
    return gate(test_case)
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
def require_single_gpu(test_case):
    """
    Decorator for tests that need exactly one CUDA GPU. The test is skipped when `torch.cuda.device_count()` is
    anything other than one.
    """
    gpu_total = torch.cuda.device_count()
    return unittest.skipUnless(gpu_total == 1, "test requires a GPU")(test_case)
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
def require_single_xpu(test_case):
    """
    Decorator for tests that need exactly one XPU. The test is skipped when `torch.xpu.device_count()` is anything
    other than one.
    """
    xpu_total = torch.xpu.device_count()
    return unittest.skipUnless(xpu_total == 1, "test requires a XPU")(test_case)
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
def require_multi_device(test_case):
    """
    Decorator for tests that need a multi-device setup. The test is skipped unless the detected `device_count` is
    greater than one.
    """
    gate = unittest.skipUnless(device_count > 1, "test requires multiple hardware accelerators")
    return gate(test_case)
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
def require_multi_gpu(test_case):
    """
    Decorator for tests that need several CUDA GPUs. The test is skipped unless `torch.cuda.device_count()` is
    greater than one.
    """
    gpu_total = torch.cuda.device_count()
    return unittest.skipUnless(gpu_total > 1, "test requires multiple GPUs")(test_case)
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
def require_multi_xpu(test_case):
    """
    Decorator for tests that need several XPUs. The test is skipped unless `torch.xpu.device_count()` is greater
    than one.
    """
    xpu_total = torch.xpu.device_count()
    return unittest.skipUnless(xpu_total > 1, "test requires multiple XPUs")(test_case)
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
def require_multi_gpu_or_xpu(test_case):
    """
    Decorator for tests that need several GPUs or XPUs. The test is skipped unless CUDA or XPU is available and the
    detected `device_count` is greater than one.
    """
    several_devices = (is_cuda_available() or is_xpu_available()) and device_count > 1
    gate = unittest.skipUnless(several_devices, "test requires multiple GPUs or XPUs")
    return gate(test_case)
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
def require_deepspeed(test_case):
    """
    Decorator for tests that need DeepSpeed. The test is skipped when `is_deepspeed_available()` is false.
    """
    gate = unittest.skipUnless(is_deepspeed_available(), "test requires DeepSpeed")
    return gate(test_case)
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
def require_tp(test_case):
    """
    Decorator for tests that need TP. The test is skipped unless torch is >= 2.3.0 and transformers is >= 4.52.0.
    """
    # The transformers version is only compared when the torch check passes.
    versions_ok = is_torch_version(">=", "2.3.0") and compare_versions("transformers", ">=", "4.52.0")
    gate = unittest.skipUnless(
        versions_ok,
        "test requires torch version >= 2.3.0 and transformers version >= 4.52.0",
    )
    return gate(test_case)
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
def require_torch_min_version(test_case=None, version=None):
    """
    Decorator for tests that need a minimum torch version. The test is skipped when the installed torch is older
    than `version`.

    When called without `test_case` (i.e. `@require_torch_min_version(version=...)`), a partial bound to `version`
    is returned so that it can then be applied to the test.
    """
    if test_case is not None:
        gate = unittest.skipUnless(is_torch_version(">=", version), f"test requires torch version >= {version}")
        return gate(test_case)
    return partial(require_torch_min_version, version=version)
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
def require_tensorboard(test_case):
    """
    Decorator for tests that need tensorboard. The test is skipped when `is_tensorboard_available()` is false.
    """
    gate = unittest.skipUnless(is_tensorboard_available(), "test requires Tensorboard")
    return gate(test_case)
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
def require_wandb(test_case):
    """
    Decorator for tests that need wandb. The test is skipped when `is_wandb_available()` is false.
    """
    gate = unittest.skipUnless(is_wandb_available(), "test requires wandb")
    return gate(test_case)
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
def require_trackio(test_case):
    """
    Decorator for tests that need trackio. The test is skipped when `is_trackio_available()` is false.
    """
    gate = unittest.skipUnless(is_trackio_available(), "test requires trackio")
    return gate(test_case)
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
def require_comet_ml(test_case):
    """
    Decorator for tests that need comet_ml. The test is skipped when `is_comet_ml_available()` is false.
    """
    gate = unittest.skipUnless(is_comet_ml_available(), "test requires comet_ml")
    return gate(test_case)
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
def require_aim(test_case):
    """
    Decorator for tests that need aim. The test is skipped when `is_aim_available()` is false.
    """
    gate = unittest.skipUnless(is_aim_available(), "test requires aim")
    return gate(test_case)
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
def require_clearml(test_case):
    """
    Decorator for tests that need clearml. The test is skipped when `is_clearml_available()` is false.
    """
    gate = unittest.skipUnless(is_clearml_available(), "test requires clearml")
    return gate(test_case)
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
def require_dvclive(test_case):
    """
    Decorator for tests that need dvclive. The test is skipped when `is_dvclive_available()` is false.
    """
    gate = unittest.skipUnless(is_dvclive_available(), "test requires dvclive")
    return gate(test_case)
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
def require_swanlab(test_case):
    """
    Decorator for tests that need swanlab. The test is skipped when `is_swanlab_available()` is false.
    """
    gate = unittest.skipUnless(is_swanlab_available(), "test requires swanlab")
    return gate(test_case)
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
def require_pandas(test_case):
    """
    Decorator for tests that need pandas. The test is skipped when `is_pandas_available()` is false.
    """
    gate = unittest.skipUnless(is_pandas_available(), "test requires pandas")
    return gate(test_case)
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
def require_mlflow(test_case):
    """
    Decorator for tests that need mlflow. The test is skipped when `is_mlflow_available()` is false.
    """
    gate = unittest.skipUnless(is_mlflow_available(), "test requires mlflow")
    return gate(test_case)
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
def require_pippy(test_case):
    """
    Decorator for tests that need pippy. The test is skipped when `is_pippy_available()` is false, and also when
    `is_habana_gaudi1()` is true.
    """
    pippy_usable = is_pippy_available() and not is_habana_gaudi1()
    return unittest.skipUnless(pippy_usable, "test requires pippy")(test_case)
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
def require_import_timer(test_case):
    """
    Decorator for tests that need the import timer (tuna). The test is skipped when `is_import_timer_available()`
    is false.
    """
    gate = unittest.skipUnless(is_import_timer_available(), "test requires tuna interpreter")
    return gate(test_case)
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
def require_transformer_engine(test_case):
    """
    Decorator for tests that need transformer engine. The test is skipped when `is_transformer_engine_available()`
    is false.
    """
    gate = unittest.skipUnless(is_transformer_engine_available(), "test requires transformers engine")
    return gate(test_case)
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
def require_transformer_engine_mxfp8(test_case):
    """
    Decorator for tests that need transformer engine MXFP8 block scaling. The test is skipped when
    `is_transformer_engine_mxfp8_available()` is false.
    """
    mxfp8_supported = is_transformer_engine_mxfp8_available()
    gate = unittest.skipUnless(mxfp8_supported, "test requires transformers engine MXFP8 block scaling")
    return gate(test_case)
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
def require_torchao(test_case):
    """
    Decorator for tests that need torchao. The test is skipped when `is_torchao_available()` is false.
    """
    gate = unittest.skipUnless(is_torchao_available(), "test requires torchao")
    return gate(test_case)
|
| 569 |
+
|
| 570 |
+
|
| 571 |
+
def require_matplotlib(test_case):
    """
    Decorator for tests that need matplotlib. The test is skipped when `is_matplotlib_available()` is false.
    """
    gate = unittest.skipUnless(is_matplotlib_available(), "test requires matplotlib")
    return gate(test_case)
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
# True when at least one of wandb / tensorboard / trackio / swanlab is available AND comet_ml is not.
# All four tracker checks are evaluated up front; the comet_ml check only runs if one of them succeeded.
_tracker_checks = (is_wandb_available(), is_tensorboard_available(), is_trackio_available(), is_swanlab_available())
_atleast_one_tracker_available = any(_tracker_checks) and not is_comet_ml_available()
del _tracker_checks
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
def require_trackers(test_case):
    """
    Decorator for tests that need at least one tracking library. The test is skipped when the module-level
    `_atleast_one_tracker_available` flag is false.
    """
    reason = "test requires at least one tracker to be available and for `comet_ml` to not be installed"
    return unittest.skipUnless(_atleast_one_tracker_available, reason)(test_case)
|
| 594 |
+
|
| 595 |
+
|
| 596 |
+
def require_torchdata_stateful_dataloader(test_case):
    """
    Decorator for tests that need torchdata.stateful_dataloader.

    The test is skipped when `is_torchdata_stateful_dataloader_available()` is false.
    """
    stateful_available = is_torchdata_stateful_dataloader_available()
    gate = unittest.skipUnless(stateful_available, "test requires torchdata.stateful_dataloader")
    return gate(test_case)
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
def run_first(test_case):
    """
    Decorator that tags a test with `pytest.mark.order(1)`. With the pytest-order plugin installed, tests carrying
    this mark are scheduled first.

    This matters in setups such as a Gaudi instance, where a device can only be used by one process at a time:
    running the subprocess-launching tests first avoids device allocation conflicts.

    When pytest is not available the test is returned unchanged.
    """
    if not is_pytest_available():
        return test_case

    import pytest

    return pytest.mark.order(1)(test_case)
|
| 625 |
+
|
| 626 |
+
|
| 627 |
+
class TempDirTestCase(unittest.TestCase):
    """
    A TestCase class that keeps a single temporary directory open for the duration of the class, wipes its contents
    at the start of each test, and then destroys it at the end of the TestCase.

    Useful for when a class or API requires a single constant folder throughout its use, such as Weights and Biases.

    The temporary directory location is stored in `self.tmpdir` (a `pathlib.Path`). Set `clear_on_setup = False` in
    a subclass to keep the contents between tests.
    """

    # When True, `setUp` empties `tmpdir` before every test.
    clear_on_setup = True

    @classmethod
    def setUpClass(cls):
        "Creates a temporary directory with `tempfile.mkdtemp` and stores its path in `cls.tmpdir`"
        cls.tmpdir = Path(tempfile.mkdtemp())

    @classmethod
    def tearDownClass(cls):
        "Remove `cls.tmpdir` after test suite has finished"
        if os.path.exists(cls.tmpdir):
            shutil.rmtree(cls.tmpdir)

    def setUp(self):
        "Destroy all contents in `self.tmpdir`, but not `self.tmpdir`"
        if not self.clear_on_setup:
            return
        # Snapshot the top-level entries before deleting anything, so the directory is never mutated while it is
        # being iterated. Removing each top-level entry is enough to empty the whole tree.
        for entry in list(self.tmpdir.iterdir()):
            if entry.is_dir() and not entry.is_symlink():
                shutil.rmtree(entry)
            else:
                # Regular files and symlinks (including broken ones and links to directories, which
                # `shutil.rmtree` refuses to handle) are removed without touching their targets.
                entry.unlink()
|
| 658 |
+
|
| 659 |
+
|
| 660 |
+
class AccelerateTestCase(unittest.TestCase):
    """
    TestCase base class that resets the `AcceleratorState` singleton after every test.

    Tests that inspect or rely on `AcceleratorState` should inherit from this class so that state set up by one test
    cannot leak into the next and cause silent failures.
    """

    def tearDown(self):
        "Run the regular teardown, then reset the shared `AcceleratorState`"
        super().tearDown()
        AcceleratorState._reset_state(True)
|
| 671 |
+
|
| 672 |
+
|
| 673 |
+
class MockingTestCase(unittest.TestCase):
    """
    TestCase base class for registering mocks that must be active in every test, for the cases where an ordinary
    class-wide mock cannot be used.

    This helps when a mock depends on information that only exists after `TestCase.setUpClass` has run, such as an
    environment variable holding that information.

    Call `add_mocks` at the end of the `TestCase`'s `setUp`, after `super().setUp()`, for example:
    ```python
    def setUp(self):
        super().setUp()
        mocks = mock.patch.dict(os.environ, {"SOME_ENV_VAR": "SOME_VALUE"})
        self.add_mocks(mocks)
    ```
    """

    def add_mocks(self, mocks: Union[mock.Mock, list[mock.Mock]]):
        """
        Start the given mocks and register their `stop` as a cleanup. Should be called during
        `MockingTestCase.setUp`, after `super().setUp()`.

        Args:
            mocks (`mock.Mock` or list of `mock.Mock`):
                One mock, or a list/tuple of mocks, to start now and stop automatically during test cleanup. The
                collection is kept on `self.mocks`.
        """
        if not isinstance(mocks, (tuple, list)):
            mocks = [mocks]
        self.mocks = mocks
        for patcher in self.mocks:
            patcher.start()
            self.addCleanup(patcher.stop)
|
| 704 |
+
|
| 705 |
+
|
| 706 |
+
def are_the_same_tensors(tensor):
    """
    Check that `tensor` is identical on every process.

    The tensor is given a leading dimension, cloned, moved to the `AcceleratorState` device and gathered; each row
    of the gathered result is then compared on CPU against this process's copy.

    Returns:
        `True` if every gathered row equals the local tensor, `False` at the first mismatch.
    """
    device = AcceleratorState().device
    local = tensor[None].clone().to(device)
    gathered = gather(local).cpu()
    reference = local[0].cpu()
    return all(torch.equal(gathered[row], reference) for row in range(gathered.shape[0]))
|
| 715 |
+
|
| 716 |
+
|
| 717 |
+
class _RunOutput:
|
| 718 |
+
def __init__(self, returncode, stdout, stderr):
|
| 719 |
+
self.returncode = returncode
|
| 720 |
+
self.stdout = stdout
|
| 721 |
+
self.stderr = stderr
|
| 722 |
+
|
| 723 |
+
|
| 724 |
+
async def _read_stream(stream, callback):
|
| 725 |
+
while True:
|
| 726 |
+
line = await stream.readline()
|
| 727 |
+
if line:
|
| 728 |
+
callback(line)
|
| 729 |
+
else:
|
| 730 |
+
break
|
| 731 |
+
|
| 732 |
+
|
| 733 |
+
async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> _RunOutput:
    """
    Run `cmd` as a child process and collect its stdout and stderr line by line while it runs.

    Args:
        cmd (`list` of `str`): Program and arguments to execute.
        env: Environment passed to `asyncio.create_subprocess_exec`.
        stdin: Stdin passed to `asyncio.create_subprocess_exec`.
        timeout: Timeout handed to `asyncio.wait` for the two stream readers. The process exit is still awaited
            afterwards, so this does not bound the total run time.
        quiet (`bool`): When `False`, every captured line is also printed to this process's stdout/stderr with a
            `stdout:` / `stderr:` label.
        echo (`bool`): When `True`, print the command line before running it.

    Returns:
        `_RunOutput` holding the exit code and the captured stdout and stderr as lists of decoded, right-stripped
        lines.
    """
    if echo:
        print("\nRunning: ", " ".join(cmd))

    p = await asyncio.create_subprocess_exec(
        cmd[0],
        *cmd[1:],
        stdin=stdin,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
        env=env,
    )

    # note: there is a warning for a possible deadlock when using `wait` with huge amounts of data in the pipe
    # https://docs.python.org/3/library/asyncio-subprocess.html#asyncio.asyncio.subprocess.Process.wait
    #
    # If it starts hanging, will need to switch to the following code. The problem is that no data
    # will be seen until it's done and if it hangs for example there will be no debug info.
    # out, err = await p.communicate()
    # return _RunOutput(p.returncode, out, err)

    out = []
    err = []

    def tee(line, sink, pipe, label=""):
        # Decode leniently: a strict decode would raise inside the reader task on any non-UTF-8 byte. That exception
        # is not re-raised by `asyncio.wait`, so the stream would silently stop being drained.
        line = line.decode("utf-8", errors="replace").rstrip()
        sink.append(line)
        if not quiet:
            print(label, line, file=pipe)

    # XXX: the timeout doesn't seem to make any difference here
    await asyncio.wait(
        [
            asyncio.create_task(_read_stream(p.stdout, lambda l: tee(l, out, sys.stdout, label="stdout:"))),
            asyncio.create_task(_read_stream(p.stderr, lambda l: tee(l, err, sys.stderr, label="stderr:"))),
        ],
        timeout=timeout,
    )
    return _RunOutput(await p.wait(), out, err)
|
| 772 |
+
|
| 773 |
+
|
| 774 |
+
def execute_subprocess_async(cmd: list, env=None, stdin=None, timeout=180, quiet=False, echo=True) -> _RunOutput:
    """
    Run `cmd` through `_stream_subprocess` on a fresh event loop and raise if the process fails.

    Args:
        cmd (`list`):
            Program and arguments. `Path` entries are converted to `str` in place, so the caller's list is modified.
        env (`dict`, *optional*):
            Environment for the child process.
        stdin:
            Forwarded to `_stream_subprocess`.
        timeout (`float`, *optional*, defaults to 180):
            Forwarded to `_stream_subprocess`.
        quiet (`bool`):
            Suppress live printing of the child's output.
        echo (`bool`):
            Print the command before running it.

    Returns:
        `_RunOutput`: the result returned by `_stream_subprocess`.

    Raises:
        `RuntimeError`: if the process exits with a non-zero return code. The message contains the captured stderr.
    """
    # Cast every path in `cmd` to a string
    for i, c in enumerate(cmd):
        if isinstance(c, Path):
            cmd[i] = str(c)

    result = asyncio.run(_stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo))

    cmd_str = " ".join(cmd)
    # Any non-zero code is a failure. On POSIX a child killed by a signal reports a negative return code, which a
    # `> 0` check would silently treat as success.
    if result.returncode != 0:
        stderr = "\n".join(result.stderr)
        raise RuntimeError(
            f"'{cmd_str}' failed with returncode {result.returncode}\n\n"
            f"The combined stderr from workers follows:\n{stderr}"
        )

    return result
|
| 791 |
+
|
| 792 |
+
|
| 793 |
+
def pytest_xdist_worker_id():
    """
    Returns an int value of worker's numerical id under `pytest-xdist`'s concurrent workers `pytest -n N` regime, or 0
    if `-n 1` or `pytest-xdist` isn't being used.

    The id is read from the `PYTEST_XDIST_WORKER` environment variable (e.g. `"gw3"` -> `3`); when the variable is
    unset, `"gw0"` is assumed.
    """
    worker = os.environ.get("PYTEST_XDIST_WORKER", "gw0")
    # `count` and `flags` are passed by keyword: passing them positionally is deprecated since Python 3.13.
    worker = re.sub(r"^gw", "", worker, count=0, flags=re.M)
    return int(worker)
|
| 801 |
+
|
| 802 |
+
|
| 803 |
+
def get_torch_dist_unique_port():
    """
    Returns a port number that can be fed to `torch.distributed.launch`'s `--master_port` argument.

    The result is the base port 29500 plus the `pytest-xdist` worker id, so that concurrent test workers don't try
    to use the same port at once.
    """
    base_port = 29500
    return base_port + pytest_xdist_worker_id()
|
| 813 |
+
|
| 814 |
+
|
| 815 |
+
class SubprocessCallException(Exception):
    """Raised by `run_command` when the executed command exits with an error."""
|
| 817 |
+
|
| 818 |
+
|
| 819 |
+
def run_command(command: list[str], return_stdout=False, env=None):
    """
    Runs `command` with `subprocess.check_output` (stderr merged into stdout) and optionally returns the output.

    Args:
        command (`list[str]`):
            Program and arguments. `Path` entries are converted to `str` in place.
        return_stdout (`bool`):
            When `True`, return the captured output (decoded as UTF-8 when it has a `decode` method); otherwise
            `None` is returned.
        env (`dict`, *optional*):
            Environment for the child process; defaults to a copy of `os.environ`.

    Raises:
        `SubprocessCallException`: if the command fails; the message embeds the captured output.
    """
    # Cast every path in `command` to a string
    for position, part in enumerate(command):
        if isinstance(part, Path):
            command[position] = str(part)

    child_env = os.environ.copy() if env is None else env

    try:
        raw_output = subprocess.check_output(command, stderr=subprocess.STDOUT, env=child_env)
    except subprocess.CalledProcessError as e:
        raise SubprocessCallException(
            f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
        ) from e

    if not return_stdout:
        return None
    if hasattr(raw_output, "decode"):
        return raw_output.decode("utf-8")
    return raw_output
|
| 840 |
+
|
| 841 |
+
|
| 842 |
+
def path_in_accelerate_package(*components: str) -> Path:
    """
    Build a path located inside the installed `accelerate` package directory.

    Args:
        *components: Path components to append to the package directory.

    Returns:
        `Path`: The package directory (parent of the file `inspect.getfile(accelerate)` reports) joined with
        `components`.
    """
    package_root = Path(inspect.getfile(accelerate)).parent
    return package_root.joinpath(*components)
|
| 855 |
+
|
| 856 |
+
|
| 857 |
+
@contextmanager
def assert_exception(exception_class: type[Exception], msg: Optional[str] = None):
    """
    Context manager to assert that the right `Exception` class was raised.

    If `msg` is provided, will check that the message is contained in the raised exception.

    Args:
        exception_class (`type[Exception]`):
            The exception type (or a parent type) the wrapped code is expected to raise.
        msg (`str`, *optional*):
            Substring that must appear in `str()` of the raised exception.

    Raises:
        `AssertionError`: if the wrapped code raises a different exception type, raises with a message that does not
        contain `msg`, or does not raise at all.
    """
    try:
        yield
    except Exception as e:
        # Explicit raises rather than `assert` statements: asserts are stripped under `python -O`, which would make
        # this helper accept any exception.
        if not isinstance(e, exception_class):
            raise AssertionError(f"Expected exception of type {exception_class} but got {type(e)}")
        if msg is not None and msg not in str(e):
            raise AssertionError(f"Expected message '{msg}' to be in exception but got '{str(e)}'")
    else:
        # The body finished without raising anything.
        raise AssertionError(f"Expected exception of type {exception_class} but ran without issue.")
|
| 874 |
+
|
| 875 |
+
|
| 876 |
+
def capture_call_output(func, *args, **kwargs):
    """
    Call `func(*args, **kwargs)` with `sys.stdout` temporarily replaced and return everything it printed.

    The return value of `func` is discarded. `sys.stdout` is restored even when `func` raises, in which case the
    exception propagates to the caller.
    """
    buffer = io.StringIO()
    saved_stdout = sys.stdout
    sys.stdout = buffer
    try:
        func(*args, **kwargs)
    finally:
        sys.stdout = saved_stdout
    return buffer.getvalue()
|
accelerate/test_utils/training.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
import torch
|
| 17 |
+
from torch.utils.data import DataLoader
|
| 18 |
+
|
| 19 |
+
from accelerate.utils.dataclasses import DistributedType
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class RegressionDataset:
    """
    Synthetic 1-D regression dataset where `y = a * x + b` plus Gaussian noise (standard deviation 0.1).

    Args:
        a: Slope of the underlying line.
        b: Intercept of the underlying line.
        length (`int`): Number of samples.
        seed (`int`, *optional*): Seed for `np.random.default_rng`.
    """

    def __init__(self, a=2, b=3, length=64, seed=None):
        generator = np.random.default_rng(seed)
        self.length = length
        # Draw order matters for reproducibility with a fixed seed: inputs first, then the noise.
        self.x = generator.normal(size=(length,)).astype(np.float32)
        noise = generator.normal(scale=0.1, size=(length,)).astype(np.float32)
        self.y = a * self.x + b + noise

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        sample = {"x": self.x[i], "y": self.y[i]}
        return sample
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class RegressionModel(torch.nn.Module):
    """
    Minimal linear model computing `x * a + b` with two scalar learnable parameters.

    Args:
        a: Initial value of the slope parameter.
        b: Initial value of the intercept parameter.
        double_output (`bool`): Accepted but not used by this class.
    """

    def __init__(self, a=0, b=0, double_output=False):
        super().__init__()
        slope = torch.tensor(a).float()
        intercept = torch.tensor(b).float()
        self.a = torch.nn.Parameter(slope)
        self.b = torch.nn.Parameter(intercept)
        # Used to log dtypes exactly once, on the first forward pass.
        self.first_batch = True

    def forward(self, x=None):
        if self.first_batch:
            print(f"Model dtype: {self.a.dtype}, {self.b.dtype}. Input dtype: {x.dtype}")
            self.first_batch = False
        prediction = x * self.a + self.b
        return prediction
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def mocked_dataloaders(accelerator, batch_size: int = 16):
    """
    Build small train/eval dataloaders over the MRPC sample CSV files, tokenized with `bert-base-cased`.

    Args:
        accelerator:
            Object whose `distributed_type` selects the padding strategy used when collating.
        batch_size (`int`):
            Accepted but not used: the train loader uses a batch size of 2 and the eval loader a batch size of 1.

    Returns:
        `tuple`: `(train_dataloader, eval_dataloader)`.
    """
    from datasets import load_dataset
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    data_files = {"train": "tests/test_samples/MRPC/train.csv", "validation": "tests/test_samples/MRPC/dev.csv"}
    datasets = load_dataset("csv", data_files=data_files)

    # Map each distinct label value found in the train split to an integer id.
    label_to_id = {}
    for index, label in enumerate(datasets["train"].unique("label")):
        label_to_id[label] = index

    def tokenize_function(examples):
        # max_length=None => use the model max length (it's actually the default)
        encoded = tokenizer(
            examples["sentence1"], examples["sentence2"], truncation=True, max_length=None, padding="max_length"
        )
        if "label" in examples:
            encoded["labels"] = [label_to_id[raw_label] for raw_label in examples["label"]]
        return encoded

    # Apply the method we just defined to all the examples in all the splits of the dataset
    tokenized_datasets = datasets.map(
        tokenize_function,
        batched=True,
        remove_columns=["sentence1", "sentence2", "label"],
    )

    def collate_fn(examples):
        # On TPU it's best to pad everything to the same length or training will be very slow.
        on_xla = accelerator.distributed_type == DistributedType.XLA
        if on_xla:
            return tokenizer.pad(examples, padding="max_length", max_length=128, return_tensors="pt")
        return tokenizer.pad(examples, padding="longest", return_tensors="pt")

    # Instantiate dataloaders.
    train_dataloader = DataLoader(tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=2)
    eval_dataloader = DataLoader(tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=1)

    return train_dataloader, eval_dataloader
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def mocked_dataloaders_for_autoregressive_models(accelerator, batch_size: int = 16):
    """
    Build small train/eval dataloaders for next-token prediction over the MRPC sample CSV files, tokenized with the
    `HuggingFaceTB/SmolLM-360M` tokenizer.

    Args:
        accelerator:
            Object providing `main_process_first()`, `distributed_type` and `mixed_precision`.
        batch_size (`int`):
            Accepted but not used: the train loader uses a batch size of 2 and the eval loader a batch size of 1.

    Returns:
        `tuple`: `(train_dataloader, eval_dataloader)`; neither loader shuffles.
    """
    from datasets import load_dataset
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-360M")
    # The EOS token doubles as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

    data_files = {"train": "tests/test_samples/MRPC/train.csv", "validation": "tests/test_samples/MRPC/dev.csv"}
    datasets = load_dataset("csv", data_files=data_files)

    def tokenize_function(examples):
        # max_length=None => use the model max length (it's actually the default)
        return tokenizer(examples["sentence1"], truncation=True, max_length=None, return_attention_mask=False)

    # Apply the method we just defined to all the examples in all the splits of the dataset
    # starting with the main process first:
    with accelerator.main_process_first():
        tokenized_datasets = datasets.map(
            tokenize_function,
            batched=True,
            remove_columns=["sentence1", "sentence2", "label"],
        )

    def collate_fn(examples):
        # On TPU it's best to pad everything to the same length or training will be very slow.
        if accelerator.distributed_type == DistributedType.XLA:
            max_length = 128
        else:
            max_length = max(len(example["input_ids"]) for example in examples)

        # When using mixed precision we want round multiples of 8/16
        precision = accelerator.mixed_precision
        if precision == "fp8":
            pad_to_multiple_of = 16
        elif precision != "no":
            pad_to_multiple_of = 8
        else:
            pad_to_multiple_of = None

        # Pad one position beyond the target length so inputs and labels can be offset by one token below.
        batch = tokenizer.pad(
            examples,
            padding="max_length",
            max_length=max_length + 1,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors="pt",
        )

        padded_ids = batch["input_ids"]
        next_tokens = padded_ids[:, 1:]
        batch["input_ids"] = padded_ids[:, :-1]
        if "attention_mask" in batch:
            batch["attention_mask"] = batch["attention_mask"][:, :-1]

        # Positions holding the pad token id (same id as EOS here) are replaced by -100 in the labels.
        batch["labels"] = torch.where(next_tokens == tokenizer.pad_token_id, -100, next_tokens)

        return batch

    # Instantiate dataloaders.
    train_dataloader = DataLoader(tokenized_datasets["train"], shuffle=False, collate_fn=collate_fn, batch_size=2)
    eval_dataloader = DataLoader(tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=1)

    return train_dataloader, eval_dataloader
|
accelerate/utils/__init__.py
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from ..parallelism_config import ParallelismConfig
|
| 15 |
+
from .ao import convert_model_to_fp8_ao, filter_first_and_last_linear_layers, has_ao_layers
|
| 16 |
+
from .constants import (
|
| 17 |
+
MITA_PROFILING_AVAILABLE_PYTORCH_VERSION,
|
| 18 |
+
MODEL_NAME,
|
| 19 |
+
OPTIMIZER_NAME,
|
| 20 |
+
PROFILE_PATTERN_NAME,
|
| 21 |
+
RNG_STATE_NAME,
|
| 22 |
+
SAFE_MODEL_NAME,
|
| 23 |
+
SAFE_WEIGHTS_INDEX_NAME,
|
| 24 |
+
SAFE_WEIGHTS_NAME,
|
| 25 |
+
SAFE_WEIGHTS_PATTERN_NAME,
|
| 26 |
+
SAMPLER_NAME,
|
| 27 |
+
SCALER_NAME,
|
| 28 |
+
SCHEDULER_NAME,
|
| 29 |
+
TORCH_DISTRIBUTED_OPERATION_TYPES,
|
| 30 |
+
TORCH_LAUNCH_PARAMS,
|
| 31 |
+
WEIGHTS_INDEX_NAME,
|
| 32 |
+
WEIGHTS_NAME,
|
| 33 |
+
WEIGHTS_PATTERN_NAME,
|
| 34 |
+
XPU_PROFILING_AVAILABLE_PYTORCH_VERSION,
|
| 35 |
+
)
|
| 36 |
+
from .dataclasses import (
|
| 37 |
+
AORecipeKwargs,
|
| 38 |
+
AutocastKwargs,
|
| 39 |
+
BnbQuantizationConfig,
|
| 40 |
+
ComputeEnvironment,
|
| 41 |
+
CustomDtype,
|
| 42 |
+
DataLoaderConfiguration,
|
| 43 |
+
DDPCommunicationHookType,
|
| 44 |
+
DeepSpeedPlugin,
|
| 45 |
+
DeepSpeedSequenceParallelConfig,
|
| 46 |
+
DistributedDataParallelKwargs,
|
| 47 |
+
DistributedType,
|
| 48 |
+
DynamoBackend,
|
| 49 |
+
FP8RecipeKwargs,
|
| 50 |
+
FullyShardedDataParallelPlugin,
|
| 51 |
+
GradientAccumulationPlugin,
|
| 52 |
+
GradScalerKwargs,
|
| 53 |
+
InitProcessGroupKwargs,
|
| 54 |
+
KwargsHandler,
|
| 55 |
+
LoggerType,
|
| 56 |
+
MegatronLMPlugin,
|
| 57 |
+
MSAMPRecipeKwargs,
|
| 58 |
+
PrecisionType,
|
| 59 |
+
ProfileKwargs,
|
| 60 |
+
ProjectConfiguration,
|
| 61 |
+
RNGType,
|
| 62 |
+
SageMakerDistributedType,
|
| 63 |
+
TensorInformation,
|
| 64 |
+
TERecipeKwargs,
|
| 65 |
+
TorchContextParallelConfig,
|
| 66 |
+
TorchDynamoPlugin,
|
| 67 |
+
TorchTensorParallelConfig,
|
| 68 |
+
TorchTensorParallelPlugin,
|
| 69 |
+
add_model_config_to_megatron_parser,
|
| 70 |
+
)
|
| 71 |
+
from .environment import (
|
| 72 |
+
are_libraries_initialized,
|
| 73 |
+
check_cuda_fp8_capability,
|
| 74 |
+
check_cuda_p2p_ib_support,
|
| 75 |
+
clear_environment,
|
| 76 |
+
convert_dict_to_env_variables,
|
| 77 |
+
get_cpu_distributed_information,
|
| 78 |
+
get_current_device_type,
|
| 79 |
+
get_gpu_info,
|
| 80 |
+
get_int_from_env,
|
| 81 |
+
parse_choice_from_env,
|
| 82 |
+
parse_flag_from_env,
|
| 83 |
+
patch_environment,
|
| 84 |
+
purge_accelerate_environment,
|
| 85 |
+
set_numa_affinity,
|
| 86 |
+
str_to_bool,
|
| 87 |
+
)
|
| 88 |
+
from .imports import (
|
| 89 |
+
deepspeed_required,
|
| 90 |
+
is_4bit_bnb_available,
|
| 91 |
+
is_8bit_bnb_available,
|
| 92 |
+
is_aim_available,
|
| 93 |
+
is_bf16_available,
|
| 94 |
+
is_bitsandbytes_multi_backend_available,
|
| 95 |
+
is_bnb_available,
|
| 96 |
+
is_boto3_available,
|
| 97 |
+
is_clearml_available,
|
| 98 |
+
is_comet_ml_available,
|
| 99 |
+
is_cuda_available,
|
| 100 |
+
is_datasets_available,
|
| 101 |
+
is_deepspeed_available,
|
| 102 |
+
is_dvclive_available,
|
| 103 |
+
is_fp8_available,
|
| 104 |
+
is_fp16_available,
|
| 105 |
+
is_habana_gaudi1,
|
| 106 |
+
is_hpu_available,
|
| 107 |
+
is_import_timer_available,
|
| 108 |
+
is_lomo_available,
|
| 109 |
+
is_matplotlib_available,
|
| 110 |
+
is_megatron_lm_available,
|
| 111 |
+
is_mlflow_available,
|
| 112 |
+
is_mlu_available,
|
| 113 |
+
is_mps_available,
|
| 114 |
+
is_msamp_available,
|
| 115 |
+
is_musa_available,
|
| 116 |
+
is_neuron_available,
|
| 117 |
+
is_npu_available,
|
| 118 |
+
is_pandas_available,
|
| 119 |
+
is_peft_available,
|
| 120 |
+
is_pippy_available,
|
| 121 |
+
is_pynvml_available,
|
| 122 |
+
is_pytest_available,
|
| 123 |
+
is_rich_available,
|
| 124 |
+
is_sagemaker_available,
|
| 125 |
+
is_schedulefree_available,
|
| 126 |
+
is_sdaa_available,
|
| 127 |
+
is_swanlab_available,
|
| 128 |
+
is_tensorboard_available,
|
| 129 |
+
is_timm_available,
|
| 130 |
+
is_torch_xla_available,
|
| 131 |
+
is_torchao_available,
|
| 132 |
+
is_torchdata_available,
|
| 133 |
+
is_torchdata_stateful_dataloader_available,
|
| 134 |
+
is_torchvision_available,
|
| 135 |
+
is_trackio_available,
|
| 136 |
+
is_transformer_engine_available,
|
| 137 |
+
is_transformer_engine_mxfp8_available,
|
| 138 |
+
is_transformers_available,
|
| 139 |
+
is_triton_available,
|
| 140 |
+
is_wandb_available,
|
| 141 |
+
is_weights_only_available,
|
| 142 |
+
is_xccl_available,
|
| 143 |
+
is_xpu_available,
|
| 144 |
+
torchao_required,
|
| 145 |
+
)
|
| 146 |
+
from .modeling import (
|
| 147 |
+
align_module_device,
|
| 148 |
+
calculate_maximum_sizes,
|
| 149 |
+
check_device_map,
|
| 150 |
+
check_tied_parameters_in_config,
|
| 151 |
+
check_tied_parameters_on_same_device,
|
| 152 |
+
compute_module_sizes,
|
| 153 |
+
convert_file_size_to_int,
|
| 154 |
+
dtype_byte_size,
|
| 155 |
+
find_tied_parameters,
|
| 156 |
+
get_balanced_memory,
|
| 157 |
+
get_grad_scaler,
|
| 158 |
+
get_max_layer_size,
|
| 159 |
+
get_max_memory,
|
| 160 |
+
get_mixed_precision_context_manager,
|
| 161 |
+
has_offloaded_params,
|
| 162 |
+
id_tensor_storage,
|
| 163 |
+
infer_auto_device_map,
|
| 164 |
+
is_peft_model,
|
| 165 |
+
load_checkpoint_in_model,
|
| 166 |
+
load_offloaded_weights,
|
| 167 |
+
load_state_dict,
|
| 168 |
+
named_module_tensors,
|
| 169 |
+
retie_parameters,
|
| 170 |
+
set_module_tensor_to_device,
|
| 171 |
+
)
|
| 172 |
+
from .offload import (
|
| 173 |
+
OffloadedWeightsLoader,
|
| 174 |
+
PrefixedDataset,
|
| 175 |
+
extract_submodules_state_dict,
|
| 176 |
+
load_offloaded_weight,
|
| 177 |
+
offload_state_dict,
|
| 178 |
+
offload_weight,
|
| 179 |
+
save_offload_index,
|
| 180 |
+
)
|
| 181 |
+
from .operations import (
|
| 182 |
+
CannotPadNestedTensorWarning,
|
| 183 |
+
GatheredParameters,
|
| 184 |
+
broadcast,
|
| 185 |
+
broadcast_object_list,
|
| 186 |
+
concatenate,
|
| 187 |
+
convert_outputs_to_fp32,
|
| 188 |
+
convert_to_fp32,
|
| 189 |
+
copy_tensor_to_devices,
|
| 190 |
+
find_batch_size,
|
| 191 |
+
find_device,
|
| 192 |
+
gather,
|
| 193 |
+
gather_object,
|
| 194 |
+
get_data_structure,
|
| 195 |
+
honor_type,
|
| 196 |
+
ignorant_find_batch_size,
|
| 197 |
+
initialize_tensors,
|
| 198 |
+
is_namedtuple,
|
| 199 |
+
is_tensor_information,
|
| 200 |
+
is_torch_tensor,
|
| 201 |
+
listify,
|
| 202 |
+
pad_across_processes,
|
| 203 |
+
pad_input_tensors,
|
| 204 |
+
recursively_apply,
|
| 205 |
+
reduce,
|
| 206 |
+
send_to_device,
|
| 207 |
+
slice_tensors,
|
| 208 |
+
)
|
| 209 |
+
from .versions import compare_versions, is_torch_version
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
if is_deepspeed_available():
|
| 213 |
+
from .deepspeed import (
|
| 214 |
+
DeepSpeedEngineWrapper,
|
| 215 |
+
DeepSpeedOptimizerWrapper,
|
| 216 |
+
DeepSpeedSchedulerWrapper,
|
| 217 |
+
DummyOptim,
|
| 218 |
+
DummyScheduler,
|
| 219 |
+
HfDeepSpeedConfig,
|
| 220 |
+
get_active_deepspeed_plugin,
|
| 221 |
+
map_pytorch_optim_to_deepspeed,
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
from .bnb import has_4bit_bnb_layers, load_and_quantize_model
|
| 225 |
+
from .fsdp_utils import (
|
| 226 |
+
disable_fsdp_ram_efficient_loading,
|
| 227 |
+
enable_fsdp_ram_efficient_loading,
|
| 228 |
+
ensure_weights_retied,
|
| 229 |
+
fsdp2_apply_ac,
|
| 230 |
+
fsdp2_canonicalize_names,
|
| 231 |
+
fsdp2_load_full_state_dict,
|
| 232 |
+
fsdp2_prepare_model,
|
| 233 |
+
fsdp2_switch_optimizer_parameters,
|
| 234 |
+
get_fsdp2_grad_scaler,
|
| 235 |
+
load_fsdp_model,
|
| 236 |
+
load_fsdp_optimizer,
|
| 237 |
+
merge_fsdp_weights,
|
| 238 |
+
save_fsdp_model,
|
| 239 |
+
save_fsdp_optimizer,
|
| 240 |
+
)
|
| 241 |
+
from .launch import (
|
| 242 |
+
PrepareForLaunch,
|
| 243 |
+
_filter_args,
|
| 244 |
+
prepare_deepspeed_cmd_env,
|
| 245 |
+
prepare_multi_gpu_env,
|
| 246 |
+
prepare_sagemager_args_inputs,
|
| 247 |
+
prepare_simple_launcher_cmd_env,
|
| 248 |
+
prepare_tpu,
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
# For docs
|
| 252 |
+
from .megatron_lm import (
|
| 253 |
+
AbstractTrainStep,
|
| 254 |
+
BertTrainStep,
|
| 255 |
+
GPTTrainStep,
|
| 256 |
+
MegatronLMDummyDataLoader,
|
| 257 |
+
MegatronLMDummyScheduler,
|
| 258 |
+
T5TrainStep,
|
| 259 |
+
avg_losses_across_data_parallel_group,
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
if is_megatron_lm_available():
|
| 264 |
+
from .megatron_lm import (
|
| 265 |
+
MegatronEngine,
|
| 266 |
+
MegatronLMOptimizerWrapper,
|
| 267 |
+
MegatronLMSchedulerWrapper,
|
| 268 |
+
gather_across_data_parallel_groups,
|
| 269 |
+
)
|
| 270 |
+
from .megatron_lm import initialize as megatron_lm_initialize
|
| 271 |
+
from .megatron_lm import prepare_data_loader as megatron_lm_prepare_data_loader
|
| 272 |
+
from .megatron_lm import prepare_model_optimizer_scheduler as megatron_lm_prepare_model_optimizer_scheduler
|
| 273 |
+
from .megatron_lm import prepare_optimizer as megatron_lm_prepare_optimizer
|
| 274 |
+
from .megatron_lm import prepare_scheduler as megatron_lm_prepare_scheduler
|
| 275 |
+
from .memory import find_executable_batch_size, release_memory
|
| 276 |
+
from .other import (
|
| 277 |
+
check_os_kernel,
|
| 278 |
+
clean_state_dict_for_safetensors,
|
| 279 |
+
compile_regions,
|
| 280 |
+
compile_regions_deepspeed,
|
| 281 |
+
convert_bytes,
|
| 282 |
+
extract_model_from_parallel,
|
| 283 |
+
get_module_children_bottom_up,
|
| 284 |
+
get_pretty_name,
|
| 285 |
+
has_compiled_regions,
|
| 286 |
+
is_compiled_module,
|
| 287 |
+
is_port_in_use,
|
| 288 |
+
load,
|
| 289 |
+
merge_dicts,
|
| 290 |
+
model_has_dtensor,
|
| 291 |
+
recursive_getattr,
|
| 292 |
+
save,
|
| 293 |
+
wait_for_everyone,
|
| 294 |
+
write_basic_config,
|
| 295 |
+
)
|
| 296 |
+
from .random import set_seed, synchronize_rng_state, synchronize_rng_states
|
| 297 |
+
from .torch_xla import install_xla
|
| 298 |
+
from .tqdm import tqdm
|
| 299 |
+
from .transformer_engine import (
|
| 300 |
+
apply_fp8_autowrap,
|
| 301 |
+
contextual_fp8_autocast,
|
| 302 |
+
convert_model,
|
| 303 |
+
has_transformer_engine_layers,
|
| 304 |
+
)
|
accelerate/utils/ao.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Needed utilities for torchao FP8 training.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from functools import partial
|
| 20 |
+
from typing import TYPE_CHECKING, Callable, Optional
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
|
| 24 |
+
from .imports import is_torchao_available, torchao_required
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
if TYPE_CHECKING:
|
| 28 |
+
if is_torchao_available():
|
| 29 |
+
from torchao.float8.float8_linear import Float8LinearConfig
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def find_first_last_linear_layers(model: torch.nn.Module):
    """
    Return the names of the first and last `torch.nn.Linear` modules in `model.named_modules()` order.

    This is needed during FP8 to avoid issues with instability by keeping the first and last layers unquantized.

    Ref: https://x.com/xariusrke/status/1826669142604141052

    Returns:
        `tuple`: `(first_name, last_name)`; both are `None` when the model contains no linear layer, and both are
        the same name when it contains exactly one.
    """
    linear_names = [name for name, module in model.named_modules() if isinstance(module, torch.nn.Linear)]
    if not linear_names:
        return None, None
    return linear_names[0], linear_names[-1]
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def filter_linear_layers(module, fqn: str, layers_to_filter: list[str]) -> bool:
    """
    Decide whether `module` passes the filter.

    A `torch.nn.Linear` module is rejected (`False`) when:
        - its `in_features` or `out_features` is not divisible by 16, or
        - its name `fqn` is listed in `layers_to_filter`.

    Every other module, including any module that is not a `torch.nn.Linear`, is accepted (`True`).

    Args:
        module (`torch.nn.Module`):
            The module to check.
        fqn (`str`):
            The fully qualified name of the layer.
        layers_to_filter (`List[str]`):
            The list of layers to filter.
    """
    if not isinstance(module, torch.nn.Linear):
        return True
    dims_divisible = module.in_features % 16 == 0 and module.out_features % 16 == 0
    if not dims_divisible:
        return False
    return fqn not in layers_to_filter
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def filter_first_and_last_linear_layers(module, fqn: str) -> bool:
    """
    A filter function meant to accept linear layers except the first and last ones.

    <Tip>

        For stability reasons, the first and last linear layers should be skipped. Converting them can otherwise
        lead to the model not training or converging properly.

    </Tip>

    NOTE(review): the first/last lookup below runs on `module` itself, not on the parent model, so the names it
    finds are relative to `module`. Confirm this matches how callers supply `fqn`.

    Args:
        module (`torch.nn.Module`):
            The module to check.
        fqn (`str`):
            The fully qualified name of the layer.
    """
    excluded_names = list(find_first_last_linear_layers(module))
    return filter_linear_layers(module, fqn, layers_to_filter=excluded_names)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@torchao_required
def has_ao_layers(model: torch.nn.Module):
    """
    Return `True` if any module of `model` (the model itself included) is a torchao `Float8Linear`.
    """
    from torchao.float8.float8_linear import Float8Linear

    return any(isinstance(submodule, Float8Linear) for submodule in model.modules())
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
@torchao_required
def convert_model_to_fp8_ao(
    model: torch.nn.Module,
    config: Optional["Float8LinearConfig"] = None,
    module_filter_func: Optional[Callable] = filter_first_and_last_linear_layers,
):
    """
    Converts all `nn.Linear` layers in the model (except the first and last) to torchao's `Float8Linear` layer inplace.

    Args:
        model (`torch.nn.Module`):
            The model to convert.
        config (`torchao.float8.Float8LinearConfig`, *optional*):
            The configuration for the FP8 training. Recommended to utilize
            `torchao.float8.recipe_name_to_linear_config` to generate this. In general, the default config should be
            sufficient (what is passed when set to `None`).
        module_filter_func (`Callable`, *optional*, defaults to `filter_first_and_last_linear_layers`):
            Optional function that must take in a module and layer name, and returns a boolean indicating whether the
            module should be converted to FP8. When left at its default (or set to `None`), `filter_linear_layers` is
            used with the first and last linear layers of `model` excluded. See it for an example.

    Example:

    ```python
    from accelerate import Accelerator
    from accelerate.utils.ao import convert_model_to_fp8_ao

    accelerator = Accelerator()

    model = MyModel()
    model.to(accelerator.device)
    convert_model_to_fp8_ao(model)

    model.train()
    ```
    """
    from torchao.float8 import convert_to_float8_training

    # The default filter only receives one module at a time and searches for first/last linear layers inside that
    # module, so it cannot identify the first and last linear layers of the whole model. Resolve them here, where
    # the full model is available, and bind them into the filter.
    if module_filter_func is None or module_filter_func is filter_first_and_last_linear_layers:
        first_linear, last_linear = find_first_last_linear_layers(model)
        module_filter_func = partial(filter_linear_layers, layers_to_filter=[first_linear, last_linear])
    convert_to_float8_training(model, module_filter_fn=module_filter_func, config=config)
|
accelerate/utils/bnb.py
ADDED
|
@@ -0,0 +1,464 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
import logging
|
| 17 |
+
import os
|
| 18 |
+
from copy import deepcopy
|
| 19 |
+
from typing import Optional, Union
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
|
| 24 |
+
from accelerate.utils.imports import (
|
| 25 |
+
is_4bit_bnb_available,
|
| 26 |
+
is_8bit_bnb_available,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
from ..big_modeling import dispatch_model, init_empty_weights
|
| 30 |
+
from .dataclasses import BnbQuantizationConfig
|
| 31 |
+
from .modeling import (
|
| 32 |
+
find_tied_parameters,
|
| 33 |
+
get_balanced_memory,
|
| 34 |
+
infer_auto_device_map,
|
| 35 |
+
load_checkpoint_in_model,
|
| 36 |
+
offload_weight,
|
| 37 |
+
set_module_tensor_to_device,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
logger = logging.getLogger(__name__)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def load_and_quantize_model(
    model: torch.nn.Module,
    bnb_quantization_config: BnbQuantizationConfig,
    weights_location: Optional[Union[str, os.PathLike]] = None,
    device_map: Optional[dict[str, Union[int, str, torch.device]]] = None,
    no_split_module_classes: Optional[list[str]] = None,
    max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None,
    offload_folder: Optional[Union[str, os.PathLike]] = None,
    offload_state_dict: bool = False,
):
    """
    This function will quantize the input model with the associated config passed in `bnb_quantization_config`. If the
    model is in the meta device, we will load and dispatch the weights according to the `device_map` passed. If the
    model is already loaded, we will quantize the model and put the model on the GPU.

    Note that `bnb_quantization_config.skip_modules` and `bnb_quantization_config.keep_in_fp32_modules` are filled in
    (and `skip_modules` is extended in place) by this function.

    Args:
        model (`torch.nn.Module`):
            Input model. The model can be already loaded or on the meta device
        bnb_quantization_config (`BnbQuantizationConfig`):
            The bitsandbytes quantization parameters
        weights_location (`str` or `os.PathLike`):
            The folder weights_location to load. It can be:
            - a path to a file containing a whole model state dict
            - a path to a `.json` file containing the index to a sharded checkpoint
            - a path to a folder containing a unique `.index.json` file and the shards of a checkpoint.
            - a path to a folder containing a unique pytorch_model.bin file.
        device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*):
            A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer
            name, once a given module name is inside, every submodule of it will be sent to the same device.
        no_split_module_classes (`List[str]`, *optional*):
            A list of layer class names that should never be split across device (for instance any layer that has a
            residual connection).
        max_memory (`Dict`, *optional*):
            A dictionary device identifier to maximum memory. Will default to the maximum memory available if unset.
        offload_folder (`str` or `os.PathLike`, *optional*):
            If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
        offload_state_dict (`bool`, *optional*, defaults to `False`):
            If `True`, will temporarily offload the CPU state dict on the hard drive to avoid getting out of CPU RAM if
            the weight of the CPU state dict + the biggest shard does not fit.

    Returns:
        `torch.nn.Module`: The quantized model

    Raises:
        ImportError: If 8-bit loading is requested but the installed `bitsandbytes` does not support it.
        ValueError: If 4-bit loading is requested but the installed `bitsandbytes` does not support it.
        RuntimeError: If an already loaded model cannot be moved to a GPU/XPU, or if the model is on the meta device
            and `weights_location` is `None`.
    """

    load_in_4bit = bnb_quantization_config.load_in_4bit
    load_in_8bit = bnb_quantization_config.load_in_8bit

    if load_in_8bit and not is_8bit_bnb_available():
        raise ImportError(
            "You have a version of `bitsandbytes` that is not compatible with 8bit quantization,"
            " make sure you have the latest version of `bitsandbytes` installed."
        )
    if load_in_4bit and not is_4bit_bnb_available():
        # Kept as `ValueError` (not `ImportError`) so existing callers catching it keep working.
        raise ValueError(
            "You have a version of `bitsandbytes` that is not compatible with 4bit quantization,"
            " make sure you have the latest version of `bitsandbytes` installed."
        )

    modules_on_cpu = []
    # custom device map
    if isinstance(device_map, dict) and len(device_map.keys()) > 1:
        modules_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]

    # We keep some modules such as the lm_head in their original dtype for numerical stability reasons
    if bnb_quantization_config.skip_modules is None:
        bnb_quantization_config.skip_modules = get_keys_to_not_convert(model)

    # add cpu modules to skip modules only for 4-bit modules
    if load_in_4bit:
        bnb_quantization_config.skip_modules.extend(modules_on_cpu)
    # NOTE: this aliases `bnb_quantization_config.skip_modules`; the `extend` below therefore also updates the
    # config, which `get_quantized_model_device_map` reads afterwards.
    modules_to_not_convert = bnb_quantization_config.skip_modules

    # We add the modules we want to keep in full precision
    if bnb_quantization_config.keep_in_fp32_modules is None:
        bnb_quantization_config.keep_in_fp32_modules = []
    keep_in_fp32_modules = bnb_quantization_config.keep_in_fp32_modules
    modules_to_not_convert.extend(keep_in_fp32_modules)

    # compatibility with peft
    model.is_loaded_in_4bit = load_in_4bit
    model.is_loaded_in_8bit = load_in_8bit

    model_device = get_parameter_device(model)
    if model_device.type != "meta":
        # quantization of an already loaded model
        logger.warning(
            "It is not recommended to quantize a loaded model. "
            "The model should be instantiated under the `init_empty_weights` context manager."
        )
        model = replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=modules_to_not_convert)
        # convert param to the right dtype
        dtype = bnb_quantization_config.torch_dtype
        for name, param in model.named_parameters():
            if any(module_to_keep_in_fp32 in name for module_to_keep_in_fp32 in keep_in_fp32_modules):
                param.data = param.data.to(torch.float32)
            elif torch.is_floating_point(param):
                param.data = param.data.to(dtype)
        if model_device.type == "cuda":
            model.cuda(torch.cuda.current_device())
            torch.cuda.empty_cache()
        elif torch.cuda.is_available():
            model.to(torch.cuda.current_device())
        elif torch.xpu.is_available():
            model.to(torch.xpu.current_device())
        else:
            raise RuntimeError("No GPU or Intel XPU found. A GPU or Intel XPU is needed for quantization.")
        logger.info(
            f"The model device type is {model_device.type}. However, gpu or intel xpu is needed for quantization. "
            "We move the model to it."
        )
        return model

    elif weights_location is None:
        raise RuntimeError(
            f"`weights_location` needs to be the folder path containing the weights of the model, but we found {weights_location} "
        )

    else:
        # Swap the linear layers for their bnb counterparts while still on the meta device.
        with init_empty_weights():
            model = replace_with_bnb_layers(
                model, bnb_quantization_config, modules_to_not_convert=modules_to_not_convert
            )
        device_map = get_quantized_model_device_map(
            model,
            bnb_quantization_config,
            device_map,
            max_memory=max_memory,
            no_split_module_classes=no_split_module_classes,
        )
        if offload_state_dict is None and device_map is not None and "disk" in device_map.values():
            offload_state_dict = True

        offload = any(x in list(device_map.values()) for x in ["cpu", "disk"])

        load_checkpoint_in_model(
            model,
            weights_location,
            device_map,
            dtype=bnb_quantization_config.torch_dtype,
            offload_folder=offload_folder,
            offload_state_dict=offload_state_dict,
            keep_in_fp32_modules=bnb_quantization_config.keep_in_fp32_modules,
            offload_8bit_bnb=load_in_8bit and offload,
        )
        return dispatch_model(model, device_map=device_map, offload_dir=offload_folder)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def get_quantized_model_device_map(
    model, bnb_quantization_config, device_map=None, max_memory=None, no_split_module_classes=None
):
    """
    Resolve the device map to use for a model that is being quantized with bitsandbytes.

    Args:
        model (`torch.nn.Module`):
            The model whose parameters are inspected when a device map has to be inferred.
        bnb_quantization_config (`BnbQuantizationConfig`):
            The quantization config. Its `skip_modules`, `keep_in_fp32_modules`, `torch_dtype`, `target_dtype` and
            `load_in_4bit` attributes are read.
        device_map (`dict` or `str`, *optional*):
            Either an explicit device map, one of `"auto"`, `"balanced"`, `"balanced_low_0"` or `"sequential"`, or
            `None` to place the whole model on the current CUDA (or XPU) device.
        max_memory (`Dict`, *optional*):
            Forwarded to `get_balanced_memory` / `infer_auto_device_map` when `device_map` is a string.
        no_split_module_classes (`List[str]`, *optional*):
            Forwarded to `get_balanced_memory` / `infer_auto_device_map` when `device_map` is a string.

    Returns:
        The resolved device map.

    Raises:
        RuntimeError: If `device_map` is `None` and neither CUDA nor XPU is available.
        ValueError: If `device_map` is an unsupported string, or if 4-bit loading is requested while modules that
            would be quantized are placed on the CPU or the disk.
    """
    if device_map is None:
        if torch.cuda.is_available():
            device_map = {"": torch.cuda.current_device()}
        elif torch.xpu.is_available():
            device_map = {"": torch.xpu.current_device()}
        else:
            raise RuntimeError("No GPU found. A GPU is needed for quantization.")
        # Report the device map actually chosen (it may be an XPU device, not a CUDA one).
        logger.info(f"The device_map was not initialized. Setting device_map to `{device_map}`.")

    if isinstance(device_map, str):
        if device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
            raise ValueError(
                "If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or "
                "'sequential'."
            )

        # Parameters of skipped modules keep `torch_dtype`; fp32 modules override that with float32.
        special_dtypes = {}
        special_dtypes.update(
            {
                name: bnb_quantization_config.torch_dtype
                for name, _ in model.named_parameters()
                if any(m in name for m in bnb_quantization_config.skip_modules)
            }
        )
        special_dtypes.update(
            {
                name: torch.float32
                for name, _ in model.named_parameters()
                if any(m in name for m in bnb_quantization_config.keep_in_fp32_modules)
            }
        )

        kwargs = {}
        kwargs["special_dtypes"] = special_dtypes
        kwargs["no_split_module_classes"] = no_split_module_classes
        kwargs["dtype"] = bnb_quantization_config.target_dtype

        # get max_memory for each device.
        if device_map != "sequential":
            max_memory = get_balanced_memory(
                model,
                low_zero=(device_map == "balanced_low_0"),
                max_memory=max_memory,
                **kwargs,
            )

        kwargs["max_memory"] = max_memory
        device_map = infer_auto_device_map(model, **kwargs)

    if isinstance(device_map, dict):
        # check if don't have any quantized module on the cpu
        modules_not_to_convert = bnb_quantization_config.skip_modules + bnb_quantization_config.keep_in_fp32_modules

        devices_of_quantized_modules = [
            device for key, device in device_map.items() if key not in modules_not_to_convert
        ]
        if any(device in devices_of_quantized_modules for device in ("cpu", "disk")):
            if bnb_quantization_config.load_in_4bit:
                raise ValueError(
                    "Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit "
                    "the quantized model. If you want to dispatch the model on the CPU or the disk while keeping "
                    "these modules in `torch_dtype`, you need to pass a custom `device_map` to "
                    "`load_and_quantize_model`. Check "
                    "https://huggingface.co/docs/accelerate/main/en/usage_guides/quantization#offload-modules-to-cpu-and-disk "
                    "for more details."
                )
            logger.info(
                "Some modules are offloaded to the CPU or the disk. Note that these modules will be converted to 8-bit"
            )
    return device_map
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=None, current_key_name=None):
    """
    Swap every eligible `torch.nn.Linear` in `model` for its `bitsandbytes` 8-bit or 4-bit counterpart.

    The actual (recursive) replacement is delegated to `_replace_with_bnb_layers`; this wrapper only normalizes the
    arguments and warns when nothing was replaced.

    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
        bnb_quantization_config (`BnbQuantizationConfig`):
            The bitsandbytes quantization parameters, forwarded as is to the recursive helper.
        modules_to_not_convert (`List[str]`):
            Names of the modules to not quantize convert. In practice we keep the `lm_head` in full precision for
            numerical stability reasons. `None` is treated as an empty list.
        current_key_name (`List[str]`, *optional*):
            An array to track the current key of the recursion. This is used to check whether the current key (part of
            it) is not in the list of modules to not convert.

    Returns:
        The model returned by the recursive helper.
    """
    skipped = [] if modules_to_not_convert is None else modules_to_not_convert

    converted_model, replaced_any = _replace_with_bnb_layers(
        model, bnb_quantization_config, skipped, current_key_name
    )
    if replaced_any:
        return converted_model

    logger.warning(
        "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
        " this can happen for some architectures such as gpt2 that uses Conv1D instead of Linear layers."
        " Please double check your model architecture, or submit an issue on github if you think this is"
        " a bug."
    )
    return converted_model
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def _replace_with_bnb_layers(
    model,
    bnb_quantization_config,
    modules_to_not_convert=None,
    current_key_name=None,
):
    """
    Private method that wraps the recursion for module replacement.

    Parameters:
        model (`torch.nn.Module`):
            The module whose direct children are inspected (and recursed into).
        bnb_quantization_config (`BnbQuantizationConfig`):
            Selects 8-bit vs 4-bit layers and provides their construction arguments.
        modules_to_not_convert (`List[str]`, *optional*):
            Module names (child name or dotted key) that must be left untouched. `None` means no exclusions.
        current_key_name (`List[str]`, *optional*):
            Path of child names leading to `model`; extended and restored while iterating.

    Returns the converted model and a boolean that indicates if the conversion has been successful or not.

    Raises:
        ValueError: If neither `load_in_8bit` nor `load_in_4bit` is set on the config.
    """
    # bitsandbytes will initialize device(e.g. CUDA, XPU) on import, so it needs to be imported lazily
    import bitsandbytes as bnb

    if modules_to_not_convert is None:
        modules_to_not_convert = []
    if current_key_name is None:
        current_key_name = []

    has_been_replaced = False
    for name, module in model.named_children():
        current_key_name.append(name)
        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
            # Check if the current key is not in the `modules_to_not_convert`: either it is listed exactly, or one
            # of the listed keys followed by a "." occurs in the dotted path.
            current_key_name_str = ".".join(current_key_name)
            proceed = not any(
                key == current_key_name_str or key + "." in current_key_name_str for key in modules_to_not_convert
            )
            if proceed:
                # Load bnb module with empty weight and replace `nn.Linear` module
                if bnb_quantization_config.load_in_8bit:
                    bnb_module = bnb.nn.Linear8bitLt(
                        module.in_features,
                        module.out_features,
                        module.bias is not None,
                        has_fp16_weights=False,
                        threshold=bnb_quantization_config.llm_int8_threshold,
                    )
                elif bnb_quantization_config.load_in_4bit:
                    bnb_module = bnb.nn.Linear4bit(
                        module.in_features,
                        module.out_features,
                        module.bias is not None,
                        bnb_quantization_config.bnb_4bit_compute_dtype,
                        compress_statistics=bnb_quantization_config.bnb_4bit_use_double_quant,
                        quant_type=bnb_quantization_config.bnb_4bit_quant_type,
                    )
                else:
                    raise ValueError("load_in_8bit and load_in_4bit can't be both False")
                # Reuse the original tensors and freeze the new layer.
                bnb_module.weight.data = module.weight.data
                if module.bias is not None:
                    bnb_module.bias.data = module.bias.data
                bnb_module.requires_grad_(False)
                setattr(model, name, bnb_module)
                has_been_replaced = True
        if len(list(module.children())) > 0:
            _, child_has_been_replaced = _replace_with_bnb_layers(
                module, bnb_quantization_config, modules_to_not_convert, current_key_name
            )
            has_been_replaced = has_been_replaced or child_has_been_replaced
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model, has_been_replaced
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
def get_keys_to_not_convert(model):
    r"""
    An utility function to get the key of the module to keep in full precision if any. For example for CausalLM
    modules we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures,
    we want to keep the tied weights of the model. The function will return a list of the keys of the modules to not
    convert in int8.

    Parameters:
        model (`torch.nn.Module`):
            Input model

    Returns:
        `List[str]`: The tied-parameter keys followed by the name of the last child module (when it is not already
        among them), each with a trailing `.weight` / `.bias` removed. Empty for a base model without tied parameters.
    """
    # Create a copy of the model
    with init_empty_weights():
        tied_model = deepcopy(model)  # this has 0 cost since it is done inside `init_empty_weights` context manager`

    tied_params = find_tied_parameters(tied_model)
    # For compatibility with Accelerate < 0.18
    if isinstance(tied_params, dict):
        tied_keys = [key for group in tied_params.values() for key in group] + list(tied_params.keys())
    else:
        tied_keys = [key for group in tied_params for key in group]
    has_tied_params = len(tied_keys) > 0

    # Check if it is a base model
    is_base_model = False
    if hasattr(model, "base_model_prefix"):
        is_base_model = not hasattr(model, model.base_model_prefix)

    # Ignore this for base models (BertModel, GPT2Model, etc.)
    if (not has_tied_params) and is_base_model:
        return []

    # otherwise they have an attached head; a model without children simply contributes no head name
    list_modules = list(model.named_children())
    list_last_module = [list_modules[-1][0]] if list_modules else []

    # add last module together with tied weights (de-duplicated, in a deterministic order)
    list_untouched = list(dict.fromkeys(tied_keys))
    list_untouched.extend(name for name in list_last_module if name not in tied_keys)

    # remove the trailing ".weight" / ".bias" from the keys
    names_to_remove = [".weight", ".bias"]
    filtered_module_names = []
    for name in list_untouched:
        for name_to_remove in names_to_remove:
            name = name.removesuffix(name_to_remove)
        filtered_module_names.append(name)

    return filtered_module_names
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
def has_4bit_bnb_layers(model):
    """Check if we have `bnb.nn.Linear4bit` layers inside our model (8-bit `Linear8bitLt` layers are not counted)"""
    # bitsandbytes will initialize device(e.g. CUDA, XPU) on import, so it needs to be imported lazily
    import bitsandbytes as bnb

    return any(isinstance(m, bnb.nn.Linear4bit) for m in model.modules())
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
def get_parameter_device(parameter: nn.Module):
    """
    Return the device of the first parameter of `parameter`, falling back to its first buffer.

    Args:
        parameter (`torch.nn.Module`):
            The module to inspect.

    Returns:
        `torch.device`: The device of the first parameter (or, if the module has none, of the first buffer).

    Raises:
        ValueError: If the module has neither parameters nor buffers, so no device can be determined.
    """
    # A bare `next(...)` would leak `StopIteration` for parameter-less modules; use explicit defaults instead.
    first_tensor = next(parameter.parameters(), None)
    if first_tensor is None:
        first_tensor = next(parameter.buffers(), None)
    if first_tensor is None:
        raise ValueError("Cannot determine the device of a module that has no parameters and no buffers.")
    return first_tensor.device
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
def quantize_and_offload_8bit(model, param, param_name, new_dtype, offload_folder, offload_index, fp16_statistics):
    """
    Offload an 8-bit weight (and its `SCB` statistics) to `offload_folder`, then reset the tensor to the meta device.

    Args:
        model (`torch.nn.Module`):
            The model that owns the parameter.
        param (`torch.Tensor`):
            The value of the parameter to offload.
        param_name (`str`):
            Dotted name of the parameter inside `model`.
        new_dtype:
            dtype forwarded to `set_module_tensor_to_device`.
        offload_folder (`str` or `os.PathLike`):
            Folder passed to `offload_weight`.
        offload_index:
            Index object passed to `offload_weight` as `index`.
        fp16_statistics:
            The `SCB` statistics when the weight is already quantized, `None` otherwise.

    Raises:
        ValueError: If an intermediate attribute of `param_name` cannot be resolved on `model`.
    """
    # if it is not quantized, we quantize and offload the quantized weights and the SCB stats
    if fp16_statistics is None:
        # Setting the tensor on device `0` is what triggers the quantization — presumably the first accelerator;
        # NOTE(review): confirm against `set_module_tensor_to_device`.
        set_module_tensor_to_device(model, param_name, 0, dtype=new_dtype, value=param)
        tensor_name = param_name
        module = model
        if "." in tensor_name:
            splits = tensor_name.split(".")
            for split in splits[:-1]:
                # Default to `None` so a missing attribute reaches the explicit error below.
                new_module = getattr(module, split, None)
                if new_module is None:
                    raise ValueError(f"{module} has no attribute {split}.")
                module = new_module
            tensor_name = splits[-1]
        # offload weights
        module._parameters[tensor_name].requires_grad = False
        offload_weight(module._parameters[tensor_name], param_name, offload_folder, index=offload_index)
        if hasattr(module._parameters[tensor_name], "SCB"):
            offload_weight(
                module._parameters[tensor_name].SCB,
                param_name.replace("weight", "SCB"),
                offload_folder,
                index=offload_index,
            )
    else:
        # Already quantized: offload the value and the provided statistics as they are.
        offload_weight(param, param_name, offload_folder, index=offload_index)
        offload_weight(fp16_statistics, param_name.replace("weight", "SCB"), offload_folder, index=offload_index)

    # In both cases the in-model tensor is replaced by an empty tensor of the same size on the meta device.
    set_module_tensor_to_device(model, param_name, "meta", dtype=new_dtype, value=torch.empty(*param.size()))
|
accelerate/utils/constants.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import operator as op
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Base names of the files written when saving a training state / checkpoint.
SCALER_NAME = "scaler.pt"
MODEL_NAME = "pytorch_model"
SAFE_MODEL_NAME = "model"
RNG_STATE_NAME = "random_states"
OPTIMIZER_NAME = "optimizer"
SCHEDULER_NAME = "scheduler"
SAMPLER_NAME = "sampler"
PROFILE_PATTERN_NAME = "profile_{suffix}.json"
WEIGHTS_NAME = f"{MODEL_NAME}.bin"
# `{suffix}` is left as a literal placeholder to be filled in with `str.format`.
WEIGHTS_PATTERN_NAME = f"{MODEL_NAME}{{suffix}}.bin"
WEIGHTS_INDEX_NAME = f"{WEIGHTS_NAME}.index.json"
SAFE_WEIGHTS_NAME = f"{SAFE_MODEL_NAME}.safetensors"
SAFE_WEIGHTS_PATTERN_NAME = f"{SAFE_MODEL_NAME}{{suffix}}.safetensors"
SAFE_WEIGHTS_INDEX_NAME = f"{SAFE_WEIGHTS_NAME}.index.json"
SAGEMAKER_PYTORCH_VERSION = "1.10.2"
SAGEMAKER_PYTHON_VERSION = "py38"
SAGEMAKER_TRANSFORMERS_VERSION = "4.17.0"
SAGEMAKER_PARALLEL_EC2_INSTANCES = ["ml.p3.16xlarge", "ml.p3dn.24xlarge", "ml.p4dn.24xlarge"]
FSDP_SHARDING_STRATEGY = ["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD", "HYBRID_SHARD_ZERO2"]
FSDP_AUTO_WRAP_POLICY = ["TRANSFORMER_BASED_WRAP", "SIZE_BASED_WRAP", "NO_WRAP"]
FSDP_BACKWARD_PREFETCH = ["BACKWARD_PRE", "BACKWARD_POST", "NO_PREFETCH"]
FSDP_STATE_DICT_TYPE = ["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"]
FSDP2_STATE_DICT_TYPE = ["SHARDED_STATE_DICT", "FULL_STATE_DICT"]
FSDP_PYTORCH_VERSION = (
    "2.1.0.a0+32f93b1"  # Technically should be 2.1.0, but MS-AMP uses this specific prerelease in their Docker image.
)
FSDP2_PYTORCH_VERSION = "2.6.0"
DTENSOR_PYTORCH_VERSION = "2.5.0"
FSDP_MODEL_NAME = "pytorch_model_fsdp"
DEEPSPEED_MULTINODE_LAUNCHERS = ["pdsh", "standard", "openmpi", "mvapich", "mpich", "nossh", "slurm"]
TORCH_DYNAMO_MODES = ["default", "reduce-overhead", "max-autotune"]
ELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION = "2.2.0"
XPU_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.4.0"
MITA_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.1.0"
BETA_TP_AVAILABLE_PYTORCH_VERSION = "2.3.0"

BETA_TP_AVAILABLE_TRANSFORMERS_VERSION = "4.52.0"
BETA_CP_AVAILABLE_PYTORCH_VERSION = "2.6.0"
BETA_SP_AVAILABLE_DEEPSPEED_VERSION = "0.18.2"

# Maps a comparison operator written as a string to the matching `operator` function.
STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}

# These are the args for `torch.distributed.launch` for pytorch < 1.9
TORCH_LAUNCH_PARAMS = [
    "nnodes",
    "nproc_per_node",
    "rdzv_backend",
    "rdzv_endpoint",
    "rdzv_id",
    "rdzv_conf",
    "standalone",
    "max_restarts",
    "monitor_interval",
    "start_method",
    "role",
    "module",
    "m",
    "no_python",
    "run_path",
    "log_dir",
    "r",
    "redirects",
    "t",
    "tee",
    "node_rank",
    "master_addr",
    "master_port",
]

CUDA_DISTRIBUTED_TYPES = ["DEEPSPEED", "MULTI_GPU", "FSDP", "MEGATRON_LM", "TP"]
TORCH_DISTRIBUTED_OPERATION_TYPES = CUDA_DISTRIBUTED_TYPES + [
    "MULTI_NPU",
    "MULTI_MLU",
    "MULTI_SDAA",
    "MULTI_MUSA",
    "MULTI_XPU",
    "MULTI_CPU",
    "MULTI_HPU",
    "MULTI_NEURON",
]
SUPPORTED_PYTORCH_LAYERS_FOR_UPCASTING = (
    torch.nn.Conv1d,
    torch.nn.Conv2d,
    torch.nn.Conv3d,
    torch.nn.ConvTranspose1d,
    torch.nn.ConvTranspose2d,
    torch.nn.ConvTranspose3d,
    torch.nn.Linear,
)
|