BryanW commited on
Commit
0d11c13
·
verified ·
1 Parent(s): 38f7dd9

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/_distutils_hack/__pycache__/__init__.cpython-312.pyc +0 -0
  2. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/_distutils_hack/__pycache__/override.cpython-312.pyc +0 -0
  3. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/tpu.py +157 -0
  4. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/scripts/external_deps/__pycache__/test_ds_multiple_model.cpython-312.pyc +0 -0
  5. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/__init__.py +306 -0
  6. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/ao.py +140 -0
  7. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/bnb.py +469 -0
  8. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/constants.py +106 -0
  9. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/dataclasses.py +0 -0
  10. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/deepspeed.py +385 -0
  11. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/environment.py +471 -0
  12. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/fsdp_utils.py +829 -0
  13. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/imports.py +564 -0
  14. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/launch.py +781 -0
  15. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/megatron_lm.py +1424 -0
  16. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/memory.py +210 -0
  17. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/modeling.py +2186 -0
  18. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/offload.py +213 -0
  19. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/operations.py +867 -0
  20. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/other.py +561 -0
  21. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/random.py +156 -0
  22. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/rich.py +24 -0
  23. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/torch_xla.py +51 -0
  24. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/tqdm.py +43 -0
  25. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/transformer_engine.py +186 -0
  26. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/versions.py +56 -0
  27. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/annotated_doc/__pycache__/__init__.cpython-312.pyc +0 -0
  28. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/annotated_doc/__pycache__/main.cpython-312.pyc +0 -0
  29. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/cuda_pathfinder-1.4.0.dist-info/licenses/LICENSE +177 -0
  30. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/__init__.cpython-312.pyc +0 -0
  31. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/__version__.cpython-312.pyc +0 -0
  32. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_align.cpython-312.pyc +0 -0
  33. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_align_getter.cpython-312.pyc +0 -0
  34. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_base.cpython-312.pyc +0 -0
  35. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_column.cpython-312.pyc +0 -0
  36. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_common.cpython-312.pyc +0 -0
  37. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_container.cpython-312.pyc +0 -0
  38. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_converter.cpython-312.pyc +0 -0
  39. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_dataproperty.cpython-312.pyc +0 -0
  40. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_extractor.cpython-312.pyc +0 -0
  41. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_formatter.cpython-312.pyc +0 -0
  42. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_function.cpython-312.pyc +0 -0
  43. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_interface.cpython-312.pyc +0 -0
  44. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_line_break.cpython-312.pyc +0 -0
  45. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_preprocessor.cpython-312.pyc +0 -0
  46. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/typing.cpython-312.pyc +0 -0
  47. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/logger/__init__.py +7 -0
  48. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/logger/__pycache__/__init__.cpython-312.pyc +0 -0
  49. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/logger/__pycache__/_logger.cpython-312.pyc +0 -0
  50. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/logger/__pycache__/_null_logger.cpython-312.pyc +0 -0
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/_distutils_hack/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (10.6 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/_distutils_hack/__pycache__/override.cpython-312.pyc ADDED
Binary file (322 Bytes). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/commands/tpu.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import argparse
18
+ import os
19
+ import subprocess
20
+
21
+ from packaging.version import Version, parse
22
+
23
+ from accelerate.commands.config.config_args import default_config_file, load_config_from_file
24
+
25
+
26
+ _description = "Run commands across TPU VMs for initial setup before running `accelerate launch`."
27
+
28
+
29
+ def tpu_command_parser(subparsers=None):
30
+ if subparsers is not None:
31
+ parser = subparsers.add_parser("tpu-config", description=_description)
32
+ else:
33
+ parser = argparse.ArgumentParser("Accelerate tpu-config command", description=_description)
34
+ # Core arguments
35
+ config_args = parser.add_argument_group(
36
+ "Config Arguments", "Arguments that can be configured through `accelerate config`."
37
+ )
38
+ config_args.add_argument(
39
+ "--config_file",
40
+ type=str,
41
+ default=None,
42
+ help="Path to the config file to use for accelerate.",
43
+ )
44
+ config_args.add_argument(
45
+ "--tpu_name",
46
+ default=None,
47
+ help="The name of the TPU to use. If not specified, will use the TPU specified in the config file.",
48
+ )
49
+ config_args.add_argument(
50
+ "--tpu_zone",
51
+ default=None,
52
+ help="The zone of the TPU to use. If not specified, will use the zone specified in the config file.",
53
+ )
54
+ pod_args = parser.add_argument_group("TPU Arguments", "Arguments for options ran inside the TPU.")
55
+ pod_args.add_argument(
56
+ "--use_alpha",
57
+ action="store_true",
58
+ help="Whether to use `gcloud alpha` when running the TPU training script instead of `gcloud`.",
59
+ )
60
+ pod_args.add_argument(
61
+ "--command_file",
62
+ default=None,
63
+ help="The path to the file containing the commands to run on the pod on startup.",
64
+ )
65
+ pod_args.add_argument(
66
+ "--command",
67
+ action="append",
68
+ nargs="+",
69
+ help="A command to run on the pod. Can be passed multiple times.",
70
+ )
71
+ pod_args.add_argument(
72
+ "--install_accelerate",
73
+ action="store_true",
74
+ help="Whether to install accelerate on the pod. Defaults to False.",
75
+ )
76
+ pod_args.add_argument(
77
+ "--accelerate_version",
78
+ default="latest",
79
+ help="The version of accelerate to install on the pod. If not specified, will use the latest pypi version. Specify 'dev' to install from GitHub.",
80
+ )
81
+ pod_args.add_argument(
82
+ "--debug", action="store_true", help="If set, will print the command that would be run instead of running it."
83
+ )
84
+
85
+ if subparsers is not None:
86
+ parser.set_defaults(func=tpu_command_launcher)
87
+ return parser
88
+
89
+
90
+ def tpu_command_launcher(args):
91
+ defaults = None
92
+
93
+ # Get the default from the config file if it exists.
94
+ if args.config_file is not None or os.path.isfile(default_config_file):
95
+ defaults = load_config_from_file(args.config_file)
96
+ if not args.command_file and defaults.command_file is not None and not args.command:
97
+ args.command_file = defaults.command_file
98
+ if not args.command and defaults.commands is not None:
99
+ args.command = defaults.commands
100
+ if not args.tpu_name:
101
+ args.tpu_name = defaults.tpu_name
102
+ if not args.tpu_zone:
103
+ args.tpu_zone = defaults.tpu_zone
104
+ if args.accelerate_version == "dev":
105
+ args.accelerate_version = "git+https://github.com/huggingface/accelerate.git"
106
+ elif args.accelerate_version == "latest":
107
+ args.accelerate_version = "accelerate -U"
108
+ elif isinstance(parse(args.accelerate_version), Version):
109
+ args.accelerate_version = f"accelerate=={args.accelerate_version}"
110
+
111
+ if not args.command_file and not args.command:
112
+ raise ValueError("You must specify either a command file or a command to run on the pod.")
113
+
114
+ if args.command_file:
115
+ with open(args.command_file) as f:
116
+ args.command = [f.read().splitlines()]
117
+
118
+ # To turn list of lists into list of strings
119
+ if isinstance(args.command[0], list):
120
+ args.command = [line for cmd in args.command for line in cmd]
121
+ # Default to the shared folder and install accelerate
122
+ new_cmd = ["cd /usr/share"]
123
+ if args.install_accelerate:
124
+ new_cmd += [f"pip install {args.accelerate_version}"]
125
+ new_cmd += args.command
126
+ args.command = "; ".join(new_cmd)
127
+
128
+ # Then send it to gcloud
129
+ # Eventually try to use google-api-core to do this instead of subprocess
130
+ cmd = ["gcloud"]
131
+ if args.use_alpha:
132
+ cmd += ["alpha"]
133
+ cmd += [
134
+ "compute",
135
+ "tpus",
136
+ "tpu-vm",
137
+ "ssh",
138
+ args.tpu_name,
139
+ "--zone",
140
+ args.tpu_zone,
141
+ "--command",
142
+ args.command,
143
+ "--worker",
144
+ "all",
145
+ ]
146
+ if args.debug:
147
+ print(f"Running {' '.join(cmd)}")
148
+ return
149
+ subprocess.run(cmd)
150
+ print("Successfully setup pod.")
151
+
152
+
153
+ def main():
154
+ parser = tpu_command_parser()
155
+ args = parser.parse_args()
156
+
157
+ tpu_command_launcher(args)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/test_utils/scripts/external_deps/__pycache__/test_ds_multiple_model.cpython-312.pyc ADDED
Binary file (14.1 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/__init__.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from ..parallelism_config import ParallelismConfig
15
+ from .ao import convert_model_to_fp8_ao, filter_first_and_last_linear_layers, has_ao_layers
16
+ from .constants import (
17
+ MITA_PROFILING_AVAILABLE_PYTORCH_VERSION,
18
+ MODEL_NAME,
19
+ OPTIMIZER_NAME,
20
+ PROFILE_PATTERN_NAME,
21
+ RNG_STATE_NAME,
22
+ SAFE_MODEL_NAME,
23
+ SAFE_WEIGHTS_INDEX_NAME,
24
+ SAFE_WEIGHTS_NAME,
25
+ SAFE_WEIGHTS_PATTERN_NAME,
26
+ SAMPLER_NAME,
27
+ SCALER_NAME,
28
+ SCHEDULER_NAME,
29
+ TORCH_DISTRIBUTED_OPERATION_TYPES,
30
+ TORCH_LAUNCH_PARAMS,
31
+ WEIGHTS_INDEX_NAME,
32
+ WEIGHTS_NAME,
33
+ WEIGHTS_PATTERN_NAME,
34
+ XPU_PROFILING_AVAILABLE_PYTORCH_VERSION,
35
+ )
36
+ from .dataclasses import (
37
+ AORecipeKwargs,
38
+ AutocastKwargs,
39
+ BnbQuantizationConfig,
40
+ ComputeEnvironment,
41
+ CustomDtype,
42
+ DataLoaderConfiguration,
43
+ DDPCommunicationHookType,
44
+ DeepSpeedPlugin,
45
+ DeepSpeedSequenceParallelConfig,
46
+ DistributedDataParallelKwargs,
47
+ DistributedType,
48
+ DynamoBackend,
49
+ FP8RecipeKwargs,
50
+ FullyShardedDataParallelPlugin,
51
+ GradientAccumulationPlugin,
52
+ GradScalerKwargs,
53
+ InitProcessGroupKwargs,
54
+ KwargsHandler,
55
+ LoggerType,
56
+ MegatronLMPlugin,
57
+ MSAMPRecipeKwargs,
58
+ PrecisionType,
59
+ ProfileKwargs,
60
+ ProjectConfiguration,
61
+ RNGType,
62
+ SageMakerDistributedType,
63
+ TensorInformation,
64
+ TERecipeKwargs,
65
+ TorchContextParallelConfig,
66
+ TorchDynamoPlugin,
67
+ TorchTensorParallelConfig,
68
+ TorchTensorParallelPlugin,
69
+ add_model_config_to_megatron_parser,
70
+ )
71
+ from .environment import (
72
+ are_libraries_initialized,
73
+ check_cuda_fp8_capability,
74
+ check_cuda_p2p_ib_support,
75
+ clear_environment,
76
+ convert_dict_to_env_variables,
77
+ get_cpu_distributed_information,
78
+ get_current_device_type,
79
+ get_gpu_info,
80
+ get_int_from_env,
81
+ parse_choice_from_env,
82
+ parse_flag_from_env,
83
+ patch_environment,
84
+ purge_accelerate_environment,
85
+ set_numa_affinity,
86
+ str_to_bool,
87
+ )
88
+ from .imports import (
89
+ deepspeed_required,
90
+ get_ccl_version,
91
+ is_4bit_bnb_available,
92
+ is_8bit_bnb_available,
93
+ is_aim_available,
94
+ is_bf16_available,
95
+ is_bitsandbytes_multi_backend_available,
96
+ is_bnb_available,
97
+ is_boto3_available,
98
+ is_ccl_available,
99
+ is_clearml_available,
100
+ is_comet_ml_available,
101
+ is_cuda_available,
102
+ is_datasets_available,
103
+ is_deepspeed_available,
104
+ is_dvclive_available,
105
+ is_fp8_available,
106
+ is_fp16_available,
107
+ is_habana_gaudi1,
108
+ is_hpu_available,
109
+ is_import_timer_available,
110
+ is_ipex_available,
111
+ is_lomo_available,
112
+ is_matplotlib_available,
113
+ is_megatron_lm_available,
114
+ is_mlflow_available,
115
+ is_mlu_available,
116
+ is_mps_available,
117
+ is_msamp_available,
118
+ is_musa_available,
119
+ is_npu_available,
120
+ is_pandas_available,
121
+ is_peft_available,
122
+ is_pippy_available,
123
+ is_pynvml_available,
124
+ is_pytest_available,
125
+ is_rich_available,
126
+ is_sagemaker_available,
127
+ is_schedulefree_available,
128
+ is_sdaa_available,
129
+ is_swanlab_available,
130
+ is_tensorboard_available,
131
+ is_timm_available,
132
+ is_torch_xla_available,
133
+ is_torchao_available,
134
+ is_torchdata_available,
135
+ is_torchdata_stateful_dataloader_available,
136
+ is_torchvision_available,
137
+ is_trackio_available,
138
+ is_transformer_engine_available,
139
+ is_transformer_engine_mxfp8_available,
140
+ is_transformers_available,
141
+ is_triton_available,
142
+ is_wandb_available,
143
+ is_weights_only_available,
144
+ is_xccl_available,
145
+ is_xpu_available,
146
+ torchao_required,
147
+ )
148
+ from .modeling import (
149
+ align_module_device,
150
+ calculate_maximum_sizes,
151
+ check_device_map,
152
+ check_tied_parameters_in_config,
153
+ check_tied_parameters_on_same_device,
154
+ compute_module_sizes,
155
+ convert_file_size_to_int,
156
+ dtype_byte_size,
157
+ find_tied_parameters,
158
+ get_balanced_memory,
159
+ get_grad_scaler,
160
+ get_max_layer_size,
161
+ get_max_memory,
162
+ get_mixed_precision_context_manager,
163
+ has_offloaded_params,
164
+ id_tensor_storage,
165
+ infer_auto_device_map,
166
+ is_peft_model,
167
+ load_checkpoint_in_model,
168
+ load_offloaded_weights,
169
+ load_state_dict,
170
+ named_module_tensors,
171
+ retie_parameters,
172
+ set_module_tensor_to_device,
173
+ )
174
+ from .offload import (
175
+ OffloadedWeightsLoader,
176
+ PrefixedDataset,
177
+ extract_submodules_state_dict,
178
+ load_offloaded_weight,
179
+ offload_state_dict,
180
+ offload_weight,
181
+ save_offload_index,
182
+ )
183
+ from .operations import (
184
+ CannotPadNestedTensorWarning,
185
+ GatheredParameters,
186
+ broadcast,
187
+ broadcast_object_list,
188
+ concatenate,
189
+ convert_outputs_to_fp32,
190
+ convert_to_fp32,
191
+ copy_tensor_to_devices,
192
+ find_batch_size,
193
+ find_device,
194
+ gather,
195
+ gather_object,
196
+ get_data_structure,
197
+ honor_type,
198
+ ignorant_find_batch_size,
199
+ initialize_tensors,
200
+ is_namedtuple,
201
+ is_tensor_information,
202
+ is_torch_tensor,
203
+ listify,
204
+ pad_across_processes,
205
+ pad_input_tensors,
206
+ recursively_apply,
207
+ reduce,
208
+ send_to_device,
209
+ slice_tensors,
210
+ )
211
+ from .versions import compare_versions, is_torch_version
212
+
213
+
214
+ if is_deepspeed_available():
215
+ from .deepspeed import (
216
+ DeepSpeedEngineWrapper,
217
+ DeepSpeedOptimizerWrapper,
218
+ DeepSpeedSchedulerWrapper,
219
+ DummyOptim,
220
+ DummyScheduler,
221
+ HfDeepSpeedConfig,
222
+ get_active_deepspeed_plugin,
223
+ map_pytorch_optim_to_deepspeed,
224
+ )
225
+
226
+ from .bnb import has_4bit_bnb_layers, load_and_quantize_model
227
+ from .fsdp_utils import (
228
+ disable_fsdp_ram_efficient_loading,
229
+ enable_fsdp_ram_efficient_loading,
230
+ ensure_weights_retied,
231
+ fsdp2_apply_ac,
232
+ fsdp2_canonicalize_names,
233
+ fsdp2_load_full_state_dict,
234
+ fsdp2_prepare_model,
235
+ fsdp2_switch_optimizer_parameters,
236
+ get_fsdp2_grad_scaler,
237
+ load_fsdp_model,
238
+ load_fsdp_optimizer,
239
+ merge_fsdp_weights,
240
+ save_fsdp_model,
241
+ save_fsdp_optimizer,
242
+ )
243
+ from .launch import (
244
+ PrepareForLaunch,
245
+ _filter_args,
246
+ prepare_deepspeed_cmd_env,
247
+ prepare_multi_gpu_env,
248
+ prepare_sagemager_args_inputs,
249
+ prepare_simple_launcher_cmd_env,
250
+ prepare_tpu,
251
+ )
252
+
253
+ # For docs
254
+ from .megatron_lm import (
255
+ AbstractTrainStep,
256
+ BertTrainStep,
257
+ GPTTrainStep,
258
+ MegatronLMDummyDataLoader,
259
+ MegatronLMDummyScheduler,
260
+ T5TrainStep,
261
+ avg_losses_across_data_parallel_group,
262
+ )
263
+
264
+
265
+ if is_megatron_lm_available():
266
+ from .megatron_lm import (
267
+ MegatronEngine,
268
+ MegatronLMOptimizerWrapper,
269
+ MegatronLMSchedulerWrapper,
270
+ gather_across_data_parallel_groups,
271
+ )
272
+ from .megatron_lm import initialize as megatron_lm_initialize
273
+ from .megatron_lm import prepare_data_loader as megatron_lm_prepare_data_loader
274
+ from .megatron_lm import prepare_model_optimizer_scheduler as megatron_lm_prepare_model_optimizer_scheduler
275
+ from .megatron_lm import prepare_optimizer as megatron_lm_prepare_optimizer
276
+ from .megatron_lm import prepare_scheduler as megatron_lm_prepare_scheduler
277
+ from .memory import find_executable_batch_size, release_memory
278
+ from .other import (
279
+ check_os_kernel,
280
+ clean_state_dict_for_safetensors,
281
+ compile_regions,
282
+ compile_regions_deepspeed,
283
+ convert_bytes,
284
+ extract_model_from_parallel,
285
+ get_module_children_bottom_up,
286
+ get_pretty_name,
287
+ has_compiled_regions,
288
+ is_compiled_module,
289
+ is_port_in_use,
290
+ load,
291
+ merge_dicts,
292
+ model_has_dtensor,
293
+ recursive_getattr,
294
+ save,
295
+ wait_for_everyone,
296
+ write_basic_config,
297
+ )
298
+ from .random import set_seed, synchronize_rng_state, synchronize_rng_states
299
+ from .torch_xla import install_xla
300
+ from .tqdm import tqdm
301
+ from .transformer_engine import (
302
+ apply_fp8_autowrap,
303
+ contextual_fp8_autocast,
304
+ convert_model,
305
+ has_transformer_engine_layers,
306
+ )
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/ao.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Needed utilities for torchao FP8 training.
17
+ """
18
+
19
+ from functools import partial
20
+ from typing import TYPE_CHECKING, Callable, Optional
21
+
22
+ import torch
23
+
24
+ from .imports import is_torchao_available, torchao_required
25
+
26
+
27
+ if TYPE_CHECKING:
28
+ if is_torchao_available():
29
+ from torchao.float8.float8_linear import Float8LinearConfig
30
+
31
+
32
+ def find_first_last_linear_layers(model: torch.nn.Module):
33
+ """
34
+ Finds the first and last linear layer names in a model.
35
+
36
+ This is needed during FP8 to avoid issues with instability by keeping the first and last layers unquantized.
37
+
38
+ Ref: https://x.com/xariusrke/status/1826669142604141052
39
+ """
40
+ first_linear, last_linear = None, None
41
+ for name, module in model.named_modules():
42
+ if isinstance(module, torch.nn.Linear):
43
+ if first_linear is None:
44
+ first_linear = name
45
+ last_linear = name
46
+ return first_linear, last_linear
47
+
48
+
49
+ def filter_linear_layers(module, fqn: str, layers_to_filter: list[str]) -> bool:
50
+ """
51
+ A function which will check if `module` is:
52
+ - a `torch.nn.Linear` layer
53
+ - has in_features and out_features divisible by 16
54
+ - is not part of `layers_to_filter`
55
+
56
+ Args:
57
+ module (`torch.nn.Module`):
58
+ The module to check.
59
+ fqn (`str`):
60
+ The fully qualified name of the layer.
61
+ layers_to_filter (`List[str]`):
62
+ The list of layers to filter.
63
+ """
64
+ if isinstance(module, torch.nn.Linear):
65
+ if module.in_features % 16 != 0 or module.out_features % 16 != 0:
66
+ return False
67
+ if fqn in layers_to_filter:
68
+ return False
69
+ return True
70
+
71
+
72
+ def filter_first_and_last_linear_layers(module, fqn: str) -> bool:
73
+ """
74
+ A filter function which will filter out all linear layers except the first and last.
75
+
76
+ <Tip>
77
+
78
+ For stability reasons, we skip the first and last linear layers Otherwise can lead to the model not training or
79
+ converging properly
80
+
81
+ </Tip>
82
+
83
+ Args:
84
+ module (`torch.nn.Module`):
85
+ The module to check.
86
+ fqn (`str`):
87
+ The fully qualified name of the layer.
88
+ """
89
+ first_linear, last_linear = find_first_last_linear_layers(module)
90
+ return filter_linear_layers(module, fqn, layers_to_filter=[first_linear, last_linear])
91
+
92
+
93
+ @torchao_required
94
+ def has_ao_layers(model: torch.nn.Module):
95
+ from torchao.float8.float8_linear import Float8Linear
96
+
97
+ for name, module in model.named_modules():
98
+ if isinstance(module, Float8Linear):
99
+ return True
100
+ return False
101
+
102
+
103
+ @torchao_required
104
+ def convert_model_to_fp8_ao(
105
+ model: torch.nn.Module,
106
+ config: Optional["Float8LinearConfig"] = None,
107
+ module_filter_func: Optional[Callable] = filter_first_and_last_linear_layers,
108
+ ):
109
+ """
110
+ Converts all `nn.Linear` layers in the model (except the first and last) to torchao's `Float8Linear` layer inplace.
111
+
112
+ Args:
113
+ model (`torch.nn.Module`):
114
+ The model to convert.
115
+ config (`torchao.float8.Float8LinearConfig`, *optional*):
116
+ The configuration for the FP8 training. Recommended to utilize
117
+ `torchao.float8.recipe_name_to_linear_config` to generate this. In general, the default config should be
118
+ sufficient (what is passed when set to `None`).
119
+ module_filter_func (`Callable`, *optional*, defaults to `filter_linear_layers`):
120
+ Optional function that must take in a module and layer name, and returns a boolean indicating whether the
121
+ module should be converted to FP8. Defaults to `filter_linear_layers`. See it for an example.
122
+
123
+ Example:
124
+
125
+ ```python
126
+ from accelerate.utils.ao import convert_model_to_fp8_ao
127
+
128
+ model = MyModel()
129
+ model.to("cuda")
130
+ convert_to_float8_training(model)
131
+
132
+ model.train()
133
+ ```
134
+ """
135
+ from torchao.float8 import convert_to_float8_training
136
+
137
+ first_linear, last_linear = find_first_last_linear_layers(model)
138
+ if module_filter_func is None:
139
+ module_filter_func = partial(filter_linear_layers, layers_to_filter=[first_linear, last_linear])
140
+ convert_to_float8_training(model, module_filter_fn=module_filter_func, config=config)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/bnb.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import logging
17
+ import os
18
+ from copy import deepcopy
19
+ from typing import Optional, Union
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+
24
+ from accelerate.utils.imports import (
25
+ is_4bit_bnb_available,
26
+ is_8bit_bnb_available,
27
+ )
28
+
29
+ from ..big_modeling import dispatch_model, init_empty_weights
30
+ from .dataclasses import BnbQuantizationConfig
31
+ from .modeling import (
32
+ find_tied_parameters,
33
+ get_balanced_memory,
34
+ infer_auto_device_map,
35
+ load_checkpoint_in_model,
36
+ offload_weight,
37
+ set_module_tensor_to_device,
38
+ )
39
+
40
+
41
+ logger = logging.getLogger(__name__)
42
+
43
+
44
+ def load_and_quantize_model(
45
+ model: torch.nn.Module,
46
+ bnb_quantization_config: BnbQuantizationConfig,
47
+ weights_location: Optional[Union[str, os.PathLike]] = None,
48
+ device_map: Optional[dict[str, Union[int, str, torch.device]]] = None,
49
+ no_split_module_classes: Optional[list[str]] = None,
50
+ max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None,
51
+ offload_folder: Optional[Union[str, os.PathLike]] = None,
52
+ offload_state_dict: bool = False,
53
+ ):
54
+ """
55
+ This function will quantize the input model with the associated config passed in `bnb_quantization_config`. If the
56
+ model is in the meta device, we will load and dispatch the weights according to the `device_map` passed. If the
57
+ model is already loaded, we will quantize the model and put the model on the GPU,
58
+
59
+ Args:
60
+ model (`torch.nn.Module`):
61
+ Input model. The model can be already loaded or on the meta device
62
+ bnb_quantization_config (`BnbQuantizationConfig`):
63
+ The bitsandbytes quantization parameters
64
+ weights_location (`str` or `os.PathLike`):
65
+ The folder weights_location to load. It can be:
66
+ - a path to a file containing a whole model state dict
67
+ - a path to a `.json` file containing the index to a sharded checkpoint
68
+ - a path to a folder containing a unique `.index.json` file and the shards of a checkpoint.
69
+ - a path to a folder containing a unique pytorch_model.bin file.
70
+ device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*):
71
+ A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer
72
+ name, once a given module name is inside, every submodule of it will be sent to the same device.
73
+ no_split_module_classes (`List[str]`, *optional*):
74
+ A list of layer class names that should never be split across device (for instance any layer that has a
75
+ residual connection).
76
+ max_memory (`Dict`, *optional*):
77
+ A dictionary device identifier to maximum memory. Will default to the maximum memory available if unset.
78
+ offload_folder (`str` or `os.PathLike`, *optional*):
79
+ If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
80
+ offload_state_dict (`bool`, *optional*, defaults to `False`):
81
+ If `True`, will temporarily offload the CPU state dict on the hard drive to avoid getting out of CPU RAM if
82
+ the weight of the CPU state dict + the biggest shard does not fit.
83
+
84
+ Returns:
85
+ `torch.nn.Module`: The quantized model
86
+ """
87
+
88
+ load_in_4bit = bnb_quantization_config.load_in_4bit
89
+ load_in_8bit = bnb_quantization_config.load_in_8bit
90
+
91
+ if load_in_8bit and not is_8bit_bnb_available():
92
+ raise ImportError(
93
+ "You have a version of `bitsandbytes` that is not compatible with 8bit quantization,"
94
+ " make sure you have the latest version of `bitsandbytes` installed."
95
+ )
96
+ if load_in_4bit and not is_4bit_bnb_available():
97
+ raise ValueError(
98
+ "You have a version of `bitsandbytes` that is not compatible with 4bit quantization,"
99
+ "make sure you have the latest version of `bitsandbytes` installed."
100
+ )
101
+
102
+ modules_on_cpu = []
103
+ # custom device map
104
+ if isinstance(device_map, dict) and len(device_map.keys()) > 1:
105
+ modules_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
106
+
107
+ # We keep some modules such as the lm_head in their original dtype for numerical stability reasons
108
+ if bnb_quantization_config.skip_modules is None:
109
+ bnb_quantization_config.skip_modules = get_keys_to_not_convert(model)
110
+
111
+ # add cpu modules to skip modules only for 4-bit modules
112
+ if load_in_4bit:
113
+ bnb_quantization_config.skip_modules.extend(modules_on_cpu)
114
+ modules_to_not_convert = bnb_quantization_config.skip_modules
115
+
116
+ # We add the modules we want to keep in full precision
117
+ if bnb_quantization_config.keep_in_fp32_modules is None:
118
+ bnb_quantization_config.keep_in_fp32_modules = []
119
+ keep_in_fp32_modules = bnb_quantization_config.keep_in_fp32_modules
120
+ modules_to_not_convert.extend(keep_in_fp32_modules)
121
+
122
+ # compatibility with peft
123
+ model.is_loaded_in_4bit = load_in_4bit
124
+ model.is_loaded_in_8bit = load_in_8bit
125
+
126
+ model_device = get_parameter_device(model)
127
+ if model_device.type != "meta":
128
+ # quantization of an already loaded model
129
+ logger.warning(
130
+ "It is not recommended to quantize a loaded model. "
131
+ "The model should be instantiated under the `init_empty_weights` context manager."
132
+ )
133
+ model = replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=modules_to_not_convert)
134
+ # convert param to the right dtype
135
+ dtype = bnb_quantization_config.torch_dtype
136
+ for name, param in model.state_dict().items():
137
+ if any(module_to_keep_in_fp32 in name for module_to_keep_in_fp32 in keep_in_fp32_modules):
138
+ param.to(torch.float32)
139
+ if param.dtype != torch.float32:
140
+ name = name.replace(".weight", "").replace(".bias", "")
141
+ param = getattr(model, name, None)
142
+ if param is not None:
143
+ param.to(torch.float32)
144
+ elif torch.is_floating_point(param):
145
+ param.to(dtype)
146
+ if model_device.type == "cuda":
147
+ model.cuda(torch.cuda.current_device())
148
+ torch.cuda.empty_cache()
149
+ elif torch.cuda.is_available():
150
+ model.to(torch.cuda.current_device())
151
+ elif torch.xpu.is_available():
152
+ model.to(torch.xpu.current_device())
153
+ else:
154
+ raise RuntimeError("No GPU or Intel XPU found. A GPU or Intel XPU is needed for quantization.")
155
+ logger.info(
156
+ f"The model device type is {model_device.type}. However, gpu or intel xpu is needed for quantization."
157
+ "We move the model to it."
158
+ )
159
+ return model
160
+
161
+ elif weights_location is None:
162
+ raise RuntimeError(
163
+ f"`weights_location` needs to be the folder path containing the weights of the model, but we found {weights_location} "
164
+ )
165
+
166
+ else:
167
+ with init_empty_weights():
168
+ model = replace_with_bnb_layers(
169
+ model, bnb_quantization_config, modules_to_not_convert=modules_to_not_convert
170
+ )
171
+ device_map = get_quantized_model_device_map(
172
+ model,
173
+ bnb_quantization_config,
174
+ device_map,
175
+ max_memory=max_memory,
176
+ no_split_module_classes=no_split_module_classes,
177
+ )
178
+ if offload_state_dict is None and device_map is not None and "disk" in device_map.values():
179
+ offload_state_dict = True
180
+
181
+ offload = any(x in list(device_map.values()) for x in ["cpu", "disk"])
182
+
183
+ load_checkpoint_in_model(
184
+ model,
185
+ weights_location,
186
+ device_map,
187
+ dtype=bnb_quantization_config.torch_dtype,
188
+ offload_folder=offload_folder,
189
+ offload_state_dict=offload_state_dict,
190
+ keep_in_fp32_modules=bnb_quantization_config.keep_in_fp32_modules,
191
+ offload_8bit_bnb=load_in_8bit and offload,
192
+ )
193
+ return dispatch_model(model, device_map=device_map, offload_dir=offload_folder)
194
+
195
+
196
+ def get_quantized_model_device_map(
197
+ model, bnb_quantization_config, device_map=None, max_memory=None, no_split_module_classes=None
198
+ ):
199
+ if device_map is None:
200
+ if torch.cuda.is_available():
201
+ device_map = {"": torch.cuda.current_device()}
202
+ elif torch.xpu.is_available():
203
+ device_map = {"": torch.xpu.current_device()}
204
+ else:
205
+ raise RuntimeError("No GPU found. A GPU is needed for quantization.")
206
+ logger.info("The device_map was not initialized.Setting device_map to `{'':torch.cuda.current_device()}`.")
207
+
208
+ if isinstance(device_map, str):
209
+ if device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
210
+ raise ValueError(
211
+ "If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or "
212
+ "'sequential'."
213
+ )
214
+
215
+ special_dtypes = {}
216
+ special_dtypes.update(
217
+ {
218
+ name: bnb_quantization_config.torch_dtype
219
+ for name, _ in model.named_parameters()
220
+ if any(m in name for m in bnb_quantization_config.skip_modules)
221
+ }
222
+ )
223
+ special_dtypes.update(
224
+ {
225
+ name: torch.float32
226
+ for name, _ in model.named_parameters()
227
+ if any(m in name for m in bnb_quantization_config.keep_in_fp32_modules)
228
+ }
229
+ )
230
+
231
+ kwargs = {}
232
+ kwargs["special_dtypes"] = special_dtypes
233
+ kwargs["no_split_module_classes"] = no_split_module_classes
234
+ kwargs["dtype"] = bnb_quantization_config.target_dtype
235
+
236
+ # get max_memory for each device.
237
+ if device_map != "sequential":
238
+ max_memory = get_balanced_memory(
239
+ model,
240
+ low_zero=(device_map == "balanced_low_0"),
241
+ max_memory=max_memory,
242
+ **kwargs,
243
+ )
244
+
245
+ kwargs["max_memory"] = max_memory
246
+ device_map = infer_auto_device_map(model, **kwargs)
247
+
248
+ if isinstance(device_map, dict):
249
+ # check if don't have any quantized module on the cpu
250
+ modules_not_to_convert = bnb_quantization_config.skip_modules + bnb_quantization_config.keep_in_fp32_modules
251
+
252
+ device_map_without_some_modules = {
253
+ key: device_map[key] for key in device_map.keys() if key not in modules_not_to_convert
254
+ }
255
+ for device in ["cpu", "disk"]:
256
+ if device in device_map_without_some_modules.values():
257
+ if bnb_quantization_config.load_in_4bit:
258
+ raise ValueError(
259
+ """
260
+ Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit
261
+ the quantized model. If you want to dispatch the model on the CPU or the disk while keeping
262
+ these modules in `torch_dtype`, you need to pass a custom `device_map` to
263
+ `load_and_quantize_model`. Check
264
+ https://huggingface.co/docs/accelerate/main/en/usage_guides/quantization#offload-modules-to-cpu-and-disk
265
+ for more details.
266
+ """
267
+ )
268
+ else:
269
+ logger.info(
270
+ "Some modules are are offloaded to the CPU or the disk. Note that these modules will be converted to 8-bit"
271
+ )
272
+ del device_map_without_some_modules
273
+ return device_map
274
+
275
+
276
+ def replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=None, current_key_name=None):
277
+ """
278
+ A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules or by `bnb.nn.Linear4bit`
279
+ modules from the `bitsandbytes`library. The function will be run recursively and replace `torch.nn.Linear` modules.
280
+
281
+ Parameters:
282
+ model (`torch.nn.Module`):
283
+ Input model or `torch.nn.Module` as the function is run recursively.
284
+ modules_to_not_convert (`List[str]`):
285
+ Names of the modules to not quantize convert. In practice we keep the `lm_head` in full precision for
286
+ numerical stability reasons.
287
+ current_key_name (`List[str]`, *optional*):
288
+ An array to track the current key of the recursion. This is used to check whether the current key (part of
289
+ it) is not in the list of modules to not convert.
290
+ """
291
+
292
+ if modules_to_not_convert is None:
293
+ modules_to_not_convert = []
294
+
295
+ model, has_been_replaced = _replace_with_bnb_layers(
296
+ model, bnb_quantization_config, modules_to_not_convert, current_key_name
297
+ )
298
+ if not has_been_replaced:
299
+ logger.warning(
300
+ "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
301
+ " this can happen for some architectures such as gpt2 that uses Conv1D instead of Linear layers."
302
+ " Please double check your model architecture, or submit an issue on github if you think this is"
303
+ " a bug."
304
+ )
305
+ return model
306
+
307
+
308
+ def _replace_with_bnb_layers(
309
+ model,
310
+ bnb_quantization_config,
311
+ modules_to_not_convert=None,
312
+ current_key_name=None,
313
+ ):
314
+ """
315
+ Private method that wraps the recursion for module replacement.
316
+
317
+ Returns the converted model and a boolean that indicates if the conversion has been successful or not.
318
+ """
319
+ # bitsandbytes will initialize CUDA on import, so it needs to be imported lazily
320
+ import bitsandbytes as bnb
321
+
322
+ has_been_replaced = False
323
+ for name, module in model.named_children():
324
+ if current_key_name is None:
325
+ current_key_name = []
326
+ current_key_name.append(name)
327
+ if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
328
+ # Check if the current key is not in the `modules_to_not_convert`
329
+ current_key_name_str = ".".join(current_key_name)
330
+ proceed = True
331
+ for key in modules_to_not_convert:
332
+ if (
333
+ (key in current_key_name_str) and (key + "." in current_key_name_str)
334
+ ) or key == current_key_name_str:
335
+ proceed = False
336
+ break
337
+ if proceed:
338
+ # Load bnb module with empty weight and replace ``nn.Linear` module
339
+ if bnb_quantization_config.load_in_8bit:
340
+ bnb_module = bnb.nn.Linear8bitLt(
341
+ module.in_features,
342
+ module.out_features,
343
+ module.bias is not None,
344
+ has_fp16_weights=False,
345
+ threshold=bnb_quantization_config.llm_int8_threshold,
346
+ )
347
+ elif bnb_quantization_config.load_in_4bit:
348
+ bnb_module = bnb.nn.Linear4bit(
349
+ module.in_features,
350
+ module.out_features,
351
+ module.bias is not None,
352
+ bnb_quantization_config.bnb_4bit_compute_dtype,
353
+ compress_statistics=bnb_quantization_config.bnb_4bit_use_double_quant,
354
+ quant_type=bnb_quantization_config.bnb_4bit_quant_type,
355
+ )
356
+ else:
357
+ raise ValueError("load_in_8bit and load_in_4bit can't be both False")
358
+ bnb_module.weight.data = module.weight.data
359
+ if module.bias is not None:
360
+ bnb_module.bias.data = module.bias.data
361
+ bnb_module.requires_grad_(False)
362
+ setattr(model, name, bnb_module)
363
+ has_been_replaced = True
364
+ if len(list(module.children())) > 0:
365
+ _, _has_been_replaced = _replace_with_bnb_layers(
366
+ module, bnb_quantization_config, modules_to_not_convert, current_key_name
367
+ )
368
+ has_been_replaced = has_been_replaced | _has_been_replaced
369
+ # Remove the last key for recursion
370
+ current_key_name.pop(-1)
371
+ return model, has_been_replaced
372
+
373
+
374
+ def get_keys_to_not_convert(model):
375
+ r"""
376
+ An utility function to get the key of the module to keep in full precision if any For example for CausalLM modules
377
+ we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want
378
+ to keep the tied weights of the model. The function will return a list of the keys of the modules to not convert in
379
+ int8.
380
+
381
+ Parameters:
382
+ model (`torch.nn.Module`):
383
+ Input model
384
+ """
385
+ # Create a copy of the model
386
+ with init_empty_weights():
387
+ tied_model = deepcopy(model) # this has 0 cost since it is done inside `init_empty_weights` context manager`
388
+
389
+ tied_params = find_tied_parameters(tied_model)
390
+ # For compatibility with Accelerate < 0.18
391
+ if isinstance(tied_params, dict):
392
+ tied_keys = sum(list(tied_params.values()), []) + list(tied_params.keys())
393
+ else:
394
+ tied_keys = sum(tied_params, [])
395
+ has_tied_params = len(tied_keys) > 0
396
+
397
+ # Check if it is a base model
398
+ is_base_model = False
399
+ if hasattr(model, "base_model_prefix"):
400
+ is_base_model = not hasattr(model, model.base_model_prefix)
401
+
402
+ # Ignore this for base models (BertModel, GPT2Model, etc.)
403
+ if (not has_tied_params) and is_base_model:
404
+ return []
405
+
406
+ # otherwise they have an attached head
407
+ list_modules = list(model.named_children())
408
+ list_last_module = [list_modules[-1][0]]
409
+
410
+ # add last module together with tied weights
411
+ intersection = set(list_last_module) - set(tied_keys)
412
+ list_untouched = list(set(tied_keys)) + list(intersection)
413
+
414
+ # remove ".weight" from the keys
415
+ names_to_remove = [".weight", ".bias"]
416
+ filtered_module_names = []
417
+ for name in list_untouched:
418
+ for name_to_remove in names_to_remove:
419
+ if name_to_remove in name:
420
+ name = name.replace(name_to_remove, "")
421
+ filtered_module_names.append(name)
422
+
423
+ return filtered_module_names
424
+
425
+
426
+ def has_4bit_bnb_layers(model):
427
+ """Check if we have `bnb.nn.Linear4bit` or `bnb.nn.Linear8bitLt` layers inside our model"""
428
+ # bitsandbytes will initialize CUDA on import, so it needs to be imported lazily
429
+ import bitsandbytes as bnb
430
+
431
+ for m in model.modules():
432
+ if isinstance(m, bnb.nn.Linear4bit):
433
+ return True
434
+ return False
435
+
436
+
437
+ def get_parameter_device(parameter: nn.Module):
438
+ return next(parameter.parameters()).device
439
+
440
+
441
+ def quantize_and_offload_8bit(model, param, param_name, new_dtype, offload_folder, offload_index, fp16_statistics):
442
+ # if it is not quantized, we quantize and offload the quantized weights and the SCB stats
443
+ if fp16_statistics is None:
444
+ set_module_tensor_to_device(model, param_name, 0, dtype=new_dtype, value=param)
445
+ tensor_name = param_name
446
+ module = model
447
+ if "." in tensor_name:
448
+ splits = tensor_name.split(".")
449
+ for split in splits[:-1]:
450
+ new_module = getattr(module, split)
451
+ if new_module is None:
452
+ raise ValueError(f"{module} has no attribute {split}.")
453
+ module = new_module
454
+ tensor_name = splits[-1]
455
+ # offload weights
456
+ module._parameters[tensor_name].requires_grad = False
457
+ offload_weight(module._parameters[tensor_name], param_name, offload_folder, index=offload_index)
458
+ if hasattr(module._parameters[tensor_name], "SCB"):
459
+ offload_weight(
460
+ module._parameters[tensor_name].SCB,
461
+ param_name.replace("weight", "SCB"),
462
+ offload_folder,
463
+ index=offload_index,
464
+ )
465
+ else:
466
+ offload_weight(param, param_name, offload_folder, index=offload_index)
467
+ offload_weight(fp16_statistics, param_name.replace("weight", "SCB"), offload_folder, index=offload_index)
468
+
469
+ set_module_tensor_to_device(model, param_name, "meta", dtype=new_dtype, value=torch.empty(*param.size()))
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/constants.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import operator as op
16
+
17
+ import torch
18
+
19
+
20
+ SCALER_NAME = "scaler.pt"
21
+ MODEL_NAME = "pytorch_model"
22
+ SAFE_MODEL_NAME = "model"
23
+ RNG_STATE_NAME = "random_states"
24
+ OPTIMIZER_NAME = "optimizer"
25
+ SCHEDULER_NAME = "scheduler"
26
+ SAMPLER_NAME = "sampler"
27
+ PROFILE_PATTERN_NAME = "profile_{suffix}.json"
28
+ WEIGHTS_NAME = f"{MODEL_NAME}.bin"
29
+ WEIGHTS_PATTERN_NAME = "pytorch_model{suffix}.bin"
30
+ WEIGHTS_INDEX_NAME = f"{WEIGHTS_NAME}.index.json"
31
+ SAFE_WEIGHTS_NAME = f"{SAFE_MODEL_NAME}.safetensors"
32
+ SAFE_WEIGHTS_PATTERN_NAME = "model{suffix}.safetensors"
33
+ SAFE_WEIGHTS_INDEX_NAME = f"{SAFE_WEIGHTS_NAME}.index.json"
34
+ SAGEMAKER_PYTORCH_VERSION = "1.10.2"
35
+ SAGEMAKER_PYTHON_VERSION = "py38"
36
+ SAGEMAKER_TRANSFORMERS_VERSION = "4.17.0"
37
+ SAGEMAKER_PARALLEL_EC2_INSTANCES = ["ml.p3.16xlarge", "ml.p3dn.24xlarge", "ml.p4dn.24xlarge"]
38
+ FSDP_SHARDING_STRATEGY = ["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD", "HYBRID_SHARD_ZERO2"]
39
+ FSDP_AUTO_WRAP_POLICY = ["TRANSFORMER_BASED_WRAP", "SIZE_BASED_WRAP", "NO_WRAP"]
40
+ FSDP_BACKWARD_PREFETCH = ["BACKWARD_PRE", "BACKWARD_POST", "NO_PREFETCH"]
41
+ FSDP_STATE_DICT_TYPE = ["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"]
42
+ FSDP2_STATE_DICT_TYPE = ["SHARDED_STATE_DICT", "FULL_STATE_DICT"]
43
+ FSDP_PYTORCH_VERSION = (
44
+ "2.1.0.a0+32f93b1" # Technically should be 2.1.0, but MS-AMP uses this specific prerelease in their Docker image.
45
+ )
46
+ FSDP2_PYTORCH_VERSION = "2.6.0"
47
+ FSDP_MODEL_NAME = "pytorch_model_fsdp"
48
+ DEEPSPEED_MULTINODE_LAUNCHERS = ["pdsh", "standard", "openmpi", "mvapich", "mpich", "nossh", "slurm"]
49
+ TORCH_DYNAMO_MODES = ["default", "reduce-overhead", "max-autotune"]
50
+ ELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION = "2.2.0"
51
+ XPU_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.4.0"
52
+ MITA_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.1.0"
53
+ BETA_TP_AVAILABLE_PYTORCH_VERSION = "2.3.0"
54
+
55
+ BETA_TP_AVAILABLE_TRANSFORMERS_VERSION = "4.52.0"
56
+ BETA_CP_AVAILABLE_PYTORCH_VERSION = "2.6.0"
57
+ BETA_SP_AVAILABLE_DEEPSPEED_VERSION = "0.18.2"
58
+
59
+ STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}
60
+
61
+ # These are the args for `torch.distributed.launch` for pytorch < 1.9
62
+ TORCH_LAUNCH_PARAMS = [
63
+ "nnodes",
64
+ "nproc_per_node",
65
+ "rdzv_backend",
66
+ "rdzv_endpoint",
67
+ "rdzv_id",
68
+ "rdzv_conf",
69
+ "standalone",
70
+ "max_restarts",
71
+ "monitor_interval",
72
+ "start_method",
73
+ "role",
74
+ "module",
75
+ "m",
76
+ "no_python",
77
+ "run_path",
78
+ "log_dir",
79
+ "r",
80
+ "redirects",
81
+ "t",
82
+ "tee",
83
+ "node_rank",
84
+ "master_addr",
85
+ "master_port",
86
+ ]
87
+
88
+ CUDA_DISTRIBUTED_TYPES = ["DEEPSPEED", "MULTI_GPU", "FSDP", "MEGATRON_LM", "TP"]
89
+ TORCH_DISTRIBUTED_OPERATION_TYPES = CUDA_DISTRIBUTED_TYPES + [
90
+ "MULTI_NPU",
91
+ "MULTI_MLU",
92
+ "MULTI_SDAA",
93
+ "MULTI_MUSA",
94
+ "MULTI_XPU",
95
+ "MULTI_CPU",
96
+ "MULTI_HPU",
97
+ ]
98
+ SUPPORTED_PYTORCH_LAYERS_FOR_UPCASTING = (
99
+ torch.nn.Conv1d,
100
+ torch.nn.Conv2d,
101
+ torch.nn.Conv3d,
102
+ torch.nn.ConvTranspose1d,
103
+ torch.nn.ConvTranspose2d,
104
+ torch.nn.ConvTranspose3d,
105
+ torch.nn.Linear,
106
+ )
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/dataclasses.py ADDED
The diff for this file is too large to render. See raw diff
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/deepspeed.py ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import base64
16
+ import json
17
+ import os
18
+ from copy import deepcopy
19
+
20
+ from torch import optim
21
+
22
+ from ..optimizer import AcceleratedOptimizer
23
+ from ..scheduler import AcceleratedScheduler
24
+ from .dataclasses import DistributedType
25
+ from .imports import is_bnb_available
26
+ from .versions import compare_versions
27
+
28
+
29
+ def map_pytorch_optim_to_deepspeed(optimizer):
30
+ """
31
+ Args:
32
+ optimizer: torch.optim.Optimizer
33
+
34
+ Returns the DeepSeedCPUOptimizer (deepspeed.ops) version of the optimizer.
35
+ """
36
+
37
+ defaults = {k: v for k, v in optimizer.defaults.items() if k in ["lr", "weight_decay"]}
38
+
39
+ # Select the DeepSpeedCPUOptimizer based on the original optimizer class.
40
+ # DeepSpeedCPUAdam is the default
41
+ from deepspeed.ops.adam import DeepSpeedCPUAdam
42
+
43
+ optimizer_class = DeepSpeedCPUAdam
44
+
45
+ # For DeepSpeedCPUAdam (adamw_mode)
46
+ if compare_versions("deepspeed", ">=", "0.3.1"):
47
+ defaults["adamw_mode"] = False
48
+ is_adaw = isinstance(optimizer, optim.AdamW)
49
+
50
+ if is_bnb_available() and not is_adaw:
51
+ import bitsandbytes.optim as bnb_opt
52
+
53
+ if isinstance(optimizer, (bnb_opt.AdamW, bnb_opt.AdamW32bit)):
54
+ try:
55
+ is_adaw = optimizer.optim_bits == 32
56
+ except AttributeError:
57
+ is_adaw = optimizer.args.optim_bits == 32
58
+ else:
59
+ is_adaw = False
60
+
61
+ if is_adaw:
62
+ defaults["adamw_mode"] = True
63
+
64
+ # For DeepSpeedCPUAdagrad
65
+ if compare_versions("deepspeed", ">=", "0.5.5"):
66
+ # Check if the optimizer is PyTorch's Adagrad.
67
+ is_ada = isinstance(optimizer, optim.Adagrad)
68
+ # If not, and bitsandbytes is available,
69
+ # # check if the optimizer is the 32-bit bitsandbytes Adagrad.
70
+ if is_bnb_available() and not is_ada:
71
+ import bitsandbytes.optim as bnb_opt
72
+
73
+ if isinstance(optimizer, (bnb_opt.Adagrad, bnb_opt.Adagrad32bit)):
74
+ try:
75
+ is_ada = optimizer.optim_bits == 32
76
+ except AttributeError:
77
+ is_ada = optimizer.args.optim_bits == 32
78
+ if is_ada:
79
+ from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad
80
+
81
+ optimizer_class = DeepSpeedCPUAdagrad
82
+
83
+ # For DeepSpeedCPULion
84
+ if is_bnb_available(min_version="0.38.0") and compare_versions("deepspeed", ">=", "0.11.0"):
85
+ from bitsandbytes.optim import Lion, Lion32bit
86
+
87
+ if isinstance(optimizer, (Lion, Lion32bit)):
88
+ try:
89
+ is_bnb_32bits = optimizer.optim_bits == 32
90
+ except AttributeError:
91
+ is_bnb_32bits = optimizer.args.optim_bits == 32
92
+ if is_bnb_32bits:
93
+ from deepspeed.ops.lion import DeepSpeedCPULion
94
+
95
+ optimizer_class = DeepSpeedCPULion
96
+
97
+ return optimizer_class(optimizer.param_groups, **defaults)
98
+
99
+
100
+ def get_active_deepspeed_plugin(state):
101
+ """
102
+ Returns the currently active DeepSpeedPlugin.
103
+
104
+ Raises:
105
+ ValueError: If DeepSpeed was not enabled and this function is called.
106
+ """
107
+ if state.distributed_type != DistributedType.DEEPSPEED:
108
+ raise ValueError(
109
+ "Couldn't retrieve the active `DeepSpeedPlugin` as none were enabled. "
110
+ "Please make sure that either `Accelerator` is configured for `deepspeed` "
111
+ "or make sure that the desired `DeepSpeedPlugin` has been enabled (`AcceleratorState().select_deepspeed_plugin(name)`) "
112
+ "before calling this function."
113
+ )
114
+ if not isinstance(state.deepspeed_plugins, dict):
115
+ return state.deepspeed_plugins
116
+ return next(plugin for plugin in state.deepspeed_plugins.values() if plugin.selected)
117
+
118
+
119
+ class HfDeepSpeedConfig:
120
+ """
121
+ This object contains a DeepSpeed configuration dictionary and can be quickly queried for things like zero stage.
122
+
123
+ A `weakref` of this object is stored in the module's globals to be able to access the config from areas where
124
+ things like the Trainer object is not available (e.g. `from_pretrained` and `_get_resized_embeddings`). Therefore
125
+ it's important that this object remains alive while the program is still running.
126
+
127
+ [`Trainer`] uses the `HfTrainerDeepSpeedConfig` subclass instead. That subclass has logic to sync the configuration
128
+ with values of [`TrainingArguments`] by replacing special placeholder values: `"auto"`. Without this special logic
129
+ the DeepSpeed configuration is not modified in any way.
130
+
131
+ Args:
132
+ config_file_or_dict (`Union[str, Dict]`): path to DeepSpeed config file or dict.
133
+
134
+ """
135
+
136
+ def __init__(self, config_file_or_dict):
137
+ if isinstance(config_file_or_dict, dict):
138
+ # Don't modify user's data should they want to reuse it (e.g. in tests), because once we
139
+ # modified it, it will not be accepted here again, since `auto` values would have been overridden
140
+ config = deepcopy(config_file_or_dict)
141
+ elif os.path.exists(config_file_or_dict):
142
+ with open(config_file_or_dict, encoding="utf-8") as f:
143
+ config = json.load(f)
144
+ else:
145
+ try:
146
+ try:
147
+ # First try parsing as JSON directly
148
+ config = json.loads(config_file_or_dict)
149
+ except json.JSONDecodeError:
150
+ # If that fails, try base64 decoding
151
+ config_decoded = base64.urlsafe_b64decode(config_file_or_dict).decode("utf-8")
152
+ config = json.loads(config_decoded)
153
+ except (UnicodeDecodeError, AttributeError, ValueError):
154
+ raise ValueError(
155
+ f"Expected a string path to an existing deepspeed config, or a dictionary, or a base64 encoded string. Received: {config_file_or_dict}"
156
+ )
157
+
158
+ self.config = config
159
+
160
+ self.set_stage_and_offload()
161
+
162
+ def set_stage_and_offload(self):
163
+ # zero stage - this is done as early as possible, before model is created, to allow
164
+ # ``is_deepspeed_zero3_enabled`` query and getting to the early deepspeed config object
165
+ # during ``zero.Init()`` which needs to know the dtype, and some other hparams.
166
+ self._stage = self.get_value("zero_optimization.stage", -1)
167
+
168
+ # offload
169
+ self._offload = False
170
+ if self.is_zero2() or self.is_zero3():
171
+ offload_devices_valid = set(["cpu", "nvme"])
172
+ offload_devices = set(
173
+ [
174
+ self.get_value("zero_optimization.offload_optimizer.device"),
175
+ self.get_value("zero_optimization.offload_param.device"),
176
+ ]
177
+ )
178
+ if len(offload_devices & offload_devices_valid) > 0:
179
+ self._offload = True
180
+
181
+ def find_config_node(self, ds_key_long):
182
+ config = self.config
183
+
184
+ # find the config node of interest if it exists
185
+ nodes = ds_key_long.split(".")
186
+ ds_key = nodes.pop()
187
+ for node in nodes:
188
+ config = config.get(node)
189
+ if config is None:
190
+ return None, ds_key
191
+
192
+ return config, ds_key
193
+
194
+ def get_value(self, ds_key_long, default=None):
195
+ """
196
+ Returns the set value or `default` if no value is set
197
+ """
198
+ config, ds_key = self.find_config_node(ds_key_long)
199
+ if config is None:
200
+ return default
201
+ return config.get(ds_key, default)
202
+
203
+ def del_config_sub_tree(self, ds_key_long, must_exist=False):
204
+ """
205
+ Deletes a sub-section of the config file if it's found.
206
+
207
+ Unless `must_exist` is `True` the section doesn't have to exist.
208
+ """
209
+ config = self.config
210
+
211
+ # find the config node of interest if it exists
212
+ nodes = ds_key_long.split(".")
213
+ for node in nodes:
214
+ parent_config = config
215
+ config = config.get(node)
216
+ if config is None:
217
+ if must_exist:
218
+ raise ValueError(f"Can't find {ds_key_long} entry in the config: {self.config}")
219
+ else:
220
+ return
221
+
222
+ # if found remove it
223
+ if parent_config is not None:
224
+ parent_config.pop(node)
225
+
226
+ def is_true(self, ds_key_long):
227
+ """
228
+ Returns `True`/``False` only if the value is set, always `False` otherwise. So use this method to ask the very
229
+ specific question of whether the value is set to `True` (and it's not set to `False`` or isn't set).
230
+
231
+ """
232
+ value = self.get_value(ds_key_long)
233
+ return False if value is None else bool(value)
234
+
235
+ def is_false(self, ds_key_long):
236
+ """
237
+ Returns `True`/``False` only if the value is set, always `False` otherwise. So use this method to ask the very
238
+ specific question of whether the value is set to `False` (and it's not set to `True`` or isn't set).
239
+ """
240
+ value = self.get_value(ds_key_long)
241
+ return False if value is None else not bool(value)
242
+
243
+ def is_zero2(self):
244
+ return self._stage == 2
245
+
246
+ def is_zero3(self):
247
+ return self._stage == 3
248
+
249
+ def is_offload(self):
250
+ return self._offload
251
+
252
+
253
+ class DeepSpeedEngineWrapper:
254
+ """
255
+ Internal wrapper for deepspeed.runtime.engine.DeepSpeedEngine. This is used to follow conventional training loop.
256
+
257
+ Args:
258
+ engine (deepspeed.runtime.engine.DeepSpeedEngine): deepspeed engine to wrap
259
+ """
260
+
261
+ def __init__(self, engine):
262
+ self.engine = engine
263
+
264
+ def backward(self, loss, sync_gradients=True, **kwargs):
265
+ # Set gradient accumulation boundary based on Accelerate's sync_gradients state
266
+ # This tells DeepSpeed whether this is the final micro-batch before gradient sync
267
+ self.engine.set_gradient_accumulation_boundary(is_boundary=sync_gradients)
268
+
269
+ # runs backpropagation and handles mixed precision
270
+ self.engine.backward(loss, **kwargs)
271
+
272
+ # Only perform step and related operations at gradient accumulation boundaries
273
+ if sync_gradients:
274
+ # Deepspeed's `engine.step` performs the following operations:
275
+ # - gradient accumulation check
276
+ # - gradient clipping
277
+ # - optimizer step
278
+ # - zero grad
279
+ # - checking overflow
280
+ # - lr_scheduler step (only if engine.lr_scheduler is not None)
281
+ self.engine.step()
282
+ # and this plugin overrides the above calls with no-ops when Accelerate runs under
283
+ # Deepspeed, but allows normal functionality for non-Deepspeed cases thus enabling a simple
284
+ # training loop that works transparently under many training regimes.
285
+
286
+ def get_global_grad_norm(self):
287
+ """Get the global gradient norm from DeepSpeed engine."""
288
+ grad_norm = self.engine.get_global_grad_norm()
289
+ # Convert to scalar if it's a tensor
290
+ if hasattr(grad_norm, "item"):
291
+ return grad_norm.item()
292
+ return grad_norm
293
+
294
+
295
+ class DeepSpeedOptimizerWrapper(AcceleratedOptimizer):
296
+ """
297
+ Internal wrapper around a deepspeed optimizer.
298
+
299
+ Args:
300
+ optimizer (`torch.optim.optimizer.Optimizer`):
301
+ The optimizer to wrap.
302
+ """
303
+
304
+ def __init__(self, optimizer):
305
+ super().__init__(optimizer, device_placement=False, scaler=None)
306
+ self.__has_overflow__ = hasattr(self.optimizer, "overflow")
307
+
308
+ def zero_grad(self, set_to_none=None):
309
+ pass # `accelerator.backward(loss)` is doing that automatically. Therefore, its implementation is not needed
310
+
311
+ def step(self):
312
+ pass # `accelerator.backward(loss)` is doing that automatically. Therefore, its implementation is not needed
313
+
314
+ @property
315
+ def step_was_skipped(self):
316
+ """Whether or not the optimizer step was done, or skipped because of gradient overflow."""
317
+ if self.__has_overflow__:
318
+ return self.optimizer.overflow
319
+ return False
320
+
321
+
322
+ class DeepSpeedSchedulerWrapper(AcceleratedScheduler):
323
+ """
324
+ Internal wrapper around a deepspeed scheduler.
325
+
326
+ Args:
327
+ scheduler (`torch.optim.lr_scheduler.LambdaLR`):
328
+ The scheduler to wrap.
329
+ optimizers (one or a list of `torch.optim.Optimizer`):
330
+ """
331
+
332
+ def __init__(self, scheduler, optimizers):
333
+ super().__init__(scheduler, optimizers)
334
+
335
+ def step(self):
336
+ pass # `accelerator.backward(loss)` is doing that automatically. Therefore, its implementation is not needed
337
+
338
+
339
+ class DummyOptim:
340
+ """
341
+ Dummy optimizer presents model parameters or param groups, this is primarily used to follow conventional training
342
+ loop when optimizer config is specified in the deepspeed config file.
343
+
344
+ Args:
345
+ lr (float):
346
+ Learning rate.
347
+ params (iterable): iterable of parameters to optimize or dicts defining
348
+ parameter groups
349
+ weight_decay (float):
350
+ Weight decay.
351
+ **kwargs (additional keyword arguments, *optional*):
352
+ Other arguments.
353
+ """
354
+
355
+ def __init__(self, params, lr=0.001, weight_decay=0, **kwargs):
356
+ self.params = params
357
+ self.lr = lr
358
+ self.weight_decay = weight_decay
359
+ self.kwargs = kwargs
360
+
361
+
362
+ class DummyScheduler:
363
+ """
364
+ Dummy scheduler presents model parameters or param groups, this is primarily used to follow conventional training
365
+ loop when scheduler config is specified in the deepspeed config file.
366
+
367
+ Args:
368
+ optimizer (`torch.optim.optimizer.Optimizer`):
369
+ The optimizer to wrap.
370
+ total_num_steps (int, *optional*):
371
+ Total number of steps.
372
+ warmup_num_steps (int, *optional*):
373
+ Number of steps for warmup.
374
+ lr_scheduler_callable (callable, *optional*):
375
+ A callable function that creates an LR Scheduler. It accepts only one argument `optimizer`.
376
+ **kwargs (additional keyword arguments, *optional*):
377
+ Other arguments.
378
+ """
379
+
380
+ def __init__(self, optimizer, total_num_steps=None, warmup_num_steps=0, lr_scheduler_callable=None, **kwargs):
381
+ self.optimizer = optimizer
382
+ self.total_num_steps = total_num_steps
383
+ self.warmup_num_steps = warmup_num_steps
384
+ self.lr_scheduler_callable = lr_scheduler_callable
385
+ self.kwargs = kwargs
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/environment.py ADDED
@@ -0,0 +1,471 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import logging
16
+ import math
17
+ import os
18
+ import platform
19
+ import subprocess
20
+ import sys
21
+ from contextlib import contextmanager
22
+ from dataclasses import dataclass, field
23
+ from functools import lru_cache, wraps
24
+ from shutil import which
25
+ from typing import Optional, Union
26
+
27
+ import torch
28
+ from packaging.version import parse
29
+
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+
34
+ def convert_dict_to_env_variables(current_env: dict):
35
+ """
36
+ Verifies that all keys and values in `current_env` do not contain illegal keys or values, and returns a list of
37
+ strings as the result.
38
+
39
+ Example:
40
+ ```python
41
+ >>> from accelerate.utils.environment import verify_env
42
+
43
+ >>> env = {"ACCELERATE_DEBUG_MODE": "1", "BAD_ENV_NAME": "<mything", "OTHER_ENV": "2"}
44
+ >>> valid_env_items = verify_env(env)
45
+ >>> print(valid_env_items)
46
+ ["ACCELERATE_DEBUG_MODE=1\n", "OTHER_ENV=2\n"]
47
+ ```
48
+ """
49
+ forbidden_chars = [";", "\n", "<", ">", " "]
50
+ valid_env_items = []
51
+ for key, value in current_env.items():
52
+ if all(char not in (key + value) for char in forbidden_chars) and len(key) >= 1 and len(value) >= 1:
53
+ valid_env_items.append(f"{key}={value}\n")
54
+ else:
55
+ logger.warning(f"WARNING: Skipping {key}={value} as it contains forbidden characters or missing values.")
56
+ return valid_env_items
57
+
58
+
59
+ def str_to_bool(value, to_bool: bool = False) -> Union[int, bool]:
60
+ """
61
+ Converts a string representation of truth to `True` (1) or `False` (0).
62
+
63
+ True values are `y`, `yes`, `t`, `true`, `on`, and `1`; False value are `n`, `no`, `f`, `false`, `off`, and `0`;
64
+ """
65
+ value = value.lower()
66
+ if value in ("y", "yes", "t", "true", "on", "1"):
67
+ return 1 if not to_bool else True
68
+ elif value in ("n", "no", "f", "false", "off", "0"):
69
+ return 0 if not to_bool else False
70
+ else:
71
+ raise ValueError(f"invalid truth value {value}")
72
+
73
+
74
+ def get_int_from_env(env_keys, default):
75
+ """Returns the first positive env value found in the `env_keys` list or the default."""
76
+ for e in env_keys:
77
+ val = int(os.environ.get(e, -1))
78
+ if val >= 0:
79
+ return val
80
+ return default
81
+
82
+
83
+ def parse_flag_from_env(key, default=False):
84
+ """Returns truthy value for `key` from the env if available else the default."""
85
+ value = os.environ.get(key, str(default))
86
+ return str_to_bool(value) == 1 # As its name indicates `str_to_bool` actually returns an int...
87
+
88
+
89
+ def parse_choice_from_env(key, default="no"):
90
+ value = os.environ.get(key, str(default))
91
+ return value
92
+
93
+
94
+ def are_libraries_initialized(*library_names: str) -> list[str]:
95
+ """
96
+ Checks if any of `library_names` are imported in the environment. Will return any names that are.
97
+ """
98
+ return [lib_name for lib_name in library_names if lib_name in sys.modules.keys()]
99
+
100
+
101
+ def get_current_device_type() -> tuple[str, str]:
102
+ """
103
+ Determines the current device type and distributed type without initializing any device.
104
+
105
+ This is particularly important when using fork-based multiprocessing, as device initialization
106
+ before forking can cause errors.
107
+
108
+ The device detection order follows the same priority as state.py:_prepare_backend():
109
+ MLU -> SDAA -> MUSA -> NPU -> HPU -> CUDA -> XPU
110
+
111
+ Returns:
112
+ tuple[str, str]: A tuple of (device_type, distributed_type)
113
+ - device_type: The device string (e.g., "cuda", "npu", "xpu")
114
+ - distributed_type: The distributed type string (e.g., "MULTI_GPU", "MULTI_NPU")
115
+
116
+ Example:
117
+ ```python
118
+ >>> device_type, distributed_type = get_current_device_type()
119
+ >>> print(device_type) # "cuda"
120
+ >>> print(distributed_type) # "MULTI_GPU"
121
+ ```
122
+ """
123
+ from .imports import (
124
+ is_hpu_available,
125
+ is_mlu_available,
126
+ is_musa_available,
127
+ is_npu_available,
128
+ is_sdaa_available,
129
+ is_xpu_available,
130
+ )
131
+
132
+ if is_mlu_available():
133
+ return "mlu", "MULTI_MLU"
134
+ elif is_sdaa_available():
135
+ return "sdaa", "MULTI_SDAA"
136
+ elif is_musa_available():
137
+ return "musa", "MULTI_MUSA"
138
+ elif is_npu_available():
139
+ return "npu", "MULTI_NPU"
140
+ elif is_hpu_available():
141
+ return "hpu", "MULTI_HPU"
142
+ elif torch.cuda.is_available():
143
+ return "cuda", "MULTI_GPU"
144
+ elif is_xpu_available():
145
+ return "xpu", "MULTI_XPU"
146
+ else:
147
+ # Default to CUDA even if not available (for CPU-only scenarios where CUDA code paths are still used)
148
+ return "cuda", "MULTI_GPU"
149
+
150
+
151
+ def _nvidia_smi():
152
+ """
153
+ Returns the right nvidia-smi command based on the system.
154
+ """
155
+ if platform.system() == "Windows":
156
+ # If platform is Windows and nvidia-smi can't be found in path
157
+ # try from systemd drive with default installation path
158
+ command = which("nvidia-smi")
159
+ if command is None:
160
+ command = f"{os.environ['systemdrive']}\\Program Files\\NVIDIA Corporation\\NVSMI\\nvidia-smi.exe"
161
+ else:
162
+ command = "nvidia-smi"
163
+ return command
164
+
165
+
166
+ def get_gpu_info():
167
+ """
168
+ Gets GPU count and names using `nvidia-smi` instead of torch to not initialize CUDA.
169
+
170
+ Largely based on the `gputil` library.
171
+ """
172
+ # Returns as list of `n` GPUs and their names
173
+ output = subprocess.check_output(
174
+ [_nvidia_smi(), "--query-gpu=count,name", "--format=csv,noheader"], universal_newlines=True
175
+ )
176
+ output = output.strip()
177
+ gpus = output.split(os.linesep)
178
+ # Get names from output
179
+ gpu_count = len(gpus)
180
+ gpu_names = [gpu.split(",")[1].strip() for gpu in gpus]
181
+ return gpu_names, gpu_count
182
+
183
+
184
+ def get_driver_version():
185
+ """
186
+ Returns the driver version
187
+
188
+ In the case of multiple GPUs, will return the first.
189
+ """
190
+ output = subprocess.check_output(
191
+ [_nvidia_smi(), "--query-gpu=driver_version", "--format=csv,noheader"], universal_newlines=True
192
+ )
193
+ output = output.strip()
194
+ return output.split(os.linesep)[0]
195
+
196
+
197
+ def check_cuda_p2p_ib_support():
198
+ """
199
+ Checks if the devices being used have issues with P2P and IB communications, namely any consumer GPU hardware after
200
+ the 3090.
201
+
202
+ Notably uses `nvidia-smi` instead of torch to not initialize CUDA.
203
+ """
204
+ try:
205
+ device_names, device_count = get_gpu_info()
206
+ # As new consumer GPUs get released, add them to `unsupported_devices``
207
+ unsupported_devices = {"RTX 40"}
208
+ if device_count > 1:
209
+ if any(
210
+ unsupported_device in device_name
211
+ for device_name in device_names
212
+ for unsupported_device in unsupported_devices
213
+ ):
214
+ # Check if they have the right driver version
215
+ acceptable_driver_version = "550.40.07"
216
+ current_driver_version = get_driver_version()
217
+ if parse(current_driver_version) < parse(acceptable_driver_version):
218
+ return False
219
+ return True
220
+ except Exception:
221
+ pass
222
+ return True
223
+
224
+
225
+ @lru_cache
226
+ def check_cuda_fp8_capability():
227
+ """
228
+ Checks if the current GPU available supports FP8.
229
+
230
+ Notably might initialize `torch.cuda` to check.
231
+ """
232
+
233
+ try:
234
+ # try to get the compute capability from nvidia-smi
235
+ output = subprocess.check_output(
236
+ [_nvidia_smi(), "--query-gpu=compute_capability", "--format=csv,noheader"], universal_newlines=True
237
+ )
238
+ output = output.strip()
239
+ # we take the first GPU's compute capability
240
+ compute_capability = tuple(map(int, output.split(os.linesep)[0].split(".")))
241
+ except Exception:
242
+ compute_capability = torch.cuda.get_device_capability()
243
+
244
+ return compute_capability >= (8, 9)
245
+
246
+
247
+ @dataclass
248
+ class CPUInformation:
249
+ """
250
+ Stores information about the CPU in a distributed environment. It contains the following attributes:
251
+ - rank: The rank of the current process.
252
+ - world_size: The total number of processes in the world.
253
+ - local_rank: The rank of the current process on the local node.
254
+ - local_world_size: The total number of processes on the local node.
255
+ """
256
+
257
+ rank: int = field(default=0, metadata={"help": "The rank of the current process."})
258
+ world_size: int = field(default=1, metadata={"help": "The total number of processes in the world."})
259
+ local_rank: int = field(default=0, metadata={"help": "The rank of the current process on the local node."})
260
+ local_world_size: int = field(default=1, metadata={"help": "The total number of processes on the local node."})
261
+
262
+
263
+ def get_cpu_distributed_information() -> CPUInformation:
264
+ """
265
+ Returns various information about the environment in relation to CPU distributed training as a `CPUInformation`
266
+ dataclass.
267
+ """
268
+ information = {}
269
+ information["rank"] = get_int_from_env(["RANK", "PMI_RANK", "OMPI_COMM_WORLD_RANK", "MV2_COMM_WORLD_RANK"], 0)
270
+ information["world_size"] = get_int_from_env(
271
+ ["WORLD_SIZE", "PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE"], 1
272
+ )
273
+ information["local_rank"] = get_int_from_env(
274
+ ["LOCAL_RANK", "MPI_LOCALRANKID", "OMPI_COMM_WORLD_LOCAL_RANK", "MV2_COMM_WORLD_LOCAL_RANK"], 0
275
+ )
276
+ information["local_world_size"] = get_int_from_env(
277
+ ["LOCAL_WORLD_SIZE", "MPI_LOCALNRANKS", "OMPI_COMM_WORLD_LOCAL_SIZE", "MV2_COMM_WORLD_LOCAL_SIZE"],
278
+ 1,
279
+ )
280
+ return CPUInformation(**information)
281
+
282
+
283
+ def override_numa_affinity(local_process_index: int, verbose: Optional[bool] = None) -> None:
284
+ """
285
+ Overrides whatever NUMA affinity is set for the current process. This is very taxing and requires recalculating the
286
+ affinity to set, ideally you should use `utils.environment.set_numa_affinity` instead.
287
+
288
+ Args:
289
+ local_process_index (int):
290
+ The index of the current process on the current server.
291
+ verbose (bool, *optional*):
292
+ Whether to log out the assignment of each CPU. If `ACCELERATE_DEBUG_MODE` is enabled, will default to True.
293
+ """
294
+ if verbose is None:
295
+ verbose = parse_flag_from_env("ACCELERATE_DEBUG_MODE", False)
296
+ if torch.cuda.is_available():
297
+ from accelerate.utils import is_pynvml_available
298
+
299
+ if not is_pynvml_available():
300
+ raise ImportError(
301
+ "To set CPU affinity on CUDA GPUs the `nvidia-ml-py` package must be available. (`pip install nvidia-ml-py`)"
302
+ )
303
+ import pynvml as nvml
304
+
305
+ # The below code is based on https://github.com/NVIDIA/DeepLearningExamples/blob/master/TensorFlow2/LanguageModeling/BERT/gpu_affinity.py
306
+ nvml.nvmlInit()
307
+ num_elements = math.ceil(os.cpu_count() / 64)
308
+ handle = nvml.nvmlDeviceGetHandleByIndex(local_process_index)
309
+ affinity_string = ""
310
+ for j in nvml.nvmlDeviceGetCpuAffinity(handle, num_elements):
311
+ # assume nvml returns list of 64 bit ints
312
+ affinity_string = f"{j:064b}{affinity_string}"
313
+ affinity_list = [int(x) for x in affinity_string]
314
+ affinity_list.reverse() # so core 0 is the 0th element
315
+ affinity_to_set = [i for i, e in enumerate(affinity_list) if e != 0]
316
+ os.sched_setaffinity(0, affinity_to_set)
317
+ if verbose:
318
+ cpu_cores = os.sched_getaffinity(0)
319
+ logger.info(f"Assigning {len(cpu_cores)} cpu cores to process {local_process_index}: {cpu_cores}")
320
+
321
+
322
+ @lru_cache
323
+ def set_numa_affinity(local_process_index: int, verbose: Optional[bool] = None) -> None:
324
+ """
325
+ Assigns the current process to a specific NUMA node. Ideally most efficient when having at least 2 cpus per node.
326
+
327
+ This result is cached between calls. If you want to override it, please use
328
+ `accelerate.utils.environment.override_numa_afifnity`.
329
+
330
+ Args:
331
+ local_process_index (int):
332
+ The index of the current process on the current server.
333
+ verbose (bool, *optional*):
334
+ Whether to print the new cpu cores assignment for each process. If `ACCELERATE_DEBUG_MODE` is enabled, will
335
+ default to True.
336
+ """
337
+ override_numa_affinity(local_process_index=local_process_index, verbose=verbose)
338
+
339
+
340
+ @contextmanager
341
+ def clear_environment():
342
+ """
343
+ A context manager that will temporarily clear environment variables.
344
+
345
+ When this context exits, the previous environment variables will be back.
346
+
347
+ Example:
348
+
349
+ ```python
350
+ >>> import os
351
+ >>> from accelerate.utils import clear_environment
352
+
353
+ >>> os.environ["FOO"] = "bar"
354
+ >>> with clear_environment():
355
+ ... print(os.environ)
356
+ ... os.environ["FOO"] = "new_bar"
357
+ ... print(os.environ["FOO"])
358
+ {}
359
+ new_bar
360
+
361
+ >>> print(os.environ["FOO"])
362
+ bar
363
+ ```
364
+ """
365
+ _old_os_environ = os.environ.copy()
366
+ os.environ.clear()
367
+
368
+ try:
369
+ yield
370
+ finally:
371
+ os.environ.clear() # clear any added keys,
372
+ os.environ.update(_old_os_environ) # then restore previous environment
373
+
374
+
375
+ @contextmanager
376
+ def patch_environment(**kwargs):
377
+ """
378
+ A context manager that will add each keyword argument passed to `os.environ` and remove them when exiting.
379
+
380
+ Will convert the values in `kwargs` to strings and upper-case all the keys.
381
+
382
+ Example:
383
+
384
+ ```python
385
+ >>> import os
386
+ >>> from accelerate.utils import patch_environment
387
+
388
+ >>> with patch_environment(FOO="bar"):
389
+ ... print(os.environ["FOO"]) # prints "bar"
390
+ >>> print(os.environ["FOO"]) # raises KeyError
391
+ ```
392
+ """
393
+ existing_vars = {}
394
+ for key, value in kwargs.items():
395
+ key = key.upper()
396
+ if key in os.environ:
397
+ existing_vars[key] = os.environ[key]
398
+ os.environ[key] = str(value)
399
+
400
+ try:
401
+ yield
402
+ finally:
403
+ for key in kwargs:
404
+ key = key.upper()
405
+ if key in existing_vars:
406
+ # restore previous value
407
+ os.environ[key] = existing_vars[key]
408
+ else:
409
+ os.environ.pop(key, None)
410
+
411
+
412
+ def purge_accelerate_environment(func_or_cls):
413
+ """Decorator to clean up accelerate environment variables set by the decorated class or function.
414
+
415
+ In some circumstances, calling certain classes or functions can result in accelerate env vars being set and not
416
+ being cleaned up afterwards. As an example, when calling:
417
+
418
+ TrainingArguments(fp16=True, ...)
419
+
420
+ The following env var will be set:
421
+
422
+ ACCELERATE_MIXED_PRECISION=fp16
423
+
424
+ This can affect subsequent code, since the env var takes precedence over TrainingArguments(fp16=False). This is
425
+ especially relevant for unit testing, where we want to avoid the individual tests to have side effects on one
426
+ another. Decorate the unit test function or whole class with this decorator to ensure that after each test, the env
427
+ vars are cleaned up. This works for both unittest.TestCase and normal classes (pytest); it also works when
428
+ decorating the parent class.
429
+
430
+ """
431
+ prefix = "ACCELERATE_"
432
+
433
+ @contextmanager
434
+ def env_var_context():
435
+ # Store existing accelerate env vars
436
+ existing_vars = {k: v for k, v in os.environ.items() if k.startswith(prefix)}
437
+ try:
438
+ yield
439
+ finally:
440
+ # Restore original env vars or remove new ones
441
+ for key in [k for k in os.environ if k.startswith(prefix)]:
442
+ if key in existing_vars:
443
+ os.environ[key] = existing_vars[key]
444
+ else:
445
+ os.environ.pop(key, None)
446
+
447
+ def wrap_function(func):
448
+ @wraps(func)
449
+ def wrapper(*args, **kwargs):
450
+ with env_var_context():
451
+ return func(*args, **kwargs)
452
+
453
+ wrapper._accelerate_is_purged_environment_wrapped = True
454
+ return wrapper
455
+
456
+ if not isinstance(func_or_cls, type):
457
+ return wrap_function(func_or_cls)
458
+
459
+ # Handle classes by wrapping test methods
460
+ def wrap_test_methods(test_class_instance):
461
+ for name in dir(test_class_instance):
462
+ if name.startswith("test"):
463
+ method = getattr(test_class_instance, name)
464
+ if callable(method) and not hasattr(method, "_accelerate_is_purged_environment_wrapped"):
465
+ setattr(test_class_instance, name, wrap_function(method))
466
+ return test_class_instance
467
+
468
+ # Handle inheritance
469
+ wrap_test_methods(func_or_cls)
470
+ func_or_cls.__init_subclass__ = classmethod(lambda cls, **kw: wrap_test_methods(cls))
471
+ return func_or_cls
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/fsdp_utils.py ADDED
@@ -0,0 +1,829 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import copy
15
+ import functools
16
+ import os
17
+ import re
18
+ import shutil
19
+ import warnings
20
+ from collections import defaultdict
21
+ from collections.abc import Iterable
22
+ from contextlib import nullcontext
23
+ from pathlib import Path
24
+ from typing import Callable, Union
25
+
26
+ import torch
27
+
28
+ from ..logging import get_logger
29
+ from .constants import FSDP_MODEL_NAME, OPTIMIZER_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_NAME
30
+ from .dataclasses import get_module_class_from_name
31
+ from .modeling import get_non_persistent_buffers, is_peft_model
32
+ from .other import get_module_children_bottom_up, is_compiled_module, save
33
+ from .versions import is_torch_version
34
+
35
+
36
+ logger = get_logger(__name__)
37
+
38
+
39
+ def enable_fsdp_ram_efficient_loading():
40
+ """
41
+ Enables RAM efficient loading of Hugging Face models for FSDP in the environment.
42
+ """
43
+ # Sets values for `transformers.modeling_utils.is_fsdp_enabled`
44
+ if "ACCELERATE_USE_FSDP" not in os.environ:
45
+ os.environ["ACCELERATE_USE_FSDP"] = "True"
46
+ os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "True"
47
+
48
+
49
+ def disable_fsdp_ram_efficient_loading():
50
+ """
51
+ Disables RAM efficient loading of Hugging Face models for FSDP in the environment.
52
+ """
53
+ os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "False"
54
+
55
+
56
+ def _get_model_state_dict(model, adapter_only=False, sd_options=None):
57
+ if adapter_only and is_peft_model(model):
58
+ from peft import get_peft_model_state_dict
59
+
60
+ return get_peft_model_state_dict(model, adapter_name=model.active_adapter)
61
+
62
+ # Invariant: `sd_options` is not None only for FSDP2
63
+ if sd_options is not None:
64
+ from torch.distributed.checkpoint.state_dict import get_model_state_dict
65
+
66
+ return get_model_state_dict(model, options=sd_options)
67
+ else:
68
+ return model.state_dict()
69
+
70
+
71
+ def _set_model_state_dict(model, state_dict, adapter_only=False, sd_options=None):
72
+ if adapter_only and is_peft_model(model):
73
+ from peft import set_peft_model_state_dict
74
+
75
+ return set_peft_model_state_dict(model, state_dict, adapter_name=model.active_adapter)
76
+
77
+ # Invariant: `sd_options` is not None only for FSDP2
78
+ if sd_options is not None:
79
+ from torch.distributed.checkpoint.state_dict import set_model_state_dict
80
+
81
+ return set_model_state_dict(model, state_dict, options=sd_options)
82
+ else:
83
+ return model.load_state_dict(state_dict)
84
+
85
+
86
+ def _prepare_sd_options(fsdp_plugin):
87
+ sd_options = None
88
+
89
+ # we use this only for FSDP2, as it requires torch >= 2.6.0 and this api requires torch >= 2.2.0
90
+ if fsdp_plugin.fsdp_version == 2:
91
+ from torch.distributed.checkpoint.state_dict import StateDictOptions
92
+ from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
93
+
94
+ sd_options = StateDictOptions(
95
+ full_state_dict=fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT,
96
+ cpu_offload=getattr(fsdp_plugin.state_dict_config, "offload_to_cpu", False),
97
+ broadcast_from_rank0=getattr(fsdp_plugin.state_dict_config, "rank0_only", False),
98
+ )
99
+
100
+ return sd_options
101
+
102
+
103
+ def save_fsdp_model(fsdp_plugin, accelerator, model, output_dir, model_index=0, adapter_only=False):
104
+ # Note: We import here to reduce import time from general modules, and isolate outside dependencies
105
+ import torch.distributed.checkpoint as dist_cp
106
+ from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
107
+ from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
108
+ from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
109
+
110
+ os.makedirs(output_dir, exist_ok=True)
111
+ if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
112
+ # FSDP raises error when single GPU is used with `offload_to_cpu=True` for FULL_STATE_DICT
113
+ # so, only enable it when num_processes>1
114
+ is_multi_process = accelerator.num_processes > 1
115
+ fsdp_plugin.state_dict_config.offload_to_cpu = is_multi_process
116
+ fsdp_plugin.state_dict_config.rank0_only = is_multi_process
117
+
118
+ ctx = (
119
+ FSDP.state_dict_type(
120
+ model, fsdp_plugin.state_dict_type, fsdp_plugin.state_dict_config, fsdp_plugin.optim_state_dict_config
121
+ )
122
+ if fsdp_plugin.fsdp_version == 1
123
+ else nullcontext()
124
+ )
125
+ sd_options = _prepare_sd_options(fsdp_plugin)
126
+
127
+ with ctx:
128
+ state_dict = _get_model_state_dict(model, adapter_only=adapter_only, sd_options=sd_options)
129
+ if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
130
+ weights_name = f"{FSDP_MODEL_NAME}.bin" if model_index == 0 else f"{FSDP_MODEL_NAME}_{model_index}.bin"
131
+ output_model_file = os.path.join(output_dir, weights_name)
132
+ if accelerator.process_index == 0:
133
+ logger.info(f"Saving model to {output_model_file}")
134
+ torch.save(state_dict, output_model_file)
135
+ logger.info(f"Model saved to {output_model_file}")
136
+ # Invariant: `LOCAL_STATE_DICT` is never possible with `FSDP2`
137
+ elif fsdp_plugin.state_dict_type == StateDictType.LOCAL_STATE_DICT:
138
+ weights_name = (
139
+ f"{FSDP_MODEL_NAME}_rank{accelerator.process_index}.bin"
140
+ if model_index == 0
141
+ else f"{FSDP_MODEL_NAME}_{model_index}_rank{accelerator.process_index}.bin"
142
+ )
143
+ output_model_file = os.path.join(output_dir, weights_name)
144
+ logger.info(f"Saving model to {output_model_file}")
145
+ torch.save(state_dict, output_model_file)
146
+ logger.info(f"Model saved to {output_model_file}")
147
+ elif fsdp_plugin.state_dict_type == StateDictType.SHARDED_STATE_DICT:
148
+ ckpt_dir = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{model_index}")
149
+ os.makedirs(ckpt_dir, exist_ok=True)
150
+ logger.info(f"Saving model to {ckpt_dir}")
151
+ state_dict = {"model": state_dict}
152
+
153
+ dist_cp.save(
154
+ state_dict=state_dict,
155
+ storage_writer=dist_cp.FileSystemWriter(ckpt_dir),
156
+ planner=DefaultSavePlanner(),
157
+ )
158
+ logger.info(f"Model saved to {ckpt_dir}")
159
+
160
+
161
+ def load_fsdp_model(fsdp_plugin, accelerator, model, input_dir, model_index=0, adapter_only=False):
162
+ # Note: We import here to reduce import time from general modules, and isolate outside dependencies
163
+ import torch.distributed.checkpoint as dist_cp
164
+ from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
165
+ from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
166
+ from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
167
+
168
+ accelerator.wait_for_everyone()
169
+ if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
170
+ # FSDP raises error when single GPU is used with `offload_to_cpu=True` for FULL_STATE_DICT
171
+ # so, only enable it when num_processes>1
172
+ is_multi_process = accelerator.num_processes > 1
173
+ fsdp_plugin.state_dict_config.offload_to_cpu = is_multi_process
174
+ fsdp_plugin.state_dict_config.rank0_only = is_multi_process
175
+
176
+ ctx = (
177
+ FSDP.state_dict_type(
178
+ model, fsdp_plugin.state_dict_type, fsdp_plugin.state_dict_config, fsdp_plugin.optim_state_dict_config
179
+ )
180
+ if fsdp_plugin.fsdp_version == 1
181
+ else nullcontext()
182
+ )
183
+ sd_options = _prepare_sd_options(fsdp_plugin)
184
+ with ctx:
185
+ if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
186
+ if type(model) is not FSDP and accelerator.process_index != 0 and not accelerator.is_fsdp2:
187
+ if not fsdp_plugin.sync_module_states and fsdp_plugin.fsdp_version == 1:
188
+ raise ValueError(
189
+ "Set the `sync_module_states` flag to `True` so that model states are synced across processes when "
190
+ "initializing FSDP object"
191
+ )
192
+ return
193
+ weights_name = f"{FSDP_MODEL_NAME}.bin" if model_index == 0 else f"{FSDP_MODEL_NAME}_{model_index}.bin"
194
+ input_model_file = os.path.join(input_dir, weights_name)
195
+ logger.info(f"Loading model from {input_model_file}")
196
+ # we want an empty state dict for FSDP2 as we use `broadcast_from_rank0`
197
+ load_model = not accelerator.is_fsdp2 or accelerator.is_main_process
198
+ if load_model:
199
+ state_dict = torch.load(input_model_file, weights_only=True)
200
+ else:
201
+ state_dict = {}
202
+ logger.info(f"Model loaded from {input_model_file}")
203
+ elif fsdp_plugin.state_dict_type == StateDictType.LOCAL_STATE_DICT:
204
+ weights_name = (
205
+ f"{FSDP_MODEL_NAME}_rank{accelerator.process_index}.bin"
206
+ if model_index == 0
207
+ else f"{FSDP_MODEL_NAME}_{model_index}_rank{accelerator.process_index}.bin"
208
+ )
209
+ input_model_file = os.path.join(input_dir, weights_name)
210
+ logger.info(f"Loading model from {input_model_file}")
211
+ state_dict = torch.load(input_model_file, weights_only=True)
212
+ logger.info(f"Model loaded from {input_model_file}")
213
+ elif fsdp_plugin.state_dict_type == StateDictType.SHARDED_STATE_DICT:
214
+ ckpt_dir = (
215
+ os.path.join(input_dir, f"{FSDP_MODEL_NAME}_{model_index}")
216
+ if f"{FSDP_MODEL_NAME}" not in input_dir
217
+ else input_dir
218
+ )
219
+ logger.info(f"Loading model from {ckpt_dir}")
220
+ state_dict = {"model": _get_model_state_dict(model, adapter_only=adapter_only, sd_options=sd_options)}
221
+ dist_cp.load(
222
+ state_dict=state_dict,
223
+ storage_reader=dist_cp.FileSystemReader(ckpt_dir),
224
+ planner=DefaultLoadPlanner(),
225
+ )
226
+ state_dict = state_dict["model"]
227
+ logger.info(f"Model loaded from {ckpt_dir}")
228
+
229
+ load_result = _set_model_state_dict(model, state_dict, adapter_only=adapter_only, sd_options=sd_options)
230
+ return load_result
231
+
232
+
233
+ def save_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, output_dir, optimizer_index=0):
234
+ # Note: We import here to reduce import time from general modules, and isolate outside dependencies
235
+ import torch.distributed.checkpoint as dist_cp
236
+ from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
237
+ from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
238
+ from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
239
+
240
+ os.makedirs(output_dir, exist_ok=True)
241
+
242
+ ctx = (
243
+ FSDP.state_dict_type(
244
+ model, fsdp_plugin.state_dict_type, fsdp_plugin.state_dict_config, fsdp_plugin.optim_state_dict_config
245
+ )
246
+ if fsdp_plugin.fsdp_version == 1
247
+ else nullcontext()
248
+ )
249
+
250
+ sd_options = _prepare_sd_options(fsdp_plugin)
251
+
252
+ with ctx:
253
+ if fsdp_plugin.fsdp_version == 2:
254
+ from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict
255
+
256
+ optim_state = get_optimizer_state_dict(model, optimizer, options=sd_options)
257
+ else:
258
+ optim_state = FSDP.optim_state_dict(model, optimizer)
259
+
260
+ if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
261
+ if accelerator.process_index == 0:
262
+ optim_state_name = (
263
+ f"{OPTIMIZER_NAME}.bin" if optimizer_index == 0 else f"{OPTIMIZER_NAME}_{optimizer_index}.bin"
264
+ )
265
+ output_optimizer_file = os.path.join(output_dir, optim_state_name)
266
+ logger.info(f"Saving Optimizer state to {output_optimizer_file}")
267
+ torch.save(optim_state, output_optimizer_file)
268
+ logger.info(f"Optimizer state saved in {output_optimizer_file}")
269
+ else:
270
+ ckpt_dir = os.path.join(output_dir, f"{OPTIMIZER_NAME}_{optimizer_index}")
271
+ os.makedirs(ckpt_dir, exist_ok=True)
272
+ logger.info(f"Saving Optimizer state to {ckpt_dir}")
273
+ dist_cp.save(
274
+ state_dict={"optimizer": optim_state},
275
+ storage_writer=dist_cp.FileSystemWriter(ckpt_dir),
276
+ planner=DefaultSavePlanner(),
277
+ )
278
+ logger.info(f"Optimizer state saved in {ckpt_dir}")
279
+
280
+
281
+ def load_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, input_dir, optimizer_index=0, adapter_only=False):
282
+ # Note: We import here to reduce import time from general modules, and isolate outside dependencies
283
+ import torch.distributed.checkpoint as dist_cp
284
+ from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
285
+ from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
286
+
287
+ accelerator.wait_for_everyone()
288
+ ctx = (
289
+ FSDP.state_dict_type(
290
+ model, fsdp_plugin.state_dict_type, fsdp_plugin.state_dict_config, fsdp_plugin.optim_state_dict_config
291
+ )
292
+ if fsdp_plugin.fsdp_version == 1
293
+ else nullcontext()
294
+ )
295
+ sd_options = _prepare_sd_options(fsdp_plugin)
296
+ with ctx:
297
+ if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
298
+ optim_state = None
299
+ if accelerator.process_index == 0 or not fsdp_plugin.optim_state_dict_config.rank0_only:
300
+ optimizer_name = (
301
+ f"{OPTIMIZER_NAME}.bin" if optimizer_index == 0 else f"{OPTIMIZER_NAME}_{optimizer_index}.bin"
302
+ )
303
+ input_optimizer_file = os.path.join(input_dir, optimizer_name)
304
+ logger.info(f"Loading Optimizer state from {input_optimizer_file}")
305
+ optim_state = torch.load(input_optimizer_file, weights_only=True)
306
+ logger.info(f"Optimizer state loaded from {input_optimizer_file}")
307
+ else:
308
+ ckpt_dir = (
309
+ os.path.join(input_dir, f"{OPTIMIZER_NAME}_{optimizer_index}")
310
+ if f"{OPTIMIZER_NAME}" not in input_dir
311
+ else input_dir
312
+ )
313
+ logger.info(f"Loading Optimizer from {ckpt_dir}")
314
+ optim_state = {"optimizer": optimizer.state_dict()}
315
+ dist_cp.load(
316
+ optim_state,
317
+ checkpoint_id=ckpt_dir,
318
+ storage_reader=dist_cp.FileSystemReader(ckpt_dir),
319
+ )
320
+ optim_state = optim_state["optimizer"]
321
+ logger.info(f"Optimizer loaded from {ckpt_dir}")
322
+
323
+ if fsdp_plugin.fsdp_version == 1:
324
+ flattened_osd = FSDP.optim_state_dict_to_load(model=model, optim=optimizer, optim_state_dict=optim_state)
325
+ optimizer.load_state_dict(flattened_osd)
326
+ else:
327
+ from torch.distributed.checkpoint.state_dict import set_optimizer_state_dict
328
+
329
+ set_optimizer_state_dict(model, optimizer, optim_state, options=sd_options)
330
+
331
+
332
+ def _distributed_checkpoint_to_merged_weights(checkpoint_dir: str, save_path: str, safe_serialization: bool = True):
333
+ """
334
+ Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save`
335
+
336
+ Will save under `save_path` as either `model.safetensors` or `pytorch_model.bin`.
337
+ """
338
+ # Note: We import here to reduce import time from general modules, and isolate outside dependencies
339
+ import torch.distributed.checkpoint as dist_cp
340
+ import torch.distributed.checkpoint.format_utils as dist_cp_format_utils
341
+
342
+ state_dict = {}
343
+ save_path = Path(save_path)
344
+ save_path.mkdir(exist_ok=True)
345
+ dist_cp_format_utils._load_state_dict(
346
+ state_dict,
347
+ storage_reader=dist_cp.FileSystemReader(checkpoint_dir),
348
+ planner=dist_cp_format_utils._EmptyStateDictLoadPlanner(),
349
+ no_dist=True,
350
+ )
351
+ save_path = save_path / SAFE_WEIGHTS_NAME if safe_serialization else save_path / WEIGHTS_NAME
352
+
353
+ # To handle if state is a dict like {model: {...}}
354
+ if len(state_dict.keys()) == 1:
355
+ state_dict = state_dict[list(state_dict)[0]]
356
+ save(state_dict, save_path, safe_serialization=safe_serialization)
357
+ return save_path
358
+
359
+
360
+ def merge_fsdp_weights(
361
+ checkpoint_dir: str, output_path: str, safe_serialization: bool = True, remove_checkpoint_dir: bool = False
362
+ ):
363
+ """
364
+ Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if
365
+ `SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}/model.safetensors` if
366
+ `safe_serialization` else `pytorch_model.bin`.
367
+
368
+ Note: this is a CPU-bound process.
369
+
370
+ Args:
371
+ checkpoint_dir (`str`):
372
+ The directory containing the FSDP checkpoints (can be either the model or optimizer).
373
+ output_path (`str`):
374
+ The path to save the merged checkpoint.
375
+ safe_serialization (`bool`, *optional*, defaults to `True`):
376
+ Whether to save the merged weights with safetensors (recommended).
377
+ remove_checkpoint_dir (`bool`, *optional*, defaults to `False`):
378
+ Whether to remove the checkpoint directory after merging.
379
+ """
380
+ checkpoint_dir = Path(checkpoint_dir)
381
+ from accelerate.state import PartialState
382
+
383
+ if not is_torch_version(">=", "2.3.0"):
384
+ raise ValueError("`merge_fsdp_weights` requires PyTorch >= 2.3.0`")
385
+
386
+ # Verify that the checkpoint directory exists
387
+ if not checkpoint_dir.exists():
388
+ model_path_exists = (checkpoint_dir / "pytorch_model_fsdp_0").exists()
389
+ optimizer_path_exists = (checkpoint_dir / "optimizer_0").exists()
390
+ err = f"Tried to load from {checkpoint_dir} but couldn't find a valid metadata file."
391
+ if model_path_exists and optimizer_path_exists:
392
+ err += " However, potential model and optimizer checkpoint directories exist."
393
+ err += f"Please pass in either {checkpoint_dir}/pytorch_model_fsdp_0 or {checkpoint_dir}/optimizer_0"
394
+ err += "instead."
395
+ elif model_path_exists:
396
+ err += " However, a potential model checkpoint directory exists."
397
+ err += f"Please try passing in {checkpoint_dir}/pytorch_model_fsdp_0 instead."
398
+ elif optimizer_path_exists:
399
+ err += " However, a potential optimizer checkpoint directory exists."
400
+ err += f"Please try passing in {checkpoint_dir}/optimizer_0 instead."
401
+ raise ValueError(err)
402
+
403
+ # To setup `save` to work
404
+ state = PartialState()
405
+ if state.is_main_process:
406
+ logger.info(f"Merging FSDP weights from {checkpoint_dir}")
407
+ save_path = _distributed_checkpoint_to_merged_weights(checkpoint_dir, output_path, safe_serialization)
408
+ logger.info(f"Successfully merged FSDP weights and saved to {save_path}")
409
+ if remove_checkpoint_dir:
410
+ logger.info(f"Removing old checkpoint directory {checkpoint_dir}")
411
+ shutil.rmtree(checkpoint_dir)
412
+ state.wait_for_everyone()
413
+
414
+
415
+ def ensure_weights_retied(param_init_fn, model: torch.nn.Module, device: torch.device):
416
+ _tied_names = getattr(model, "_tied_weights_keys", None)
417
+ if not _tied_names:
418
+ # if no tied names just passthrough
419
+ return param_init_fn
420
+
421
+ # get map of parameter instances to params.
422
+ # - needed for replacement later
423
+ _tied_params = {}
424
+ for name in _tied_names:
425
+ name = name.split(".")
426
+ name, param_name = ".".join(name[:-1]), name[-1]
427
+ mod = model.get_submodule(name)
428
+ param = getattr(mod, param_name)
429
+
430
+ _tied_params[id(param)] = None # placeholder for the param first
431
+
432
+ # build param_init_fn for the case with tied params
433
+ def param_init_fn_tied_param(module: torch.nn.Module):
434
+ # track which params to tie
435
+ # - usually only 1, but for completeness consider > 1
436
+ params_to_tie = defaultdict(list)
437
+ for n, param in module.named_parameters(recurse=False):
438
+ if id(param) in _tied_params:
439
+ params_to_tie[id(param)].append(n)
440
+
441
+ # call the param init fn, which potentially re-allocates the
442
+ # parameters
443
+ module = param_init_fn(module)
444
+
445
+ # search the parameters again and tie them up again
446
+ for id_key, _param_names in params_to_tie.items():
447
+ for param_name in _param_names:
448
+ param = _tied_params[id_key]
449
+ if param is None:
450
+ # everything will be tied to the first time the
451
+ # param is observed
452
+ _tied_params[id_key] = getattr(module, param_name)
453
+ else:
454
+ setattr(module, param_name, param) # tie
455
+
456
+ return module
457
+
458
+ return param_init_fn_tied_param
459
+
460
+
461
+ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict):
462
+ """
463
+ Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the
464
+ parameters from rank 0 to all other ranks. This function modifies the model in-place.
465
+
466
+ Args:
467
+ accelerator (`Accelerator`): The accelerator instance
468
+ model (`torch.nn.Module`):
469
+ The model to load the state dict into, expected to be on meta device or a VRAM spike can occur
470
+ full_sd (`dict`): The full state dict to load, can only be on rank 0
471
+ """
472
+ import torch.distributed as dist
473
+ from torch.distributed.tensor import DTensor, distribute_tensor
474
+
475
+ # Model was previously copied to meta device
476
+ meta_sharded_sd = model.state_dict()
477
+ sharded_sd = {}
478
+
479
+ # Rank 0 distributes the full state dict to other ranks
480
+ def _infer_parameter_dtype(model, param_name, empty_param):
481
+ try:
482
+ old_param = model.get_parameter_or_buffer(param_name)
483
+ except AttributeError:
484
+ # Need this for LORA, as there some params are not *parameters* of sorts
485
+ base_param_name, local_param_name = param_name.rsplit(".", 1)
486
+ submodule = model.get_submodule(base_param_name)
487
+ old_param = getattr(submodule, local_param_name)
488
+
489
+ is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
490
+ casting_dtype = None
491
+ is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn
492
+
493
+ if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
494
+ casting_dtype = old_param.dtype
495
+
496
+ return old_param is not None and old_param.is_contiguous(), casting_dtype
497
+
498
+ def _cast_and_contiguous(tensor, to_contiguous, dtype):
499
+ if dtype is not None:
500
+ tensor = tensor.to(dtype=dtype)
501
+ if to_contiguous:
502
+ tensor = tensor.contiguous()
503
+ return tensor
504
+
505
+ if accelerator.is_main_process:
506
+ for (param_name, full_param), sharded_param in zip(full_sd.items(), meta_sharded_sd.values()):
507
+ device_mesh = sharded_param.device_mesh
508
+ full_param = full_param.detach().to(device_mesh.device_type)
509
+ if isinstance(full_param, DTensor):
510
+ # dist.broadcast() only supports torch.Tensor.
511
+ # After prepare_tp(), model parameters may become DTensor.
512
+ # To broadcast such a parameter, convert it to a local tensor first.
513
+ full_param = full_param.to_local()
514
+ dist.broadcast(full_param, src=0, group=dist.group.WORLD)
515
+ sharded_tensor = distribute_tensor(full_param, device_mesh, sharded_param.placements)
516
+ to_contiguous, casting_dtype = _infer_parameter_dtype(
517
+ model,
518
+ param_name,
519
+ full_param,
520
+ )
521
+ sharded_tensor = _cast_and_contiguous(sharded_tensor, to_contiguous, casting_dtype)
522
+ sharded_sd[param_name] = sharded_tensor
523
+ # We need this else to have a matching `broadcast` for all of the ranks, else we deadlock
524
+ else:
525
+ for param_name, sharded_param in meta_sharded_sd.items():
526
+ device_mesh = sharded_param.device_mesh
527
+ full_tensor = torch.empty(sharded_param.size(), device=device_mesh.device_type, dtype=sharded_param.dtype)
528
+ dist.broadcast(full_tensor, src=0, group=dist.group.WORLD)
529
+ sharded_tensor = distribute_tensor(full_tensor, device_mesh, sharded_param.placements)
530
+ to_contiguous, casting_dtype = _infer_parameter_dtype(
531
+ model,
532
+ param_name,
533
+ full_tensor,
534
+ )
535
+ sharded_tensor = _cast_and_contiguous(sharded_tensor, to_contiguous, casting_dtype)
536
+ sharded_sd[param_name] = sharded_tensor
537
+
538
+ # we set `assign=True` because our params are on meta device
539
+ model.load_state_dict(sharded_sd, assign=True)
540
+ return model
541
+
542
+
543
+ def fsdp2_switch_optimizer_parameters(optimizer: torch.optim.Optimizer, mapping: dict):
544
+ """
545
+ Switches the parameters of the optimizer to new ones (sharded parameters in usual case). This function modifies the
546
+ optimizer in-place.
547
+
548
+ Args:
549
+ optimizer (`torch.optim.Optimizer`): Optimizer instance which contains the original model parameters
550
+ mapping (`dict`): Mapping from the original parameter (specified by `data_ptr`) to the sharded parameter
551
+
552
+ Raises:
553
+ KeyError:
554
+ If a parameter in the optimizer couldn't be switched to its sharded version. This should never happen and
555
+ indicates a bug. If we kept the original params instead of raising, the training wouldn't be numerically
556
+ correct and weights wouldn't get updated.
557
+ """
558
+ from torch.distributed.tensor import DTensor
559
+
560
+ accessor_mapping = {}
561
+
562
+ accessor_mapping[DTensor] = "_local_tensor"
563
+ try:
564
+ for param_group in optimizer.param_groups:
565
+ param_group["params"] = [mapping[p.data_ptr] for p in param_group["params"]]
566
+ except KeyError:
567
+ # This shouldn't ever happen, but we want to fail here else training wouldn't be numerically correct
568
+ # This basically means that we're missing a mapping from the original parameter to the sharded parameter
569
+ raise KeyError(
570
+ "A parameter in the optimizer couldn't be switched to its sharded version. This breaks the training. Please raise an issue on GitHub."
571
+ )
572
+
573
+
574
+ def fsdp2_apply_ac(accelerator, model: torch.nn.Module):
575
+ """
576
+ Applies the activation checkpointing to the model.
577
+
578
+ Args:
579
+ accelerator (`Accelerator`): The accelerator instance
580
+ model (`torch.nn.Module`): The model to apply the activation checkpointing to
581
+
582
+ Returns:
583
+ `torch.nn.Module`: The model with the activation checkpointing applied
584
+ """
585
+
586
+ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
587
+ checkpoint_wrapper,
588
+ )
589
+
590
+ auto_wrap_policy_func = fsdp2_prepare_auto_wrap_policy(accelerator.state.fsdp_plugin, model)
591
+
592
+ for layer_name, layer in get_module_children_bottom_up(model, return_fqns=True)[:-1]:
593
+ if len(layer_name.split(".")) > 1:
594
+ parent_name, child_name = layer_name.rsplit(".", 1)
595
+ else:
596
+ parent_name = None
597
+ child_name = layer_name
598
+
599
+ parent_module = model.get_submodule(parent_name) if parent_name else model
600
+ if auto_wrap_policy_func(parent_module):
601
+ layer = checkpoint_wrapper(layer, preserve_rng_state=False)
602
+ parent_module.register_module(child_name, layer)
603
+
604
+ return model
605
+
606
+
607
+ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
608
+ """Prepares the model for FSDP2 in-place. Also returns the model to avoid misuse of the original model.
609
+
610
+ Args:
611
+ accelerator (`Accelerator`): The accelerator instance
612
+ model (`torch.nn.Module`): The model to prepare
613
+
614
+ Returns:
615
+ `torch.nn.Module`: Prepared model
616
+ """
617
+ from torch.distributed.fsdp import FSDPModule, MixedPrecisionPolicy, fully_shard
618
+
619
+ is_type_fsdp = isinstance(model, FSDPModule) or (
620
+ is_compiled_module(model) and isinstance(model._orig_mod, FSDPModule)
621
+ )
622
+ if is_type_fsdp:
623
+ return model
624
+
625
+ fsdp2_plugin = accelerator.state.fsdp_plugin
626
+
627
+ fsdp2_plugin.set_auto_wrap_policy(model)
628
+
629
+ original_sd = model.state_dict()
630
+ mesh = getattr(accelerator, "torch_device_mesh", None)
631
+
632
+ fsdp2_kwargs = {
633
+ "reshard_after_forward": fsdp2_plugin.reshard_after_forward,
634
+ "offload_policy": fsdp2_plugin.cpu_offload,
635
+ # `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
636
+ "mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
637
+ "mesh": mesh[tuple(accelerator.parallelism_config.fsdp_dim_names)] if mesh is not None else None,
638
+ "ignored_params": get_parameters_from_modules(fsdp2_plugin.ignored_modules, model, accelerator.device),
639
+ }
640
+
641
+ model_has_params4bit = False
642
+ for name, param in model.named_parameters():
643
+ # this is a temporary fix whereby loading models with bnb params cannot be moved from
644
+ # GPU to a meta device due with FSDP2 because torch operations don't return the original class type
645
+ # bypassing the move to meta will still cause the VRAM spike, but at least it still will load
646
+ if param.__class__.__name__ == "Params4bit":
647
+ model_has_params4bit = True
648
+ break
649
+
650
+ if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
651
+ # Context: `fully_shard` moves the model to GPU if it was on CPU, however it can also be on `meta` and then it stays there even after `fully_shard`
652
+ # For this reason, we need to move the model to `meta` device, as then sharding happens on `meta` device
653
+ # If we kept the model on CPU (`cpu_ram_efficient_loading` has model be on CPU on all ranks, though non-main ranks only have `torch.empty`), `fully_shard` would move it to GPU
654
+ # Afterwards, when we call `fsdp2_load_full_state_dict`, us creating the state_dict would result into briefly having two copies of model state_dict on the GPU -> VRAM spike
655
+
656
+ # We need to keep the original non-persistent buffers, as those MAY not be in the state_dict, resulting in them staying on meta device
657
+ # Also, these buffers aren't getting sharded by default
658
+ # We get the FQNs of all non-persistent buffers, to re-register them after
659
+ non_persistent_buffer_fqns = get_non_persistent_buffers(model, recurse=True, fqns=True)
660
+ original_non_persistent_buffers = copy.deepcopy(
661
+ {k: v for k, v in model.named_buffers() if k in non_persistent_buffer_fqns}
662
+ )
663
+ # We move the model to meta device, as then sharding happens on meta device
664
+ model = model.to(torch.device("meta"))
665
+ # We need to re-tie the weights, not exactly sure why, but if we don't do this, reference to `lm_head/embed_tokens` stay hanging -> more VRAM usage
666
+ # We assume `transformers` models have a `tie_weights` method if they support it
667
+ if hasattr(model, "tie_weights"):
668
+ model.tie_weights()
669
+
670
+ auto_wrap_policy_func = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
671
+ if auto_wrap_policy_func is not None:
672
+ # We skip the model itself, as that one is always wrapped
673
+ for module in get_module_children_bottom_up(model)[:-1]:
674
+ if auto_wrap_policy_func(module) and not isinstance(module, FSDPModule):
675
+ fully_shard(module, **fsdp2_kwargs)
676
+
677
+ if not isinstance(model, FSDPModule):
678
+ fully_shard(model, **fsdp2_kwargs)
679
+
680
+ if fsdp2_plugin.cpu_ram_efficient_loading:
681
+ # If `cpu_ram_efficient_loading` is enabled, only rank 0 loads the weights
682
+ # Other ranks have an empty model on `meta` device, so we need to distribute the weights properly
683
+ fsdp2_load_full_state_dict(accelerator, model, original_sd)
684
+
685
+ if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
686
+ # We re-register the buffers, as they may not be in the state_dict
687
+ for fqn, buffer_tensor in original_non_persistent_buffers.items():
688
+ buffer_tensor = buffer_tensor.to(accelerator.device)
689
+
690
+ if "." in fqn:
691
+ parent_fqn, local_buffer_name = fqn.rsplit(".", 1)
692
+ parent_module = model.get_submodule(parent_fqn)
693
+ else:
694
+ local_buffer_name = fqn
695
+ parent_module = model
696
+
697
+ parent_module.register_buffer(local_buffer_name, buffer_tensor, persistent=False)
698
+
699
+ # We need to tie the weights again, as call to `load_full_state_dict` breaks the tie
700
+ # Needs to be called both here and above
701
+ # removing this call makes the have slightly different loss
702
+ # removing the call above leads to extra memory usage as explained in the comment above
703
+ if hasattr(model, "tie_weights"):
704
+ model.tie_weights()
705
+
706
+ # There is no `dtype` attribution for nn.Module
707
+ # Set it to None if it doesn't exist and do the upcast always
708
+ model_dtype = getattr(model, "dtype", None)
709
+ if accelerator.mixed_precision != "no" and (model_dtype is None or model_dtype != torch.float32):
710
+ # We upcast the model according to `deepspeed`'s implementation
711
+ # More info about this can be found in `accelerator.py:prepare_model`s FSDP1 section
712
+ model = model.to(torch.float32)
713
+ if accelerator.is_main_process:
714
+ # TODO(siro1): Add a warning for each parameter that was upcasted
715
+ warnings.warn(
716
+ "FSDP upcast of low precision parameters to fp32 (since mixed_precision != 'no') may affect the precision of model checkpoints."
717
+ )
718
+ return model
719
+
720
+
721
+ def fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model: torch.nn.Module) -> Callable[[torch.nn.Module], bool]:
722
+ """Prepares the auto wrap policy based on its type, done to mimic the behaviour of FSDP1 auto wrap policy.
723
+
724
+ Args:
725
+ fsdp2_plugin (`FullyShardedDataParallelPlugin`):
726
+ Instance of `FullyShardedDataParallelPlugin` containing the configuration options
727
+ auto_wrap_policy_type (`str`):
728
+ Either `transformer` or `size`
729
+ model (`torch.nn.Module`):
730
+ The model to wrap
731
+
732
+ Returns:
733
+ `Callable[[torch.nn.Module], bool]`:
734
+ The auto wrap policy function to be applied to the model
735
+ """
736
+ from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
737
+
738
+ fn = fsdp2_plugin.auto_wrap_policy
739
+
740
+ if isinstance(fn, functools.partial):
741
+ fn = fn.func
742
+
743
+ if fn is transformer_auto_wrap_policy:
744
+ no_split_modules = getattr(model, "_no_split_modules", None)
745
+ if no_split_modules is None:
746
+ no_split_modules = []
747
+ transformer_cls_names_to_wrap = list(no_split_modules)
748
+ if fsdp2_plugin.transformer_cls_names_to_wrap is not None:
749
+ transformer_cls_names_to_wrap = fsdp2_plugin.transformer_cls_names_to_wrap
750
+ transformer_cls_to_wrap = set()
751
+
752
+ for layer_class in transformer_cls_names_to_wrap:
753
+ transformer_cls = get_module_class_from_name(model, layer_class)
754
+ if transformer_cls is None:
755
+ raise ValueError(f"Could not find the transformer layer class {layer_class} in the model.")
756
+ transformer_cls_to_wrap.add(transformer_cls)
757
+
758
+ def policy(module: torch.nn.Module) -> bool:
759
+ if fsdp2_plugin.transformer_cls_names_to_wrap is None:
760
+ return False
761
+ return isinstance(module, tuple(transformer_cls_to_wrap))
762
+
763
+ elif fn is size_based_auto_wrap_policy:
764
+
765
+ def policy(module: torch.nn.Module) -> bool:
766
+ module_num_params = sum(p.numel() for p in module.parameters())
767
+ return module_num_params > fsdp2_plugin.min_num_params
768
+ else:
769
+ return None
770
+
771
+ return policy
772
+
773
+
774
+ def get_fsdp2_grad_scaler(**kwargs):
775
+ """
776
+ Returns a `GradScaler` for FSDP2, as the current implementation of `get_grad_scaler` doesn't accept other args. We
777
+ need this as current `get_grad_scaler` accepts only `distributed_type` as arg, which doesn't differentiate between
778
+ FSDP1 and FSDP2
779
+ """
780
+ from torch.amp.grad_scaler import GradScaler
781
+
782
+ return GradScaler(**kwargs)
783
+
784
+
785
+ def fsdp2_canonicalize_names(named_params: dict) -> dict:
786
+ """Removes parameter name modifiers in order to map them back to their original names.
787
+
788
+ See huggingface/accelerate#3554 for more context.
789
+
790
+ Args:
791
+ named_params (`dict`): The named parameters dictionary to canonicalize.
792
+
793
+ Returns:
794
+ `dict`: The canonicalized named parameters dictionary
795
+ """
796
+ named_params = {k.replace("._checkpoint_wrapped_module", ""): v for k, v in named_params.items()}
797
+ named_params = {
798
+ k.replace("_orig_mod.", "") if k.startswith("_orig_mod.") else k: v for k, v in named_params.items()
799
+ }
800
+ named_params = {k.replace("._orig_mod", ""): v for k, v in named_params.items()}
801
+ return named_params
802
+
803
+
804
+ def get_parameters_from_modules(
805
+ modules: Union[Iterable[torch.nn.Module], str], model, device
806
+ ) -> set[torch.nn.Parameter]:
807
+ """Converts modules to parameters where modules can be a string or list of torch.nn.Module
808
+
809
+ Args:
810
+ modules (`Union[Iterable[torch.nn.Module], str]`): List of modules
811
+
812
+ Returns:
813
+ `set[torch.nn.Parameter]`: List of parameters
814
+ """
815
+ if modules is None:
816
+ return set()
817
+ parameters = []
818
+ # code taken from accelerate while preparing kwargs for FSDP
819
+ if isinstance(modules, str):
820
+ reg = re.compile(modules)
821
+ mapped_modules = []
822
+ for name, module in model.named_modules():
823
+ if reg.fullmatch(name):
824
+ module.to(device)
825
+ mapped_modules.append(module)
826
+ modules = mapped_modules
827
+ for module in modules:
828
+ parameters.extend(list(module.parameters()))
829
+ return set(parameters)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/imports.py ADDED
@@ -0,0 +1,564 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import importlib
16
+ import importlib.metadata
17
+ import os
18
+ import sys
19
+ import warnings
20
+ from functools import lru_cache, wraps
21
+
22
+ import torch
23
+ from packaging import version
24
+ from packaging.version import parse
25
+
26
+ from .environment import parse_flag_from_env, patch_environment, str_to_bool
27
+ from .versions import compare_versions, is_torch_version
28
+
29
+
30
+ # Try to run Torch native job in an environment with TorchXLA installed by setting this value to 0.
31
+ USE_TORCH_XLA = parse_flag_from_env("USE_TORCH_XLA", default=True)
32
+
33
+ _torch_xla_available = False
34
+ if USE_TORCH_XLA:
35
+ try:
36
+ import torch_xla.core.xla_model as xm # noqa: F401
37
+ import torch_xla.runtime
38
+
39
+ _torch_xla_available = True
40
+ except ImportError:
41
+ pass
42
+
43
+ # Keep it for is_tpu_available. It will be removed along with is_tpu_available.
44
+ _tpu_available = _torch_xla_available
45
+
46
+ # Cache this result has it's a C FFI call which can be pretty time-consuming
47
+ _torch_distributed_available = torch.distributed.is_available()
48
+
49
+
50
+ def _is_package_available(pkg_name, metadata_name=None):
51
+ # Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version
52
+ package_exists = importlib.util.find_spec(pkg_name) is not None
53
+ if package_exists:
54
+ try:
55
+ # Some libraries have different names in the metadata
56
+ _ = importlib.metadata.metadata(pkg_name if metadata_name is None else metadata_name)
57
+ return True
58
+ except importlib.metadata.PackageNotFoundError:
59
+ return False
60
+
61
+
62
+ def is_torch_distributed_available() -> bool:
63
+ return _torch_distributed_available
64
+
65
+
66
+ def is_xccl_available():
67
+ if is_torch_version(">=", "2.7.0"):
68
+ return torch.distributed.distributed_c10d.is_xccl_available()
69
+ if is_ipex_available():
70
+ return False
71
+ return False
72
+
73
+
74
+ def is_ccl_available():
75
+ try:
76
+ pass
77
+ except ImportError:
78
+ print(
79
+ "Intel(R) oneCCL Bindings for PyTorch* is required to run DDP on Intel(R) XPUs, but it is not"
80
+ " detected. If you see \"ValueError: Invalid backend: 'ccl'\" error, please install Intel(R) oneCCL"
81
+ " Bindings for PyTorch*."
82
+ )
83
+ return importlib.util.find_spec("oneccl_bindings_for_pytorch") is not None
84
+
85
+
86
+ def get_ccl_version():
87
+ return importlib.metadata.version("oneccl_bind_pt")
88
+
89
+
90
+ def is_import_timer_available():
91
+ return _is_package_available("import_timer")
92
+
93
+
94
+ def is_pynvml_available():
95
+ return _is_package_available("pynvml") or _is_package_available("pynvml", "nvidia-ml-py")
96
+
97
+
98
+ def is_pytest_available():
99
+ return _is_package_available("pytest")
100
+
101
+
102
+ def is_msamp_available():
103
+ return _is_package_available("msamp", "ms-amp")
104
+
105
+
106
+ def is_schedulefree_available():
107
+ return _is_package_available("schedulefree")
108
+
109
+
110
+ def is_transformer_engine_available():
111
+ if is_hpu_available():
112
+ return _is_package_available("intel_transformer_engine", "intel-transformer-engine")
113
+ else:
114
+ return _is_package_available("transformer_engine", "transformer-engine")
115
+
116
+
117
+ def is_transformer_engine_mxfp8_available():
118
+ if _is_package_available("transformer_engine", "transformer-engine"):
119
+ import transformer_engine.pytorch as te
120
+
121
+ return te.fp8.check_mxfp8_support()[0]
122
+ return False
123
+
124
+
125
+ def is_lomo_available():
126
+ return _is_package_available("lomo_optim")
127
+
128
+
129
+ def is_cuda_available():
130
+ """
131
+ Checks if `cuda` is available via an `nvml-based` check which won't trigger the drivers and leave cuda
132
+ uninitialized.
133
+ """
134
+ with patch_environment(PYTORCH_NVML_BASED_CUDA_CHECK="1"):
135
+ available = torch.cuda.is_available()
136
+
137
+ return available
138
+
139
+
140
+ @lru_cache
141
+ def is_torch_xla_available(check_is_tpu=False, check_is_gpu=False):
142
+ """
143
+ Check if `torch_xla` is available. To train a native pytorch job in an environment with torch xla installed, set
144
+ the USE_TORCH_XLA to false.
145
+ """
146
+ assert not (check_is_tpu and check_is_gpu), "The check_is_tpu and check_is_gpu cannot both be true."
147
+
148
+ if not _torch_xla_available:
149
+ return False
150
+ elif check_is_gpu:
151
+ return torch_xla.runtime.device_type() in ["GPU", "CUDA"]
152
+ elif check_is_tpu:
153
+ return torch_xla.runtime.device_type() == "TPU"
154
+
155
+ return True
156
+
157
+
158
+ def is_torchao_available():
159
+ package_exists = _is_package_available("torchao")
160
+ if package_exists:
161
+ torchao_version = version.parse(importlib.metadata.version("torchao"))
162
+ return compare_versions(torchao_version, ">=", "0.6.1")
163
+ return False
164
+
165
+
166
+ def is_deepspeed_available():
167
+ return _is_package_available("deepspeed")
168
+
169
+
170
+ def is_pippy_available():
171
+ return is_torch_version(">=", "2.4.0")
172
+
173
+
174
+ def is_bf16_available(ignore_tpu=False):
175
+ "Checks if bf16 is supported, optionally ignoring the TPU"
176
+ if is_torch_xla_available(check_is_tpu=True):
177
+ return not ignore_tpu
178
+ if is_cuda_available():
179
+ return torch.cuda.is_bf16_supported()
180
+ if is_mlu_available():
181
+ return torch.mlu.is_bf16_supported()
182
+ if is_xpu_available():
183
+ return torch.xpu.is_bf16_supported()
184
+ if is_mps_available():
185
+ return torch.backends.mps.is_macos_or_newer(14, 0)
186
+ return True
187
+
188
+
189
+ def is_fp16_available():
190
+ "Checks if fp16 is supported"
191
+ if is_habana_gaudi1():
192
+ return False
193
+
194
+ return True
195
+
196
+
197
+ def is_fp8_available():
198
+ "Checks if fp8 is supported"
199
+ return is_msamp_available() or is_transformer_engine_available() or is_torchao_available()
200
+
201
+
202
+ def is_4bit_bnb_available():
203
+ package_exists = _is_package_available("bitsandbytes")
204
+ if package_exists:
205
+ bnb_version = version.parse(importlib.metadata.version("bitsandbytes"))
206
+ return compare_versions(bnb_version, ">=", "0.39.0")
207
+ return False
208
+
209
+
210
+ def is_8bit_bnb_available():
211
+ package_exists = _is_package_available("bitsandbytes")
212
+ if package_exists:
213
+ bnb_version = version.parse(importlib.metadata.version("bitsandbytes"))
214
+ return compare_versions(bnb_version, ">=", "0.37.2")
215
+ return False
216
+
217
+
218
+ def is_bnb_available(min_version=None):
219
+ package_exists = _is_package_available("bitsandbytes")
220
+ if package_exists and min_version is not None:
221
+ bnb_version = version.parse(importlib.metadata.version("bitsandbytes"))
222
+ return compare_versions(bnb_version, ">=", min_version)
223
+ else:
224
+ return package_exists
225
+
226
+
227
+ def is_bitsandbytes_multi_backend_available():
228
+ if not is_bnb_available():
229
+ return False
230
+ import bitsandbytes as bnb
231
+
232
+ return "multi_backend" in getattr(bnb, "features", set())
233
+
234
+
235
+ def is_torchvision_available():
236
+ return _is_package_available("torchvision")
237
+
238
+
239
+ def is_megatron_lm_available():
240
+ if str_to_bool(os.environ.get("ACCELERATE_USE_MEGATRON_LM", "False")) == 1:
241
+ if importlib.util.find_spec("megatron") is not None:
242
+ try:
243
+ megatron_version = parse(importlib.metadata.version("megatron-core"))
244
+ if compare_versions(megatron_version, ">=", "0.8.0"):
245
+ return importlib.util.find_spec(".training", "megatron")
246
+ except Exception as e:
247
+ warnings.warn(f"Parse Megatron version failed. Exception:{e}")
248
+ return False
249
+
250
+
251
+ def is_transformers_available():
252
+ return _is_package_available("transformers")
253
+
254
+
255
+ def is_datasets_available():
256
+ return _is_package_available("datasets")
257
+
258
+
259
+ def is_peft_available():
260
+ return _is_package_available("peft")
261
+
262
+
263
+ def is_timm_available():
264
+ return _is_package_available("timm")
265
+
266
+
267
+ def is_triton_available():
268
+ if is_xpu_available():
269
+ return _is_package_available("triton", "pytorch-triton-xpu")
270
+ return _is_package_available("triton")
271
+
272
+
273
+ def is_aim_available():
274
+ package_exists = _is_package_available("aim")
275
+ if package_exists:
276
+ aim_version = version.parse(importlib.metadata.version("aim"))
277
+ return compare_versions(aim_version, "<", "4.0.0")
278
+ return False
279
+
280
+
281
+ def is_tensorboard_available():
282
+ return _is_package_available("tensorboard") or _is_package_available("tensorboardX")
283
+
284
+
285
+ def is_wandb_available():
286
+ return _is_package_available("wandb")
287
+
288
+
289
+ def is_comet_ml_available():
290
+ return _is_package_available("comet_ml")
291
+
292
+
293
+ def is_swanlab_available():
294
+ return _is_package_available("swanlab")
295
+
296
+
297
+ def is_trackio_available():
298
+ return sys.version_info >= (3, 10) and _is_package_available("trackio")
299
+
300
+
301
+ def is_boto3_available():
302
+ return _is_package_available("boto3")
303
+
304
+
305
+ def is_rich_available():
306
+ if _is_package_available("rich"):
307
+ return parse_flag_from_env("ACCELERATE_ENABLE_RICH", False)
308
+ return False
309
+
310
+
311
+ def is_sagemaker_available():
312
+ return _is_package_available("sagemaker")
313
+
314
+
315
+ def is_tqdm_available():
316
+ return _is_package_available("tqdm")
317
+
318
+
319
+ def is_clearml_available():
320
+ return _is_package_available("clearml")
321
+
322
+
323
+ def is_pandas_available():
324
+ return _is_package_available("pandas")
325
+
326
+
327
+ def is_matplotlib_available():
328
+ return _is_package_available("matplotlib")
329
+
330
+
331
+ def is_mlflow_available():
332
+ if _is_package_available("mlflow"):
333
+ return True
334
+
335
+ if importlib.util.find_spec("mlflow") is not None:
336
+ try:
337
+ _ = importlib.metadata.metadata("mlflow-skinny")
338
+ return True
339
+ except importlib.metadata.PackageNotFoundError:
340
+ return False
341
+ return False
342
+
343
+
344
+ def is_mps_available(min_version="1.12"):
345
+ "Checks if MPS device is available. The minimum version required is 1.12."
346
+ # With torch 1.12, you can use torch.backends.mps
347
+ # With torch 2.0.0, you can use torch.mps
348
+ return is_torch_version(">=", min_version) and torch.backends.mps.is_available() and torch.backends.mps.is_built()
349
+
350
+
351
+ def is_ipex_available():
352
+ "Checks if ipex is installed."
353
+
354
+ def get_major_and_minor_from_version(full_version):
355
+ return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor)
356
+
357
+ _torch_version = importlib.metadata.version("torch")
358
+ if importlib.util.find_spec("intel_extension_for_pytorch") is None:
359
+ return False
360
+ _ipex_version = "N/A"
361
+ try:
362
+ _ipex_version = importlib.metadata.version("intel_extension_for_pytorch")
363
+ except importlib.metadata.PackageNotFoundError:
364
+ return False
365
+ torch_major_and_minor = get_major_and_minor_from_version(_torch_version)
366
+ ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version)
367
+ if torch_major_and_minor != ipex_major_and_minor:
368
+ warnings.warn(
369
+ f"Intel Extension for PyTorch {ipex_major_and_minor} needs to work with PyTorch {ipex_major_and_minor}.*,"
370
+ f" but PyTorch {_torch_version} is found. Please switch to the matching version and run again."
371
+ )
372
+ return False
373
+ return True
374
+
375
+
376
+ @lru_cache
377
+ def is_mlu_available(check_device=False):
378
+ """
379
+ Checks if `mlu` is available via an `cndev-based` check which won't trigger the drivers and leave mlu
380
+ uninitialized.
381
+ """
382
+ if importlib.util.find_spec("torch_mlu") is None:
383
+ return False
384
+
385
+ import torch_mlu # noqa: F401
386
+
387
+ with patch_environment(PYTORCH_CNDEV_BASED_MLU_CHECK="1"):
388
+ available = torch.mlu.is_available()
389
+
390
+ return available
391
+
392
+
393
+ @lru_cache
394
+ def is_musa_available(check_device=False):
395
+ "Checks if `torch_musa` is installed and potentially if a MUSA is in the environment"
396
+ if importlib.util.find_spec("torch_musa") is None:
397
+ return False
398
+
399
+ import torch_musa # noqa: F401
400
+
401
+ if check_device:
402
+ try:
403
+ # Will raise a RuntimeError if no MUSA is found
404
+ _ = torch.musa.device_count()
405
+ return torch.musa.is_available()
406
+ except RuntimeError:
407
+ return False
408
+ return hasattr(torch, "musa") and torch.musa.is_available()
409
+
410
+
411
+ @lru_cache
412
+ def is_npu_available(check_device=False):
413
+ "Checks if `torch_npu` is installed and potentially if a NPU is in the environment"
414
+ if importlib.util.find_spec("torch_npu") is None:
415
+ return False
416
+
417
+ # NOTE: importing torch_npu may raise error in some envs
418
+ # e.g. inside cpu-only container with torch_npu installed
419
+ try:
420
+ import torch_npu # noqa: F401
421
+ except Exception:
422
+ return False
423
+
424
+ if check_device:
425
+ try:
426
+ # Will raise a RuntimeError if no NPU is found
427
+ _ = torch.npu.device_count()
428
+ return torch.npu.is_available()
429
+ except RuntimeError:
430
+ return False
431
+ return hasattr(torch, "npu") and torch.npu.is_available()
432
+
433
+
434
+ @lru_cache
435
+ def is_sdaa_available(check_device=False):
436
+ "Checks if `torch_sdaa` is installed and potentially if a SDAA is in the environment"
437
+ if importlib.util.find_spec("torch_sdaa") is None:
438
+ return False
439
+
440
+ import torch_sdaa # noqa: F401
441
+
442
+ if check_device:
443
+ try:
444
+ # Will raise a RuntimeError if no NPU is found
445
+ _ = torch.sdaa.device_count()
446
+ return torch.sdaa.is_available()
447
+ except RuntimeError:
448
+ return False
449
+ return hasattr(torch, "sdaa") and torch.sdaa.is_available()
450
+
451
+
452
+ @lru_cache
453
+ def is_hpu_available(init_hccl=False):
454
+ "Checks if `torch.hpu` is installed and potentially if a HPU is in the environment"
455
+ if (
456
+ importlib.util.find_spec("habana_frameworks") is None
457
+ or importlib.util.find_spec("habana_frameworks.torch") is None
458
+ ):
459
+ return False
460
+
461
+ import habana_frameworks.torch # noqa: F401
462
+
463
+ if init_hccl:
464
+ import habana_frameworks.torch.distributed.hccl as hccl # noqa: F401
465
+
466
+ return hasattr(torch, "hpu") and torch.hpu.is_available()
467
+
468
+
469
+ def is_habana_gaudi1():
470
+ if is_hpu_available():
471
+ import habana_frameworks.torch.utils.experimental as htexp # noqa: F401
472
+
473
+ if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi:
474
+ return True
475
+
476
+ return False
477
+
478
+
479
+ @lru_cache
480
+ def is_xpu_available(check_device=False):
481
+ """
482
+ Checks if XPU acceleration is available either via `intel_extension_for_pytorch` or via stock PyTorch (>=2.4) and
483
+ potentially if a XPU is in the environment
484
+ """
485
+
486
+ if is_ipex_available():
487
+ import intel_extension_for_pytorch # noqa: F401
488
+ else:
489
+ if is_torch_version("<=", "2.3"):
490
+ return False
491
+
492
+ if check_device:
493
+ try:
494
+ # Will raise a RuntimeError if no XPU is found
495
+ _ = torch.xpu.device_count()
496
+ return torch.xpu.is_available()
497
+ except RuntimeError:
498
+ return False
499
+ return hasattr(torch, "xpu") and torch.xpu.is_available()
500
+
501
+
502
+ def is_dvclive_available():
503
+ return _is_package_available("dvclive")
504
+
505
+
506
+ def is_torchdata_available():
507
+ return _is_package_available("torchdata")
508
+
509
+
510
+ # TODO: Remove this function once stateful_dataloader is a stable feature in torchdata.
511
+ def is_torchdata_stateful_dataloader_available():
512
+ package_exists = _is_package_available("torchdata")
513
+ if package_exists:
514
+ torchdata_version = version.parse(importlib.metadata.version("torchdata"))
515
+ return compare_versions(torchdata_version, ">=", "0.8.0")
516
+ return False
517
+
518
+
519
+ def torchao_required(func):
520
+ """
521
+ A decorator that ensures the decorated function is only called when torchao is available.
522
+ """
523
+
524
+ @wraps(func)
525
+ def wrapper(*args, **kwargs):
526
+ if not is_torchao_available():
527
+ raise ImportError(
528
+ "`torchao` is not available, please install it before calling this function via `pip install torchao`."
529
+ )
530
+ return func(*args, **kwargs)
531
+
532
+ return wrapper
533
+
534
+
535
+ # TODO: Rework this into `utils.deepspeed` and migrate the "core" chunks into `accelerate.deepspeed`
536
+ def deepspeed_required(func):
537
+ """
538
+ A decorator that ensures the decorated function is only called when deepspeed is enabled.
539
+ """
540
+
541
+ @wraps(func)
542
+ def wrapper(*args, **kwargs):
543
+ from accelerate.state import AcceleratorState
544
+ from accelerate.utils.dataclasses import DistributedType
545
+
546
+ if AcceleratorState._shared_state != {} and AcceleratorState().distributed_type != DistributedType.DEEPSPEED:
547
+ raise ValueError(
548
+ "DeepSpeed is not enabled, please make sure that an `Accelerator` is configured for `deepspeed` "
549
+ "before calling this function."
550
+ )
551
+ return func(*args, **kwargs)
552
+
553
+ return wrapper
554
+
555
+
556
+ def is_weights_only_available():
557
+ # Weights only with allowlist was added in 2.4.0
558
+ # ref: https://github.com/pytorch/pytorch/pull/124331
559
+ return is_torch_version(">=", "2.4.0")
560
+
561
+
562
+ def is_numpy_available(min_version="1.25.0"):
563
+ numpy_version = parse(importlib.metadata.version("numpy"))
564
+ return compare_versions(numpy_version, ">=", min_version)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/launch.py ADDED
@@ -0,0 +1,781 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import argparse
16
+ import os
17
+ import subprocess
18
+ import sys
19
+ import warnings
20
+ from ast import literal_eval
21
+ from shutil import which
22
+ from typing import Any
23
+
24
+ import torch
25
+
26
+ from ..commands.config.config_args import SageMakerConfig
27
+ from ..utils import (
28
+ DynamoBackend,
29
+ PrecisionType,
30
+ is_ccl_available,
31
+ is_fp8_available,
32
+ is_hpu_available,
33
+ is_ipex_available,
34
+ is_mlu_available,
35
+ is_musa_available,
36
+ is_npu_available,
37
+ is_sdaa_available,
38
+ is_torch_xla_available,
39
+ is_xpu_available,
40
+ )
41
+ from ..utils.constants import DEEPSPEED_MULTINODE_LAUNCHERS
42
+ from ..utils.other import get_free_port, is_port_in_use, merge_dicts
43
+ from ..utils.versions import compare_versions
44
+ from .dataclasses import DistributedType, SageMakerDistributedType
45
+
46
+
47
+ def _filter_args(args, parser, default_args=[]):
48
+ """
49
+ Filters out all `accelerate` specific args
50
+ """
51
+ new_args, _ = parser.parse_known_args(default_args)
52
+ for key, value in vars(args).items():
53
+ if key in vars(new_args).keys():
54
+ setattr(new_args, key, value)
55
+ return new_args
56
+
57
+
58
+ def _get_mpirun_args():
59
+ """
60
+ Determines the executable and argument names for mpirun, based on the type of install. The supported MPI programs
61
+ are: OpenMPI, Intel MPI, or MVAPICH.
62
+
63
+ Returns: Program name and arg names for hostfile, num processes, and processes per node
64
+ """
65
+ # Find the MPI program name
66
+ mpi_apps = [x for x in ["mpirun", "mpiexec"] if which(x)]
67
+
68
+ if len(mpi_apps) == 0:
69
+ raise OSError("mpirun or mpiexec were not found. Ensure that Intel MPI, Open MPI, or MVAPICH are installed.")
70
+
71
+ # Call the app with the --version flag to determine which MPI app is installed
72
+ mpi_app = mpi_apps[0]
73
+ mpirun_version = subprocess.check_output([mpi_app, "--version"])
74
+
75
+ if b"Open MPI" in mpirun_version:
76
+ return mpi_app, "--hostfile", "-n", "--npernode", "--bind-to"
77
+ else:
78
+ # Intel MPI and MVAPICH both use the same arg names
79
+ return mpi_app, "-f", "-n", "-ppn", ""
80
+
81
+
82
+ def setup_fp8_env(args: argparse.Namespace, current_env: dict[str, str]):
83
+ """
84
+ Setup the FP8 environment variables.
85
+ """
86
+ prefix = "ACCELERATE_"
87
+ for arg in vars(args):
88
+ if arg.startswith("fp8_"):
89
+ value = getattr(args, arg)
90
+ if value is not None:
91
+ if arg == "fp8_override_linear_precision":
92
+ current_env[prefix + "FP8_OVERRIDE_FPROP"] = str(value[0])
93
+ current_env[prefix + "FP8_OVERRIDE_DGRAD"] = str(value[1])
94
+ current_env[prefix + "FP8_OVERRIDE_WGRAD"] = str(value[2])
95
+ else:
96
+ current_env[f"{prefix}{arg.upper()}"] = str(getattr(args, arg))
97
+ return current_env
98
+
99
+
100
+ def prepare_simple_launcher_cmd_env(args: argparse.Namespace) -> tuple[list[str], dict[str, str]]:
101
+ """
102
+ Prepares and returns the command list and an environment with the correct simple launcher environment variables.
103
+ """
104
+ cmd = []
105
+ if args.no_python and args.module:
106
+ raise ValueError("--module and --no_python cannot be used together")
107
+
108
+ num_processes = getattr(args, "num_processes", None)
109
+ num_machines = args.num_machines
110
+ if args.mpirun_hostfile is not None:
111
+ mpi_app_name, hostfile_arg, num_proc_arg, proc_per_node_arg, bind_to_arg = _get_mpirun_args()
112
+ bind_to = getattr(args, "bind-to", "socket")
113
+ nproc_per_node = str(num_processes // num_machines) if num_processes and num_machines else "1"
114
+ cmd += [
115
+ mpi_app_name,
116
+ hostfile_arg,
117
+ args.mpirun_hostfile,
118
+ proc_per_node_arg,
119
+ nproc_per_node,
120
+ ]
121
+ if num_processes:
122
+ cmd += [num_proc_arg, str(num_processes)]
123
+ if bind_to_arg:
124
+ cmd += [bind_to_arg, bind_to]
125
+ if not args.no_python:
126
+ cmd.append(sys.executable)
127
+ if args.module:
128
+ cmd.append("-m")
129
+ cmd.append(args.training_script)
130
+ cmd.extend(args.training_script_args)
131
+
132
+ current_env = os.environ.copy()
133
+ current_env["ACCELERATE_USE_CPU"] = str(args.cpu or args.use_cpu)
134
+ if args.debug:
135
+ current_env["ACCELERATE_DEBUG_MODE"] = "true"
136
+ if args.gpu_ids != "all" and args.gpu_ids is not None:
137
+ if is_xpu_available():
138
+ current_env["ZE_AFFINITY_MASK"] = args.gpu_ids
139
+ elif is_mlu_available():
140
+ current_env["MLU_VISIBLE_DEVICES"] = args.gpu_ids
141
+ elif is_sdaa_available():
142
+ current_env["SDAA_VISIBLE_DEVICES"] = args.gpu_ids
143
+ elif is_musa_available():
144
+ current_env["MUSA_VISIBLE_DEVICES"] = args.gpu_ids
145
+ elif is_npu_available():
146
+ current_env["ASCEND_RT_VISIBLE_DEVICES"] = args.gpu_ids
147
+ elif is_hpu_available():
148
+ current_env["HABANA_VISIBLE_MODULES"] = args.gpu_ids
149
+ else:
150
+ current_env["CUDA_VISIBLE_DEVICES"] = args.gpu_ids
151
+ if num_machines > 1:
152
+ assert args.main_process_ip is not None, (
153
+ "When using multiple machines, you need to specify the main process IP."
154
+ )
155
+ assert args.main_process_port is not None, (
156
+ "When using multiple machines, you need to specify the main process port."
157
+ )
158
+
159
+ ccl_worker_count = getattr(args, "mpirun_ccl", 0) if is_ccl_available() else 0
160
+ if (num_processes is not None and num_processes > 1) or num_machines > 1:
161
+ current_env["MASTER_ADDR"] = args.main_process_ip if args.main_process_ip is not None else "127.0.0.1"
162
+ current_env["MASTER_PORT"] = str(args.main_process_port) if args.main_process_port is not None else "29500"
163
+ current_env["CCL_WORKER_COUNT"] = str(ccl_worker_count)
164
+ if current_env["ACCELERATE_USE_CPU"]:
165
+ current_env["KMP_AFFINITY"] = "granularity=fine,compact,1,0"
166
+ current_env["KMP_BLOCKTIME"] = str(1)
167
+
168
+ try:
169
+ mixed_precision = PrecisionType(args.mixed_precision.lower())
170
+ except ValueError:
171
+ raise ValueError(
172
+ f"Unknown mixed_precision mode: {args.mixed_precision.lower()}. Choose between {PrecisionType.list()}."
173
+ )
174
+
175
+ current_env["ACCELERATE_MIXED_PRECISION"] = str(mixed_precision)
176
+ if args.mixed_precision.lower() == "fp8":
177
+ if not is_fp8_available():
178
+ raise RuntimeError(
179
+ "FP8 is not available on this machine. Please ensure that either Transformer Engine, MSAMP or torchao is installed."
180
+ )
181
+ current_env = setup_fp8_env(args, current_env)
182
+
183
+ try:
184
+ dynamo_backend = DynamoBackend(args.dynamo_backend.upper())
185
+ except ValueError:
186
+ raise ValueError(
187
+ f"Unknown dynamo backend: {args.dynamo_backend.upper()}. Choose between {DynamoBackend.list()}."
188
+ )
189
+ current_env["ACCELERATE_DYNAMO_BACKEND"] = dynamo_backend.value
190
+ current_env["ACCELERATE_DYNAMO_MODE"] = args.dynamo_mode
191
+ current_env["ACCELERATE_DYNAMO_USE_FULLGRAPH"] = str(args.dynamo_use_fullgraph)
192
+ current_env["ACCELERATE_DYNAMO_USE_DYNAMIC"] = str(args.dynamo_use_dynamic)
193
+ current_env["ACCELERATE_DYNAMO_USE_REGIONAL_COMPILATION"] = str(args.dynamo_use_regional_compilation)
194
+
195
+ current_env["OMP_NUM_THREADS"] = str(args.num_cpu_threads_per_process)
196
+ if is_ipex_available():
197
+ current_env["ACCELERATE_USE_IPEX"] = str(args.ipex).lower()
198
+ if args.enable_cpu_affinity:
199
+ current_env["ACCELERATE_CPU_AFFINITY"] = "1"
200
+ return cmd, current_env
201
+
202
+
203
+ def prepare_multi_gpu_env(args: argparse.Namespace) -> dict[str, str]:
204
+ """
205
+ Prepares and returns an environment with the correct multi-GPU environment variables.
206
+ """
207
+ # get free port and update configurations
208
+ if args.main_process_port == 0:
209
+ args.main_process_port = get_free_port()
210
+
211
+ elif args.main_process_port is None:
212
+ args.main_process_port = 29500
213
+
214
+ num_processes = args.num_processes
215
+ num_machines = args.num_machines
216
+ main_process_ip = args.main_process_ip
217
+ main_process_port = args.main_process_port
218
+ if num_machines > 1:
219
+ args.nproc_per_node = str(num_processes // num_machines)
220
+ args.nnodes = str(num_machines)
221
+ args.node_rank = int(args.machine_rank)
222
+ if getattr(args, "same_network", False):
223
+ args.master_addr = str(main_process_ip)
224
+ args.master_port = str(main_process_port)
225
+ else:
226
+ args.rdzv_endpoint = f"{main_process_ip}:{main_process_port}"
227
+ else:
228
+ args.nproc_per_node = str(num_processes)
229
+ if main_process_port is not None:
230
+ args.master_port = str(main_process_port)
231
+
232
+ # only need to check port availability in main process, in case we have to start multiple launchers on the same machine
233
+ # for some reasons like splitting log files.
234
+ need_port_check = num_machines <= 1 or int(args.machine_rank) == 0
235
+ if need_port_check and is_port_in_use(main_process_port):
236
+ if num_machines <= 1:
237
+ args.standalone = True
238
+ warnings.warn(
239
+ f"Port `{main_process_port}` is already in use. "
240
+ "Accelerate will attempt to launch in a standalone-like mode by finding an open port automatically for this session. "
241
+ "If this current attempt fails, or for more control in future runs, please specify a different port "
242
+ "(e.g., `--main_process_port <your_chosen_port>`) or use `--main_process_port 0` for automatic selection "
243
+ "in your launch command or Accelerate config file."
244
+ )
245
+ else:
246
+ raise ConnectionError(
247
+ f"Tried to launch distributed communication on port `{main_process_port}`, but another process is utilizing it. "
248
+ "Please specify a different port (such as using the `--main_process_port` flag or specifying a different `main_process_port` in your config file)"
249
+ " and rerun your script. To automatically use the next open port (on a single node), you can set this to `0`."
250
+ )
251
+
252
+ if args.module and args.no_python:
253
+ raise ValueError("--module and --no_python cannot be used together")
254
+ elif args.module:
255
+ args.module = True
256
+ elif args.no_python:
257
+ args.no_python = True
258
+
259
+ current_env = os.environ.copy()
260
+ if args.debug:
261
+ current_env["ACCELERATE_DEBUG_MODE"] = "true"
262
+ gpu_ids = getattr(args, "gpu_ids", "all")
263
+ if gpu_ids != "all" and args.gpu_ids is not None:
264
+ if is_xpu_available():
265
+ current_env["ZE_AFFINITY_MASK"] = gpu_ids
266
+ elif is_mlu_available():
267
+ current_env["MLU_VISIBLE_DEVICES"] = gpu_ids
268
+ elif is_sdaa_available():
269
+ current_env["SDAA_VISIBLE_DEVICES"] = gpu_ids
270
+ elif is_musa_available():
271
+ current_env["MUSA_VISIBLE_DEVICES"] = gpu_ids
272
+ elif is_npu_available():
273
+ current_env["ASCEND_RT_VISIBLE_DEVICES"] = gpu_ids
274
+ elif is_hpu_available():
275
+ current_env["HABANA_VISIBLE_MODULES"] = gpu_ids
276
+ else:
277
+ current_env["CUDA_VISIBLE_DEVICES"] = gpu_ids
278
+ mixed_precision = args.mixed_precision.lower()
279
+ try:
280
+ mixed_precision = PrecisionType(mixed_precision)
281
+ except ValueError:
282
+ raise ValueError(f"Unknown mixed_precision mode: {mixed_precision}. Choose between {PrecisionType.list()}.")
283
+
284
+ current_env["ACCELERATE_MIXED_PRECISION"] = str(mixed_precision)
285
+ if args.mixed_precision.lower() == "fp8":
286
+ if not is_fp8_available():
287
+ raise RuntimeError(
288
+ "FP8 is not available on this machine. Please ensure that either Transformer Engine, MSAMP or torchao is installed."
289
+ )
290
+ current_env = setup_fp8_env(args, current_env)
291
+
292
+ try:
293
+ dynamo_backend = DynamoBackend(args.dynamo_backend.upper())
294
+ except ValueError:
295
+ raise ValueError(
296
+ f"Unknown dynamo backend: {args.dynamo_backend.upper()}. Choose between {DynamoBackend.list()}."
297
+ )
298
+ current_env["ACCELERATE_DYNAMO_BACKEND"] = dynamo_backend.value
299
+ current_env["ACCELERATE_DYNAMO_MODE"] = args.dynamo_mode
300
+ current_env["ACCELERATE_DYNAMO_USE_FULLGRAPH"] = str(args.dynamo_use_fullgraph)
301
+ current_env["ACCELERATE_DYNAMO_USE_DYNAMIC"] = str(args.dynamo_use_dynamic)
302
+ current_env["ACCELERATE_DYNAMO_USE_REGIONAL_COMPILATION"] = str(args.dynamo_use_regional_compilation)
303
+
304
+ if args.use_fsdp:
305
+ current_env["ACCELERATE_USE_FSDP"] = "true"
306
+ if args.fsdp_cpu_ram_efficient_loading and not args.fsdp_sync_module_states:
307
+ raise ValueError("When using `--fsdp_cpu_ram_efficient_loading` set `--fsdp_sync_module_states` to `True`")
308
+
309
+ current_env["FSDP_VERSION"] = str(args.fsdp_version) if hasattr(args, "fsdp_version") else "1"
310
+
311
+ # For backwards compatibility, we support this in launched scripts,
312
+ # however, we do not ask users for this in `accelerate config` CLI
313
+ current_env["FSDP_SHARDING_STRATEGY"] = str(args.fsdp_sharding_strategy)
314
+
315
+ current_env["FSDP_RESHARD_AFTER_FORWARD"] = str(args.fsdp_reshard_after_forward).lower()
316
+ current_env["FSDP_OFFLOAD_PARAMS"] = str(args.fsdp_offload_params).lower()
317
+ current_env["FSDP_MIN_NUM_PARAMS"] = str(args.fsdp_min_num_params)
318
+ if args.fsdp_auto_wrap_policy is not None:
319
+ current_env["FSDP_AUTO_WRAP_POLICY"] = str(args.fsdp_auto_wrap_policy)
320
+ if args.fsdp_transformer_layer_cls_to_wrap is not None:
321
+ current_env["FSDP_TRANSFORMER_CLS_TO_WRAP"] = str(args.fsdp_transformer_layer_cls_to_wrap)
322
+ if args.fsdp_backward_prefetch is not None:
323
+ current_env["FSDP_BACKWARD_PREFETCH"] = str(args.fsdp_backward_prefetch)
324
+ if args.fsdp_state_dict_type is not None:
325
+ current_env["FSDP_STATE_DICT_TYPE"] = str(args.fsdp_state_dict_type)
326
+ current_env["FSDP_FORWARD_PREFETCH"] = str(args.fsdp_forward_prefetch).lower()
327
+ current_env["FSDP_USE_ORIG_PARAMS"] = str(args.fsdp_use_orig_params).lower()
328
+ current_env["FSDP_CPU_RAM_EFFICIENT_LOADING"] = str(args.fsdp_cpu_ram_efficient_loading).lower()
329
+ current_env["FSDP_SYNC_MODULE_STATES"] = str(args.fsdp_sync_module_states).lower()
330
+ current_env["FSDP_ACTIVATION_CHECKPOINTING"] = str(args.fsdp_activation_checkpointing).lower()
331
+ if getattr(args, "fsdp_ignored_modules", None) is not None:
332
+ current_env["FSDP_IGNORED_MODULES"] = str(args.fsdp_ignored_modules)
333
+
334
+ if args.use_megatron_lm:
335
+ prefix = "MEGATRON_LM_"
336
+ current_env["ACCELERATE_USE_MEGATRON_LM"] = "true"
337
+ current_env[prefix + "TP_DEGREE"] = str(args.megatron_lm_tp_degree)
338
+ current_env[prefix + "PP_DEGREE"] = str(args.megatron_lm_pp_degree)
339
+ current_env[prefix + "GRADIENT_CLIPPING"] = str(args.megatron_lm_gradient_clipping)
340
+ if args.megatron_lm_num_micro_batches is not None:
341
+ current_env[prefix + "NUM_MICRO_BATCHES"] = str(args.megatron_lm_num_micro_batches)
342
+ if args.megatron_lm_sequence_parallelism is not None:
343
+ current_env[prefix + "SEQUENCE_PARALLELISM"] = str(args.megatron_lm_sequence_parallelism)
344
+ if args.megatron_lm_recompute_activations is not None:
345
+ current_env[prefix + "RECOMPUTE_ACTIVATIONS"] = str(args.megatron_lm_recompute_activations)
346
+ if args.megatron_lm_use_distributed_optimizer is not None:
347
+ current_env[prefix + "USE_DISTRIBUTED_OPTIMIZER"] = str(args.megatron_lm_use_distributed_optimizer)
348
+
349
+ current_env["OMP_NUM_THREADS"] = str(args.num_cpu_threads_per_process)
350
+ if args.enable_cpu_affinity:
351
+ current_env["ACCELERATE_CPU_AFFINITY"] = "1"
352
+
353
+ if args.use_parallelism_config:
354
+ current_env = prepare_extend_env_parallelism_config(args, current_env)
355
+
356
+ return current_env
357
+
358
+
359
+ def prepare_extend_env_parallelism_config(
360
+ args: argparse.Namespace, current_env: dict
361
+ ) -> tuple[list[str], dict[str, str]]:
362
+ """
363
+ Extends `current_env` with context parallelism env vars if any have been set
364
+ """
365
+
366
+ prefix = "PARALLELISM_CONFIG_"
367
+
368
+ current_env["ACCELERATE_USE_PARALLELISM_CONFIG"] = "true"
369
+ current_env[prefix + "DP_REPLICATE_SIZE"] = str(args.parallelism_config_dp_replicate_size)
370
+ current_env[prefix + "DP_SHARD_SIZE"] = str(args.parallelism_config_dp_shard_size)
371
+ current_env[prefix + "TP_SIZE"] = str(args.parallelism_config_tp_size)
372
+ current_env[prefix + "CP_SIZE"] = str(args.parallelism_config_cp_size)
373
+ current_env[prefix + "CP_BACKEND"] = str(args.parallelism_config_cp_backend)
374
+ current_env[prefix + "SP_SIZE"] = str(args.parallelism_config_sp_size)
375
+ current_env[prefix + "SP_BACKEND"] = str(args.parallelism_config_sp_backend)
376
+ if args.parallelism_config_cp_size > 1:
377
+ current_env[prefix + "CP_COMM_STRATEGY"] = str(args.parallelism_config_cp_comm_strategy)
378
+ if args.parallelism_config_sp_size > 1:
379
+ current_env[prefix + "SP_SEQ_LENGTH"] = str(args.parallelism_config_sp_seq_length)
380
+ current_env[prefix + "SP_SEQ_LENGTH_IS_VARIABLE"] = str(args.parallelism_config_sp_seq_length_is_variable)
381
+ current_env[prefix + "SP_ATTN_IMPLEMENTATION"] = str(args.parallelism_config_sp_attn_implementation)
382
+
383
+ return current_env
384
+
385
+
386
+ def prepare_deepspeed_cmd_env(args: argparse.Namespace) -> tuple[list[str], dict[str, str]]:
387
+ """
388
+ Prepares and returns the command list and an environment with the correct DeepSpeed environment variables.
389
+ """
390
+ # get free port and update configurations
391
+ if args.main_process_port == 0:
392
+ args.main_process_port = get_free_port()
393
+
394
+ elif args.main_process_port is None:
395
+ args.main_process_port = 29500
396
+
397
+ num_processes = args.num_processes
398
+ num_machines = args.num_machines
399
+ main_process_ip = args.main_process_ip
400
+ main_process_port = args.main_process_port
401
+ cmd = None
402
+
403
+ # make sure launcher is not None
404
+ if args.deepspeed_multinode_launcher is None:
405
+ # set to default pdsh
406
+ args.deepspeed_multinode_launcher = DEEPSPEED_MULTINODE_LAUNCHERS[0]
407
+
408
+ if num_machines > 1 and args.deepspeed_multinode_launcher != DEEPSPEED_MULTINODE_LAUNCHERS[1]:
409
+ cmd = ["deepspeed"]
410
+ cmd.extend(["--hostfile", str(args.deepspeed_hostfile)])
411
+ if args.deepspeed_multinode_launcher == "nossh":
412
+ if compare_versions("deepspeed", "<", "0.14.5"):
413
+ raise ValueError("nossh launcher requires DeepSpeed >= 0.14.5")
414
+ cmd.extend(["--node_rank", str(args.machine_rank), "--no_ssh"])
415
+ else:
416
+ cmd.extend(["--no_local_rank", "--launcher", str(args.deepspeed_multinode_launcher)])
417
+ if args.deepspeed_exclusion_filter is not None:
418
+ cmd.extend(
419
+ [
420
+ "--exclude",
421
+ str(args.deepspeed_exclusion_filter),
422
+ ]
423
+ )
424
+ elif args.deepspeed_inclusion_filter is not None:
425
+ cmd.extend(
426
+ [
427
+ "--include",
428
+ str(args.deepspeed_inclusion_filter),
429
+ ]
430
+ )
431
+ else:
432
+ cmd.extend(["--num_gpus", str(args.num_processes // args.num_machines)])
433
+ if main_process_ip:
434
+ cmd.extend(["--master_addr", str(main_process_ip)])
435
+ cmd.extend(["--master_port", str(main_process_port)])
436
+ if args.module and args.no_python:
437
+ raise ValueError("--module and --no_python cannot be used together")
438
+ elif args.module:
439
+ cmd.append("--module")
440
+ elif args.no_python:
441
+ cmd.append("--no_python")
442
+ cmd.append(args.training_script)
443
+ cmd.extend(args.training_script_args)
444
+ elif num_machines > 1 and args.deepspeed_multinode_launcher == DEEPSPEED_MULTINODE_LAUNCHERS[1]:
445
+ args.nproc_per_node = str(num_processes // num_machines)
446
+ args.nnodes = str(num_machines)
447
+ args.node_rank = int(args.machine_rank)
448
+ if getattr(args, "same_network", False):
449
+ args.master_addr = str(main_process_ip)
450
+ args.master_port = str(main_process_port)
451
+ else:
452
+ args.rdzv_endpoint = f"{main_process_ip}:{main_process_port}"
453
+ else:
454
+ args.nproc_per_node = str(num_processes)
455
+ if main_process_port is not None:
456
+ args.master_port = str(main_process_port)
457
+
458
+ # only need to check port availability in main process, in case we have to start multiple launchers on the same machine
459
+ # for some reasons like splitting log files.
460
+ need_port_check = num_machines <= 1 or int(args.machine_rank) == 0
461
+ if need_port_check and is_port_in_use(main_process_port):
462
+ if num_machines <= 1:
463
+ args.standalone = True
464
+ warnings.warn(
465
+ f"Port `{main_process_port}` is already in use. "
466
+ "Accelerate will attempt to launch in a standalone-like mode by finding an open port automatically for this session. "
467
+ "If this current attempt fails, or for more control in future runs, please specify a different port "
468
+ "(e.g., `--main_process_port <your_chosen_port>`) or use `--main_process_port 0` for automatic selection "
469
+ "in your launch command or Accelerate config file."
470
+ )
471
+ else:
472
+ raise ConnectionError(
473
+ f"Tried to launch distributed communication on port `{main_process_port}`, but another process is utilizing it. "
474
+ "Please specify a different port (such as using the `--main_process_port` flag or specifying a different `main_process_port` in your config file)"
475
+ " and rerun your script. To automatically use the next open port (on a single node), you can set this to `0`."
476
+ )
477
+
478
+ if args.module and args.no_python:
479
+ raise ValueError("--module and --no_python cannot be used together")
480
+ elif args.module:
481
+ args.module = True
482
+ elif args.no_python:
483
+ args.no_python = True
484
+
485
+ current_env = os.environ.copy()
486
+ if args.debug:
487
+ current_env["ACCELERATE_DEBUG_MODE"] = "true"
488
+ gpu_ids = getattr(args, "gpu_ids", "all")
489
+ if gpu_ids != "all" and args.gpu_ids is not None:
490
+ if is_xpu_available():
491
+ current_env["ZE_AFFINITY_MASK"] = gpu_ids
492
+ elif is_mlu_available():
493
+ current_env["MLU_VISIBLE_DEVICES"] = gpu_ids
494
+ elif is_sdaa_available():
495
+ current_env["SDAA_VISIBLE_DEVICES"] = gpu_ids
496
+ elif is_musa_available():
497
+ current_env["MUSA_VISIBLE_DEVICES"] = gpu_ids
498
+ elif is_npu_available():
499
+ current_env["ASCEND_RT_VISIBLE_DEVICES"] = gpu_ids
500
+ elif is_hpu_available():
501
+ current_env["HABANA_VISIBLE_MODULES"] = gpu_ids
502
+ else:
503
+ current_env["CUDA_VISIBLE_DEVICES"] = gpu_ids
504
+ try:
505
+ mixed_precision = PrecisionType(args.mixed_precision.lower())
506
+ except ValueError:
507
+ raise ValueError(
508
+ f"Unknown mixed_precision mode: {args.mixed_precision.lower()}. Choose between {PrecisionType.list()}."
509
+ )
510
+
511
+ current_env["PYTHONPATH"] = env_var_path_add("PYTHONPATH", os.path.abspath("."))
512
+ current_env["ACCELERATE_MIXED_PRECISION"] = str(mixed_precision)
513
+ if args.mixed_precision.lower() == "fp8":
514
+ if not is_fp8_available():
515
+ raise RuntimeError(
516
+ "FP8 is not available on this machine. Please ensure that either Transformer Engine, MSAMP or torchao is installed."
517
+ )
518
+ current_env = setup_fp8_env(args, current_env)
519
+ current_env["ACCELERATE_CONFIG_DS_FIELDS"] = str(args.deepspeed_fields_from_accelerate_config).lower()
520
+ current_env["ACCELERATE_USE_DEEPSPEED"] = "true"
521
+ if args.zero_stage is not None:
522
+ current_env["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(args.zero_stage)
523
+ if args.gradient_accumulation_steps is not None:
524
+ current_env["ACCELERATE_GRADIENT_ACCUMULATION_STEPS"] = str(args.gradient_accumulation_steps)
525
+ if args.gradient_clipping is not None:
526
+ current_env["ACCELERATE_GRADIENT_CLIPPING"] = str(args.gradient_clipping).lower()
527
+ if args.offload_optimizer_device is not None:
528
+ current_env["ACCELERATE_DEEPSPEED_OFFLOAD_OPTIMIZER_DEVICE"] = str(args.offload_optimizer_device).lower()
529
+ if args.offload_param_device is not None:
530
+ current_env["ACCELERATE_DEEPSPEED_OFFLOAD_PARAM_DEVICE"] = str(args.offload_param_device).lower()
531
+ if args.zero3_init_flag is not None:
532
+ current_env["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = str(args.zero3_init_flag).lower()
533
+ if args.zero3_save_16bit_model is not None:
534
+ current_env["ACCELERATE_DEEPSPEED_ZERO3_SAVE_16BIT_MODEL"] = str(args.zero3_save_16bit_model).lower()
535
+ if args.deepspeed_config_file is not None:
536
+ current_env["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = str(args.deepspeed_config_file)
537
+ if args.enable_cpu_affinity:
538
+ current_env["ACCELERATE_CPU_AFFINITY"] = "1"
539
+ if args.deepspeed_moe_layer_cls_names is not None:
540
+ current_env["ACCELERATE_DEEPSPEED_MOE_LAYER_CLS_NAMES"] = str(args.deepspeed_moe_layer_cls_names)
541
+
542
+ if args.use_parallelism_config:
543
+ current_env = prepare_extend_env_parallelism_config(args, current_env)
544
+
545
+ return cmd, current_env
546
+
547
+
548
+ def prepare_tpu(
549
+ args: argparse.Namespace, current_env: dict[str, str], pod: bool = False
550
+ ) -> tuple[argparse.Namespace, dict[str, str]]:
551
+ """
552
+ Prepares and returns an environment with the correct TPU environment variables.
553
+ """
554
+ if args.mixed_precision == "bf16" and is_torch_xla_available(check_is_tpu=True):
555
+ if args.downcast_bf16:
556
+ current_env["XLA_DOWNCAST_BF16"] = "1"
557
+ else:
558
+ current_env["XLA_USE_BF16"] = "1"
559
+ if args.debug:
560
+ current_env["ACCELERATE_DEBUG_MODE"] = "true"
561
+ if pod:
562
+ # Take explicit args and set them up for XLA
563
+ args.vm = args.tpu_vm
564
+ args.tpu = args.tpu_name
565
+ return args, current_env
566
+
567
+
568
+ def _convert_nargs_to_dict(nargs: list[str]) -> dict[str, str]:
569
+ if len(nargs) < 0:
570
+ return {}
571
+ # helper function to infer type for argsparser
572
+
573
+ def _infer_type(s):
574
+ try:
575
+ s = float(s)
576
+
577
+ if s // 1 == s:
578
+ return int(s)
579
+ return s
580
+ except ValueError:
581
+ return s
582
+
583
+ parser = argparse.ArgumentParser()
584
+ _, unknown = parser.parse_known_args(nargs)
585
+ for index, argument in enumerate(unknown):
586
+ if argument.startswith(("-", "--")):
587
+ action = None
588
+ if index + 1 < len(unknown): # checks if next index would be in list
589
+ if unknown[index + 1].startswith(("-", "--")): # checks if next element is an key
590
+ # raise an error if element is store_true or store_false
591
+ raise ValueError(
592
+ "SageMaker doesn’t support argparse actions for `store_true` or `store_false`. Please define explicit types"
593
+ )
594
+ else: # raise an error if last element is store_true or store_false
595
+ raise ValueError(
596
+ "SageMaker doesn’t support argparse actions for `store_true` or `store_false`. Please define explicit types"
597
+ )
598
+ # adds argument to parser based on action_store true
599
+ if action is None:
600
+ parser.add_argument(argument, type=_infer_type)
601
+ else:
602
+ parser.add_argument(argument, action=action)
603
+
604
+ return {
605
+ key: (literal_eval(value) if value in ("True", "False") else value)
606
+ for key, value in parser.parse_args(nargs).__dict__.items()
607
+ }
608
+
609
+
610
+ def prepare_sagemager_args_inputs(
611
+ sagemaker_config: SageMakerConfig, args: argparse.Namespace
612
+ ) -> tuple[argparse.Namespace, dict[str, Any]]:
613
+ # configure environment
614
+ print("Configuring Amazon SageMaker environment")
615
+ os.environ["AWS_DEFAULT_REGION"] = sagemaker_config.region
616
+
617
+ # configure credentials
618
+ if sagemaker_config.profile is not None:
619
+ os.environ["AWS_PROFILE"] = sagemaker_config.profile
620
+ elif args.aws_access_key_id is not None and args.aws_secret_access_key is not None:
621
+ os.environ["AWS_ACCESS_KEY_ID"] = args.aws_access_key_id
622
+ os.environ["AWS_SECRET_ACCESS_KEY"] = args.aws_secret_access_key
623
+ else:
624
+ raise OSError("You need to provide an aws_access_key_id and aws_secret_access_key when not using aws_profile")
625
+
626
+ # extract needed arguments
627
+ source_dir = os.path.dirname(args.training_script)
628
+ if not source_dir: # checks if string is empty
629
+ source_dir = "."
630
+ entry_point = os.path.basename(args.training_script)
631
+ if not entry_point.endswith(".py"):
632
+ raise ValueError(f'Your training script should be a python script and not "{entry_point}"')
633
+
634
+ print("Converting Arguments to Hyperparameters")
635
+ hyperparameters = _convert_nargs_to_dict(args.training_script_args)
636
+
637
+ try:
638
+ mixed_precision = PrecisionType(args.mixed_precision.lower())
639
+ except ValueError:
640
+ raise ValueError(
641
+ f"Unknown mixed_precision mode: {args.mixed_precision.lower()}. Choose between {PrecisionType.list()}."
642
+ )
643
+
644
+ try:
645
+ dynamo_backend = DynamoBackend(args.dynamo_backend.upper())
646
+ except ValueError:
647
+ raise ValueError(
648
+ f"Unknown dynamo backend: {args.dynamo_backend.upper()}. Choose between {DynamoBackend.list()}."
649
+ )
650
+
651
+ # Environment variables to be set for use during training job
652
+ environment = {
653
+ "ACCELERATE_USE_SAGEMAKER": "true",
654
+ "ACCELERATE_MIXED_PRECISION": str(mixed_precision),
655
+ "ACCELERATE_DYNAMO_BACKEND": dynamo_backend.value,
656
+ "ACCELERATE_DYNAMO_MODE": args.dynamo_mode,
657
+ "ACCELERATE_DYNAMO_USE_FULLGRAPH": str(args.dynamo_use_fullgraph),
658
+ "ACCELERATE_DYNAMO_USE_DYNAMIC": str(args.dynamo_use_dynamic),
659
+ "ACCELERATE_DYNAMO_USE_REGIONAL_COMPILATION": str(args.dynamo_use_regional_compilation),
660
+ "ACCELERATE_SAGEMAKER_DISTRIBUTED_TYPE": sagemaker_config.distributed_type.value,
661
+ }
662
+ if args.mixed_precision.lower() == "fp8":
663
+ if not is_fp8_available():
664
+ raise RuntimeError(
665
+ "FP8 is not available on this machine. Please ensure that either Transformer Engine, MSAMP or torchao is installed."
666
+ )
667
+ environment = setup_fp8_env(args, environment)
668
+ # configure distribution set up
669
+ distribution = None
670
+ if sagemaker_config.distributed_type == SageMakerDistributedType.DATA_PARALLEL:
671
+ distribution = {"smdistributed": {"dataparallel": {"enabled": True}}}
672
+
673
+ # configure sagemaker inputs
674
+ sagemaker_inputs = None
675
+ if sagemaker_config.sagemaker_inputs_file is not None:
676
+ print(f"Loading SageMaker Inputs from {sagemaker_config.sagemaker_inputs_file} file")
677
+ sagemaker_inputs = {}
678
+ with open(sagemaker_config.sagemaker_inputs_file) as file:
679
+ for i, line in enumerate(file):
680
+ if i == 0:
681
+ continue
682
+ l = line.split("\t")
683
+ sagemaker_inputs[l[0]] = l[1].strip()
684
+ print(f"Loaded SageMaker Inputs: {sagemaker_inputs}")
685
+
686
+ # configure sagemaker metrics
687
+ sagemaker_metrics = None
688
+ if sagemaker_config.sagemaker_metrics_file is not None:
689
+ print(f"Loading SageMaker Metrics from {sagemaker_config.sagemaker_metrics_file} file")
690
+ sagemaker_metrics = []
691
+ with open(sagemaker_config.sagemaker_metrics_file) as file:
692
+ for i, line in enumerate(file):
693
+ if i == 0:
694
+ continue
695
+ l = line.split("\t")
696
+ metric_dict = {
697
+ "Name": l[0],
698
+ "Regex": l[1].strip(),
699
+ }
700
+ sagemaker_metrics.append(metric_dict)
701
+ print(f"Loaded SageMaker Metrics: {sagemaker_metrics}")
702
+
703
+ # configure session
704
+ print("Creating Estimator")
705
+ args = {
706
+ "image_uri": sagemaker_config.image_uri,
707
+ "entry_point": entry_point,
708
+ "source_dir": source_dir,
709
+ "role": sagemaker_config.iam_role_name,
710
+ "transformers_version": sagemaker_config.transformers_version,
711
+ "pytorch_version": sagemaker_config.pytorch_version,
712
+ "py_version": sagemaker_config.py_version,
713
+ "base_job_name": sagemaker_config.base_job_name,
714
+ "instance_count": sagemaker_config.num_machines,
715
+ "instance_type": sagemaker_config.ec2_instance_type,
716
+ "debugger_hook_config": False,
717
+ "distribution": distribution,
718
+ "hyperparameters": hyperparameters,
719
+ "environment": environment,
720
+ "metric_definitions": sagemaker_metrics,
721
+ }
722
+
723
+ if sagemaker_config.additional_args is not None:
724
+ args = merge_dicts(sagemaker_config.additional_args, args)
725
+ return args, sagemaker_inputs
726
+
727
+
728
+ def env_var_path_add(env_var_name, path_to_add):
729
+ """
730
+ Extends a path-based environment variable's value with a new path and returns the updated value. It's up to the
731
+ caller to set it in os.environ.
732
+ """
733
+ paths = [p for p in os.environ.get(env_var_name, "").split(":") if len(p) > 0]
734
+ paths.append(str(path_to_add))
735
+ return ":".join(paths)
736
+
737
+
738
+ class PrepareForLaunch:
739
+ """
740
+ Prepare a function that will launched in a distributed setup.
741
+
742
+ Args:
743
+ launcher (`Callable`):
744
+ The function to launch.
745
+ distributed_type ([`~state.DistributedType`]):
746
+ The distributed type to prepare for.
747
+ debug (`bool`, *optional*, defaults to `False`):
748
+ Whether or not this is a debug launch.
749
+ """
750
+
751
+ def __init__(self, launcher, distributed_type="NO", debug=False):
752
+ self.launcher = launcher
753
+ self.distributed_type = DistributedType(distributed_type)
754
+ self.debug = debug
755
+
756
+ def __call__(self, index, *args):
757
+ if self.debug:
758
+ world_size = int(os.environ.get("WORLD_SIZE"))
759
+ rdv_file = os.environ.get("ACCELERATE_DEBUG_RDV_FILE")
760
+ torch.distributed.init_process_group(
761
+ "gloo",
762
+ rank=index,
763
+ store=torch.distributed.FileStore(rdv_file, world_size),
764
+ world_size=world_size,
765
+ )
766
+ elif self.distributed_type in (
767
+ DistributedType.MULTI_GPU,
768
+ DistributedType.MULTI_MLU,
769
+ DistributedType.MULTI_MUSA,
770
+ DistributedType.MULTI_NPU,
771
+ DistributedType.MULTI_XPU,
772
+ DistributedType.MULTI_CPU,
773
+ ):
774
+ # Prepare the environment for torch.distributed
775
+ os.environ["LOCAL_RANK"] = str(index)
776
+ nproc = int(os.environ.get("NPROC", 1))
777
+ node_rank = int(os.environ.get("NODE_RANK", 0))
778
+ os.environ["RANK"] = str(nproc * node_rank + index)
779
+
780
+ os.environ["FORK_LAUNCHED"] = str(1)
781
+ self.launcher(*args)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/megatron_lm.py ADDED
@@ -0,0 +1,1424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import argparse
16
+ import math
17
+ import os
18
+ from abc import ABC
19
+ from functools import partial
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
24
+ from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
25
+
26
+ from ..optimizer import AcceleratedOptimizer
27
+ from ..scheduler import AcceleratedScheduler
28
+ from .imports import is_megatron_lm_available
29
+ from .operations import recursively_apply, send_to_device
30
+
31
+
32
+ if is_megatron_lm_available():
33
+ from megatron.core import mpu, tensor_parallel
34
+ from megatron.core.distributed import DistributedDataParallel as LocalDDP
35
+ from megatron.core.distributed import finalize_model_grads
36
+ from megatron.core.enums import ModelType
37
+ from megatron.core.num_microbatches_calculator import get_num_microbatches
38
+ from megatron.core.optimizer import get_megatron_optimizer
39
+ from megatron.core.parallel_state import get_tensor_model_parallel_group, get_tensor_model_parallel_src_rank
40
+ from megatron.core.pipeline_parallel import get_forward_backward_func
41
+ from megatron.core.utils import get_model_config
42
+ from megatron.inference.text_generation.communication import broadcast_int_list, broadcast_tensor
43
+ from megatron.inference.text_generation.generation import (
44
+ beam_search_and_return_on_first_stage,
45
+ generate_tokens_probs_and_return_on_first_stage,
46
+ )
47
+ from megatron.legacy.data.dataset_utils import build_train_valid_test_datasets
48
+ from megatron.legacy.model import BertModel, Float16Module, GPTModel, T5Model
49
+ from megatron.legacy.model.classification import Classification
50
+ from megatron.training import (
51
+ get_args,
52
+ get_tensorboard_writer,
53
+ get_tokenizer,
54
+ print_rank_last,
55
+ )
56
+ from megatron.training.arguments import (
57
+ _add_data_args,
58
+ _add_validation_args,
59
+ core_transformer_config_from_args,
60
+ parse_args,
61
+ validate_args,
62
+ )
63
+ from megatron.training.checkpointing import load_args_from_checkpoint, load_checkpoint, save_checkpoint
64
+ from megatron.training.global_vars import set_global_variables
65
+ from megatron.training.initialize import (
66
+ _compile_dependencies,
67
+ _init_autoresume,
68
+ _initialize_distributed,
69
+ _set_random_seed,
70
+ set_jit_fusion_options,
71
+ write_args_to_tensorboard,
72
+ )
73
+ from megatron.training.tokenizer.tokenizer import _vocab_size_with_padding
74
+ from megatron.training.training import (
75
+ build_train_valid_test_data_iterators,
76
+ get_optimizer_param_scheduler,
77
+ num_floating_point_operations,
78
+ setup_model_and_optimizer,
79
+ train_step,
80
+ training_log,
81
+ )
82
+ from megatron.training.utils import (
83
+ average_losses_across_data_parallel_group,
84
+ calc_params_l2_norm,
85
+ get_ltor_masks_and_position_ids,
86
+ unwrap_model,
87
+ )
88
+
89
+
90
+ # model utilities
91
+ def model_provider_func(pre_process=True, post_process=True, add_encoder=True, add_decoder=True):
92
+ """Build the model."""
93
+ args = get_args()
94
+ mode = "pre-training" if args.pretraining_flag else "fine-tuning"
95
+ if args.rank == 0:
96
+ print(f"Building {args.model_type_name} model in the {mode} mode.")
97
+ print(
98
+ "The Megatron LM model weights are initialized at random in `accelerator.prepare`. "
99
+ "Please use `accelerator.load_checkpoint` to load a pre-trained checkpoint matching the distributed setup."
100
+ )
101
+ config = core_transformer_config_from_args(args)
102
+ if args.model_type_name == "bert":
103
+ if args.pretraining_flag:
104
+ num_tokentypes = 2 if args.bert_binary_head else 0
105
+ model = BertModel(
106
+ config=config,
107
+ num_tokentypes=num_tokentypes,
108
+ add_binary_head=args.bert_binary_head,
109
+ parallel_output=True,
110
+ pre_process=pre_process,
111
+ post_process=post_process,
112
+ )
113
+ else:
114
+ model = Classification(
115
+ config=config,
116
+ num_classes=args.num_labels,
117
+ num_tokentypes=2,
118
+ pre_process=pre_process,
119
+ post_process=post_process,
120
+ )
121
+ elif args.model_type_name == "gpt":
122
+ model = GPTModel(
123
+ config=config,
124
+ num_tokentypes=0,
125
+ parallel_output=True,
126
+ pre_process=pre_process,
127
+ post_process=post_process,
128
+ )
129
+ elif args.model_type_name == "t5":
130
+ model = T5Model(
131
+ config=config,
132
+ num_tokentypes=0,
133
+ parallel_output=True,
134
+ pre_process=pre_process,
135
+ post_process=post_process,
136
+ add_encoder=add_encoder,
137
+ add_decoder=add_decoder,
138
+ )
139
+ else:
140
+ raise ValueError(f"Unsupported model type: {args.model_type_name}")
141
+ return model
142
+
143
+
144
+ def prepare_model_optimizer_scheduler(accelerator):
145
+ accelerator.print("Preparing model optimizer scheduler")
146
+ args = get_args()
147
+ if accelerator.state.megatron_lm_plugin.custom_prepare_model_function is not None:
148
+ if accelerator.state.megatron_lm_plugin.custom_model_provider_function is None:
149
+ raise ValueError(
150
+ "You must provide a `custom_model_provider_function` when using a `custom_prepare_model_function`."
151
+ )
152
+ custom_model_provider_func = accelerator.state.megatron_lm_plugin.custom_model_provider_function
153
+ model = accelerator.state.megatron_lm_plugin.custom_prepare_model_function(custom_model_provider_func)
154
+ optimizer = prepare_optimizer(accelerator, model)
155
+ scheduler = prepare_scheduler(accelerator, optimizer, scheduler=None)
156
+ else:
157
+ model_type = ModelType.encoder_or_decoder
158
+ if args.model_type_name == "t5":
159
+ model_type = ModelType.encoder_and_decoder
160
+ model_provider_func_ = model_provider_func
161
+ if accelerator.state.megatron_lm_plugin.custom_model_provider_function is not None:
162
+ model_provider_func_ = accelerator.state.megatron_lm_plugin.custom_model_provider_function
163
+ (model, optimizer, scheduler) = setup_model_and_optimizer(
164
+ model_provider_func_,
165
+ model_type,
166
+ no_wd_decay_cond=args.no_wd_decay_cond,
167
+ scale_lr_cond=args.scale_lr_cond,
168
+ lr_mult=args.lr_mult,
169
+ )
170
+ args.model_len = len(model)
171
+ return model, optimizer, scheduler
172
+
173
+
174
+ # dataloader utilities
175
+ class MegatronLMDummyDataLoader:
176
+ """
177
+ Dummy dataloader presents model parameters or param groups, this is primarily used to follow conventional training
178
+
179
+ Args:
180
+ **dataset_kwargs: Megatron data arguments.
181
+ """
182
+
183
+ def __init__(self, **dataset_kwargs):
184
+ parser = argparse.ArgumentParser()
185
+ parser = _add_data_args(parser)
186
+ parser = _add_validation_args(parser)
187
+ data_args = parser.parse_known_args()
188
+ self.dataset_args = vars(data_args[0])
189
+ self.dataset_args.update(dataset_kwargs)
190
+ self.dataset_args["megatron_dataset_flag"] = True
191
+
192
+ def set_megatron_data_args(self):
193
+ args = get_args()
194
+ for key, value in self.dataset_args.items():
195
+ old_value = getattr(args, key, "")
196
+ if old_value != value:
197
+ print(
198
+ f"WARNING: MegatronLMDummyDataLoader overriding arguments for {key}:{old_value} with {key}:{value}"
199
+ )
200
+ setattr(args, key, value)
201
+
202
+ def get_train_valid_test_datasets_provider(self, accelerator):
203
+ def train_valid_test_datasets_provider(train_val_test_num_samples):
204
+ """Build train, valid, and test datasets."""
205
+ args = get_args()
206
+ dataset_args = {
207
+ "data_prefix": args.data_path if isinstance(args.data_path, (list, tuple)) else [args.data_path],
208
+ "splits_string": args.split,
209
+ "train_valid_test_num_samples": train_val_test_num_samples,
210
+ "seed": args.seed,
211
+ }
212
+ if args.model_type_name == "bert":
213
+ dataset_args.update(
214
+ {
215
+ "max_seq_length": args.seq_length,
216
+ "binary_head": args.bert_binary_head,
217
+ }
218
+ )
219
+ elif args.model_type_name == "gpt":
220
+ dataset_args.update(
221
+ {
222
+ "max_seq_length": args.seq_length,
223
+ }
224
+ )
225
+ elif args.model_type_name == "t5":
226
+ dataset_args.update(
227
+ {
228
+ "max_seq_length": args.encoder_seq_length,
229
+ "max_seq_length_dec": args.decoder_seq_length,
230
+ "dataset_type": "t5",
231
+ }
232
+ )
233
+ else:
234
+ raise ValueError(f"Unsupported model type: {args.model_type_name}")
235
+ train_ds, valid_ds, test_ds = build_train_valid_test_datasets(**dataset_args)
236
+ return train_ds, valid_ds, test_ds
237
+
238
+ if accelerator.state.megatron_lm_plugin.custom_megatron_datasets_provider_function is not None:
239
+ return accelerator.state.megatron_lm_plugin.custom_megatron_datasets_provider_function
240
+ try:
241
+ args = get_args()
242
+ # Use '--no-use-pep517 -e' to pip install nvidia's megatron from source
243
+ if args.model_type_name == "bert":
244
+ from pretrain_bert import train_valid_test_datasets_provider
245
+
246
+ train_valid_test_datasets_provider.is_distributed = True
247
+ return train_valid_test_datasets_provider
248
+ elif args.model_type_name == "gpt":
249
+ from pretrain_gpt import train_valid_test_datasets_provider
250
+
251
+ train_valid_test_datasets_provider.is_distributed = True
252
+ return train_valid_test_datasets_provider
253
+ elif args.model_type_name == "t5":
254
+ from pretrain_t5 import train_valid_test_datasets_provider
255
+
256
+ train_valid_test_datasets_provider.is_distributed = True
257
+ return train_valid_test_datasets_provider
258
+ except ImportError:
259
+ pass
260
+ return train_valid_test_datasets_provider
261
+
262
+ def build_train_valid_test_data_iterators(self, accelerator):
263
+ args = get_args()
264
+
265
+ train_valid_test_dataset_provider = self.get_train_valid_test_datasets_provider(accelerator)
266
+ if args.virtual_pipeline_model_parallel_size is not None:
267
+ train_data_iterator = []
268
+ valid_data_iterator = []
269
+ test_data_iterator = []
270
+ for i in range(getattr(args, "model_len", 0)):
271
+ mpu.set_virtual_pipeline_model_parallel_rank(i)
272
+ iterators = build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
273
+ train_data_iterator.append(iterators[0])
274
+ valid_data_iterator.append(iterators[1])
275
+ test_data_iterator.append(iterators[2])
276
+ else:
277
+ train_data_iterator, valid_data_iterator, test_data_iterator = build_train_valid_test_data_iterators(
278
+ train_valid_test_dataset_provider
279
+ )
280
+
281
+ return train_data_iterator, valid_data_iterator, test_data_iterator
282
+
283
+
284
+ def _handle_megatron_data_iterator(accelerator, data_iterator):
285
+ class DummyMegatronDataloader:
286
+ def __iter__(self):
287
+ return self
288
+
289
+ def __next__(self):
290
+ return {}
291
+
292
+ is_data_iterator_empty = data_iterator is None
293
+ is_src_data_iterator_empty = torch.tensor(is_data_iterator_empty, dtype=torch.bool, device=accelerator.device)
294
+ torch.distributed.broadcast(
295
+ is_src_data_iterator_empty, get_tensor_model_parallel_src_rank(), group=get_tensor_model_parallel_group()
296
+ )
297
+ if not is_src_data_iterator_empty and is_data_iterator_empty:
298
+ return DummyMegatronDataloader()
299
+ return data_iterator
300
+
301
+
302
+ def prepare_data_loader(accelerator, dataloader):
303
+ accelerator.print("Preparing dataloader")
304
+ args = get_args()
305
+ if not args.megatron_dataset_flag:
306
+ from ..data_loader import _PYTORCH_DATALOADER_KWARGS, prepare_data_loader
307
+
308
+ micro_batch_size = args.micro_batch_size * args.num_micro_batches
309
+ kwargs = {k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k]) for k in _PYTORCH_DATALOADER_KWARGS}
310
+ if kwargs["batch_size"] is None:
311
+ if isinstance(kwargs["sampler"], torch.utils.data.BatchSampler):
312
+ kwargs["sampler"].batch_size = micro_batch_size
313
+ else:
314
+ del kwargs["sampler"]
315
+ del kwargs["shuffle"]
316
+ del kwargs["batch_size"]
317
+ kwargs["batch_sampler"].batch_size = micro_batch_size
318
+ else:
319
+ del kwargs["batch_sampler"]
320
+ kwargs["batch_size"] = micro_batch_size
321
+
322
+ dataloader = torch.utils.data.DataLoader(dataloader.dataset, **kwargs)
323
+ # split_batches:
324
+ # Megatron only needs to fetch different data between different dp groups,
325
+ # and does not need to split the data within the dp group.
326
+ return prepare_data_loader(
327
+ dataloader,
328
+ accelerator.device,
329
+ num_processes=mpu.get_data_parallel_world_size(),
330
+ process_index=mpu.get_data_parallel_rank(),
331
+ split_batches=False,
332
+ put_on_device=True,
333
+ rng_types=accelerator.rng_types.copy(),
334
+ dispatch_batches=accelerator.dispatch_batches,
335
+ )
336
+ else:
337
+ if args.consumed_samples is not None:
338
+ (
339
+ args.consumed_train_samples,
340
+ args.consumed_valid_samples,
341
+ args.consumed_test_samples,
342
+ ) = args.consumed_samples
343
+ else:
344
+ args.consumed_train_samples, args.consumed_valid_samples, args.consumed_test_samples = 0, 0, 0
345
+ args.micro_batch_size = args.micro_batch_size * args.num_micro_batches
346
+ # In order to be compatible with data in transform format,
347
+ # it needs to increase the size of mbs first,
348
+ # and then split the large batch data into some mbs.
349
+ (
350
+ train_data_iterator,
351
+ valid_data_iterator,
352
+ test_data_iterator,
353
+ ) = dataloader.build_train_valid_test_data_iterators(accelerator)
354
+ args.micro_batch_size = args.micro_batch_size // args.num_micro_batches
355
+
356
+ train_data_iterator = _handle_megatron_data_iterator(
357
+ accelerator=accelerator, data_iterator=train_data_iterator
358
+ )
359
+ valid_data_iterator = _handle_megatron_data_iterator(
360
+ accelerator=accelerator, data_iterator=valid_data_iterator
361
+ )
362
+ test_data_iterator = _handle_megatron_data_iterator(accelerator=accelerator, data_iterator=test_data_iterator)
363
+
364
+ return train_data_iterator, valid_data_iterator, test_data_iterator
365
+
366
+
367
+ # optimizer utilities
368
+ class MegatronLMOptimizerWrapper(AcceleratedOptimizer):
369
+ def __init__(self, optimizer):
370
+ super().__init__(optimizer, device_placement=False, scaler=None)
371
+
372
+ def zero_grad(self, set_to_none=None):
373
+ pass # `model(**batch)` is doing that automatically. Therefore, its implementation is not needed
374
+
375
+ def step(self):
376
+ pass # `model(**batch)` is doing that automatically. Therefore, its implementation is not needed
377
+
378
+ @property
379
+ def step_was_skipped(self):
380
+ """Whether or not the optimizer step was done, or skipped because of gradient overflow."""
381
+ return self.optimizer.skipped_iter
382
+
383
+
384
+ def prepare_optimizer(accelerator, model):
385
+ accelerator.print("Preparing optimizer")
386
+ args = get_args()
387
+ return get_megatron_optimizer(model, args.no_wd_decay_cond, args.scale_lr_cond, args.lr_mult)
388
+
389
+
390
+ # scheduler utilities
391
+ class MegatronLMDummyScheduler:
392
+ """
393
+ Dummy scheduler presents model parameters or param groups, this is primarily used to follow conventional training
394
+ loop when scheduler config is specified in the deepspeed config file.
395
+
396
+ Args:
397
+ optimizer (`torch.optim.optimizer.Optimizer`):
398
+ The optimizer to wrap.
399
+ total_num_steps (int):
400
+ Total number of steps.
401
+ warmup_num_steps (int):
402
+ Number of steps for warmup.
403
+ **kwargs (additional keyword arguments, *optional*):
404
+ Other arguments.
405
+ """
406
+
407
+ def __init__(self, optimizer, total_num_steps=None, warmup_num_steps=0, **kwargs):
408
+ self.optimizer = optimizer
409
+ self.total_num_steps = total_num_steps
410
+ self.warmup_num_steps = warmup_num_steps
411
+ self.kwargs = kwargs
412
+
413
+
414
+ class MegatronLMSchedulerWrapper(AcceleratedScheduler):
415
+ def __init__(self, scheduler, optimizers):
416
+ super().__init__(scheduler, optimizers)
417
+
418
+ def step(self, *args, **kwargs):
419
+ return # `model(**batch)` is doing that automatically. Therefore, its implementation is not needed
420
+
421
+
422
+ def prepare_scheduler(accelerator, optimizer, scheduler):
423
+ accelerator.print("Preparing scheduler")
424
+ scheduler = get_optimizer_param_scheduler(optimizer)
425
+ return scheduler
426
+
427
+
428
+ class AbstractTrainStep(ABC):
429
+ """Abstract class for batching, forward pass and loss handler."""
430
+
431
+ def __init__(self, name):
432
+ super().__init__()
433
+ self.name = name
434
+
435
+ def get_batch_func(self, accelerator, megatron_dataset_flag):
436
+ pass
437
+
438
+ def get_forward_step_func(self):
439
+ pass
440
+
441
+ def get_loss_func(self, accelerator):
442
+ pass
443
+
444
+
445
+ class BertTrainStep(AbstractTrainStep):
446
+ """
447
+ Bert train step class.
448
+
449
+ Args:
450
+ args (`argparse.Namespace`): Megatron-LM arguments.
451
+ """
452
+
453
+ def __init__(self, accelerator, args):
454
+ super().__init__("BertTrainStep")
455
+ self.get_batch = self.get_batch_func(accelerator, args.megatron_dataset_flag)
456
+ self.loss_func = self.get_loss_func(accelerator, args.pretraining_flag, args.num_labels)
457
+ self.forward_step = self.get_forward_step_func(args.pretraining_flag, args.bert_binary_head)
458
+ if not args.model_return_dict:
459
+ self.model_output_class = None
460
+ else:
461
+ from transformers.modeling_outputs import SequenceClassifierOutput
462
+
463
+ self.model_output_class = SequenceClassifierOutput
464
+
465
+ def get_batch_func(self, accelerator, megatron_dataset_flag):
466
+ def get_batch_megatron(data_iterator):
467
+ """Build the batch."""
468
+
469
+ # Items and their type.
470
+ keys = ["text", "types", "labels", "is_random", "loss_mask", "padding_mask"]
471
+ datatype = torch.int64
472
+
473
+ # Broadcast data.
474
+ if data_iterator is not None:
475
+ data = next(data_iterator)
476
+ else:
477
+ data = None
478
+ data_b = tensor_parallel.broadcast_data(keys, data, datatype)
479
+
480
+ # Unpack.
481
+ tokens = data_b["text"].long()
482
+ types = data_b["types"].long()
483
+ sentence_order = data_b["is_random"].long()
484
+ loss_mask = data_b["loss_mask"].float()
485
+ lm_labels = data_b["labels"].long()
486
+ padding_mask = data_b["padding_mask"].long()
487
+
488
+ return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
489
+
490
+ def get_batch_transformer(data_iterator):
491
+ """Build the batch."""
492
+ data = next(data_iterator)
493
+ data = send_to_device(data, torch.cuda.current_device())
494
+
495
+ # Unpack.
496
+ tokens = data["input_ids"].long()
497
+ padding_mask = data["attention_mask"].long()
498
+ if "token_type_ids" in data:
499
+ types = data["token_type_ids"].long()
500
+ else:
501
+ types = None
502
+ if "labels" in data:
503
+ lm_labels = data["labels"].long()
504
+ loss_mask = (data["labels"] != -100).to(torch.float)
505
+ else:
506
+ lm_labels = None
507
+ loss_mask = None
508
+ if "next_sentence_label" in data:
509
+ sentence_order = data["next_sentence_label"].long()
510
+ else:
511
+ sentence_order = None
512
+
513
+ return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
514
+
515
+ if accelerator.state.megatron_lm_plugin.custom_get_batch_function is not None:
516
+ return accelerator.state.megatron_lm_plugin.custom_get_batch_function
517
+ if megatron_dataset_flag:
518
+ try:
519
+ # Use '--no-use-pep517 -e' to pip install nvidia's megatron from source
520
+ from pretrain_bert import get_batch
521
+
522
+ return get_batch
523
+ except ImportError:
524
+ pass
525
+ return get_batch_megatron
526
+ else:
527
+ return get_batch_transformer
528
+
529
+ def get_loss_func(self, accelerator, pretraining_flag, num_labels):
530
+ def loss_func_pretrain(loss_mask, sentence_order, output_tensor):
531
+ lm_loss_, sop_logits = output_tensor
532
+
533
+ lm_loss_ = lm_loss_.float()
534
+ loss_mask = loss_mask.float()
535
+ lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
536
+
537
+ if sop_logits is not None:
538
+ sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(), sentence_order.view(-1), ignore_index=-1)
539
+ sop_loss = sop_loss.float()
540
+ loss = lm_loss + sop_loss
541
+ averaged_losses = average_losses_across_data_parallel_group([lm_loss, sop_loss])
542
+ return loss, {"lm loss": averaged_losses[0], "sop loss": averaged_losses[1]}
543
+
544
+ else:
545
+ loss = lm_loss
546
+ averaged_losses = average_losses_across_data_parallel_group([lm_loss])
547
+ return loss, {"lm loss": averaged_losses[0]}
548
+
549
+ def loss_func_finetune(labels, logits):
550
+ if num_labels == 1:
551
+ # We are doing regression
552
+ loss_fct = MSELoss()
553
+ loss = loss_fct(logits.view(-1), labels.view(-1))
554
+ elif self.num_labels > 1 and (labels.dtype in (torch.long, torch.int)):
555
+ loss_fct = CrossEntropyLoss()
556
+ loss = loss_fct(logits.view(-1, num_labels), labels.view(-1))
557
+ else:
558
+ loss_fct = BCEWithLogitsLoss()
559
+ loss = loss_fct(logits, labels)
560
+ averaged_losses = average_losses_across_data_parallel_group([loss])
561
+ return loss, {"loss": averaged_losses[0]}
562
+
563
+ if accelerator.state.megatron_lm_plugin.custom_loss_function is not None:
564
+ return accelerator.state.megatron_lm_plugin.custom_loss_function
565
+ if pretraining_flag:
566
+ return loss_func_pretrain
567
+ else:
568
+ return loss_func_finetune
569
+
570
+ def get_forward_step_func(self, pretraining_flag, bert_binary_head):
571
+ def forward_step(data_iterator, model):
572
+ """Forward step."""
573
+ tokens, types, sentence_order, loss_mask, labels, padding_mask = self.get_batch(data_iterator)
574
+ if not bert_binary_head:
575
+ types = None
576
+ # Forward pass through the model.
577
+ if pretraining_flag:
578
+ output_tensor = model(tokens, padding_mask, tokentype_ids=types, lm_labels=labels)
579
+ return output_tensor, partial(self.loss_func, loss_mask, sentence_order)
580
+ else:
581
+ logits = model(tokens, padding_mask, tokentype_ids=types)
582
+ return logits, partial(self.loss_func, labels)
583
+
584
+ return forward_step
585
+
586
+
587
+ class GPTTrainStep(AbstractTrainStep):
588
+ """
589
+ GPT train step class.
590
+
591
+ Args:
592
+ args (`argparse.Namespace`): Megatron-LM arguments.
593
+ """
594
+
595
+ def __init__(self, accelerator, args):
596
+ super().__init__("GPTTrainStep")
597
+ self.get_batch = self.get_batch_func(accelerator, args.megatron_dataset_flag)
598
+ self.loss_func = self.get_loss_func(accelerator)
599
+ self.forward_step = self.get_forward_step_func()
600
+ self.eod_token = args.padded_vocab_size - 1
601
+ if args.vocab_file is not None:
602
+ tokenizer = get_tokenizer()
603
+ self.eod_token = tokenizer.eod
604
+ self.reset_position_ids = args.reset_position_ids
605
+ self.reset_attention_mask = args.reset_attention_mask
606
+ self.eod_mask_loss = args.eod_mask_loss
607
+ if not args.model_return_dict:
608
+ self.model_output_class = None
609
+ else:
610
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
611
+
612
+ self.model_output_class = CausalLMOutputWithCrossAttentions
613
+
614
+ def get_batch_func(self, accelerator, megatron_dataset_flag):
615
+ def get_batch_megatron(data_iterator):
616
+ """Generate a batch"""
617
+ # Items and their type.
618
+ keys = ["text"]
619
+ datatype = torch.int64
620
+
621
+ # Broadcast data.
622
+ if data_iterator is not None:
623
+ data = next(data_iterator)
624
+ else:
625
+ data = None
626
+ data_b = tensor_parallel.broadcast_data(keys, data, datatype)
627
+
628
+ # Unpack.
629
+ tokens_ = data_b["text"].long()
630
+ labels = tokens_[:, 1:].contiguous()
631
+ tokens = tokens_[:, :-1].contiguous()
632
+
633
+ # Get the masks and position ids.
634
+ attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
635
+ tokens, self.eod_token, self.reset_position_ids, self.reset_attention_mask, self.eod_mask_loss
636
+ )
637
+
638
+ return tokens, labels, loss_mask, attention_mask, position_ids
639
+
640
+ def get_batch_transformer(data_iterator):
641
+ data = next(data_iterator)
642
+ data = {"input_ids": data["input_ids"]}
643
+ data = send_to_device(data, torch.cuda.current_device())
644
+
645
+ tokens_ = data["input_ids"].long()
646
+ padding = torch.zeros((tokens_.shape[0], 1), dtype=tokens_.dtype, device=tokens_.device) + self.eod_token
647
+ tokens_ = torch.concat([tokens_, padding], dim=1)
648
+ labels = tokens_[:, 1:].contiguous()
649
+ tokens = tokens_[:, :-1].contiguous()
650
+ # Get the masks and position ids.
651
+ attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
652
+ tokens, self.eod_token, self.reset_position_ids, self.reset_attention_mask, True
653
+ )
654
+ return tokens, labels, loss_mask, attention_mask, position_ids
655
+
656
+ if accelerator.state.megatron_lm_plugin.custom_get_batch_function is not None:
657
+ return accelerator.state.megatron_lm_plugin.custom_get_batch_function
658
+ if megatron_dataset_flag:
659
+ try:
660
+ # Use '--no-use-pep517 -e' to pip install nvidia's megatron from source
661
+ from pretrain_gpt import get_batch
662
+
663
+ return get_batch
664
+ except ImportError:
665
+ pass
666
+ return get_batch_megatron
667
+ else:
668
+ return get_batch_transformer
669
+
670
+ def get_loss_func(self, accelerator):
671
+ args = get_args()
672
+
673
+ def loss_func(loss_mask, output_tensor):
674
+ if args.return_logits:
675
+ losses, logits = output_tensor
676
+ else:
677
+ losses = output_tensor
678
+ losses = losses.float()
679
+ loss_mask = loss_mask.view(-1).float()
680
+ if args.context_parallel_size > 1:
681
+ loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), loss_mask.sum().view(1)])
682
+ torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group())
683
+ loss = loss[0] / loss[1]
684
+ else:
685
+ loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
686
+
687
+ # Check individual rank losses are not NaN prior to DP all-reduce.
688
+ if args.check_for_nan_in_loss_and_grad:
689
+ global_rank = torch.distributed.get_rank()
690
+ assert not loss.isnan(), (
691
+ f"Rank {global_rank}: found NaN in local forward loss calculation. "
692
+ f"Device: {torch.cuda.current_device()}, node: {os.uname()[1]}"
693
+ )
694
+
695
+ # Reduce loss for logging.
696
+ averaged_loss = average_losses_across_data_parallel_group([loss])
697
+
698
+ output_dict = {"lm loss": averaged_loss[0]}
699
+ if args.return_logits:
700
+ output_dict.update({"logits": logits})
701
+ return loss, output_dict
702
+
703
+ if accelerator.state.megatron_lm_plugin.custom_loss_function is not None:
704
+ return accelerator.state.megatron_lm_plugin.custom_loss_function
705
+ return loss_func
706
+
707
+ def get_forward_step_func(self):
708
+ def forward_step(data_iterator, model):
709
+ """Forward step."""
710
+ # Get the batch.
711
+ tokens, labels, loss_mask, attention_mask, position_ids = self.get_batch(data_iterator)
712
+ output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
713
+
714
+ return output_tensor, partial(self.loss_func, loss_mask)
715
+
716
+ return forward_step
717
+
718
+
719
+ class T5TrainStep(AbstractTrainStep):
720
+ """
721
+ T5 train step class.
722
+
723
+ Args:
724
+ args (`argparse.Namespace`): Megatron-LM arguments.
725
+ """
726
+
727
+ def __init__(self, accelerator, args):
728
+ super().__init__("T5TrainStep")
729
+ self.get_batch = self.get_batch_func(accelerator, args.megatron_dataset_flag)
730
+ self.loss_func = self.get_loss_func(accelerator)
731
+ self.forward_step = self.get_forward_step_func()
732
+ if not args.model_return_dict:
733
+ self.model_output_class = None
734
+ else:
735
+ from transformers.modeling_outputs import Seq2SeqLMOutput
736
+
737
+ self.model_output_class = Seq2SeqLMOutput
738
+
739
+ @staticmethod
740
+ def attn_mask_postprocess(attention_mask):
741
+ # We create a 3D attention mask from a 2D tensor mask.
742
+ # [b, 1, s]
743
+ attention_mask_b1s = attention_mask.unsqueeze(1)
744
+ # [b, s, 1]
745
+ attention_mask_bs1 = attention_mask.unsqueeze(2)
746
+ # [b, s, s]
747
+ attention_mask_bss = attention_mask_b1s * attention_mask_bs1
748
+ # Convert attention mask to binary:
749
+ extended_attention_mask = attention_mask_bss < 0.5
750
+ return extended_attention_mask
751
+
752
+ @staticmethod
753
+ def get_decoder_mask(seq_length, device):
754
+ attention_mask = torch.tril(torch.ones((1, seq_length, seq_length), device=device))
755
+ attention_mask = attention_mask < 0.5
756
+ return attention_mask
757
+
758
+ @staticmethod
759
+ def get_enc_dec_mask(attention_mask, dec_seq_length, device):
760
+ batch_size, _ = attention_mask.shape
761
+ # We create a 3D attention mask from a 2D tensor mask.
762
+ # [b, 1, s]
763
+ attention_mask_b1s = attention_mask.unsqueeze(1)
764
+ # [b, s, 1]
765
+ attention_mask_bs1 = torch.ones((batch_size, dec_seq_length, 1), device=device)
766
+ attention_mask_bss = attention_mask_bs1 * attention_mask_b1s
767
+ extended_attention_mask = attention_mask_bss < 0.5
768
+ return extended_attention_mask
769
+
770
+ def get_batch_func(self, accelerator, megatron_dataset_flag):
771
+ def get_batch_megatron(data_iterator):
772
+ """Build the batch."""
773
+
774
+ keys = ["text_enc", "text_dec", "labels", "loss_mask", "enc_mask", "dec_mask", "enc_dec_mask"]
775
+ datatype = torch.int64
776
+
777
+ # Broadcast data.
778
+ if data_iterator is not None:
779
+ data = next(data_iterator)
780
+ else:
781
+ data = None
782
+ data_b = tensor_parallel.broadcast_data(keys, data, datatype)
783
+
784
+ # Unpack.
785
+ tokens_enc = data_b["text_enc"].long()
786
+ tokens_dec = data_b["text_dec"].long()
787
+ labels = data_b["labels"].long()
788
+ loss_mask = data_b["loss_mask"].float()
789
+
790
+ enc_mask = data_b["enc_mask"] < 0.5
791
+ dec_mask = data_b["dec_mask"] < 0.5
792
+ enc_dec_mask = data_b["enc_dec_mask"] < 0.5
793
+
794
+ return tokens_enc, tokens_dec, loss_mask, labels, enc_mask, dec_mask, enc_dec_mask
795
+
796
+ def get_batch_transformer(data_iterator):
797
+ """Build the batch."""
798
+ data = next(data_iterator)
799
+ data = send_to_device(data, torch.cuda.current_device())
800
+
801
+ tokens_enc = data["input_ids"].long()
802
+ labels = data["labels"].long()
803
+ loss_mask = (labels != -100).to(torch.float)
804
+ if "decoder_input_ids" in data:
805
+ tokens_dec = data["decoder_input_ids"].long()
806
+ else:
807
+ tokens_dec = labels.new_zeros(labels.shape, device=labels.device, dtype=torch.long)
808
+ tokens_dec[..., 1:] = labels[..., :-1].clone()
809
+ tokens_dec[..., 0] = 0
810
+ tokens_dec.masked_fill_(tokens_dec == -100, 0)
811
+ enc_mask = T5TrainStep.attn_mask_postprocess(data["attention_mask"].long())
812
+ dec_mask = T5TrainStep.get_decoder_mask(tokens_dec.shape[1], tokens_dec.device)
813
+ enc_dec_mask = T5TrainStep.get_enc_dec_mask(
814
+ data["attention_mask"].long(), tokens_dec.shape[1], tokens_dec.device
815
+ )
816
+
817
+ return tokens_enc, tokens_dec, loss_mask, labels, enc_mask, dec_mask, enc_dec_mask
818
+
819
+ if accelerator.state.megatron_lm_plugin.custom_get_batch_function is not None:
820
+ return accelerator.state.megatron_lm_plugin.custom_get_batch_function
821
+ if megatron_dataset_flag:
822
+ try:
823
+ # Use '--no-use-pep517 -e' to pip install nvidia's megatron from source
824
+ from pretrain_t5 import get_batch
825
+
826
+ return get_batch
827
+ except ImportError:
828
+ pass
829
+ return get_batch_megatron
830
+ else:
831
+ return get_batch_transformer
832
+
833
+ def get_loss_func(self, accelerator):
834
+ def loss_func(loss_mask, output_tensor):
835
+ lm_loss_ = output_tensor.float()
836
+ lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
837
+
838
+ loss = lm_loss
839
+ averaged_losses = average_losses_across_data_parallel_group([lm_loss])
840
+
841
+ return loss, {"lm loss": averaged_losses[0]}
842
+
843
+ if accelerator.state.megatron_lm_plugin.custom_loss_function is not None:
844
+ return accelerator.state.megatron_lm_plugin.custom_loss_function
845
+ return loss_func
846
+
847
+ def get_forward_step_func(self):
848
+ def forward_step(data_iterator, model):
849
+ """Forward step."""
850
+ # Get the batch.
851
+ tokens_enc, tokens_dec, loss_mask, lm_labels, enc_mask, dec_mask, enc_dec_mask = self.get_batch(
852
+ data_iterator
853
+ )
854
+ # Forward model lm_labels
855
+ output_tensor = model(
856
+ tokens_enc, tokens_dec, enc_mask, dec_mask, enc_dec_mask, tokentype_ids=None, lm_labels=lm_labels
857
+ )
858
+
859
+ return output_tensor, partial(self.loss_func, loss_mask)
860
+
861
+ return forward_step
862
+
863
+
864
+ def finish_mpu_init():
865
+ # torch.distributed initialization
866
+ args = get_args()
867
+ # Pytorch distributed.
868
+ _initialize_distributed()
869
+
870
+ # Random seeds for reproducibility.
871
+ if args.rank == 0:
872
+ print(f"> setting random seeds to {args.seed} ...")
873
+ _set_random_seed(args.seed, args.data_parallel_random_init)
874
+
875
+
876
+ # initialize megatron setup
877
+ def initialize(accelerator, extra_args_provider=None, args_defaults={}):
878
+ accelerator.print("Initializing Megatron-LM")
879
+ assert torch.cuda.is_available(), "Megatron requires CUDA."
880
+
881
+ # Parse arguments
882
+ args = parse_args(extra_args_provider, ignore_unknown_args=True)
883
+
884
+ # Set defaults
885
+ for key, value in args_defaults.items():
886
+ if getattr(args, key, None) is not None:
887
+ if args.rank == 0:
888
+ print(
889
+ f"WARNING: overriding default arguments for {key}:{getattr(args, key)} with {key}:{value}",
890
+ flush=True,
891
+ )
892
+ setattr(args, key, value)
893
+
894
+ if args.use_checkpoint_args or args_defaults.get("use_checkpoint_args", False):
895
+ assert args.load is not None, "--use-checkpoints-args requires --load argument"
896
+ load_args_from_checkpoint(args)
897
+
898
+ validate_args(args)
899
+
900
+ # set global args, build tokenizer, and set adlr-autoresume,
901
+ # tensorboard-writer, and timers.
902
+ set_global_variables(args)
903
+
904
+ # Megatron's MPU is the master. Complete initialization right away.
905
+ finish_mpu_init()
906
+
907
+ # Autoresume.
908
+ _init_autoresume()
909
+
910
+ # Compile dependencies.
911
+ _compile_dependencies()
912
+
913
+ # Set pytorch JIT layer fusion options and warmup JIT functions.
914
+ set_jit_fusion_options()
915
+ args = get_args()
916
+ if getattr(args, "padded_vocab_size", None) is None:
917
+ args.padded_vocab_size = _vocab_size_with_padding(args.orig_vocab_size, args)
918
+ if args.model_type_name == "bert" and args.pretraining_flag and args.num_labels == 2:
919
+ args.bert_binary_head = True
920
+ else:
921
+ args.bert_binary_head = False
922
+ args.iteration = 0
923
+
924
+
925
+ class MegatronEngine(torch.nn.Module):
926
+ """
927
+ Megatron-LM model wrapper
928
+
929
+ Args:
930
+ accelerator (:class:`~accelerate.Accelerator`): The accelerator object to use.
931
+ model: Megatron-LM model
932
+ optimizer: Megatron-LM optimizer
933
+ lr_scheduler: Megatron-LM lr scheduler
934
+ """
935
+
936
+ def __init__(self, accelerator, model, optimizer, scheduler):
937
+ super().__init__()
938
+ self.module = model
939
+ self.base_model = model[0]
940
+ self.optimizer = optimizer
941
+ self.scheduler = scheduler
942
+ args = get_args()
943
+ if accelerator.state.megatron_lm_plugin.custom_train_step_class is not None:
944
+ self.train_step_handler = accelerator.state.megatron_lm_plugin.custom_train_step_class(
945
+ args, **accelerator.state.megatron_lm_plugin.custom_train_step_kwargs
946
+ )
947
+ elif args.model_type_name == "bert":
948
+ self.train_step_handler = BertTrainStep(accelerator, args)
949
+ elif args.model_type_name == "gpt":
950
+ self.train_step_handler = GPTTrainStep(accelerator, args)
951
+ elif args.model_type_name == "t5":
952
+ self.train_step_handler = T5TrainStep(accelerator, args)
953
+ else:
954
+ raise ValueError(f"Unsupported model type: {args.model_type_name}")
955
+ self.optimizer.skipped_iter = False
956
+
957
+ # Tracking loss.
958
+ self.total_loss_dict = {}
959
+ self.eval_total_loss_dict = {}
960
+ self.iteration = 0
961
+ self.report_memory_flag = True
962
+ self.num_floating_point_operations_so_far = 0
963
+ self.module_config = None
964
+ if args.tensorboard_dir is not None:
965
+ write_args_to_tensorboard()
966
+
967
+ def get_module_config(self):
968
+ args = get_args()
969
+ config = get_model_config(self.module[0])
970
+ # Setup some training config params
971
+ config.grad_scale_func = self.optimizer.scale_loss
972
+ if isinstance(self.module[0], LocalDDP) and args.overlap_grad_reduce:
973
+ assert config.no_sync_func is None, (
974
+ "When overlap_grad_reduce is True, config.no_sync_func must be None; "
975
+ "a custom no_sync_func is not supported when overlapping grad-reduce"
976
+ )
977
+ config.no_sync_func = [model_chunk.no_sync for model_chunk in self.module]
978
+ if len(self.module) == 1:
979
+ config.no_sync_func = config.no_sync_func[0]
980
+ if args.delay_grad_reduce:
981
+ config.grad_sync_func = [model_chunk.start_grad_sync for model_chunk in self.module]
982
+ if len(self.module) == 1:
983
+ config.grad_sync_func = config.grad_sync_func[0]
984
+ if args.overlap_param_gather and args.delay_param_gather:
985
+ config.param_sync_func = [
986
+ lambda x: self.optimizer.finish_param_sync(model_index, x) for model_index in range(len(self.module))
987
+ ]
988
+ if len(self.module) == 1:
989
+ config.param_sync_func = config.param_sync_func[0]
990
+ config.finalize_model_grads_func = finalize_model_grads
991
+ return config
992
+
993
+ def train(self):
994
+ for model_module in self.module:
995
+ model_module.train()
996
+
997
+ if self.module_config is None:
998
+ self.module_config = self.get_module_config()
999
+
1000
+ self.log_eval_results()
1001
+
1002
+ def eval(self):
1003
+ for model_module in self.module:
1004
+ model_module.eval()
1005
+
1006
+ if self.module_config is None:
1007
+ self.module_config = self.get_module_config()
1008
+
1009
+ def get_batch_data_iterator(self, batch_data):
1010
+ args = get_args()
1011
+ data_chunks = []
1012
+ if len(batch_data) > 0:
1013
+ if args.num_micro_batches > 1:
1014
+ for i in range(0, args.num_micro_batches):
1015
+ data_chunks.append(
1016
+ {
1017
+ k: v[i * args.micro_batch_size : (i + 1) * args.micro_batch_size]
1018
+ for k, v in batch_data.items()
1019
+ }
1020
+ )
1021
+ else:
1022
+ data_chunks = [batch_data]
1023
+
1024
+ if len(self.module) > 1:
1025
+ batch_data_iterator = (
1026
+ [iter(data_chunks) for _ in range(len(self.module))]
1027
+ if len(batch_data) > 0
1028
+ else [None] * len(self.module)
1029
+ )
1030
+ else:
1031
+ batch_data_iterator = iter(data_chunks) if len(batch_data) > 0 else None
1032
+ return batch_data_iterator
1033
+
1034
+ def train_step(self, **batch_data):
1035
+ """
1036
+ Training step for Megatron-LM
1037
+
1038
+ Args:
1039
+ batch_data (:obj:`dict`): The batch data to train on.
1040
+ """
1041
+
1042
+ batch_data_iterator = self.get_batch_data_iterator(batch_data)
1043
+
1044
+ loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad = train_step(
1045
+ forward_step_func=self.train_step_handler.forward_step,
1046
+ data_iterator=batch_data_iterator,
1047
+ model=self.module,
1048
+ optimizer=self.optimizer,
1049
+ opt_param_scheduler=self.scheduler,
1050
+ config=self.module_config,
1051
+ )
1052
+
1053
+ self.optimizer.skipped_iter = skipped_iter == 1
1054
+
1055
+ return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad
1056
+
1057
+ def eval_step(self, **batch_data):
1058
+ """
1059
+ Evaluation step for Megatron-LM
1060
+
1061
+ Args:
1062
+ batch_data (:obj:`dict`): The batch data to evaluate on.
1063
+ """
1064
+
1065
+ args = get_args()
1066
+ batch_data_iterator = self.get_batch_data_iterator(batch_data)
1067
+ forward_backward_func = get_forward_backward_func()
1068
+ loss_dicts = forward_backward_func(
1069
+ forward_step_func=self.train_step_handler.forward_step,
1070
+ data_iterator=batch_data_iterator,
1071
+ model=self.module,
1072
+ num_microbatches=get_num_microbatches(),
1073
+ seq_length=args.seq_length,
1074
+ micro_batch_size=args.micro_batch_size,
1075
+ forward_only=True,
1076
+ )
1077
+ # Empty unused memory
1078
+ if args.empty_unused_memory_level >= 1:
1079
+ torch.cuda.empty_cache()
1080
+
1081
+ args.consumed_valid_samples += (
1082
+ mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches()
1083
+ )
1084
+
1085
+ if mpu.is_pipeline_last_stage(ignore_virtual=True):
1086
+ # Average loss across microbatches.
1087
+ loss_reduced = {}
1088
+ for key in loss_dicts[0]:
1089
+ losses_reduced_for_key = [x[key] for x in loss_dicts]
1090
+ if len(losses_reduced_for_key[0].shape) == 0:
1091
+ loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key)
1092
+ else:
1093
+ loss_reduced[key] = torch.concat(losses_reduced_for_key)
1094
+ return loss_reduced
1095
+ return {}
1096
+
1097
+ def forward(self, **batch_data):
1098
+ # During training, we use train_step()
1099
+ # model(**batch_data) performs following operations by delegating it to `self.train_step`:
1100
+ # 1. Prepare **batch_data for Tendor, Pipeline and Model Parallelism
1101
+ # 2. Set grad to zero.
1102
+ # 3. forward pass and backward pass using Pipeline Parallelism
1103
+ # 4. Empty unused memory.
1104
+ # 5. Reduce gradients.
1105
+ # 6. Update parameters.
1106
+ # 7. Gather params when using Distributed Optimizer (Data Parallelism).
1107
+ # 8. Update learning rate if scheduler is specified.
1108
+ # 9. Empty unused memory.
1109
+ # 10. Average loss across microbatches and across DP ranks.
1110
+ #
1111
+ # During evaluation, we use eval_step()
1112
+ args = get_args()
1113
+ if self.module[0].training:
1114
+ loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = self.train_step(**batch_data)
1115
+ self.iteration += 1
1116
+ batch_size = mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches()
1117
+ args.consumed_train_samples += batch_size
1118
+ self.num_floating_point_operations_so_far += num_floating_point_operations(args, batch_size)
1119
+ if args.tensorboard_dir is not None:
1120
+ # Logging.
1121
+ loss_scale = self.optimizer.get_loss_scale().item()
1122
+ params_norm = None
1123
+ if args.log_params_norm:
1124
+ params_norm = calc_params_l2_norm(self.model)
1125
+ self.report_memory_flag = training_log(
1126
+ loss_dict,
1127
+ self.total_loss_dict,
1128
+ self.optimizer.param_groups[0]["lr"],
1129
+ self.iteration,
1130
+ loss_scale,
1131
+ self.report_memory_flag,
1132
+ skipped_iter,
1133
+ grad_norm,
1134
+ params_norm,
1135
+ num_zeros_in_grad,
1136
+ )
1137
+ else:
1138
+ loss_dict = self.eval_step(**batch_data)
1139
+ if args.tensorboard_dir is not None:
1140
+ for key in loss_dict:
1141
+ self.eval_total_loss_dict[key] = (
1142
+ self.eval_total_loss_dict.get(key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
1143
+ )
1144
+ self.eval_total_loss_dict[key + "_num_iters"] = self.eval_total_loss_dict.get(
1145
+ key + "_num_iters", torch.cuda.FloatTensor([0.0])
1146
+ ) + torch.cuda.FloatTensor([1.0])
1147
+
1148
+ loss = torch.tensor(0.0, device=torch.cuda.current_device())
1149
+ for key in loss_dict:
1150
+ if len(loss_dict[key].shape) == 0:
1151
+ loss += loss_dict[key]
1152
+
1153
+ logits = None
1154
+ if "logits" in loss_dict:
1155
+ logits = loss_dict["logits"]
1156
+ if self.train_step_handler.model_output_class is not None:
1157
+ return self.train_step_handler.model_output_class(loss=loss, logits=logits)
1158
+ return loss
1159
+
1160
+ def log_eval_results(self):
1161
+ args = get_args()
1162
+ if args.tensorboard_dir is None or self.iteration == 0:
1163
+ return
1164
+ args = get_args()
1165
+ writer = get_tensorboard_writer()
1166
+ string = f"validation loss at iteration {self.iteration} | "
1167
+ for key in self.eval_total_loss_dict:
1168
+ if key.endswith("_num_iters"):
1169
+ continue
1170
+ value = self.eval_total_loss_dict[key] / self.eval_total_loss_dict[key + "_num_iters"]
1171
+ string += f"{key} value: {value} | "
1172
+ ppl = math.exp(min(20, value.item()))
1173
+ if args.pretraining_flag:
1174
+ string += f"{key} PPL: {ppl} | "
1175
+ if writer:
1176
+ writer.add_scalar(f"{key} validation", value.item(), self.iteration)
1177
+ if args.pretraining_flag:
1178
+ writer.add_scalar(f"{key} validation ppl", ppl, self.iteration)
1179
+
1180
+ length = len(string) + 1
1181
+ print_rank_last("-" * length)
1182
+ print_rank_last(string)
1183
+ print_rank_last("-" * length)
1184
+ self.eval_total_loss_dict = {}
1185
+
1186
+ def save_checkpoint(self, output_dir):
1187
+ self.log_eval_results()
1188
+ args = get_args()
1189
+ args.save = output_dir
1190
+ torch.distributed.barrier()
1191
+ save_checkpoint(
1192
+ self.iteration,
1193
+ self.module,
1194
+ self.optimizer,
1195
+ self.scheduler,
1196
+ num_floating_point_operations_so_far=self.num_floating_point_operations_so_far,
1197
+ )
1198
+ torch.distributed.barrier()
1199
+
1200
+ def load_checkpoint(self, input_dir):
1201
+ args = get_args()
1202
+ args.load = input_dir
1203
+ args.consumed_train_samples = 0
1204
+ args.consumed_valid_samples = 0
1205
+ torch.distributed.barrier()
1206
+ iteration, num_floating_point_operations_so_far = load_checkpoint(self.module, self.optimizer, self.scheduler)
1207
+ torch.distributed.barrier()
1208
+ self.iteration = iteration
1209
+ self.num_floating_point_operations_so_far = num_floating_point_operations_so_far
1210
+ if args.fp16 and self.iteration == 0:
1211
+ self.optimizer.reload_model_params()
1212
+
1213
+ def megatron_generate(
1214
+ self,
1215
+ inputs,
1216
+ attention_mask=None,
1217
+ max_length=None,
1218
+ max_new_tokens=None,
1219
+ num_beams=None,
1220
+ temperature=None,
1221
+ top_k=None,
1222
+ top_p=None,
1223
+ length_penalty=None,
1224
+ **kwargs,
1225
+ ):
1226
+ """
1227
+ Generate method for GPT2 model. This method is used for inference. Supports both greedy and beam search along
1228
+ with sampling. Refer the Megatron-LM repo for more details
1229
+
1230
+ Args:
1231
+ inputs (torch.Tensor): input ids
1232
+ attention_mask (torch.Tensor, optional): attention mask. Defaults to None.
1233
+ max_length (int, optional): max length of the generated sequence. Defaults to None.
1234
+ Either this or max_new_tokens should be provided.
1235
+ max_new_tokens (int, optional): max number of tokens to be generated. Defaults to None.
1236
+ Either this or max_length should be provided.
1237
+ num_beams (int, optional): number of beams to use for beam search. Defaults to None.
1238
+ temperature (float, optional): temperature for sampling. Defaults to 1.0.
1239
+ top_k (int, optional): top k tokens to consider for sampling. Defaults to 0.0.
1240
+ top_p (float, optional): tokens in top p probability are considered for sampling. Defaults to 0.0.
1241
+ length_penalty (float, optional): length penalty for beam search. Defaults to None.
1242
+ kwargs: additional key-value arguments
1243
+ """
1244
+
1245
+ # checking if required arguments are passed
1246
+ args = get_args()
1247
+ if args.model_type_name != "gpt":
1248
+ raise NotImplementedError("Generate method is not implemented for this model")
1249
+
1250
+ if args.data_parallel_size > 1:
1251
+ raise ValueError("Generate method requires data parallelism to be 1")
1252
+
1253
+ if args.sequence_parallel:
1254
+ raise ValueError("Generate method requires sequence parallelism to be False")
1255
+
1256
+ if args.recompute_granularity is not None:
1257
+ raise ValueError("Checkpoint activations cannot be set for inference")
1258
+
1259
+ if args.vocab_file is None:
1260
+ raise ValueError("Vocab file is required for inference")
1261
+
1262
+ # Prepare inputs
1263
+ if max_length is None and max_new_tokens is None:
1264
+ raise ValueError("`max_length` or `max_new_tokens` are required for inference")
1265
+
1266
+ if temperature is None:
1267
+ temperature = 1.0
1268
+ elif not (0.0 < temperature <= 100.0):
1269
+ raise ValueError("temperature must be a positive number less than or equal to 100.0")
1270
+
1271
+ if top_k is None:
1272
+ top_k = 0
1273
+ elif not (0 <= top_k <= 1000):
1274
+ raise ValueError("top_k must be a positive number less than or equal to 1000")
1275
+
1276
+ if top_p is None:
1277
+ top_p = 0.0
1278
+ elif top_p > 0.0 and top_k > 0.0:
1279
+ raise ValueError("top_p and top_k sampling cannot be set together")
1280
+ else:
1281
+ if not (0.0 <= top_p <= 1.0):
1282
+ raise ValueError("top_p must be less than or equal to 1.0")
1283
+
1284
+ top_p_decay = kwargs.get("top_p_decay", 0.0)
1285
+ if not (0.0 <= top_p_decay <= 1.0):
1286
+ raise ValueError("top_p_decay must be less than or equal to 1.0")
1287
+
1288
+ top_p_bound = kwargs.get("top_p_bound", 0.0)
1289
+ if not (0.0 <= top_p_bound <= 1.0):
1290
+ raise ValueError("top_p_bound must be less than or equal to 1.0")
1291
+
1292
+ add_BOS = kwargs.get("add_BOS", False)
1293
+ if not (isinstance(add_BOS, bool)):
1294
+ raise ValueError("add_BOS must be a boolean")
1295
+
1296
+ beam_width = num_beams
1297
+ if beam_width is not None:
1298
+ if not isinstance(beam_width, int):
1299
+ raise ValueError("beam_width must be an integer")
1300
+ if beam_width < 1:
1301
+ raise ValueError("beam_width must be greater than 0")
1302
+ if inputs.shape[0] > 1:
1303
+ return "When doing beam_search, batch size must be 1"
1304
+
1305
+ tokenizer = get_tokenizer()
1306
+
1307
+ stop_token = kwargs.get("stop_token", tokenizer.eod)
1308
+ if stop_token is not None:
1309
+ if not isinstance(stop_token, int):
1310
+ raise ValueError("stop_token must be an integer")
1311
+
1312
+ if length_penalty is None:
1313
+ length_penalty = 1.0
1314
+
1315
+ sizes_list = None
1316
+ prompts_tokens_tensor = None
1317
+ prompts_length_tensor = None
1318
+ if torch.distributed.get_rank() == 0:
1319
+ # Get the prompts length.
1320
+ if attention_mask is None:
1321
+ prompts_length_tensor = torch.cuda.LongTensor([inputs.shape[1]] * inputs.shape[0])
1322
+ else:
1323
+ prompts_length_tensor = attention_mask.sum(axis=-1).cuda()
1324
+
1325
+ if max_new_tokens is None:
1326
+ max_new_tokens = max_length - inputs.shape[1]
1327
+ if max_new_tokens <= 0:
1328
+ raise ValueError("max_new_tokens must be greater than 0")
1329
+
1330
+ if add_BOS:
1331
+ max_length = max_new_tokens + inputs.shape[1] + 1
1332
+ # making sure that `max_length` is a multiple of 4 to leverage fused kernels
1333
+ max_length = 4 * math.ceil(max_length / 4)
1334
+ max_new_tokens = max_length - (inputs.shape[1] + 1)
1335
+ padding = torch.cuda.LongTensor([[tokenizer.eod] * max_new_tokens] * inputs.shape[0])
1336
+ prompts_tokens_tensor = torch.concat(
1337
+ [torch.unsqueeze(padding[:, 0], axis=-1), inputs.cuda(), padding], axis=-1
1338
+ )
1339
+ else:
1340
+ # making sure that `max_length` is a multiple of 4 to leverage fused kernels
1341
+ max_length = max_new_tokens + inputs.shape[1]
1342
+ max_length = 4 * math.ceil(max_length / 4)
1343
+ max_new_tokens = max_length - inputs.shape[1]
1344
+ padding = torch.cuda.LongTensor([[tokenizer.eod] * max_new_tokens] * inputs.shape[0])
1345
+ prompts_tokens_tensor = torch.concat([inputs.cuda(), padding], axis=-1)
1346
+
1347
+ # We need the sizes of these tensors for the broadcast
1348
+ sizes_list = [
1349
+ prompts_tokens_tensor.size(0), # Batch size
1350
+ prompts_tokens_tensor.size(1),
1351
+ ] # Sequence length
1352
+
1353
+ # First, broadcast the sizes.
1354
+ sizes_tensor = broadcast_int_list(2, int_list=sizes_list, rank=0)
1355
+
1356
+ # Now that we have the sizes, we can broadcast the tokens
1357
+ # and length tensors.
1358
+ sizes = sizes_tensor.tolist()
1359
+ context_tokens_tensor = broadcast_tensor(sizes, torch.int64, tensor=prompts_tokens_tensor, rank=0)
1360
+ context_length_tensor = broadcast_tensor(sizes[0], torch.int64, tensor=prompts_length_tensor, rank=0)
1361
+
1362
+ # Run the inference
1363
+ random_seed = kwargs.get("random_seed", 0)
1364
+ torch.random.manual_seed(random_seed)
1365
+ unwrapped_model = unwrap_model(self.base_model, (torchDDP, LocalDDP, Float16Module))
1366
+ if beam_width is not None:
1367
+ tokens, _ = beam_search_and_return_on_first_stage(
1368
+ unwrapped_model,
1369
+ context_tokens_tensor,
1370
+ context_length_tensor,
1371
+ beam_width,
1372
+ stop_token=stop_token,
1373
+ num_return_gen=1,
1374
+ length_penalty=length_penalty,
1375
+ )
1376
+ else:
1377
+ tokens, _, _ = generate_tokens_probs_and_return_on_first_stage(
1378
+ unwrapped_model,
1379
+ context_tokens_tensor,
1380
+ context_length_tensor,
1381
+ return_output_log_probs=False,
1382
+ top_k=top_k,
1383
+ top_p=top_p,
1384
+ top_p_decay=top_p_decay,
1385
+ top_p_bound=top_p_bound,
1386
+ temperature=temperature,
1387
+ use_eod_token_for_early_termination=True,
1388
+ )
1389
+ return tokens
1390
+
1391
+
1392
+ # other utilities
1393
+ def avg_losses_across_data_parallel_group(losses):
1394
+ """
1395
+ Average losses across data parallel group.
1396
+
1397
+ Args:
1398
+ losses (List[Tensor]): List of losses to average across data parallel group.
1399
+ """
1400
+
1401
+ return average_losses_across_data_parallel_group(losses)
1402
+
1403
+
1404
+ def gather_across_data_parallel_groups(tensor):
1405
+ """
1406
+ Recursively gather tensor in a nested list/tuple/dictionary of tensors from data parallel ranks.
1407
+
1408
+ Args:
1409
+ tensor (nested list/tuple/dictionary of `torch.Tensor`):
1410
+ The data to gather across data parallel ranks.
1411
+
1412
+ """
1413
+
1414
+ def _gpu_gather_one(tensor):
1415
+ if tensor.ndim == 0:
1416
+ tensor = tensor.clone()[None]
1417
+ output_tensors = [
1418
+ torch.empty_like(tensor)
1419
+ for _ in range(torch.distributed.get_world_size(group=mpu.get_data_parallel_group()))
1420
+ ]
1421
+ torch.distributed.all_gather(output_tensors, tensor, group=mpu.get_data_parallel_group())
1422
+ return torch.cat(output_tensors, dim=0)
1423
+
1424
+ return recursively_apply(_gpu_gather_one, tensor, error_on_other_type=True)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/memory.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ A collection of utilities for ensuring that training can always occur. Heavily influenced by the
17
+ [toma](https://github.com/BlackHC/toma) library.
18
+ """
19
+
20
+ import functools
21
+ import gc
22
+ import importlib
23
+ import inspect
24
+ import warnings
25
+ from typing import Optional
26
+
27
+ import torch
28
+ from packaging import version
29
+
30
+ from .imports import (
31
+ is_cuda_available,
32
+ is_hpu_available,
33
+ is_ipex_available,
34
+ is_mlu_available,
35
+ is_mps_available,
36
+ is_musa_available,
37
+ is_npu_available,
38
+ is_sdaa_available,
39
+ is_xpu_available,
40
+ )
41
+ from .versions import compare_versions
42
+
43
+
44
+ def clear_device_cache(garbage_collection=False):
45
+ """
46
+ Clears the device cache by calling `torch.{backend}.empty_cache`. Can also run `gc.collect()`, but do note that
47
+ this is a *considerable* slowdown and should be used sparingly.
48
+ """
49
+ if garbage_collection:
50
+ gc.collect()
51
+
52
+ if is_xpu_available():
53
+ torch.xpu.empty_cache()
54
+ elif is_mlu_available():
55
+ torch.mlu.empty_cache()
56
+ elif is_sdaa_available():
57
+ torch.sdaa.empty_cache()
58
+ elif is_musa_available():
59
+ torch.musa.empty_cache()
60
+ elif is_npu_available():
61
+ torch.npu.empty_cache()
62
+ elif is_mps_available(min_version="2.0"):
63
+ torch.mps.empty_cache()
64
+ elif is_cuda_available():
65
+ torch.cuda.empty_cache()
66
+ elif is_hpu_available():
67
+ # torch.hpu.empty_cache() # not available on hpu as it reserves all device memory for the current process
68
+ pass
69
+
70
+
71
+ def release_memory(*objects):
72
+ """
73
+ Releases memory from `objects` by setting them to `None` and calls `gc.collect()` and `torch.cuda.empty_cache()`.
74
+ Returned objects should be reassigned to the same variables.
75
+
76
+ Args:
77
+ objects (`Iterable`):
78
+ An iterable of objects
79
+ Returns:
80
+ A list of `None` objects to replace `objects`
81
+
82
+ Example:
83
+
84
+ ```python
85
+ >>> import torch
86
+ >>> from accelerate.utils import release_memory
87
+
88
+ >>> a = torch.ones(1000, 1000).cuda()
89
+ >>> b = torch.ones(1000, 1000).cuda()
90
+ >>> a, b = release_memory(a, b)
91
+ ```
92
+ """
93
+ if not isinstance(objects, list):
94
+ objects = list(objects)
95
+ for i in range(len(objects)):
96
+ objects[i] = None
97
+ clear_device_cache(garbage_collection=True)
98
+ return objects
99
+
100
+
101
+ def should_reduce_batch_size(exception: Exception) -> bool:
102
+ """
103
+ Checks if `exception` relates to CUDA out-of-memory, XPU out-of-memory, CUDNN not supported, or CPU out-of-memory
104
+
105
+ Args:
106
+ exception (`Exception`):
107
+ An exception
108
+ """
109
+ _statements = [
110
+ " out of memory.", # OOM for CUDA, HIP, XPU
111
+ "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED.", # CUDNN SNAFU
112
+ "DefaultCPUAllocator: can't allocate memory", # CPU OOM
113
+ "FATAL ERROR :: MODULE:PT_DEVMEM Allocation failed", # HPU OOM
114
+ ]
115
+ if isinstance(exception, RuntimeError) and len(exception.args) == 1:
116
+ return any(err in exception.args[0] for err in _statements)
117
+ return False
118
+
119
+
120
+ def find_executable_batch_size(
121
+ function: Optional[callable] = None,
122
+ starting_batch_size: int = 128,
123
+ reduce_batch_size_fn: Optional[callable] = None,
124
+ ):
125
+ """
126
+ A basic decorator that will try to execute `function`. If it fails from exceptions related to out-of-memory or
127
+ CUDNN, the batch size is multiplied by 0.9 and passed to `function`
128
+
129
+ `function` must take in a `batch_size` parameter as its first argument.
130
+
131
+ Args:
132
+ function (`callable`, *optional*):
133
+ A function to wrap
134
+ starting_batch_size (`int`, *optional*):
135
+ The batch size to try and fit into memory
136
+
137
+ Example:
138
+
139
+ ```python
140
+ >>> from accelerate.utils import find_executable_batch_size
141
+
142
+
143
+ >>> @find_executable_batch_size(starting_batch_size=128)
144
+ ... def train(batch_size, model, optimizer):
145
+ ... ...
146
+
147
+
148
+ >>> train(model, optimizer)
149
+ ```
150
+ """
151
+ if function is None:
152
+ return functools.partial(find_executable_batch_size, starting_batch_size=starting_batch_size)
153
+
154
+ batch_size = starting_batch_size
155
+ if reduce_batch_size_fn is None:
156
+
157
+ def reduce_batch_size_fn():
158
+ nonlocal batch_size
159
+ batch_size = int(batch_size * 0.9)
160
+ return batch_size
161
+
162
+ def decorator(*args, **kwargs):
163
+ nonlocal batch_size
164
+ clear_device_cache(garbage_collection=True)
165
+ params = list(inspect.signature(function).parameters.keys())
166
+ # Guard against user error
167
+ if len(params) < (len(args) + 1):
168
+ arg_str = ", ".join([f"{arg}={value}" for arg, value in zip(params[1:], args[1:])])
169
+ raise TypeError(
170
+ f"Batch size was passed into `{function.__name__}` as the first argument when called."
171
+ f"Remove this as the decorator already does so: `{function.__name__}({arg_str})`"
172
+ )
173
+ while True:
174
+ if batch_size == 0:
175
+ raise RuntimeError("No executable batch size found, reached zero.")
176
+ try:
177
+ return function(batch_size, *args, **kwargs)
178
+ except Exception as e:
179
+ if should_reduce_batch_size(e):
180
+ clear_device_cache(garbage_collection=True)
181
+ batch_size = reduce_batch_size_fn()
182
+ else:
183
+ raise
184
+
185
+ return decorator
186
+
187
+
188
+ def get_xpu_available_memory(device_index: int):
189
+ if version.parse(torch.__version__).release >= version.parse("2.6").release:
190
+ # torch.xpu.mem_get_info API is available starting from PyTorch 2.6
191
+ # It further requires PyTorch built with the SYCL runtime which supports API
192
+ # to query available device memory. If not available, exception will be
193
+ # raised. Version of SYCL runtime used to build PyTorch is being reported
194
+ # with print(torch.version.xpu) and corresponds to the version of Intel DPC++
195
+ # SYCL compiler. First version to support required feature is 20250001.
196
+ try:
197
+ return torch.xpu.mem_get_info(device_index)[0]
198
+ except Exception:
199
+ pass
200
+ elif is_ipex_available():
201
+ ipex_version = version.parse(importlib.metadata.version("intel_extension_for_pytorch"))
202
+ if compare_versions(ipex_version, ">=", "2.5"):
203
+ from intel_extension_for_pytorch.xpu import mem_get_info
204
+
205
+ return mem_get_info(device_index)[0]
206
+
207
+ warnings.warn(
208
+ "The XPU `mem_get_info` API is available in IPEX version >=2.5 or PyTorch >=2.6. The current returned available memory is incorrect. Please consider upgrading your IPEX or PyTorch version."
209
+ )
210
+ return torch.xpu.max_memory_allocated(device_index)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/modeling.py ADDED
@@ -0,0 +1,2186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import contextlib
16
+ import gc
17
+ import inspect
18
+ import json
19
+ import logging
20
+ import os
21
+ import re
22
+ import shutil
23
+ import tempfile
24
+ import warnings
25
+ from collections import OrderedDict, defaultdict
26
+ from typing import Optional, Union
27
+
28
+ import torch
29
+ from torch import distributed as dist
30
+ from torch import nn
31
+
32
+ from ..state import AcceleratorState
33
+ from .constants import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
34
+ from .dataclasses import AutocastKwargs, CustomDtype, DistributedType
35
+ from .imports import (
36
+ is_hpu_available,
37
+ is_mlu_available,
38
+ is_mps_available,
39
+ is_musa_available,
40
+ is_npu_available,
41
+ is_peft_available,
42
+ is_sdaa_available,
43
+ is_torch_xla_available,
44
+ is_xpu_available,
45
+ )
46
+ from .memory import clear_device_cache, get_xpu_available_memory
47
+ from .offload import load_offloaded_weight, offload_weight, save_offload_index
48
+ from .tqdm import is_tqdm_available, tqdm
49
+ from .versions import is_torch_version
50
+
51
+
52
+ if is_npu_available(check_device=False):
53
+ import torch_npu # noqa: F401
54
+
55
+ if is_mlu_available(check_device=False):
56
+ import torch_mlu # noqa: F401
57
+
58
+ if is_sdaa_available(check_device=False):
59
+ import torch_sdaa # noqa: F401
60
+
61
+ if is_musa_available(check_device=False):
62
+ import torch_musa # noqa: F401
63
+
64
+ from safetensors import safe_open
65
+ from safetensors.torch import load_file as safe_load_file
66
+
67
+
68
+ WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
69
+
70
+ logger = logging.getLogger(__name__)
71
+
72
+
73
+ def is_peft_model(model):
74
+ from .other import extract_model_from_parallel
75
+
76
+ if is_peft_available():
77
+ from peft import PeftModel
78
+
79
+ return is_peft_available() and isinstance(extract_model_from_parallel(model), PeftModel)
80
+
81
+
82
+ def check_device_same(first_device, second_device):
83
+ """
84
+ Utility method to check if two `torch` devices are similar. When dealing with CUDA devices, torch throws `False`
85
+ for `torch.device("cuda") == torch.device("cuda:0")` whereas they should be the same
86
+
87
+ Args:
88
+ first_device (`torch.device`):
89
+ First device to check
90
+ second_device (`torch.device`):
91
+ Second device to check
92
+ """
93
+ if first_device.type != second_device.type:
94
+ return False
95
+
96
+ if first_device.type != "cpu" and first_device.index is None:
97
+ # In case the first_device is a cuda device and have
98
+ # the index attribute set to `None`, default it to `0`
99
+ first_device = torch.device(first_device.type, index=0)
100
+
101
+ if second_device.type != "cpu" and second_device.index is None:
102
+ # In case the second_device is a cuda device and have
103
+ # the index attribute set to `None`, default it to `0`
104
+ second_device = torch.device(second_device.type, index=0)
105
+
106
+ return first_device == second_device
107
+
108
+
109
+ def convert_file_size_to_int(size: Union[int, str]):
110
+ """
111
+ Converts a size expressed as a string with digits an unit (like `"5MB"`) to an integer (in bytes).
112
+
113
+ Args:
114
+ size (`int` or `str`): The size to convert. Will be directly returned if an `int`.
115
+
116
+ Example:
117
+
118
+ ```py
119
+ >>> convert_file_size_to_int("1MiB")
120
+ 1048576
121
+ ```
122
+ """
123
+ mem_size = -1
124
+ err_msg = (
125
+ f"`size` {size} is not in a valid format. Use an integer for bytes, or a string with an unit (like '5.0GB')."
126
+ )
127
+ try:
128
+ if isinstance(size, int):
129
+ mem_size = size
130
+ elif size.upper().endswith("GIB"):
131
+ mem_size = int(float(size[:-3]) * (2**30))
132
+ elif size.upper().endswith("MIB"):
133
+ mem_size = int(float(size[:-3]) * (2**20))
134
+ elif size.upper().endswith("KIB"):
135
+ mem_size = int(float(size[:-3]) * (2**10))
136
+ elif size.upper().endswith("GB"):
137
+ int_size = int(float(size[:-2]) * (10**9))
138
+ mem_size = int_size // 8 if size.endswith("b") else int_size
139
+ elif size.upper().endswith("MB"):
140
+ int_size = int(float(size[:-2]) * (10**6))
141
+ mem_size = int_size // 8 if size.endswith("b") else int_size
142
+ elif size.upper().endswith("KB"):
143
+ int_size = int(float(size[:-2]) * (10**3))
144
+ mem_size = int_size // 8 if size.endswith("b") else int_size
145
+ except ValueError:
146
+ raise ValueError(err_msg)
147
+
148
+ if mem_size < 0:
149
+ raise ValueError(err_msg)
150
+ return mem_size
151
+
152
+
153
+ def dtype_byte_size(dtype: torch.dtype):
154
+ """
155
+ Returns the size (in bytes) occupied by one parameter of type `dtype`.
156
+
157
+ Example:
158
+
159
+ ```py
160
+ >>> dtype_byte_size(torch.float32)
161
+ 4
162
+ ```
163
+ """
164
+ if dtype == torch.bool:
165
+ return 1 / 8
166
+ elif dtype == CustomDtype.INT2:
167
+ return 1 / 4
168
+ elif dtype == CustomDtype.INT4:
169
+ return 1 / 2
170
+ elif dtype == CustomDtype.FP8:
171
+ return 1
172
+ elif is_torch_version(">=", "2.1.0") and dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
173
+ return 1
174
+ bit_search = re.search(r"[^\d](\d+)$", str(dtype))
175
+ if bit_search is None:
176
+ raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
177
+ bit_size = int(bit_search.groups()[0])
178
+ return bit_size // 8
179
+
180
+
181
+ def id_tensor_storage(tensor: torch.Tensor) -> tuple[torch.device, int, int]:
182
+ """
183
+ Unique identifier to a tensor storage. Multiple different tensors can share the same underlying storage. For
184
+ example, "meta" tensors all share the same storage, and thus their identifier will all be equal. This identifier is
185
+ guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
186
+ non-overlapping lifetimes may have the same id.
187
+ """
188
+ _SIZE = {
189
+ torch.int64: 8,
190
+ torch.float32: 4,
191
+ torch.int32: 4,
192
+ torch.bfloat16: 2,
193
+ torch.float16: 2,
194
+ torch.int16: 2,
195
+ torch.uint8: 1,
196
+ torch.int8: 1,
197
+ torch.bool: 1,
198
+ torch.float64: 8,
199
+ }
200
+ try:
201
+ storage_ptr = tensor.untyped_storage().data_ptr()
202
+ storage_size = tensor.untyped_storage().nbytes()
203
+ except Exception:
204
+ try:
205
+ # Fallback for torch==1.10
206
+ storage_ptr = tensor.storage().data_ptr()
207
+ storage_size = tensor.storage().size() * _SIZE[tensor.dtype]
208
+ except NotImplementedError:
209
+ # Fallback for meta storage
210
+ storage_ptr = 0
211
+ # On torch >=2.0 this is the tensor size
212
+ storage_size = tensor.nelement() * _SIZE[tensor.dtype]
213
+
214
+ return tensor.device, storage_ptr, storage_size
215
+
216
+
217
+ def set_module_tensor_to_device(
218
+ module: nn.Module,
219
+ tensor_name: str,
220
+ device: Union[int, str, torch.device],
221
+ value: Optional[torch.Tensor] = None,
222
+ dtype: Optional[Union[str, torch.dtype]] = None,
223
+ fp16_statistics: Optional[torch.HalfTensor] = None,
224
+ tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,
225
+ non_blocking: bool = False,
226
+ clear_cache: bool = True,
227
+ ):
228
+ """
229
+ A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing
230
+ `param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function).
231
+
232
+ Args:
233
+ module (`torch.nn.Module`):
234
+ The module in which the tensor we want to move lives.
235
+ tensor_name (`str`):
236
+ The full name of the parameter/buffer.
237
+ device (`int`, `str` or `torch.device`):
238
+ The device on which to set the tensor.
239
+ value (`torch.Tensor`, *optional*):
240
+ The value of the tensor (useful when going from the meta device to any other device).
241
+ dtype (`torch.dtype`, *optional*):
242
+ If passed along the value of the parameter will be cast to this `dtype`. Otherwise, `value` will be cast to
243
+ the dtype of the existing parameter in the model.
244
+ fp16_statistics (`torch.HalfTensor`, *optional*):
245
+ The list of fp16 statistics to set on the module, used for 8 bit model serialization.
246
+ tied_params_map (Dict[int, Dict[torch.device, torch.Tensor]], *optional*, defaults to `None`):
247
+ A map of current data pointers to dictionaries of devices to already dispatched tied weights. For a given
248
+ execution device, this parameter is useful to reuse the first available pointer of a shared weight on the
249
+ device for all others, instead of duplicating memory.
250
+ non_blocking (`bool`, *optional*, defaults to `False`):
251
+ If `True`, the device transfer will be asynchronous with respect to the host, if possible.
252
+ clear_cache (`bool`, *optional*, defaults to `True`):
253
+ Whether or not to clear the device cache after setting the tensor on the device.
254
+ """
255
+ # Recurse if needed
256
+ if "." in tensor_name:
257
+ splits = tensor_name.split(".")
258
+ for split in splits[:-1]:
259
+ new_module = getattr(module, split)
260
+ if new_module is None:
261
+ raise ValueError(f"{module} has no attribute {split}.")
262
+ module = new_module
263
+ tensor_name = splits[-1]
264
+
265
+ if tensor_name not in module._parameters and tensor_name not in module._buffers:
266
+ raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
267
+ is_buffer = tensor_name in module._buffers
268
+ old_value = getattr(module, tensor_name)
269
+
270
+ # Treat the case where old_value (or a custom `value`, typically offloaded to RAM/disk) belongs to a tied group, and one of the weight
271
+ # in the tied group has already been dispatched to the device, by avoiding reallocating memory on the device and just copying the pointer.
272
+ if (
273
+ value is not None
274
+ and tied_params_map is not None
275
+ and value.data_ptr() in tied_params_map
276
+ and device in tied_params_map[value.data_ptr()]
277
+ ):
278
+ module._parameters[tensor_name] = tied_params_map[value.data_ptr()][device]
279
+ return
280
+ elif (
281
+ tied_params_map is not None
282
+ and old_value.data_ptr() in tied_params_map
283
+ and device in tied_params_map[old_value.data_ptr()]
284
+ ):
285
+ module._parameters[tensor_name] = tied_params_map[old_value.data_ptr()][device]
286
+ return
287
+
288
+ if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None:
289
+ raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.")
290
+
291
+ param = module._parameters[tensor_name] if tensor_name in module._parameters else None
292
+ param_cls = type(param)
293
+
294
+ if value is not None:
295
+ # We can expect mismatches when using bnb 4bit since Params4bit will reshape and pack the weights.
296
+ # In other cases, we want to make sure we're not loading checkpoints that do not match the config.
297
+ if old_value.shape != value.shape and param_cls.__name__ != "Params4bit":
298
+ raise ValueError(
299
+ f'Trying to set a tensor of shape {value.shape} in "{tensor_name}" (which has shape {old_value.shape}), this looks incorrect.'
300
+ )
301
+
302
+ if dtype is None:
303
+ # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model
304
+ value = value.to(old_value.dtype, non_blocking=non_blocking)
305
+ elif not str(value.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
306
+ value = value.to(dtype, non_blocking=non_blocking)
307
+
308
+ device_quantization = None
309
+ with torch.no_grad():
310
+ # leave it on cpu first before moving them to cuda
311
+ # # fix the case where the device is meta, we don't want to put it on cpu because there is no data =0
312
+ if (
313
+ param is not None
314
+ and param.device.type not in ("cuda", "xpu")
315
+ and torch.device(device).type in ("cuda", "xpu")
316
+ and param_cls.__name__ in ["Int8Params", "FP4Params", "Params4bit"]
317
+ ):
318
+ device_quantization = device
319
+ device = "cpu"
320
+ # `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
321
+ if isinstance(device, int):
322
+ if is_npu_available():
323
+ device = f"npu:{device}"
324
+ elif is_mlu_available():
325
+ device = f"mlu:{device}"
326
+ elif is_sdaa_available():
327
+ device = f"sdaa:{device}"
328
+ elif is_musa_available():
329
+ device = f"musa:{device}"
330
+ elif is_hpu_available():
331
+ device = "hpu"
332
+ if "xpu" in str(device) and not is_xpu_available():
333
+ raise ValueError(f'{device} is not available, you should use device="cpu" instead')
334
+ if value is None:
335
+ new_value = old_value.to(device, non_blocking=non_blocking)
336
+ if dtype is not None and device in ["meta", torch.device("meta")]:
337
+ if not str(old_value.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
338
+ new_value = new_value.to(dtype, non_blocking=non_blocking)
339
+
340
+ if not is_buffer:
341
+ module._parameters[tensor_name] = param_cls(new_value, requires_grad=old_value.requires_grad)
342
+ elif isinstance(value, torch.Tensor):
343
+ new_value = value.to(device, non_blocking=non_blocking)
344
+ else:
345
+ new_value = torch.tensor(value, device=device)
346
+ if device_quantization is not None:
347
+ device = device_quantization
348
+ if is_buffer:
349
+ module._buffers[tensor_name] = new_value
350
+ elif value is not None or not check_device_same(torch.device(device), module._parameters[tensor_name].device):
351
+ param_cls = type(module._parameters[tensor_name])
352
+ kwargs = module._parameters[tensor_name].__dict__
353
+ if param_cls.__name__ in ["Int8Params", "FP4Params", "Params4bit"]:
354
+ if param_cls.__name__ == "Int8Params" and new_value.dtype == torch.float32:
355
+ # downcast to fp16 if any - needed for 8bit serialization
356
+ new_value = new_value.to(torch.float16, non_blocking=non_blocking)
357
+ # quantize module that are going to stay on the cpu so that we offload quantized weights
358
+ if device == "cpu" and param_cls.__name__ == "Int8Params":
359
+ new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(0).to("cpu")
360
+ new_value.CB = new_value.CB.to("cpu")
361
+ new_value.SCB = new_value.SCB.to("cpu")
362
+ else:
363
+ new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(
364
+ device, non_blocking=non_blocking
365
+ )
366
+ elif param_cls.__name__ in ["QTensor", "QBitsTensor"]:
367
+ new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad).to(
368
+ device, non_blocking=non_blocking
369
+ )
370
+ elif param_cls.__name__ in ["AffineQuantizedTensor"]:
371
+ new_value = new_value.to(device, non_blocking=non_blocking)
372
+ else:
373
+ new_value = param_cls(new_value, requires_grad=old_value.requires_grad).to(
374
+ device, non_blocking=non_blocking
375
+ )
376
+
377
+ module._parameters[tensor_name] = new_value
378
+ if fp16_statistics is not None:
379
+ module._parameters[tensor_name].SCB = fp16_statistics.to(device, non_blocking=non_blocking)
380
+ del fp16_statistics
381
+ # as we put the weight to meta, it doesn't have SCB attr anymore. make sure that it is not a meta weight
382
+ if (
383
+ module.__class__.__name__ == "Linear8bitLt"
384
+ and getattr(module.weight, "SCB", None) is None
385
+ and str(module.weight.device) != "meta"
386
+ ):
387
+ # quantize only if necessary
388
+ device_index = torch.device(device).index if torch.device(device).type == "cuda" else None
389
+ if not getattr(module.weight, "SCB", None) and device_index is not None:
390
+ if module.bias is not None and module.bias.device.type != "meta":
391
+ # if a bias exists, we need to wait until the bias is set on the correct device
392
+ module = module.cuda(device_index)
393
+ elif module.bias is None:
394
+ # if no bias exists, we can quantize right away
395
+ module = module.cuda(device_index)
396
+ elif (
397
+ module.__class__.__name__ == "Linear4bit"
398
+ and getattr(module.weight, "quant_state", None) is None
399
+ and str(module.weight.device) != "meta"
400
+ ):
401
+ # quantize only if necessary
402
+ device_index = torch.device(device).index if torch.device(device).type == "cuda" else None
403
+ if not getattr(module.weight, "quant_state", None) and device_index is not None:
404
+ module.weight = module.weight.cuda(device_index)
405
+
406
+ # clean pre and post forward hook
407
+ if clear_cache and device not in ("cpu", "meta"):
408
+ clear_device_cache()
409
+
410
+ # When handling tied weights, we update tied_params_map to keep track of the tied weights that have already been allocated on the device in
411
+ # order to avoid duplicating memory, see above.
412
+ if (
413
+ tied_params_map is not None
414
+ and old_value.data_ptr() in tied_params_map
415
+ and device not in tied_params_map[old_value.data_ptr()]
416
+ ):
417
+ tied_params_map[old_value.data_ptr()][device] = new_value
418
+ elif (
419
+ value is not None
420
+ and tied_params_map is not None
421
+ and value.data_ptr() in tied_params_map
422
+ and device not in tied_params_map[value.data_ptr()]
423
+ ):
424
+ tied_params_map[value.data_ptr()][device] = new_value
425
+
426
+
427
+ def named_module_tensors(
428
+ module: nn.Module, include_buffers: bool = True, recurse: bool = False, remove_non_persistent: bool = False
429
+ ):
430
+ """
431
+ A helper function that gathers all the tensors (parameters + buffers) of a given module. If `include_buffers=True`
432
+ it's the same as doing `module.named_parameters(recurse=recurse) + module.named_buffers(recurse=recurse)`.
433
+
434
+ Args:
435
+ module (`torch.nn.Module`):
436
+ The module we want the tensors on.
437
+ include_buffer (`bool`, *optional*, defaults to `True`):
438
+ Whether or not to include the buffers in the result.
439
+ recurse (`bool`, *optional`, defaults to `False`):
440
+ Whether or not to go look in every submodule or just return the direct parameters and buffers.
441
+ remove_non_persistent (`bool`, *optional*, defaults to `False`):
442
+ Whether or not to remove the non persistent buffer from the buffers. Useful only when include_buffers =
443
+ True
444
+ """
445
+ yield from module.named_parameters(recurse=recurse)
446
+
447
+ if include_buffers:
448
+ non_persistent_buffers = set()
449
+ if remove_non_persistent:
450
+ non_persistent_buffers = get_non_persistent_buffers(module, recurse=recurse)
451
+ for named_buffer in module.named_buffers(recurse=recurse):
452
+ name, _ = named_buffer
453
+ if name not in non_persistent_buffers:
454
+ yield named_buffer
455
+
456
+
457
+ def get_non_persistent_buffers(module: nn.Module, recurse: bool = False, fqns: bool = False):
458
+ """
459
+ Gather all non persistent buffers of a given modules into a set
460
+
461
+ Args:
462
+ module (`nn.Module`):
463
+ The module we want the non persistent buffers on.
464
+ recurse (`bool`, *optional*, defaults to `False`):
465
+ Whether or not to go look in every submodule or just return the direct non persistent buffers.
466
+ fqns (`bool`, *optional*, defaults to `False`):
467
+ Whether or not to return the fully-qualified names of the non persistent buffers.
468
+ """
469
+
470
+ non_persistent_buffers_set = module._non_persistent_buffers_set
471
+ if recurse:
472
+ for n, m in module.named_modules():
473
+ if fqns:
474
+ non_persistent_buffers_set |= {n + "." + b for b in m._non_persistent_buffers_set}
475
+ else:
476
+ non_persistent_buffers_set |= m._non_persistent_buffers_set
477
+
478
+ return non_persistent_buffers_set
479
+
480
+
481
+ def check_tied_parameters_in_config(model: nn.Module):
482
+ """
483
+ Check if there is any indication in the given model that some weights should be tied.
484
+
485
+ Args:
486
+ model (`torch.nn.Module`): The model to inspect
487
+
488
+ Returns:
489
+ bool: True if the model needs to have tied weights
490
+ """
491
+
492
+ # based on model.tie_weights() method
493
+ has_tied_word_embedding = False
494
+ has_tied_encoder_decoder = False
495
+ has_tied_module = False
496
+
497
+ if "PreTrainedModel" in [c.__name__ for c in inspect.getmro(model.__class__)]:
498
+ has_tied_word_embedding = False
499
+ model_decoder_config = None
500
+ if hasattr(model, "config"):
501
+ model_decoder_config = (
502
+ model.config.get_text_config(decoder=True)
503
+ if hasattr(model.config, "get_text_config")
504
+ else model.config
505
+ )
506
+ has_tied_word_embedding = (
507
+ model_decoder_config is not None
508
+ and getattr(model_decoder_config, "tie_word_embeddings", False)
509
+ and model.get_output_embeddings()
510
+ )
511
+
512
+ has_tied_encoder_decoder = (
513
+ hasattr(model, "config")
514
+ and getattr(model.config, "is_encoder_decoder", False)
515
+ and getattr(model.config, "tie_encoder_decoder", False)
516
+ )
517
+ has_tied_module = any(hasattr(module, "_tie_weights") for module in model.modules())
518
+ return any([has_tied_word_embedding, has_tied_encoder_decoder, has_tied_module])
519
+
520
+
521
+ def _get_param_device(param, device_map):
522
+ if param in device_map:
523
+ return device_map[param]
524
+ parent_param = ".".join(param.split(".")[:-1])
525
+ if parent_param == param:
526
+ raise ValueError(f"The `device_map` does not contain the module {param}.")
527
+ else:
528
+ return _get_param_device(parent_param, device_map)
529
+
530
+
531
+ def check_tied_parameters_on_same_device(tied_params, device_map):
532
+ """
533
+ Check if tied parameters are on the same device
534
+
535
+ Args:
536
+ tied_params (`List[List[str]]`):
537
+ A list of lists of parameter names being all tied together.
538
+
539
+ device_map (`Dict[str, Union[int, str, torch.device]]`):
540
+ A map that specifies where each submodule should go.
541
+
542
+ """
543
+ for tie_param in tied_params:
544
+ tie_param_devices = {}
545
+ for param in tie_param:
546
+ tie_param_devices[param] = _get_param_device(param, device_map)
547
+ if len(set(tie_param_devices.values())) > 1:
548
+ logger.warning(
549
+ f"Tied parameters are on different devices: {tie_param_devices}. "
550
+ "Please modify your custom device map or set `device_map='auto'`. "
551
+ )
552
+
553
+
554
+ def find_tied_parameters(model: torch.nn.Module, **kwargs) -> list[list[str]]:
555
+ """
556
+ Find the tied parameters in a given model.
557
+
558
+ <Tip warning={true}>
559
+
560
+ The signature accepts keyword arguments, but they are for the recursive part of this function and you should ignore
561
+ them.
562
+
563
+ </Tip>
564
+
565
+ Args:
566
+ model (`torch.nn.Module`): The model to inspect.
567
+
568
+ Returns:
569
+ List[List[str]]: A list of lists of parameter names being all tied together.
570
+
571
+ Example:
572
+
573
+ ```py
574
+ >>> from collections import OrderedDict
575
+ >>> import torch.nn as nn
576
+
577
+ >>> model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))]))
578
+ >>> model.linear2.weight = model.linear1.weight
579
+ >>> find_tied_parameters(model)
580
+ [['linear1.weight', 'linear2.weight']]
581
+ ```
582
+ """
583
+
584
+ # get ALL model parameters and their names
585
+ all_named_parameters = {name: param for name, param in model.named_parameters(remove_duplicate=False)}
586
+
587
+ # get ONLY unique named parameters,
588
+ # if parameter is tied and have multiple names, it will be included only once
589
+ no_duplicate_named_parameters = {name: param for name, param in model.named_parameters(remove_duplicate=True)}
590
+
591
+ # the difference of the two sets will give us the tied parameters
592
+ tied_param_names = set(all_named_parameters.keys()) - set(no_duplicate_named_parameters.keys())
593
+
594
+ # 'tied_param_names' contains the names of parameters that are tied in the model, but we do not know
595
+ # which names refer to the same parameter. To identify this, we need to group them together.
596
+ tied_param_groups = {}
597
+ for tied_param_name in tied_param_names:
598
+ tied_param = all_named_parameters[tied_param_name]
599
+ for param_name, param in no_duplicate_named_parameters.items():
600
+ # compare if parameters are the same, if so, group their names together
601
+ if param is tied_param:
602
+ if param_name not in tied_param_groups:
603
+ tied_param_groups[param_name] = []
604
+ tied_param_groups[param_name].append(tied_param_name)
605
+
606
+ return [sorted([weight] + list(set(tied))) for weight, tied in tied_param_groups.items()]
607
+
608
+
609
+ def retie_parameters(model, tied_params):
610
+ """
611
+ Reties tied parameters in a given model if the link was broken (for instance when adding hooks).
612
+
613
+ Args:
614
+ model (`torch.nn.Module`):
615
+ The model in which to retie parameters.
616
+ tied_params (`List[List[str]]`):
617
+ A mapping parameter name to tied parameter name as obtained by `find_tied_parameters`.
618
+ """
619
+ for tied_group in tied_params:
620
+ param_to_tie = None
621
+ # two loops : the first one to set param_to_tie , the second one to change the values of tied_group
622
+ for param_name in tied_group:
623
+ module = model
624
+ splits = param_name.split(".")
625
+ for split in splits[:-1]:
626
+ module = getattr(module, split)
627
+ param = getattr(module, splits[-1])
628
+ if param_to_tie is None and param.device != torch.device("meta"):
629
+ param_to_tie = param
630
+ break
631
+ if param_to_tie is not None:
632
+ for param_name in tied_group:
633
+ module = model
634
+ splits = param_name.split(".")
635
+ for split in splits[:-1]:
636
+ module = getattr(module, split)
637
+ setattr(module, splits[-1], param_to_tie)
638
+
639
+
640
+ def _get_proper_dtype(dtype: Union[str, torch.device]) -> torch.dtype:
641
+ """
642
+ Just does torch.dtype(dtype) if necessary.
643
+ """
644
+ if isinstance(dtype, str):
645
+ # We accept "torch.float16" or just "float16"
646
+ dtype = dtype.replace("torch.", "")
647
+ dtype = getattr(torch, dtype)
648
+ return dtype
649
+
650
+
651
+ def compute_module_sizes(
652
+ model: nn.Module,
653
+ dtype: Optional[Union[str, torch.device]] = None,
654
+ special_dtypes: Optional[dict[str, Union[str, torch.device]]] = None,
655
+ buffers_only: bool = False,
656
+ ):
657
+ """
658
+ Compute the size of each submodule of a given model.
659
+ """
660
+ if dtype is not None:
661
+ dtype = _get_proper_dtype(dtype)
662
+ dtype_size = dtype_byte_size(dtype)
663
+ if special_dtypes is not None:
664
+ special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()}
665
+ special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()}
666
+ module_sizes = defaultdict(int)
667
+
668
+ module_list = []
669
+
670
+ if not buffers_only:
671
+ module_list = named_module_tensors(model, recurse=True)
672
+ else:
673
+ module_list = model.named_buffers(recurse=True)
674
+
675
+ for name, tensor in module_list:
676
+ if special_dtypes is not None and name in special_dtypes:
677
+ size = tensor.numel() * special_dtypes_size[name]
678
+ elif dtype is None:
679
+ size = tensor.numel() * dtype_byte_size(tensor.dtype)
680
+ elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
681
+ # According to the code in set_module_tensor_to_device, these types won't be converted
682
+ # so use their original size here
683
+ size = tensor.numel() * dtype_byte_size(tensor.dtype)
684
+ else:
685
+ size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype))
686
+ name_parts = name.split(".")
687
+ for idx in range(len(name_parts) + 1):
688
+ module_sizes[".".join(name_parts[:idx])] += size
689
+
690
+ return module_sizes
691
+
692
+
693
+ def compute_module_total_buffer_size(
694
+ model: nn.Module,
695
+ dtype: Optional[Union[str, torch.device]] = None,
696
+ special_dtypes: Optional[dict[str, Union[str, torch.device]]] = None,
697
+ ):
698
+ """
699
+ Compute the total size of buffers in each submodule of a given model.
700
+ """
701
+ module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes, buffers_only=True)
702
+ return module_sizes.get("", 0)
703
+
704
+
705
+ def get_max_layer_size(
706
+ modules: list[tuple[str, torch.nn.Module]], module_sizes: dict[str, int], no_split_module_classes: list[str]
707
+ ):
708
+ """
709
+ Utility function that will scan a list of named modules and return the maximum size used by one full layer. The
710
+ definition of a layer being:
711
+ - a module with no direct children (just parameters and buffers)
712
+ - a module whose class name is in the list `no_split_module_classes`
713
+
714
+ Args:
715
+ modules (`List[Tuple[str, torch.nn.Module]]`):
716
+ The list of named modules where we want to determine the maximum layer size.
717
+ module_sizes (`Dict[str, int]`):
718
+ A dictionary mapping each layer name to its size (as generated by `compute_module_sizes`).
719
+ no_split_module_classes (`List[str]`):
720
+ A list of class names for layers we don't want to be split.
721
+
722
+ Returns:
723
+ `Tuple[int, List[str]]`: The maximum size of a layer with the list of layer names realizing that maximum size.
724
+ """
725
+ max_size = 0
726
+ layer_names = []
727
+ modules_to_treat = modules.copy()
728
+ while len(modules_to_treat) > 0:
729
+ module_name, module = modules_to_treat.pop(0)
730
+ modules_children = list(module.named_children()) if isinstance(module, torch.nn.Module) else []
731
+ if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes:
732
+ # No splitting this one so we compare to the max_size
733
+ size = module_sizes[module_name]
734
+ if size > max_size:
735
+ max_size = size
736
+ layer_names = [module_name]
737
+ elif size == max_size:
738
+ layer_names.append(module_name)
739
+ else:
740
+ modules_to_treat = [(f"{module_name}.{n}", v) for n, v in modules_children] + modules_to_treat
741
+ return max_size, layer_names
742
+
743
+
744
+ def get_max_memory(max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None):
745
+ """
746
+ Get the maximum memory available if nothing is passed, converts string to int otherwise.
747
+ """
748
+ import psutil
749
+
750
+ if max_memory is None:
751
+ max_memory = {}
752
+ # Make sure CUDA is initialized on each GPU to have the right memory info.
753
+ if is_npu_available():
754
+ for i in range(torch.npu.device_count()):
755
+ try:
756
+ _ = torch.tensor(0, device=torch.device("npu", i))
757
+ max_memory[i] = torch.npu.mem_get_info(i)[0]
758
+ except Exception:
759
+ logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.")
760
+ continue
761
+ elif is_mlu_available():
762
+ for i in range(torch.mlu.device_count()):
763
+ try:
764
+ _ = torch.tensor(0, device=torch.device("mlu", i))
765
+ max_memory[i] = torch.mlu.mem_get_info(i)[0]
766
+ except Exception:
767
+ logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.")
768
+ continue
769
+ elif is_sdaa_available():
770
+ for i in range(torch.sdaa.device_count()):
771
+ try:
772
+ _ = torch.tensor(0, device=torch.device("sdaa", i))
773
+ max_memory[i] = torch.sdaa.mem_get_info(i)[0]
774
+ except Exception:
775
+ logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.")
776
+ continue
777
+ elif is_musa_available():
778
+ for i in range(torch.musa.device_count()):
779
+ try:
780
+ _ = torch.tensor(0, device=torch.device("musa", i))
781
+ max_memory[i] = torch.musa.mem_get_info(i)[0]
782
+ except Exception:
783
+ logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.")
784
+ continue
785
+ elif is_xpu_available():
786
+ for i in range(torch.xpu.device_count()):
787
+ try:
788
+ _ = torch.tensor(0, device=torch.device("xpu", i))
789
+ max_memory[i] = get_xpu_available_memory(i)
790
+ except Exception:
791
+ logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.")
792
+ continue
793
+ elif is_hpu_available():
794
+ for i in range(torch.hpu.device_count()):
795
+ try:
796
+ _ = torch.tensor(0, device=torch.device("hpu", i))
797
+ max_memory[i] = torch.hpu.mem_get_info(i)[0]
798
+ except Exception:
799
+ logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.")
800
+ continue
801
+ else:
802
+ for i in range(torch.cuda.device_count()):
803
+ try:
804
+ _ = torch.tensor([0], device=i)
805
+ max_memory[i] = torch.cuda.mem_get_info(i)[0]
806
+ except Exception:
807
+ logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.")
808
+ continue
809
+ # allocate everything in the mps device as the RAM is shared
810
+ if is_mps_available():
811
+ max_memory["mps"] = psutil.virtual_memory().available
812
+ else:
813
+ max_memory["cpu"] = psutil.virtual_memory().available
814
+ return max_memory
815
+
816
+ for key in max_memory:
817
+ if isinstance(max_memory[key], str):
818
+ max_memory[key] = convert_file_size_to_int(max_memory[key])
819
+
820
+ # Need to sort the device by type to make sure that we allocate the gpu first.
821
+ # As gpu/npu/xpu are represented by int, we need to sort them first.
822
+ gpu_devices = [k for k in max_memory.keys() if isinstance(k, int)]
823
+ gpu_devices.sort()
824
+ # check if gpu/npu/xpu devices are available and if not, throw a warning
825
+ if is_npu_available():
826
+ num_devices = torch.npu.device_count()
827
+ elif is_mlu_available():
828
+ num_devices = torch.mlu.device_count()
829
+ elif is_sdaa_available():
830
+ num_devices = torch.sdaa.device_count()
831
+ elif is_musa_available():
832
+ num_devices = torch.musa.device_count()
833
+ elif is_xpu_available():
834
+ num_devices = torch.xpu.device_count()
835
+ elif is_hpu_available():
836
+ num_devices = torch.hpu.device_count()
837
+ else:
838
+ num_devices = torch.cuda.device_count()
839
+ for device in gpu_devices:
840
+ if device >= num_devices or device < 0:
841
+ logger.warning(f"Device {device} is not available, available devices are {list(range(num_devices))}")
842
+ # Add the other devices in the preset order if they are available
843
+ all_devices = gpu_devices + [k for k in ["mps", "cpu", "disk"] if k in max_memory.keys()]
844
+ # Raise an error if a device is not recognized
845
+ for k in max_memory.keys():
846
+ if k not in all_devices:
847
+ raise ValueError(
848
+ f"Device {k} is not recognized, available devices are integers(for GPU/XPU), 'mps', 'cpu' and 'disk'"
849
+ )
850
+ max_memory = {k: max_memory[k] for k in all_devices}
851
+
852
+ return max_memory
853
+
854
+
855
+ def clean_device_map(device_map: dict[str, Union[int, str, torch.device]], module_name: str = ""):
856
+ """
857
+ Cleans a device_map by grouping all submodules that go on the same device together.
858
+ """
859
+ # Get the value of the current module and if there is only one split across several keys, regroup it.
860
+ prefix = "" if module_name == "" else f"{module_name}."
861
+ values = [v for k, v in device_map.items() if k.startswith(prefix)]
862
+ if len(set(values)) == 1 and len(values) > 1:
863
+ for k in [k for k in device_map if k.startswith(prefix)]:
864
+ del device_map[k]
865
+ device_map[module_name] = values[0]
866
+
867
+ # Recurse over the children
868
+ children_modules = [k for k in device_map.keys() if k.startswith(prefix) and len(k) > len(module_name)]
869
+ idx = len(module_name.split(".")) + 1 if len(module_name) > 0 else 1
870
+ children_modules = set(".".join(k.split(".")[:idx]) for k in children_modules)
871
+ for child in children_modules:
872
+ clean_device_map(device_map, module_name=child)
873
+
874
+ return device_map
875
+
876
+
877
+ def load_offloaded_weights(model, index, offload_folder):
878
+ """
879
+ Loads the weights from the offload folder into the model.
880
+
881
+ Args:
882
+ model (`torch.nn.Module`):
883
+ The model to load the weights into.
884
+ index (`dict`):
885
+ A dictionary containing the parameter name and its metadata for each parameter that was offloaded from the
886
+ model.
887
+ offload_folder (`str`):
888
+ The folder where the offloaded weights are stored.
889
+ """
890
+ if index is None or len(index) == 0:
891
+ # Nothing to do
892
+ return
893
+ for param_name, metadata in index.items():
894
+ if "SCB" in param_name:
895
+ continue
896
+ fp16_statistics = None
897
+ if "weight" in param_name and param_name.replace("weight", "SCB") in index.keys():
898
+ weight_name = param_name.replace("weight", "SCB")
899
+ fp16_statistics = load_offloaded_weight(
900
+ os.path.join(offload_folder, f"{weight_name}.dat"), index[weight_name]
901
+ )
902
+ tensor_file = os.path.join(offload_folder, f"{param_name}.dat")
903
+ weight = load_offloaded_weight(tensor_file, metadata)
904
+ set_module_tensor_to_device(model, param_name, "cpu", value=weight, fp16_statistics=fp16_statistics)
905
+
906
+
907
+ def get_module_leaves(module_sizes):
908
+ module_children = {}
909
+ for module in module_sizes:
910
+ if module == "" or "." not in module:
911
+ continue
912
+ parent = module.rsplit(".", 1)[0]
913
+ module_children[parent] = module_children.get(parent, 0) + 1
914
+ leaves = [module for module in module_sizes if module_children.get(module, 0) == 0 and module != ""]
915
+ return leaves
916
+
917
+
918
+ def get_balanced_memory(
919
+ model: nn.Module,
920
+ max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None,
921
+ no_split_module_classes: Optional[list[str]] = None,
922
+ dtype: Optional[Union[str, torch.dtype]] = None,
923
+ special_dtypes: Optional[dict[str, Union[str, torch.device]]] = None,
924
+ low_zero: bool = False,
925
+ ):
926
+ """
927
+ Compute a `max_memory` dictionary for [`infer_auto_device_map`] that will balance the use of each available GPU.
928
+
929
+ <Tip>
930
+
931
+ All computation is done analyzing sizes and dtypes of the model parameters. As a result, the model can be on the
932
+ meta device (as it would if initialized within the `init_empty_weights` context manager).
933
+
934
+ </Tip>
935
+
936
+ Args:
937
+ model (`torch.nn.Module`):
938
+ The model to analyze.
939
+ max_memory (`Dict`, *optional*):
940
+ A dictionary device identifier to maximum memory. Will default to the maximum memory available if unset.
941
+ Example: `max_memory={0: "1GB"}`.
942
+ no_split_module_classes (`List[str]`, *optional*):
943
+ A list of layer class names that should never be split across device (for instance any layer that has a
944
+ residual connection).
945
+ dtype (`str` or `torch.dtype`, *optional*):
946
+ If provided, the weights will be converted to that type when loaded.
947
+ special_dtypes (`Dict[str, Union[str, torch.device]]`, *optional*):
948
+ If provided, special dtypes to consider for some specific weights (will override dtype used as default for
949
+ all weights).
950
+ low_zero (`bool`, *optional*):
951
+ Minimizes the number of weights on GPU 0, which is convenient when it's used for other operations (like the
952
+ Transformers generate function).
953
+ """
954
+ # Get default / clean up max_memory
955
+ user_not_set_max_memory = max_memory is None
956
+ max_memory = get_max_memory(max_memory)
957
+
958
+ if is_npu_available():
959
+ expected_device_type = "npu"
960
+ elif is_mlu_available():
961
+ expected_device_type = "mlu"
962
+ elif is_sdaa_available():
963
+ expected_device_type = "sdaa"
964
+ elif is_musa_available():
965
+ expected_device_type = "musa"
966
+ elif is_xpu_available():
967
+ expected_device_type = "xpu"
968
+ elif is_hpu_available():
969
+ expected_device_type = "hpu"
970
+ elif is_mps_available():
971
+ expected_device_type = "mps"
972
+ else:
973
+ expected_device_type = "cuda"
974
+ num_devices = len([d for d in max_memory if torch.device(d).type == expected_device_type and max_memory[d] > 0])
975
+
976
+ if num_devices == 0:
977
+ return max_memory
978
+
979
+ if num_devices == 1:
980
+ # We cannot do low_zero on just one GPU, but we will still reserve some memory for the buffer
981
+ low_zero = False
982
+ # If user just asked us to handle memory usage, we should avoid OOM
983
+ if user_not_set_max_memory:
984
+ for key in max_memory.keys():
985
+ if isinstance(key, int):
986
+ max_memory[key] *= 0.9 # 90% is a good compromise
987
+ logger.info(
988
+ f"We will use 90% of the memory on device {key} for storing the model, and 10% for the buffer to avoid OOM. "
989
+ "You can set `max_memory` in to a higher value to use more memory (at your own risk)."
990
+ )
991
+ break # only one device
992
+
993
+ module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes)
994
+ per_gpu = module_sizes[""] // (num_devices - 1 if low_zero else num_devices)
995
+
996
+ # We can't just set the memory to model_size // num_devices as it will end being too small: each GPU will get
997
+ # slightly less layers and some layers will end up offload at the end. So this function computes a buffer size to
998
+ # add which is the biggest of:
999
+ # - the size of no split block (if applicable)
1000
+ # - the mean of the layer sizes
1001
+ if no_split_module_classes is None:
1002
+ no_split_module_classes = []
1003
+ elif not isinstance(no_split_module_classes, (list, tuple)):
1004
+ no_split_module_classes = [no_split_module_classes]
1005
+
1006
+ # Identify the size of the no_split_block modules
1007
+ if len(no_split_module_classes) > 0:
1008
+ no_split_children = {}
1009
+ for name, size in module_sizes.items():
1010
+ if name == "":
1011
+ continue
1012
+ submodule = model
1013
+ for submodule_name in name.split("."):
1014
+ submodule = getattr(submodule, submodule_name)
1015
+ class_name = submodule.__class__.__name__
1016
+ if class_name in no_split_module_classes and class_name not in no_split_children:
1017
+ no_split_children[class_name] = size
1018
+
1019
+ if set(no_split_children.keys()) == set(no_split_module_classes):
1020
+ break
1021
+ buffer = max(no_split_children.values()) if len(no_split_children) > 0 else 0
1022
+ else:
1023
+ buffer = 0
1024
+
1025
+ # Compute mean of final modules. In the first dict of module sizes, leaves are the parameters
1026
+ leaves = get_module_leaves(module_sizes)
1027
+ leaves_set = set(leaves) # Convert to set for O(1) membership testing
1028
+ module_sizes = {n: v for n, v in module_sizes.items() if n not in leaves_set}
1029
+ # Once removed, leaves are the final modules.
1030
+ leaves = get_module_leaves(module_sizes)
1031
+ mean_leaves = int(sum([module_sizes[n] for n in leaves]) / max(len(leaves), 1))
1032
+ buffer = int(1.25 * max(buffer, mean_leaves))
1033
+ per_gpu += buffer
1034
+
1035
+ # Sorted list of GPUs id (we may have some gpu ids not included in the our max_memory list - let's ignore them)
1036
+ gpus_idx_list = list(
1037
+ sorted(
1038
+ device_id for device_id, device_mem in max_memory.items() if isinstance(device_id, int) and device_mem > 0
1039
+ )
1040
+ )
1041
+ # The last device is left with max_memory just in case the buffer is not enough.
1042
+ for idx in gpus_idx_list[:-1]:
1043
+ max_memory[idx] = min(max_memory[0] if low_zero and idx == 0 else per_gpu, max_memory[idx])
1044
+
1045
+ if low_zero:
1046
+ min_zero = max(0, module_sizes[""] - sum([max_memory[i] for i in range(1, num_devices)]))
1047
+ max_memory[0] = min(min_zero, max_memory[0])
1048
+
1049
+ return max_memory
1050
+
1051
+
1052
+ def calculate_maximum_sizes(model: torch.nn.Module):
1053
+ "Computes the total size of the model and its largest layer"
1054
+ sizes = compute_module_sizes(model)
1055
+ # `transformers` models store this information for us
1056
+ no_split_modules = getattr(model, "_no_split_modules", None)
1057
+ if no_split_modules is None:
1058
+ no_split_modules = []
1059
+
1060
+ modules_to_treat = (
1061
+ list(model.named_parameters(recurse=False))
1062
+ + list(model.named_children())
1063
+ + list(model.named_buffers(recurse=False))
1064
+ )
1065
+ largest_layer = get_max_layer_size(modules_to_treat, sizes, no_split_modules)
1066
+ total_size = sizes[""]
1067
+ return total_size, largest_layer
1068
+
1069
+
1070
+ def _init_infer_auto_device_map(
1071
+ model: nn.Module,
1072
+ max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None,
1073
+ no_split_module_classes: Optional[list[str]] = None,
1074
+ dtype: Optional[Union[str, torch.dtype]] = None,
1075
+ special_dtypes: Optional[dict[str, Union[str, torch.device]]] = None,
1076
+ ) -> tuple[
1077
+ list[Union[int, str]],
1078
+ dict[Union[int, str], Union[int, str]],
1079
+ list[Union[int, str]],
1080
+ list[int],
1081
+ dict[str, int],
1082
+ list[list[str]],
1083
+ list[str],
1084
+ list[tuple[str, nn.Module]],
1085
+ ]:
1086
+ """
1087
+ Initialize variables required for computing the device map for model allocation.
1088
+ """
1089
+ max_memory = get_max_memory(max_memory)
1090
+ if no_split_module_classes is None:
1091
+ no_split_module_classes = []
1092
+ elif not isinstance(no_split_module_classes, (list, tuple)):
1093
+ no_split_module_classes = [no_split_module_classes]
1094
+
1095
+ devices = list(max_memory.keys())
1096
+ if "disk" not in devices:
1097
+ devices.append("disk")
1098
+ gpus = [device for device in devices if device not in ["cpu", "disk"]]
1099
+
1100
+ # Devices that need to keep space for a potential offloaded layer.
1101
+ if "mps" in gpus:
1102
+ main_devices = ["mps"]
1103
+ elif len(gpus) > 0:
1104
+ main_devices = [gpus[0], "cpu"]
1105
+ else:
1106
+ main_devices = ["cpu"]
1107
+
1108
+ module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes)
1109
+ tied_parameters = find_tied_parameters(model)
1110
+ if check_tied_parameters_in_config(model) and len(tied_parameters) == 0:
1111
+ logger.warning(
1112
+ "The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function."
1113
+ )
1114
+
1115
+ # Direct submodules and parameters
1116
+ modules_to_treat = (
1117
+ list(model.named_parameters(recurse=False))
1118
+ + list(model.named_children())
1119
+ + list(model.named_buffers(recurse=False))
1120
+ )
1121
+
1122
+ return (
1123
+ devices,
1124
+ max_memory,
1125
+ main_devices,
1126
+ gpus,
1127
+ module_sizes,
1128
+ tied_parameters,
1129
+ no_split_module_classes,
1130
+ modules_to_treat,
1131
+ )
1132
+
1133
+
1134
+ def get_module_size_with_ties(
1135
+ tied_params,
1136
+ module_size,
1137
+ module_sizes,
1138
+ modules_to_treat,
1139
+ ) -> tuple[int, list[str], list[nn.Module]]:
1140
+ """
1141
+ Calculate the total size of a module, including its tied parameters.
1142
+
1143
+ Args:
1144
+ tied_params (`List[str]`): The list of tied parameters.
1145
+ module_size (`int`): The size of the module without tied parameters.
1146
+ module_sizes (`Dict[str, int]`): A dictionary mapping each layer name to its size.
1147
+ modules_to_treat (`List[Tuple[str, nn.Module]]`): The list of named modules to treat.
1148
+
1149
+ Returns:
1150
+ `Tuple[int, List[str], List[nn.Module]]`: The total size of the module, the names of the tied modules, and the
1151
+ tied modules.
1152
+ """
1153
+ if len(tied_params) < 1:
1154
+ return module_size, [], []
1155
+ tied_module_names = []
1156
+ tied_modules = []
1157
+
1158
+ for tied_param in tied_params:
1159
+ tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if tied_param.startswith(n + ".")][0]
1160
+ tied_module_names.append(modules_to_treat[tied_module_index][0])
1161
+ tied_modules.append(modules_to_treat[tied_module_index][1])
1162
+
1163
+ module_size_with_ties = module_size
1164
+ for tied_param, tied_module_name in zip(tied_params, tied_module_names):
1165
+ module_size_with_ties += module_sizes[tied_module_name] - module_sizes[tied_param]
1166
+
1167
+ return module_size_with_ties, tied_module_names, tied_modules
1168
+
1169
+
1170
+ def fallback_allocate(
1171
+ modules: list[tuple[str, nn.Module]],
1172
+ module_sizes: dict[str, int],
1173
+ size_limit: Union[int, str],
1174
+ no_split_module_classes: Optional[list[str]] = None,
1175
+ tied_parameters: Optional[list[list[str]]] = None,
1176
+ ) -> tuple[Optional[str], Optional[nn.Module], list[tuple[str, nn.Module]]]:
1177
+ """
1178
+ Find a module that fits in the size limit using BFS and return it with its name and the remaining modules.
1179
+
1180
+ Args:
1181
+ modules (`List[Tuple[str, nn.Module]]`):
1182
+ The list of named modules to search in.
1183
+ module_sizes (`Dict[str, int]`):
1184
+ A dictionary mapping each layer name to its size (as generated by `compute_module_sizes`).
1185
+ size_limit (`Union[int, str]`):
1186
+ The maximum size a module can have.
1187
+ no_split_module_classes (`Optional[List[str]]`, *optional*):
1188
+ A list of class names for layers we don't want to be split.
1189
+ tied_parameters (`Optional[List[List[str]]`, *optional*):
1190
+ A list of lists of parameter names being all tied together.
1191
+
1192
+ Returns:
1193
+ `Tuple[Optional[str], Optional[nn.Module], List[Tuple[str, nn.Module]]]`: A tuple containing:
1194
+ - The name of the module that fits within the size limit.
1195
+ - The module itself.
1196
+ - The list of remaining modules after the found module is removed.
1197
+ """
1198
+ try:
1199
+ size_limit = convert_file_size_to_int(size_limit)
1200
+ except ValueError:
1201
+ return None, None, modules
1202
+
1203
+ if no_split_module_classes is None:
1204
+ no_split_module_classes = []
1205
+
1206
+ if tied_parameters is None:
1207
+ tied_parameters = []
1208
+
1209
+ modules_to_search = modules.copy()
1210
+ module_found = False
1211
+
1212
+ while modules_to_search:
1213
+ name, module = modules_to_search.pop(0)
1214
+
1215
+ tied_param_groups = [
1216
+ tied_group
1217
+ for tied_group in tied_parameters
1218
+ if any(name + "." in k + "." for k in tied_group) and not all(name + "." in k + "." for k in tied_group)
1219
+ ]
1220
+
1221
+ tied_params = sum(
1222
+ [[p for p in tied_group if name + "." not in p + "."] for tied_group in tied_param_groups], []
1223
+ )
1224
+
1225
+ module_size_with_ties, _, _ = get_module_size_with_ties(
1226
+ tied_params, module_sizes[name], module_sizes, modules_to_search
1227
+ )
1228
+
1229
+ # If the module fits in the size limit, we found it.
1230
+ if module_size_with_ties <= size_limit:
1231
+ module_found = True
1232
+ break
1233
+
1234
+ # The module is too big, we need to split it if possible.
1235
+ modules_children = (
1236
+ []
1237
+ if isinstance(module, nn.Parameter) or isinstance(module, torch.Tensor)
1238
+ else list(module.named_children())
1239
+ )
1240
+
1241
+ # Split fails, move to the next module
1242
+ if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes:
1243
+ continue
1244
+
1245
+ # split is possible, add the children to the list of modules to search
1246
+ modules_children = list(module.named_parameters(recurse=False)) + modules_children
1247
+ modules_to_search = [(f"{name}.{n}", v) for n, v in modules_children] + modules_to_search
1248
+
1249
+ if not module_found:
1250
+ return None, None, modules
1251
+
1252
+ # Prepare the module list for removal of the found module
1253
+ current_names = [n for n, _ in modules]
1254
+ dot_idx = [i for i, c in enumerate(name) if c == "."]
1255
+
1256
+ for dot_index in dot_idx:
1257
+ parent_name = name[:dot_index]
1258
+ if parent_name in current_names:
1259
+ parent_module_idx = current_names.index(parent_name)
1260
+ _, parent_module = modules[parent_module_idx]
1261
+ module_children = list(parent_module.named_parameters(recurse=False)) + list(
1262
+ parent_module.named_children()
1263
+ )
1264
+ modules = (
1265
+ modules[:parent_module_idx]
1266
+ + [(f"{parent_name}.{n}", v) for n, v in module_children]
1267
+ + modules[parent_module_idx + 1 :]
1268
+ )
1269
+ current_names = [n for n, _ in modules]
1270
+
1271
+ # Now the target module should be directly in the list
1272
+ target_idx = current_names.index(name)
1273
+ name, module = modules.pop(target_idx)
1274
+
1275
+ return name, module, modules
1276
+
1277
+
1278
+ def infer_auto_device_map(
1279
+ model: nn.Module,
1280
+ max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None,
1281
+ no_split_module_classes: Optional[list[str]] = None,
1282
+ dtype: Optional[Union[str, torch.dtype]] = None,
1283
+ special_dtypes: Optional[dict[str, Union[str, torch.dtype]]] = None,
1284
+ verbose: bool = False,
1285
+ clean_result: bool = True,
1286
+ offload_buffers: bool = False,
1287
+ fallback_allocation: bool = False,
1288
+ ):
1289
+ """
1290
+ Compute a device map for a given model giving priority to GPUs, then offload on CPU and finally offload to disk,
1291
+ such that:
1292
+ - we don't exceed the memory available of any of the GPU.
1293
+ - if offload to the CPU is needed, there is always room left on GPU 0 to put back the layer offloaded on CPU that
1294
+ has the largest size.
1295
+ - if offload to the CPU is needed,we don't exceed the RAM available on the CPU.
1296
+ - if offload to the disk is needed, there is always room left on the CPU to put back the layer offloaded on disk
1297
+ that has the largest size.
1298
+
1299
+ <Tip>
1300
+
1301
+ All computation is done analyzing sizes and dtypes of the model parameters. As a result, the model can be on the
1302
+ meta device (as it would if initialized within the `init_empty_weights` context manager).
1303
+
1304
+ </Tip>
1305
+
1306
+ Args:
1307
+ model (`torch.nn.Module`):
1308
+ The model to analyze.
1309
+ max_memory (`Dict`, *optional*):
1310
+ A dictionary device identifier to maximum memory. Will default to the maximum memory available if unset.
1311
+ Example: `max_memory={0: "1GB"}`.
1312
+ no_split_module_classes (`List[str]`, *optional*):
1313
+ A list of layer class names that should never be split across device (for instance any layer that has a
1314
+ residual connection).
1315
+ dtype (`str` or `torch.dtype`, *optional*):
1316
+ If provided, the weights will be converted to that type when loaded.
1317
+ special_dtypes (`Dict[str, Union[str, torch.device]]`, *optional*):
1318
+ If provided, special dtypes to consider for some specific weights (will override dtype used as default for
1319
+ all weights).
1320
+ verbose (`bool`, *optional*, defaults to `False`):
1321
+ Whether or not to provide debugging statements as the function builds the device_map.
1322
+ clean_result (`bool`, *optional*, defaults to `True`):
1323
+ Clean the resulting device_map by grouping all submodules that go on the same device together.
1324
+ offload_buffers (`bool`, *optional*, defaults to `False`):
1325
+ In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as
1326
+ well as the parameters.
1327
+ fallback_allocation (`bool`, *optional*, defaults to `False`):
1328
+ When regular allocation fails, try to allocate a module that fits in the size limit using BFS.
1329
+ """
1330
+
1331
+ # Initialize the variables
1332
+ (
1333
+ devices,
1334
+ max_memory,
1335
+ main_devices,
1336
+ gpus,
1337
+ module_sizes,
1338
+ tied_parameters,
1339
+ no_split_module_classes,
1340
+ modules_to_treat,
1341
+ ) = _init_infer_auto_device_map(model, max_memory, no_split_module_classes, dtype, special_dtypes)
1342
+
1343
+ device_map = OrderedDict()
1344
+ current_device = 0
1345
+ device_memory_used = {device: 0 for device in devices}
1346
+ device_buffer_sizes = {}
1347
+ device_minimum_assignment_memory = {}
1348
+
1349
+ # Initialize maximum largest layer, to know which space to keep in memory
1350
+ max_layer_size, max_layer_names = get_max_layer_size(modules_to_treat, module_sizes, no_split_module_classes)
1351
+
1352
+ # Ready ? This is going to be a bit messy.
1353
+ while len(modules_to_treat) > 0:
1354
+ name, module = modules_to_treat.pop(0)
1355
+ if verbose:
1356
+ print(f"\nTreating module {name}.")
1357
+ # Max size in the remaining layers may have changed since we took one, so we maybe update it.
1358
+ max_layer_names = [n for n in max_layer_names if n != name and not n.startswith(name + ".")]
1359
+ if len(max_layer_names) == 0:
1360
+ max_layer_size, max_layer_names = get_max_layer_size(
1361
+ [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],
1362
+ module_sizes,
1363
+ no_split_module_classes,
1364
+ )
1365
+ # Assess size needed
1366
+ module_size = module_sizes[name]
1367
+
1368
+ # We keep relevant tied parameters only: one of the tied parameters in the group is inside the current module
1369
+ # and the other is not.
1370
+ # Note: If we are currently processing the name `compute.weight`, an other parameter named
1371
+ # e.g. `compute.weight_submodule.parameter`
1372
+ # needs to be considered outside the current module, hence the check with additional dots.
1373
+ tied_param_groups = [
1374
+ tied_group
1375
+ for tied_group in tied_parameters
1376
+ if any(name + "." in k + "." for k in tied_group) and not all(name + "." in k + "." for k in tied_group)
1377
+ ]
1378
+
1379
+ if verbose and len(tied_param_groups) > 0:
1380
+ print(f" Found the relevant tied param groups {tied_param_groups}")
1381
+
1382
+ # Then we keep track of all the parameters that are tied to the current module, but not in the current module
1383
+ tied_params = sum(
1384
+ [[p for p in tied_group if name + "." not in p + "."] for tied_group in tied_param_groups], []
1385
+ )
1386
+
1387
+ if verbose and len(tied_params) > 0:
1388
+ print(f" So those parameters need to be taken into account {tied_params}")
1389
+
1390
+ device = devices[current_device]
1391
+ current_max_size = max_memory[device] if device != "disk" else None
1392
+ current_memory_reserved = 0
1393
+ # Reduce max size available by the largest layer.
1394
+ if devices[current_device] in main_devices:
1395
+ current_max_size = current_max_size - max_layer_size
1396
+ current_memory_reserved = max_layer_size
1397
+
1398
+ module_size_with_ties, tied_module_names, tied_modules = get_module_size_with_ties(
1399
+ tied_params, module_size, module_sizes, modules_to_treat
1400
+ )
1401
+
1402
+ # The module and its tied modules fit on the current device.
1403
+ if current_max_size is None or device_memory_used[device] + module_size_with_ties <= current_max_size:
1404
+ if verbose:
1405
+ output = f"Putting {name}"
1406
+
1407
+ if tied_module_names:
1408
+ output += f" and {tied_module_names}"
1409
+ else:
1410
+ output += f" (size={module_size})"
1411
+
1412
+ if current_max_size is not None:
1413
+ output += f" (available={current_max_size - device_memory_used[device]})"
1414
+
1415
+ output += f" on {device}."
1416
+ print(output)
1417
+
1418
+ device_memory_used[device] += module_size_with_ties
1419
+
1420
+ # Assign the primary module to the device.
1421
+ device_map[name] = device
1422
+
1423
+ # Assign tied modules if any.
1424
+ for tied_module_name in tied_module_names:
1425
+ if tied_module_name in [m[0] for m in modules_to_treat]:
1426
+ # Find the index of the tied module in the list
1427
+ tied_module_index = next(i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name)
1428
+ # Remove the tied module from the list to prevent reprocessing
1429
+ modules_to_treat.pop(tied_module_index)
1430
+
1431
+ # Assign the tied module to the device
1432
+ device_map[tied_module_name] = device
1433
+
1434
+ # Buffer Handling
1435
+ if not offload_buffers and isinstance(module, nn.Module):
1436
+ # Compute the total buffer size for the module
1437
+ current_buffer_size = compute_module_total_buffer_size(
1438
+ module, dtype=dtype, special_dtypes=special_dtypes
1439
+ )
1440
+ # Update the buffer size on the device
1441
+ device_buffer_sizes[device] = device_buffer_sizes.get(device, 0) + current_buffer_size
1442
+
1443
+ continue
1444
+
1445
+ # The current module itself fits, so we try to split the tied modules.
1446
+ if len(tied_params) > 0 and device_memory_used[device] + module_size <= current_max_size:
1447
+ # can we split one of the tied modules to make it smaller or do we need to go on the next device?
1448
+ if verbose:
1449
+ print(
1450
+ f"Not enough space on {devices[current_device]} to put {name} and {tied_module_names} (space "
1451
+ f"available {current_max_size - device_memory_used[device]}, needed size {module_size_with_ties})."
1452
+ )
1453
+ split_happened = False
1454
+ for tied_module_name, tied_module in zip(tied_module_names, tied_modules):
1455
+ tied_module_children = list(tied_module.named_children())
1456
+ if len(tied_module_children) == 0 or tied_module.__class__.__name__ in no_split_module_classes:
1457
+ # can't break this one.
1458
+ continue
1459
+
1460
+ if verbose:
1461
+ print(f"Splitting {tied_module_name}.")
1462
+ tied_module_children = list(tied_module.named_parameters(recurse=False)) + tied_module_children
1463
+ tied_module_children = [(f"{tied_module_name}.{n}", v) for n, v in tied_module_children]
1464
+ tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name][0]
1465
+
1466
+ modules_to_treat = (
1467
+ [(name, module)]
1468
+ + modules_to_treat[:tied_module_index]
1469
+ + tied_module_children
1470
+ + modules_to_treat[tied_module_index + 1 :]
1471
+ )
1472
+ # Update the max layer size.
1473
+ max_layer_size, max_layer_names = get_max_layer_size(
1474
+ [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],
1475
+ module_sizes,
1476
+ no_split_module_classes,
1477
+ )
1478
+ split_happened = True
1479
+ break
1480
+
1481
+ if split_happened:
1482
+ continue
1483
+
1484
+ # If the tied module is not split, we go to the next device
1485
+ if verbose:
1486
+ print("None of the tied module can be split, going to the next device.")
1487
+
1488
+ # The current module itself doesn't fit, so we have to split it or go to the next device.
1489
+ if device_memory_used[device] + module_size >= current_max_size:
1490
+ # Split or not split?
1491
+ modules_children = (
1492
+ []
1493
+ if isinstance(module, nn.Parameter) or isinstance(module, torch.Tensor)
1494
+ else list(module.named_children())
1495
+ )
1496
+ if verbose:
1497
+ print(
1498
+ f"Not enough space on {devices[current_device]} to put {name} (space available "
1499
+ f"{current_max_size - device_memory_used[device]}, module size {module_size})."
1500
+ )
1501
+ if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes:
1502
+ # -> no split, we go to the next device
1503
+ if verbose:
1504
+ print("This module cannot be split, going to the next device.")
1505
+
1506
+ else:
1507
+ # -> split, we replace the module studied by its children + parameters
1508
+ if verbose:
1509
+ print(f"Splitting {name}.")
1510
+ modules_children = list(module.named_parameters(recurse=False)) + modules_children
1511
+ modules_to_treat = [(f"{name}.{n}", v) for n, v in modules_children] + modules_to_treat
1512
+ # Update the max layer size.
1513
+ max_layer_size, max_layer_names = get_max_layer_size(
1514
+ [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],
1515
+ module_sizes,
1516
+ no_split_module_classes,
1517
+ )
1518
+ continue
1519
+
1520
+ # If no module is assigned to the current device, we attempt to allocate a fallback module
1521
+ # if fallback_allocation is enabled.
1522
+ if device_memory_used[device] == 0 and fallback_allocation and device != "disk":
1523
+ # We try to allocate a module that fits in the size limit using BFS.
1524
+ # Recompute the current max size as we need to consider the current module as well.
1525
+ current_max_size = max_memory[device] - max(max_layer_size, module_size_with_ties)
1526
+
1527
+ fallback_module_name, fallback_module, remaining_modules = fallback_allocate(
1528
+ modules_to_treat,
1529
+ module_sizes,
1530
+ current_max_size - device_memory_used[device],
1531
+ no_split_module_classes,
1532
+ tied_parameters,
1533
+ )
1534
+ # use the next iteration to put the fallback module on the next device to avoid code duplication
1535
+ if fallback_module is not None:
1536
+ modules_to_treat = [(fallback_module_name, fallback_module)] + [(name, module)] + remaining_modules
1537
+ continue
1538
+
1539
+ if device_memory_used[device] == 0:
1540
+ device_minimum_assignment_memory[device] = module_size_with_ties + current_memory_reserved
1541
+
1542
+ # Neither the current module nor any tied modules can be split, so we move to the next device.
1543
+ device_memory_used[device] = device_memory_used[device] + current_memory_reserved
1544
+ current_device += 1
1545
+ modules_to_treat = [(name, module)] + modules_to_treat
1546
+
1547
+ device_memory_used = {device: mem for device, mem in device_memory_used.items() if mem > 0}
1548
+
1549
+ if clean_result:
1550
+ device_map = clean_device_map(device_map)
1551
+
1552
+ non_gpu_buffer_size = device_buffer_sizes.get("cpu", 0) + device_buffer_sizes.get("disk", 0)
1553
+ if non_gpu_buffer_size > 0 and not offload_buffers:
1554
+ is_buffer_fit_any_gpu = False
1555
+ for gpu_device, gpu_max_memory in max_memory.items():
1556
+ if gpu_device == "cpu" or gpu_device == "disk":
1557
+ continue
1558
+
1559
+ if not is_buffer_fit_any_gpu:
1560
+ gpu_memory_used = device_memory_used.get(gpu_device, 0)
1561
+
1562
+ if gpu_max_memory >= non_gpu_buffer_size + gpu_memory_used:
1563
+ is_buffer_fit_any_gpu = True
1564
+
1565
+ if len(gpus) > 0 and not is_buffer_fit_any_gpu:
1566
+ warnings.warn(
1567
+ f"Current model requires {non_gpu_buffer_size} bytes of buffer for offloaded layers, which seems does "
1568
+ f"not fit any GPU's remaining memory. If you are experiencing a OOM later, please consider using "
1569
+ f"offload_buffers=True."
1570
+ )
1571
+
1572
+ if device_minimum_assignment_memory:
1573
+ devices_info = "\n".join(
1574
+ f" - {device}: {mem} bytes required" for device, mem in device_minimum_assignment_memory.items()
1575
+ )
1576
+ logger.info(
1577
+ f"Based on the current allocation process, no modules could be assigned to the following devices due to "
1578
+ f"insufficient memory:\n"
1579
+ f"{devices_info}\n"
1580
+ f"These minimum requirements are specific to this allocation attempt and may vary. Consider increasing "
1581
+ f"the available memory for these devices to at least the specified minimum, or adjusting the model config."
1582
+ )
1583
+ return device_map
1584
+
1585
+
1586
+ def check_device_map(model: nn.Module, device_map: dict[str, Union[int, str, torch.device]]):
1587
+ """
1588
+ Checks a device map covers everything in a given model.
1589
+
1590
+ Args:
1591
+ model (`torch.nn.Module`): The model to check the device map against.
1592
+ device_map (`Dict[str, Union[int, str, torch.device]]`): The device map to check.
1593
+ """
1594
+ all_module_names = dict(model.named_modules())
1595
+ invalid_keys = [k for k in device_map if k != "" and k not in all_module_names]
1596
+
1597
+ if invalid_keys:
1598
+ warnings.warn(
1599
+ f"The following device_map keys do not match any submodules in the model: {invalid_keys}", UserWarning
1600
+ )
1601
+
1602
+ all_model_tensors = [name for name, _ in model.state_dict().items()]
1603
+ for module_name in device_map.keys():
1604
+ if module_name == "":
1605
+ all_model_tensors.clear()
1606
+ break
1607
+ else:
1608
+ all_model_tensors = [
1609
+ name
1610
+ for name in all_model_tensors
1611
+ if not name == module_name and not name.startswith(module_name + ".")
1612
+ ]
1613
+ if len(all_model_tensors) > 0:
1614
+ non_covered_params = ", ".join(all_model_tensors)
1615
+ raise ValueError(
1616
+ f"The device_map provided does not give any device for the following parameters: {non_covered_params}"
1617
+ )
1618
+
1619
+
1620
+ def load_state_dict(checkpoint_file, device_map=None):
1621
+ """
1622
+ Load a checkpoint from a given file. If the checkpoint is in the safetensors format and a device map is passed, the
1623
+ weights can be fast-loaded directly on the GPU.
1624
+
1625
+ Args:
1626
+ checkpoint_file (`str`): The path to the checkpoint to load.
1627
+ device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*):
1628
+ A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer
1629
+ name, once a given module name is inside, every submodule of it will be sent to the same device.
1630
+ """
1631
+ if checkpoint_file.endswith(".safetensors"):
1632
+ with safe_open(checkpoint_file, framework="pt") as f:
1633
+ metadata = f.metadata()
1634
+ weight_names = f.keys()
1635
+
1636
+ if metadata is None:
1637
+ logger.warning(
1638
+ f"The safetensors archive passed at {checkpoint_file} does not contain metadata. "
1639
+ "Make sure to save your model with the `save_pretrained` method. Defaulting to 'pt' metadata."
1640
+ )
1641
+ metadata = {"format": "pt"}
1642
+
1643
+ if metadata.get("format") not in ["pt", "tf", "flax"]:
1644
+ raise OSError(
1645
+ f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
1646
+ "you save your model with the `save_pretrained` method."
1647
+ )
1648
+ elif metadata["format"] != "pt":
1649
+ raise ValueError(f"The checkpoint passed was saved with {metadata['format']}, we need a the pt format.")
1650
+ if device_map is None:
1651
+ return safe_load_file(checkpoint_file)
1652
+ else:
1653
+ # if we only have one device we can load everything directly
1654
+ if len(set(device_map.values())) == 1:
1655
+ device = list(device_map.values())[0]
1656
+ target_device = device
1657
+ if isinstance(device, int):
1658
+ if is_npu_available():
1659
+ target_device = f"npu:{device}"
1660
+ elif is_hpu_available():
1661
+ target_device = "hpu"
1662
+
1663
+ return safe_load_file(checkpoint_file, device=target_device)
1664
+
1665
+ devices = list(set(device_map.values()) - {"disk"})
1666
+ # cpu device should always exist as fallback option
1667
+ if "cpu" not in devices:
1668
+ devices.append("cpu")
1669
+
1670
+ # For each device, get the weights that go there
1671
+ device_weights = {device: [] for device in devices}
1672
+ for module_name, device in device_map.items():
1673
+ if device in devices:
1674
+ device_weights[device].extend(
1675
+ [k for k in weight_names if k == module_name or k.startswith(module_name + ".")]
1676
+ )
1677
+
1678
+ # all weights that haven't defined a device should be loaded on CPU
1679
+ device_weights["cpu"].extend([k for k in weight_names if k not in sum(device_weights.values(), [])])
1680
+ tensors = {}
1681
+ if is_tqdm_available():
1682
+ progress_bar = tqdm(
1683
+ main_process_only=False,
1684
+ total=sum([len(device_weights[device]) for device in devices]),
1685
+ unit="w",
1686
+ smoothing=0,
1687
+ leave=False,
1688
+ )
1689
+ else:
1690
+ progress_bar = None
1691
+ for device in devices:
1692
+ target_device = device
1693
+ if isinstance(device, int):
1694
+ if is_npu_available():
1695
+ target_device = f"npu:{device}"
1696
+ elif is_hpu_available():
1697
+ target_device = "hpu"
1698
+
1699
+ with safe_open(checkpoint_file, framework="pt", device=target_device) as f:
1700
+ for key in device_weights[device]:
1701
+ if progress_bar is not None:
1702
+ progress_bar.set_postfix(dev=device, refresh=False)
1703
+ progress_bar.set_description(key)
1704
+ tensors[key] = f.get_tensor(key)
1705
+ if progress_bar is not None:
1706
+ progress_bar.update()
1707
+ if progress_bar is not None:
1708
+ progress_bar.close()
1709
+
1710
+ return tensors
1711
+ else:
1712
+ return torch.load(checkpoint_file, map_location=torch.device("cpu"), weights_only=True)
1713
+
1714
+
1715
+ def get_state_dict_offloaded_model(model: nn.Module):
1716
+ """
1717
+ Returns the state dictionary for an offloaded model via iterative onloading
1718
+
1719
+ Args:
1720
+ model (`torch.nn.Module`):
1721
+ The offloaded model we want to save
1722
+ """
1723
+
1724
+ state_dict = {}
1725
+ placeholders = set()
1726
+ for name, module in model.named_modules():
1727
+ if name == "":
1728
+ continue
1729
+
1730
+ try:
1731
+ with align_module_device(module, "cpu"):
1732
+ module_state_dict = module.state_dict()
1733
+ except MemoryError:
1734
+ raise MemoryError("Offloaded module must fit in CPU memory to call save_model!") from None
1735
+
1736
+ for key in module_state_dict:
1737
+ # ignore placeholder parameters that are still on the meta device
1738
+ if module_state_dict[key].device == torch.device("meta"):
1739
+ placeholders.add(name + f".{key}")
1740
+ continue
1741
+ params = module_state_dict[key]
1742
+ state_dict[name + f".{key}"] = params.to("cpu") # move buffers to cpu
1743
+ for key in placeholders.copy():
1744
+ if key in state_dict:
1745
+ placeholders.remove(key)
1746
+ if placeholders:
1747
+ logger.warning(f"The following tensors were not saved because they were still on meta device: {placeholders}")
1748
+
1749
+ return state_dict
1750
+
1751
+
1752
+ def get_state_dict_from_offload(
1753
+ module: nn.Module,
1754
+ module_name: str,
1755
+ state_dict: dict[str, Union[str, torch.tensor]],
1756
+ device_to_put_offload: Union[int, str, torch.device] = "cpu",
1757
+ ):
1758
+ """
1759
+ Retrieve the state dictionary (with parameters) from an offloaded module and load into a specified device (defaults
1760
+ to cpu).
1761
+
1762
+ Args:
1763
+ module: (`torch.nn.Module`):
1764
+ The module we want to retrieve a state dictionary from
1765
+ module_name: (`str`):
1766
+ The name of the module of interest
1767
+ state_dict (`Dict[str, Union[int, str, torch.device]]`):
1768
+ Dictionary of {module names: parameters}
1769
+ device_to_put_offload (`Union[int, str, torch.device]`):
1770
+ Device to load offloaded parameters into, defaults to the cpu.
1771
+ """
1772
+
1773
+ root = module_name[: module_name.rfind(".")] # module name without .weight or .bias
1774
+
1775
+ # do not move parameters if the module is not offloaded
1776
+ if not has_offloaded_params(module):
1777
+ device_to_put_offload = None
1778
+
1779
+ # assign the device to which the offloaded parameters will be sent
1780
+ with align_module_device(module, device_to_put_offload):
1781
+ for m_key, params in module.state_dict().items():
1782
+ if (root + f".{m_key}") in state_dict:
1783
+ state_dict[root + f".{m_key}"] = params
1784
+
1785
+ return state_dict
1786
+
1787
+
1788
+ def load_checkpoint_in_model(
1789
+ model: nn.Module,
1790
+ checkpoint: Union[str, os.PathLike],
1791
+ device_map: Optional[dict[str, Union[int, str, torch.device]]] = None,
1792
+ offload_folder: Optional[Union[str, os.PathLike]] = None,
1793
+ dtype: Optional[Union[str, torch.dtype]] = None,
1794
+ offload_state_dict: bool = False,
1795
+ offload_buffers: bool = False,
1796
+ keep_in_fp32_modules: Optional[list[str]] = None,
1797
+ offload_8bit_bnb: bool = False,
1798
+ strict: bool = False,
1799
+ full_state_dict: bool = True,
1800
+ broadcast_from_rank0: bool = False,
1801
+ ):
1802
+ """
1803
+ Loads a (potentially sharded) checkpoint inside a model, potentially sending weights to a given device as they are
1804
+ loaded.
1805
+
1806
+ <Tip warning={true}>
1807
+
1808
+ Once loaded across devices, you still need to call [`dispatch_model`] on your model to make it able to run. To
1809
+ group the checkpoint loading and dispatch in one single call, use [`load_checkpoint_and_dispatch`].
1810
+
1811
+ </Tip>
1812
+
1813
+ Args:
1814
+ model (`torch.nn.Module`):
1815
+ The model in which we want to load a checkpoint.
1816
+ checkpoint (`str` or `os.PathLike`):
1817
+ The folder checkpoint to load. It can be:
1818
+ - a path to a file containing a whole model state dict
1819
+ - a path to a `.json` file containing the index to a sharded checkpoint
1820
+ - a path to a folder containing a unique `.index.json` file and the shards of a checkpoint.
1821
+ - a path to a folder containing a unique pytorch_model.bin or a model.safetensors file.
1822
+ device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*):
1823
+ A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer
1824
+ name, once a given module name is inside, every submodule of it will be sent to the same device.
1825
+ offload_folder (`str` or `os.PathLike`, *optional*):
1826
+ If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
1827
+ dtype (`str` or `torch.dtype`, *optional*):
1828
+ If provided, the weights will be converted to that type when loaded.
1829
+ offload_state_dict (`bool`, *optional*, defaults to `False`):
1830
+ If `True`, will temporarily offload the CPU state dict on the hard drive to avoid getting out of CPU RAM if
1831
+ the weight of the CPU state dict + the biggest shard does not fit.
1832
+ offload_buffers (`bool`, *optional*, defaults to `False`):
1833
+ Whether or not to include the buffers in the weights offloaded to disk.
1834
+ keep_in_fp32_modules(`List[str]`, *optional*):
1835
+ A list of the modules that we keep in `torch.float32` dtype.
1836
+ offload_8bit_bnb (`bool`, *optional*):
1837
+ Whether or not to enable offload of 8-bit modules on cpu/disk.
1838
+ strict (`bool`, *optional*, defaults to `False`):
1839
+ Whether to strictly enforce that the keys in the checkpoint state_dict match the keys of the model's
1840
+ state_dict.
1841
+ full_state_dict (`bool`, *optional*, defaults to `True`): if this is set to `True`, all the tensors in the
1842
+ loaded state_dict will be gathered. No ShardedTensor and DTensor will be in the loaded state_dict.
1843
+ broadcast_from_rank0 (`False`, *optional*, defaults to `False`): when the option is `True`, a distributed
1844
+ `ProcessGroup` must be initialized. rank0 should receive a full state_dict and will broadcast the tensors
1845
+ in the state_dict one by one to other ranks. Other ranks will receive the tensors and shard (if applicable)
1846
+ according to the local shards in the model.
1847
+
1848
+ """
1849
+ if offload_8bit_bnb:
1850
+ from .bnb import quantize_and_offload_8bit
1851
+
1852
+ tied_params = find_tied_parameters(model)
1853
+
1854
+ if check_tied_parameters_in_config(model) and len(tied_params) == 0:
1855
+ logger.warning(
1856
+ "The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function."
1857
+ )
1858
+ if device_map is not None:
1859
+ check_tied_parameters_on_same_device(tied_params, device_map)
1860
+
1861
+ if offload_folder is None and device_map is not None and "disk" in device_map.values():
1862
+ raise ValueError(
1863
+ "At least one of the model submodule will be offloaded to disk, please pass along an `offload_folder`."
1864
+ )
1865
+ elif offload_folder is not None and device_map is not None and "disk" in device_map.values():
1866
+ os.makedirs(offload_folder, exist_ok=True)
1867
+
1868
+ if isinstance(dtype, str):
1869
+ # We accept "torch.float16" or just "float16"
1870
+ dtype = dtype.replace("torch.", "")
1871
+ dtype = getattr(torch, dtype)
1872
+
1873
+ checkpoint_files = None
1874
+ index_filename = None
1875
+ if os.path.isfile(checkpoint):
1876
+ if str(checkpoint).endswith(".json"):
1877
+ index_filename = checkpoint
1878
+ else:
1879
+ checkpoint_files = [checkpoint]
1880
+ elif os.path.isdir(checkpoint):
1881
+ # check if the whole state dict is present
1882
+ potential_state_bin = [f for f in os.listdir(checkpoint) if f == WEIGHTS_NAME]
1883
+ potential_state_safetensor = [f for f in os.listdir(checkpoint) if f == SAFE_WEIGHTS_NAME]
1884
+ if len(potential_state_bin) == 1:
1885
+ checkpoint_files = [os.path.join(checkpoint, potential_state_bin[0])]
1886
+ elif len(potential_state_safetensor) == 1:
1887
+ checkpoint_files = [os.path.join(checkpoint, potential_state_safetensor[0])]
1888
+ else:
1889
+ # otherwise check for sharded checkpoints
1890
+ potential_index = [f for f in os.listdir(checkpoint) if f.endswith(".index.json")]
1891
+ if len(potential_index) == 0:
1892
+ raise ValueError(
1893
+ f"{checkpoint} is not a folder containing a `.index.json` file or a {WEIGHTS_NAME} or a {SAFE_WEIGHTS_NAME} file"
1894
+ )
1895
+ elif len(potential_index) == 1:
1896
+ index_filename = os.path.join(checkpoint, potential_index[0])
1897
+ else:
1898
+ raise ValueError(
1899
+ f"{checkpoint} containing more than one `.index.json` file, delete the irrelevant ones."
1900
+ )
1901
+ else:
1902
+ raise ValueError(
1903
+ "`checkpoint` should be the path to a file containing a whole state dict, or the index of a sharded "
1904
+ f"checkpoint, or a folder containing a sharded checkpoint or the whole state dict, but got {checkpoint}."
1905
+ )
1906
+
1907
+ if index_filename is not None:
1908
+ checkpoint_folder = os.path.split(index_filename)[0]
1909
+ with open(index_filename) as f:
1910
+ index = json.loads(f.read())
1911
+
1912
+ if "weight_map" in index:
1913
+ index = index["weight_map"]
1914
+ checkpoint_files = sorted(list(set(index.values())))
1915
+ checkpoint_files = [os.path.join(checkpoint_folder, f) for f in checkpoint_files]
1916
+
1917
+ # Logic for missing/unexepected keys goes here.
1918
+
1919
+ offload_index = {}
1920
+ if offload_state_dict:
1921
+ state_dict_folder = tempfile.mkdtemp()
1922
+ state_dict_index = {}
1923
+
1924
+ unexpected_keys = set()
1925
+ model_keys = set(model.state_dict().keys())
1926
+ buffer_names = [name for name, _ in model.named_buffers()]
1927
+ model_devices = {t.device for t in model.state_dict().values() if isinstance(t, torch.Tensor)}
1928
+ model_physical_devices = model_devices - {torch.device("meta")}
1929
+ for checkpoint_file in checkpoint_files:
1930
+ if device_map is None:
1931
+ # exception for multi-device loading was made for the meta device in torch v2.7.0
1932
+ # https://github.com/pytorch/pytorch/blob/v2.6.0/torch/distributed/checkpoint/state_dict.py#L557-L563
1933
+ # https://github.com/pytorch/pytorch/blob/v2.7.0-rc2/torch/distributed/checkpoint/state_dict.py#L575-L587
1934
+ if is_torch_version(">=", "2.2.0") and (
1935
+ (is_torch_version(">=", "2.7.0") and len(model_physical_devices) <= 1) or len(model_devices) <= 1
1936
+ ):
1937
+ from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict
1938
+
1939
+ broadcast_from_rank0 &= is_torch_version(">=", "2.4.0")
1940
+ loaded_checkpoint = (
1941
+ load_state_dict(checkpoint_file, device_map=device_map)
1942
+ if not broadcast_from_rank0 or dist.get_rank() == 0
1943
+ else {}
1944
+ )
1945
+ set_model_state_dict(
1946
+ model,
1947
+ loaded_checkpoint,
1948
+ options=StateDictOptions(
1949
+ full_state_dict=full_state_dict,
1950
+ strict=strict,
1951
+ **({"broadcast_from_rank0": broadcast_from_rank0} if is_torch_version(">=", "2.4.0") else {}),
1952
+ ),
1953
+ )
1954
+ else:
1955
+ loaded_checkpoint = load_state_dict(checkpoint_file, device_map=device_map)
1956
+ model.load_state_dict(loaded_checkpoint, strict=strict)
1957
+
1958
+ unexpected_keys.update(set(loaded_checkpoint.keys()) - model_keys)
1959
+ else:
1960
+ loaded_checkpoint = load_state_dict(checkpoint_file, device_map=device_map)
1961
+
1962
+ for param_name, param in loaded_checkpoint.items():
1963
+ # skip SCB parameter (for 8-bit serialization)
1964
+ if "SCB" in param_name:
1965
+ continue
1966
+
1967
+ if param_name not in model_keys:
1968
+ unexpected_keys.add(param_name)
1969
+ if not strict:
1970
+ continue # Skip loading this parameter.
1971
+
1972
+ module_name = param_name
1973
+
1974
+ while len(module_name) > 0 and module_name not in device_map:
1975
+ module_name = ".".join(module_name.split(".")[:-1])
1976
+ if module_name == "" and "" not in device_map:
1977
+ # TODO: group all errors and raise at the end.
1978
+ raise ValueError(f"{param_name} doesn't have any device set.")
1979
+ param_device = device_map[module_name]
1980
+ new_dtype = dtype
1981
+ if dtype is not None and torch.is_floating_point(param):
1982
+ if keep_in_fp32_modules is not None and dtype == torch.float16:
1983
+ proceed = False
1984
+ for key in keep_in_fp32_modules:
1985
+ if ((key in param_name) and (key + "." in param_name)) or key == param_name:
1986
+ proceed = True
1987
+ break
1988
+ if proceed:
1989
+ new_dtype = torch.float32
1990
+
1991
+ if "weight" in param_name and param_name.replace("weight", "SCB") in loaded_checkpoint.keys():
1992
+ if param.dtype == torch.int8:
1993
+ fp16_statistics = loaded_checkpoint[param_name.replace("weight", "SCB")]
1994
+ else:
1995
+ fp16_statistics = None
1996
+
1997
+ if param_device == "disk":
1998
+ if offload_buffers or param_name not in buffer_names:
1999
+ if new_dtype is None:
2000
+ new_dtype = param.dtype
2001
+ if offload_8bit_bnb:
2002
+ quantize_and_offload_8bit(
2003
+ model, param, param_name, new_dtype, offload_folder, offload_index, fp16_statistics
2004
+ )
2005
+ continue
2006
+ else:
2007
+ set_module_tensor_to_device(model, param_name, "meta", dtype=new_dtype)
2008
+ offload_weight(param, param_name, offload_folder, index=offload_index)
2009
+ elif param_device == "cpu" and offload_state_dict:
2010
+ if new_dtype is None:
2011
+ new_dtype = param.dtype
2012
+ if offload_8bit_bnb:
2013
+ quantize_and_offload_8bit(
2014
+ model, param, param_name, new_dtype, state_dict_folder, state_dict_index, fp16_statistics
2015
+ )
2016
+ else:
2017
+ set_module_tensor_to_device(model, param_name, "meta", dtype=new_dtype)
2018
+ offload_weight(param, param_name, state_dict_folder, index=state_dict_index)
2019
+ else:
2020
+ set_module_tensor_to_device(
2021
+ model,
2022
+ param_name,
2023
+ param_device,
2024
+ value=param,
2025
+ dtype=new_dtype,
2026
+ fp16_statistics=fp16_statistics,
2027
+ )
2028
+
2029
+ # Force Python to clean up.
2030
+ del loaded_checkpoint
2031
+ gc.collect()
2032
+
2033
+ if not strict and len(unexpected_keys) > 0:
2034
+ logger.warning(
2035
+ f"Some weights of the model checkpoint at {checkpoint} were not used when"
2036
+ f" initializing {model.__class__.__name__}: {unexpected_keys}. This may or may not be an issue - make sure that the checkpoint does not have unnecessary parameters, or that the model definition correctly corresponds to the checkpoint."
2037
+ )
2038
+
2039
+ save_offload_index(offload_index, offload_folder)
2040
+
2041
+ # Load back offloaded state dict on CPU
2042
+ if offload_state_dict:
2043
+ load_offloaded_weights(model, state_dict_index, state_dict_folder)
2044
+ shutil.rmtree(state_dict_folder)
2045
+
2046
+ retie_parameters(model, tied_params)
2047
+
2048
+
2049
+ def get_mixed_precision_context_manager(native_amp: bool = False, autocast_kwargs: AutocastKwargs = None):
2050
+ """
2051
+ Return a context manager for autocasting mixed precision
2052
+
2053
+ Args:
2054
+ native_amp (`bool`, *optional*, defaults to False):
2055
+ Whether mixed precision is actually enabled.
2056
+ cache_enabled (`bool`, *optional*, defaults to True):
2057
+ Whether the weight cache inside autocast should be enabled.
2058
+ """
2059
+ state = AcceleratorState()
2060
+ if autocast_kwargs is None:
2061
+ autocast_kwargs = {}
2062
+ else:
2063
+ autocast_kwargs = autocast_kwargs.to_kwargs()
2064
+ if native_amp:
2065
+ device_type = (
2066
+ "cuda"
2067
+ if (state.distributed_type == DistributedType.XLA and is_torch_xla_available(check_is_gpu=True))
2068
+ else state.device.type
2069
+ )
2070
+ if state.mixed_precision == "fp16":
2071
+ return torch.autocast(device_type=device_type, dtype=torch.float16, **autocast_kwargs)
2072
+ elif state.mixed_precision in ["bf16", "fp8"] and state.distributed_type in [
2073
+ DistributedType.NO,
2074
+ DistributedType.MULTI_CPU,
2075
+ DistributedType.MULTI_GPU,
2076
+ DistributedType.MULTI_MLU,
2077
+ DistributedType.MULTI_SDAA,
2078
+ DistributedType.MULTI_MUSA,
2079
+ DistributedType.MULTI_NPU,
2080
+ DistributedType.MULTI_XPU,
2081
+ DistributedType.MULTI_HPU,
2082
+ DistributedType.FSDP,
2083
+ DistributedType.XLA,
2084
+ ]:
2085
+ return torch.autocast(device_type=device_type, dtype=torch.bfloat16, **autocast_kwargs)
2086
+ else:
2087
+ return torch.autocast(device_type=device_type, **autocast_kwargs)
2088
+ else:
2089
+ return contextlib.nullcontext()
2090
+
2091
+
2092
+ def get_grad_scaler(distributed_type: DistributedType = None, **kwargs):
2093
+ """
2094
+ A generic helper which will initialize the correct `GradScaler` implementation based on the environment and return
2095
+ it.
2096
+
2097
+ Args:
2098
+ distributed_type (`DistributedType`, *optional*, defaults to None):
2099
+ The type of distributed environment.
2100
+ kwargs:
2101
+ Additional arguments for the utilized `GradScaler` constructor.
2102
+ """
2103
+ if distributed_type == DistributedType.FSDP:
2104
+ from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
2105
+
2106
+ return ShardedGradScaler(**kwargs)
2107
+ if is_torch_xla_available(check_is_gpu=True):
2108
+ import torch_xla.amp as xamp
2109
+
2110
+ return xamp.GradScaler(**kwargs)
2111
+ elif is_mlu_available():
2112
+ return torch.mlu.amp.GradScaler(**kwargs)
2113
+ elif is_sdaa_available():
2114
+ return torch.sdaa.amp.GradScaler(**kwargs)
2115
+ elif is_musa_available():
2116
+ return torch.musa.amp.GradScaler(**kwargs)
2117
+ elif is_npu_available():
2118
+ return torch.npu.amp.GradScaler(**kwargs)
2119
+ elif is_hpu_available():
2120
+ return torch.amp.GradScaler("hpu", **kwargs)
2121
+ elif is_xpu_available():
2122
+ return torch.amp.GradScaler("xpu", **kwargs)
2123
+ elif is_mps_available():
2124
+ if not is_torch_version(">=", "2.8.0"):
2125
+ raise ValueError("Grad Scaler with MPS device requires a Pytorch >= 2.8.0")
2126
+ return torch.amp.GradScaler("mps", **kwargs)
2127
+ else:
2128
+ if is_torch_version(">=", "2.3"):
2129
+ return torch.amp.GradScaler("cuda", **kwargs)
2130
+ else:
2131
+ return torch.cuda.amp.GradScaler(**kwargs)
2132
+
2133
+
2134
+ def has_offloaded_params(module: torch.nn.Module) -> bool:
2135
+ """
2136
+ Checks if a module has offloaded parameters by checking if the given module has a AlignDevicesHook attached with
2137
+ offloading enabled
2138
+
2139
+ Args:
2140
+ module (`torch.nn.Module`): The module to check for an offload hook.
2141
+
2142
+ Returns:
2143
+ bool: `True` if the module has an offload hook and offloading is enabled, `False` otherwise.
2144
+ """
2145
+ from ..hooks import AlignDevicesHook # avoid circular import
2146
+
2147
+ return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, AlignDevicesHook) and module._hf_hook.offload
2148
+
2149
+
2150
+ @contextlib.contextmanager
2151
+ def align_module_device(module: torch.nn.Module, execution_device: Optional[torch.device] = None):
2152
+ """
2153
+ Context manager that moves a module's parameters to the specified execution device.
2154
+
2155
+ Args:
2156
+ module (`torch.nn.Module`):
2157
+ Module with parameters to align.
2158
+ execution_device (`torch.device`, *optional*):
2159
+ If provided, overrides the module's execution device within the context. Otherwise, use hook execution
2160
+ device or pass
2161
+ """
2162
+ if has_offloaded_params(module):
2163
+ if execution_device is not None:
2164
+ original_device = module._hf_hook.execution_device
2165
+ module._hf_hook.execution_device = execution_device
2166
+
2167
+ try:
2168
+ module._hf_hook.pre_forward(module)
2169
+ yield
2170
+ finally:
2171
+ module._hf_hook.post_forward(module, None)
2172
+ if execution_device is not None:
2173
+ module._hf_hook.execution_device = original_device
2174
+
2175
+ elif execution_device is not None:
2176
+ devices = {name: param.device for name, param in module.named_parameters(recurse=False)}
2177
+ try:
2178
+ for name in devices:
2179
+ set_module_tensor_to_device(module, name, execution_device)
2180
+ yield
2181
+ finally:
2182
+ for name, device in devices.items():
2183
+ set_module_tensor_to_device(module, name, device)
2184
+
2185
+ else:
2186
+ yield
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/offload.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ import os
17
+ from collections.abc import Mapping
18
+ from typing import Optional, Union
19
+
20
+ import numpy as np
21
+ import torch
22
+ from safetensors import safe_open
23
+
24
+
25
+ def offload_weight(weight, weight_name, offload_folder, index=None):
26
+ dtype = None
27
+ # Check the string instead of the dtype to be compatible with versions of PyTorch that don't have bfloat16.
28
+ if str(weight.dtype) == "torch.bfloat16":
29
+ # Need to reinterpret the underlined data as int16 since NumPy does not handle bfloat16s.
30
+ weight = weight.view(torch.int16)
31
+ dtype = "bfloat16"
32
+ array = weight.cpu().numpy()
33
+ tensor_file = os.path.join(offload_folder, f"{weight_name}.dat")
34
+ if index is not None:
35
+ if dtype is None:
36
+ dtype = str(array.dtype)
37
+ index[weight_name] = {"dtype": dtype, "shape": list(array.shape)}
38
+ if array.ndim == 0:
39
+ array = array[None]
40
+ file_array = np.memmap(tensor_file, dtype=array.dtype, mode="w+", shape=array.shape)
41
+ file_array[:] = array[:]
42
+ file_array.flush()
43
+ return index
44
+
45
+
46
+ def load_offloaded_weight(weight_file, weight_info):
47
+ shape = tuple(weight_info["shape"])
48
+ if shape == ():
49
+ # NumPy memory-mapped arrays can't have 0 dims so it was saved as 1d tensor
50
+ shape = (1,)
51
+
52
+ dtype = weight_info["dtype"]
53
+ if dtype == "bfloat16":
54
+ # NumPy does not support bfloat16 so this was saved as a int16
55
+ dtype = "int16"
56
+
57
+ weight = np.memmap(weight_file, dtype=dtype, shape=shape, mode="r")
58
+
59
+ if len(weight_info["shape"]) == 0:
60
+ weight = weight[0]
61
+ weight = torch.tensor(weight)
62
+ if weight_info["dtype"] == "bfloat16":
63
+ weight = weight.view(torch.bfloat16)
64
+
65
+ return weight
66
+
67
+
68
+ def save_offload_index(index, offload_folder):
69
+ if index is None or len(index) == 0:
70
+ # Nothing to save
71
+ return
72
+
73
+ offload_index_file = os.path.join(offload_folder, "index.json")
74
+ if os.path.isfile(offload_index_file):
75
+ with open(offload_index_file, encoding="utf-8") as f:
76
+ current_index = json.load(f)
77
+ else:
78
+ current_index = {}
79
+ current_index.update(index)
80
+
81
+ with open(offload_index_file, "w", encoding="utf-8") as f:
82
+ json.dump(current_index, f, indent=2)
83
+
84
+
85
+ def offload_state_dict(save_dir: Union[str, os.PathLike], state_dict: dict[str, torch.Tensor]):
86
+ """
87
+ Offload a state dict in a given folder.
88
+
89
+ Args:
90
+ save_dir (`str` or `os.PathLike`):
91
+ The directory in which to offload the state dict.
92
+ state_dict (`Dict[str, torch.Tensor]`):
93
+ The dictionary of tensors to offload.
94
+ """
95
+ os.makedirs(save_dir, exist_ok=True)
96
+ index = {}
97
+ for name, parameter in state_dict.items():
98
+ index = offload_weight(parameter, name, save_dir, index=index)
99
+
100
+ # Update index
101
+ save_offload_index(index, save_dir)
102
+
103
+
104
+ class PrefixedDataset(Mapping):
105
+ """
106
+ Will access keys in a given dataset by adding a prefix.
107
+
108
+ Args:
109
+ dataset (`Mapping`): Any map with string keys.
110
+ prefix (`str`): A prefix to add when trying to access any element in the underlying dataset.
111
+ """
112
+
113
+ def __init__(self, dataset: Mapping, prefix: str):
114
+ self.dataset = dataset
115
+ self.prefix = prefix
116
+
117
+ def __getitem__(self, key):
118
+ return self.dataset[f"{self.prefix}{key}"]
119
+
120
+ def __iter__(self):
121
+ return iter([key for key in self.dataset if key.startswith(self.prefix)])
122
+
123
+ def __len__(self):
124
+ return len(self.dataset)
125
+
126
+
127
+ class OffloadedWeightsLoader(Mapping):
128
+ """
129
+ A collection that loads weights stored in a given state dict or memory-mapped on disk.
130
+
131
+ Args:
132
+ state_dict (`Dict[str, torch.Tensor]`, *optional*):
133
+ A dictionary parameter name to tensor.
134
+ save_folder (`str` or `os.PathLike`, *optional*):
135
+ The directory in which the weights are stored (by `offload_state_dict` for instance).
136
+ index (`Dict`, *optional*):
137
+ A dictionary from weight name to their information (`dtype`/ `shape` or safetensors filename). Will default
138
+ to the index saved in `save_folder`.
139
+ """
140
+
141
+ def __init__(
142
+ self,
143
+ state_dict: Optional[dict[str, torch.Tensor]] = None,
144
+ save_folder: Optional[Union[str, os.PathLike]] = None,
145
+ index: Optional[Mapping] = None,
146
+ device=None,
147
+ ):
148
+ if state_dict is None and save_folder is None and index is None:
149
+ raise ValueError("Need either a `state_dict`, a `save_folder` or an `index` containing offloaded weights.")
150
+
151
+ self.state_dict = {} if state_dict is None else state_dict
152
+ self.save_folder = save_folder
153
+ if index is None and save_folder is not None:
154
+ with open(os.path.join(save_folder, "index.json")) as f:
155
+ index = json.load(f)
156
+ self.index = {} if index is None else index
157
+ self.all_keys = list(self.state_dict.keys())
158
+ self.all_keys.extend([key for key in self.index if key not in self.all_keys])
159
+ self.device = device
160
+
161
+ def __getitem__(self, key: str):
162
+ # State dict gets priority
163
+ if key in self.state_dict:
164
+ return self.state_dict[key]
165
+ weight_info = self.index[key]
166
+ if weight_info.get("safetensors_file") is not None:
167
+ device = "cpu" if self.device is None else self.device
168
+ tensor = None
169
+ try:
170
+ with safe_open(weight_info["safetensors_file"], framework="pt", device=device) as f:
171
+ tensor = f.get_tensor(weight_info.get("weight_name", key))
172
+ except TypeError:
173
+ # if failed to get_tensor on the device, such as bf16 on mps, try to load it on CPU first
174
+ with safe_open(weight_info["safetensors_file"], framework="pt", device="cpu") as f:
175
+ tensor = f.get_tensor(weight_info.get("weight_name", key))
176
+
177
+ if "dtype" in weight_info:
178
+ tensor = tensor.to(getattr(torch, weight_info["dtype"]))
179
+
180
+ if tensor.device != torch.device(device):
181
+ tensor = tensor.to(device)
182
+ return tensor
183
+
184
+ weight_file = os.path.join(self.save_folder, f"{key}.dat")
185
+ return load_offloaded_weight(weight_file, weight_info)
186
+
187
+ def __iter__(self):
188
+ return iter(self.all_keys)
189
+
190
+ def __len__(self):
191
+ return len(self.all_keys)
192
+
193
+
194
+ def extract_submodules_state_dict(state_dict: dict[str, torch.Tensor], submodule_names: list[str]):
195
+ """
196
+ Extract the sub state-dict corresponding to a list of given submodules.
197
+
198
+ Args:
199
+ state_dict (`Dict[str, torch.Tensor]`): The state dict to extract from.
200
+ submodule_names (`List[str]`): The list of submodule names we want to extract.
201
+ """
202
+ result = {}
203
+ for module_name in submodule_names:
204
+ # We want to catch module_name parameter (module_name.xxx) or potentially module_name, but not any of the
205
+ # submodules that could being like module_name (transformers.h.1 and transformers.h.10 for instance)
206
+ result.update(
207
+ {
208
+ key: param
209
+ for key, param in state_dict.items()
210
+ if key == module_name or key.startswith(module_name + ".")
211
+ }
212
+ )
213
+ return result
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/operations.py ADDED
@@ -0,0 +1,867 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ A set of basic tensor ops compatible with tpu, gpu, and multigpu
16
+ """
17
+
18
+ import pickle
19
+ import warnings
20
+ from collections.abc import Mapping
21
+ from contextlib import contextmanager, nullcontext
22
+ from functools import update_wrapper, wraps
23
+ from typing import Any
24
+
25
+ import torch
26
+
27
+ from ..state import AcceleratorState, PartialState
28
+ from .constants import TORCH_DISTRIBUTED_OPERATION_TYPES
29
+ from .dataclasses import DistributedType, TensorInformation
30
+ from .imports import (
31
+ is_npu_available,
32
+ is_torch_distributed_available,
33
+ is_torch_xla_available,
34
+ )
35
+ from .versions import is_torch_version
36
+
37
+
38
+ if is_torch_xla_available():
39
+ import torch_xla.core.xla_model as xm
40
+
41
+ if is_torch_distributed_available():
42
+ from torch.distributed import ReduceOp
43
+
44
+
45
+ def is_torch_tensor(tensor):
46
+ return isinstance(tensor, torch.Tensor)
47
+
48
+
49
+ def is_torch_xpu_tensor(tensor):
50
+ return isinstance(
51
+ tensor,
52
+ torch.xpu.FloatTensor,
53
+ torch.xpu.ByteTensor,
54
+ torch.xpu.IntTensor,
55
+ torch.xpu.LongTensor,
56
+ torch.xpu.HalfTensor,
57
+ torch.xpu.DoubleTensor,
58
+ torch.xpu.BFloat16Tensor,
59
+ )
60
+
61
+
62
+ def is_tensor_information(tensor_info):
63
+ return isinstance(tensor_info, TensorInformation)
64
+
65
+
66
+ def is_namedtuple(data):
67
+ """
68
+ Checks if `data` is a `namedtuple` or not. Can have false positives, but only if a user is trying to mimic a
69
+ `namedtuple` perfectly.
70
+ """
71
+ return isinstance(data, tuple) and hasattr(data, "_asdict") and hasattr(data, "_fields")
72
+
73
+
74
+ def honor_type(obj, generator):
75
+ """
76
+ Cast a generator to the same type as obj (list, tuple, or namedtuple)
77
+ """
78
+ # Some objects may not be able to instantiate from a generator directly
79
+ if is_namedtuple(obj):
80
+ return type(obj)(*list(generator))
81
+ else:
82
+ return type(obj)(generator)
83
+
84
+
85
+ def recursively_apply(func, data, *args, test_type=is_torch_tensor, error_on_other_type=False, **kwargs):
86
+ """
87
+ Recursively apply a function on a data structure that is a nested list/tuple/dictionary of a given base type.
88
+
89
+ Args:
90
+ func (`callable`):
91
+ The function to recursively apply.
92
+ data (nested list/tuple/dictionary of `main_type`):
93
+ The data on which to apply `func`
94
+ *args:
95
+ Positional arguments that will be passed to `func` when applied on the unpacked data.
96
+ main_type (`type`, *optional*, defaults to `torch.Tensor`):
97
+ The base type of the objects to which apply `func`.
98
+ error_on_other_type (`bool`, *optional*, defaults to `False`):
99
+ Whether to return an error or not if after unpacking `data`, we get on an object that is not of type
100
+ `main_type`. If `False`, the function will leave objects of types different than `main_type` unchanged.
101
+ **kwargs (additional keyword arguments, *optional*):
102
+ Keyword arguments that will be passed to `func` when applied on the unpacked data.
103
+
104
+ Returns:
105
+ The same data structure as `data` with `func` applied to every object of type `main_type`.
106
+ """
107
+ if isinstance(data, (tuple, list)):
108
+ return honor_type(
109
+ data,
110
+ (
111
+ recursively_apply(
112
+ func, o, *args, test_type=test_type, error_on_other_type=error_on_other_type, **kwargs
113
+ )
114
+ for o in data
115
+ ),
116
+ )
117
+ elif isinstance(data, Mapping):
118
+ return type(data)(
119
+ {
120
+ k: recursively_apply(
121
+ func, v, *args, test_type=test_type, error_on_other_type=error_on_other_type, **kwargs
122
+ )
123
+ for k, v in data.items()
124
+ }
125
+ )
126
+ elif test_type(data):
127
+ return func(data, *args, **kwargs)
128
+ elif error_on_other_type:
129
+ raise TypeError(
130
+ f"Unsupported types ({type(data)}) passed to `{func.__name__}`. Only nested list/tuple/dicts of "
131
+ f"objects that are valid for `{test_type.__name__}` should be passed."
132
+ )
133
+ return data
134
+
135
+
136
+ def send_to_device(tensor, device, non_blocking=False, skip_keys=None):
137
+ """
138
+ Recursively sends the elements in a nested list/tuple/dictionary of tensors to a given device.
139
+
140
+ Args:
141
+ tensor (nested list/tuple/dictionary of `torch.Tensor`):
142
+ The data to send to a given device.
143
+ device (`torch.device`):
144
+ The device to send the data to.
145
+
146
+ Returns:
147
+ The same data structure as `tensor` with all tensors sent to the proper device.
148
+ """
149
+ if is_torch_tensor(tensor) or hasattr(tensor, "to"):
150
+ # `torch.Tensor.to("npu")` could not find context when called for the first time (see this [issue](https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue)).
151
+ if device == "npu":
152
+ device = "npu:0"
153
+ try:
154
+ return tensor.to(device, non_blocking=non_blocking)
155
+ except TypeError: # .to() doesn't accept non_blocking as kwarg
156
+ return tensor.to(device)
157
+ except AssertionError as error:
158
+ # `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
159
+ # This call is inside the try-block since is_npu_available is not supported by torch.compile.
160
+ if is_npu_available():
161
+ if isinstance(device, int):
162
+ device = f"npu:{device}"
163
+ else:
164
+ raise error
165
+ try:
166
+ return tensor.to(device, non_blocking=non_blocking)
167
+ except TypeError: # .to() doesn't accept non_blocking as kwarg
168
+ return tensor.to(device)
169
+ elif isinstance(tensor, (tuple, list)):
170
+ return honor_type(
171
+ tensor, (send_to_device(t, device, non_blocking=non_blocking, skip_keys=skip_keys) for t in tensor)
172
+ )
173
+ elif isinstance(tensor, Mapping):
174
+ if isinstance(skip_keys, str):
175
+ skip_keys = [skip_keys]
176
+ elif skip_keys is None:
177
+ skip_keys = []
178
+ return type(tensor)(
179
+ {
180
+ k: t if k in skip_keys else send_to_device(t, device, non_blocking=non_blocking, skip_keys=skip_keys)
181
+ for k, t in tensor.items()
182
+ }
183
+ )
184
+ else:
185
+ return tensor
186
+
187
+
188
+ def get_data_structure(data):
189
+ """
190
+ Recursively gathers the information needed to rebuild a nested list/tuple/dictionary of tensors.
191
+
192
+ Args:
193
+ data (nested list/tuple/dictionary of `torch.Tensor`):
194
+ The data to send to analyze.
195
+
196
+ Returns:
197
+ The same data structure as `data` with [`~utils.TensorInformation`] instead of tensors.
198
+ """
199
+
200
+ def _get_data_structure(tensor):
201
+ return TensorInformation(shape=tensor.shape, dtype=tensor.dtype)
202
+
203
+ return recursively_apply(_get_data_structure, data)
204
+
205
+
206
+ def get_shape(data):
207
+ """
208
+ Recursively gathers the shape of a nested list/tuple/dictionary of tensors as a list.
209
+
210
+ Args:
211
+ data (nested list/tuple/dictionary of `torch.Tensor`):
212
+ The data to send to analyze.
213
+
214
+ Returns:
215
+ The same data structure as `data` with lists of tensor shapes instead of tensors.
216
+ """
217
+
218
+ def _get_shape(tensor):
219
+ return list(tensor.shape)
220
+
221
+ return recursively_apply(_get_shape, data)
222
+
223
+
224
+ def initialize_tensors(data_structure):
225
+ """
226
+ Recursively initializes tensors from a nested list/tuple/dictionary of [`~utils.TensorInformation`].
227
+
228
+ Returns:
229
+ The same data structure as `data` with tensors instead of [`~utils.TensorInformation`].
230
+ """
231
+
232
+ def _initialize_tensor(tensor_info):
233
+ return torch.empty(*tensor_info.shape, dtype=tensor_info.dtype)
234
+
235
+ return recursively_apply(_initialize_tensor, data_structure, test_type=is_tensor_information)
236
+
237
+
238
+ def find_batch_size(data):
239
+ """
240
+ Recursively finds the batch size in a nested list/tuple/dictionary of lists of tensors.
241
+
242
+ Args:
243
+ data (nested list/tuple/dictionary of `torch.Tensor`): The data from which to find the batch size.
244
+
245
+ Returns:
246
+ `int`: The batch size.
247
+ """
248
+ if isinstance(data, (tuple, list, Mapping)) and (len(data) == 0):
249
+ raise ValueError(f"Cannot find the batch size from empty {type(data)}.")
250
+
251
+ if isinstance(data, (tuple, list)):
252
+ return find_batch_size(data[0])
253
+ elif isinstance(data, Mapping):
254
+ for k in data.keys():
255
+ return find_batch_size(data[k])
256
+ elif not isinstance(data, torch.Tensor):
257
+ raise TypeError(f"Can only find the batch size of tensors but got {type(data)}.")
258
+ return data.shape[0]
259
+
260
+
261
+ def ignorant_find_batch_size(data):
262
+ """
263
+ Same as [`utils.operations.find_batch_size`] except will ignore if `ValueError` and `TypeErrors` are raised
264
+
265
+ Args:
266
+ data (nested list/tuple/dictionary of `torch.Tensor`): The data from which to find the batch size.
267
+
268
+ Returns:
269
+ `int`: The batch size.
270
+ """
271
+ try:
272
+ return find_batch_size(data)
273
+ except (ValueError, TypeError):
274
+ pass
275
+ return None
276
+
277
+
278
+ def listify(data):
279
+ """
280
+ Recursively finds tensors in a nested list/tuple/dictionary and converts them to a list of numbers.
281
+
282
+ Args:
283
+ data (nested list/tuple/dictionary of `torch.Tensor`): The data from which to convert to regular numbers.
284
+
285
+ Returns:
286
+ The same data structure as `data` with lists of numbers instead of `torch.Tensor`.
287
+ """
288
+
289
+ def _convert_to_list(tensor):
290
+ tensor = tensor.detach().cpu()
291
+ if tensor.dtype == torch.bfloat16:
292
+ # As of Numpy 1.21.4, NumPy does not support bfloat16 (see
293
+ # https://github.com/numpy/numpy/blob/a47ecdea856986cd60eabbd53265c2ca5916ad5d/doc/source/user/basics.types.rst ).
294
+ # Until Numpy adds bfloat16, we must convert float32.
295
+ tensor = tensor.to(torch.float32)
296
+ return tensor.tolist()
297
+
298
+ return recursively_apply(_convert_to_list, data)
299
+
300
+
301
+ def _tpu_gather(tensor):
302
+ def _tpu_gather_one(tensor):
303
+ if tensor.ndim == 0:
304
+ tensor = tensor.clone()[None]
305
+
306
+ # Can only gather contiguous tensors
307
+ if not tensor.is_contiguous():
308
+ tensor = tensor.contiguous()
309
+ return xm.all_gather(tensor)
310
+
311
+ res = recursively_apply(_tpu_gather_one, tensor, error_on_other_type=True)
312
+ xm.mark_step()
313
+ return res
314
+
315
+
316
+ def _gpu_gather(tensor):
317
+ state = PartialState()
318
+ gather_op = torch.distributed.all_gather_into_tensor
319
+
320
+ # NOTE: need manually synchronize to workaourd a INT64 collectives bug in oneCCL before torch 2.9.0
321
+ if state.device.type == "xpu" and is_torch_version("<=", "2.8"):
322
+ torch.xpu.synchronize()
323
+
324
+ def _gpu_gather_one(tensor):
325
+ if tensor.ndim == 0:
326
+ tensor = tensor.clone()[None]
327
+
328
+ # Can only gather contiguous tensors
329
+ if not tensor.is_contiguous():
330
+ tensor = tensor.contiguous()
331
+
332
+ if state.backend is not None and state.backend != "gloo":
333
+ # We use `empty` as `all_gather_into_tensor` slightly
334
+ # differs from `all_gather` for better efficiency,
335
+ # and we rely on the number of items in the tensor
336
+ # rather than its direct shape
337
+ output_tensors = torch.empty(
338
+ state.num_processes * tensor.numel(),
339
+ dtype=tensor.dtype,
340
+ device=state.device,
341
+ )
342
+ gather_op(output_tensors, tensor)
343
+ return output_tensors.view(-1, *tensor.size()[1:])
344
+ else:
345
+ # a backend of `None` is always CPU
346
+ # also gloo does not support `all_gather_into_tensor`,
347
+ # which will result in a larger memory overhead for the op
348
+ output_tensors = [torch.empty_like(tensor) for _ in range(state.num_processes)]
349
+ torch.distributed.all_gather(output_tensors, tensor)
350
+ return torch.cat(output_tensors, dim=0)
351
+
352
+ return recursively_apply(_gpu_gather_one, tensor, error_on_other_type=True)
353
+
354
+
355
+ class DistributedOperationException(Exception):
356
+ """
357
+ An exception class for distributed operations. Raised if the operation cannot be performed due to the shape of the
358
+ tensors.
359
+ """
360
+
361
+ pass
362
+
363
+
364
+ def verify_operation(function):
365
+ """
366
+ Verifies that `tensor` is the same shape across all processes. Only ran if `PartialState().debug` is `True`.
367
+ """
368
+
369
+ @wraps(function)
370
+ def wrapper(*args, **kwargs):
371
+ if PartialState().distributed_type == DistributedType.NO or not PartialState().debug:
372
+ return function(*args, **kwargs)
373
+ operation = f"{function.__module__}.{function.__name__}"
374
+ if "tensor" in kwargs:
375
+ tensor = kwargs["tensor"]
376
+ else:
377
+ tensor = args[0]
378
+ if PartialState().device.type != find_device(tensor).type:
379
+ raise DistributedOperationException(
380
+ f"One or more of the tensors passed to {operation} were not on the {tensor.device.type} while the `Accelerator` is configured for {PartialState().device.type}. "
381
+ f"Please move it to the {PartialState().device.type} before calling {operation}."
382
+ )
383
+ shapes = get_shape(tensor)
384
+ output = gather_object([shapes])
385
+ if output[0] is not None:
386
+ are_same = output.count(output[0]) == len(output)
387
+ if not are_same:
388
+ process_shape_str = "\n - ".join([f"Process {i}: {shape}" for i, shape in enumerate(output)])
389
+ raise DistributedOperationException(
390
+ f"Cannot apply desired operation due to shape mismatches. "
391
+ "All shapes across devices must be valid."
392
+ f"\n\nOperation: `{operation}`\nInput shapes:\n - {process_shape_str}"
393
+ )
394
+ return function(*args, **kwargs)
395
+
396
+ return wrapper
397
+
398
+
399
+ def chained_operation(function):
400
+ """
401
+ Checks that `verify_operation` failed and if so reports a more helpful error chaining the existing
402
+ `DistributedOperationException`.
403
+ """
404
+
405
+ @wraps(function)
406
+ def wrapper(*args, **kwargs):
407
+ try:
408
+ return function(*args, **kwargs)
409
+ except DistributedOperationException as e:
410
+ operation = f"{function.__module__}.{function.__name__}"
411
+ raise DistributedOperationException(
412
+ f"Error found while calling `{operation}`. Please see the earlier error for more details."
413
+ ) from e
414
+
415
+ return wrapper
416
+
417
+
418
+ @verify_operation
419
+ def gather(tensor):
420
+ """
421
+ Recursively gather tensor in a nested list/tuple/dictionary of tensors from all devices.
422
+
423
+ Args:
424
+ tensor (nested list/tuple/dictionary of `torch.Tensor`):
425
+ The data to gather.
426
+
427
+ Returns:
428
+ The same data structure as `tensor` with all tensors sent to the proper device.
429
+ """
430
+ if PartialState().distributed_type == DistributedType.XLA:
431
+ return _tpu_gather(tensor)
432
+ elif PartialState().distributed_type in TORCH_DISTRIBUTED_OPERATION_TYPES:
433
+ return _gpu_gather(tensor)
434
+ else:
435
+ return tensor
436
+
437
+
438
+ def _gpu_gather_object(object: Any):
439
+ output_objects = [None for _ in range(PartialState().num_processes)]
440
+ torch.distributed.all_gather_object(output_objects, object)
441
+ # all_gather_object returns a list of lists, so we need to flatten it
442
+ return [x for y in output_objects for x in y]
443
+
444
+
445
+ def gather_object(object: Any):
446
+ """
447
+ Recursively gather object in a nested list/tuple/dictionary of objects from all devices.
448
+
449
+ Args:
450
+ object (nested list/tuple/dictionary of picklable object):
451
+ The data to gather.
452
+
453
+ Returns:
454
+ The same data structure as `object` with all the objects sent to every device.
455
+ """
456
+ if PartialState().distributed_type == DistributedType.XLA:
457
+ raise NotImplementedError("gather objects in TPU is not supported")
458
+ elif PartialState().distributed_type in TORCH_DISTRIBUTED_OPERATION_TYPES:
459
+ return _gpu_gather_object(object)
460
+ else:
461
+ return object
462
+
463
+
464
+ def _gpu_broadcast(data, src=0):
465
+ def _gpu_broadcast_one(tensor, src=0):
466
+ torch.distributed.broadcast(tensor, src=src)
467
+ return tensor
468
+
469
+ return recursively_apply(_gpu_broadcast_one, data, error_on_other_type=True, src=src)
470
+
471
+
472
+ def _tpu_broadcast(tensor, src=0, name="broadcast tensor"):
473
+ if isinstance(tensor, (list, tuple)):
474
+ return honor_type(tensor, (_tpu_broadcast(t, name=f"{name}_{i}") for i, t in enumerate(tensor)))
475
+ elif isinstance(tensor, Mapping):
476
+ return type(tensor)({k: _tpu_broadcast(v, name=f"{name}_{k}") for k, v in tensor.items()})
477
+ return xm.mesh_reduce(name, tensor, lambda x: x[src])
478
+
479
+
480
+ TENSOR_TYPE_TO_INT = {
481
+ torch.float: 1,
482
+ torch.double: 2,
483
+ torch.half: 3,
484
+ torch.bfloat16: 4,
485
+ torch.uint8: 5,
486
+ torch.int8: 6,
487
+ torch.int16: 7,
488
+ torch.int32: 8,
489
+ torch.int64: 9,
490
+ torch.bool: 10,
491
+ }
492
+
493
+ TENSOR_INT_TO_DTYPE = {v: k for k, v in TENSOR_TYPE_TO_INT.items()}
494
+
495
+
496
+ def gather_tensor_shape(tensor):
497
+ """
498
+ Grabs the shape of `tensor` only available on one process and returns a tensor of its shape
499
+ """
500
+ # Allocate 80 bytes to store the shape
501
+ max_tensor_dimension = 2**20
502
+ state = PartialState()
503
+ base_tensor = torch.empty(max_tensor_dimension, dtype=torch.int, device=state.device)
504
+
505
+ # Since PyTorch can't just send a tensor to another GPU without
506
+ # knowing its size, we store the size of the tensor with data
507
+ # in an allocation
508
+ if tensor is not None:
509
+ shape = tensor.shape
510
+ tensor_dtype = TENSOR_TYPE_TO_INT[tensor.dtype]
511
+ base_tensor[: len(shape) + 1] = torch.tensor(list(shape) + [tensor_dtype], dtype=int)
512
+ # Perform a reduction to copy the size data onto all GPUs
513
+ base_tensor = reduce(base_tensor, reduction="sum")
514
+ base_tensor = base_tensor[base_tensor.nonzero()]
515
+ # The last non-zero data contains the coded dtype the source tensor is
516
+ dtype = int(base_tensor[-1:][0])
517
+ base_tensor = base_tensor[:-1]
518
+ return base_tensor, dtype
519
+
520
+
521
+ def copy_tensor_to_devices(tensor=None) -> torch.Tensor:
522
+ """
523
+ Copies a tensor that only exists on a single device and broadcasts it to other devices. Differs from `broadcast` as
524
+ each worker doesn't need to know its shape when used (and tensor can be `None`)
525
+
526
+ Args:
527
+ tensor (`torch.tensor`):
528
+ The tensor that should be sent to all devices. Must only have it be defined on a single device, the rest
529
+ should be `None`.
530
+ """
531
+ state = PartialState()
532
+ shape, dtype = gather_tensor_shape(tensor)
533
+ if tensor is None:
534
+ tensor = torch.zeros(shape, dtype=TENSOR_INT_TO_DTYPE[dtype]).to(state.device)
535
+ return reduce(tensor, reduction="sum")
536
+
537
+
538
+ @verify_operation
539
+ def broadcast(tensor, from_process: int = 0):
540
+ """
541
+ Recursively broadcast tensor in a nested list/tuple/dictionary of tensors to all devices.
542
+
543
+ Args:
544
+ tensor (nested list/tuple/dictionary of `torch.Tensor`):
545
+ The data to gather.
546
+ from_process (`int`, *optional*, defaults to 0):
547
+ The process from which to send the data
548
+
549
+ Returns:
550
+ The same data structure as `tensor` with all tensors broadcasted to the proper device.
551
+ """
552
+ if PartialState().distributed_type == DistributedType.XLA:
553
+ return _tpu_broadcast(tensor, src=from_process, name="accelerate.utils.broadcast")
554
+ elif PartialState().distributed_type in TORCH_DISTRIBUTED_OPERATION_TYPES:
555
+ return _gpu_broadcast(tensor, src=from_process)
556
+ else:
557
+ return tensor
558
+
559
+
560
+ def broadcast_object_list(object_list, from_process: int = 0):
561
+ """
562
+ Broadcast a list of picklable objects from one process to the others.
563
+
564
+ Args:
565
+ object_list (list of picklable objects):
566
+ The list of objects to broadcast. This list will be modified inplace.
567
+ from_process (`int`, *optional*, defaults to 0):
568
+ The process from which to send the data.
569
+
570
+ Returns:
571
+ The same list containing the objects from process 0.
572
+ """
573
+ if PartialState().distributed_type == DistributedType.XLA:
574
+ for i, obj in enumerate(object_list):
575
+ object_list[i] = xm.mesh_reduce("accelerate.utils.broadcast_object_list", obj, lambda x: x[from_process])
576
+ elif PartialState().distributed_type in TORCH_DISTRIBUTED_OPERATION_TYPES:
577
+ torch.distributed.broadcast_object_list(object_list, src=from_process)
578
+ return object_list
579
+
580
+
581
+ def slice_tensors(data, tensor_slice, process_index=None, num_processes=None):
582
+ """
583
+ Recursively takes a slice in a nested list/tuple/dictionary of tensors.
584
+
585
+ Args:
586
+ data (nested list/tuple/dictionary of `torch.Tensor`):
587
+ The data to slice.
588
+ tensor_slice (`slice`):
589
+ The slice to take.
590
+
591
+ Returns:
592
+ The same data structure as `data` with all the tensors slices.
593
+ """
594
+
595
+ def _slice_tensor(tensor, tensor_slice):
596
+ return tensor[tensor_slice]
597
+
598
+ return recursively_apply(_slice_tensor, data, tensor_slice)
599
+
600
+
601
+ def concatenate(data, dim=0):
602
+ """
603
+ Recursively concatenate the tensors in a nested list/tuple/dictionary of lists of tensors with the same shape.
604
+
605
+ Args:
606
+ data (nested list/tuple/dictionary of lists of tensors `torch.Tensor`):
607
+ The data to concatenate.
608
+ dim (`int`, *optional*, defaults to 0):
609
+ The dimension on which to concatenate.
610
+
611
+ Returns:
612
+ The same data structure as `data` with all the tensors concatenated.
613
+ """
614
+ if isinstance(data[0], (tuple, list)):
615
+ return honor_type(data[0], (concatenate([d[i] for d in data], dim=dim) for i in range(len(data[0]))))
616
+ elif isinstance(data[0], Mapping):
617
+ return type(data[0])({k: concatenate([d[k] for d in data], dim=dim) for k in data[0].keys()})
618
+ elif not isinstance(data[0], torch.Tensor):
619
+ raise TypeError(f"Can only concatenate tensors but got {type(data[0])}")
620
+ return torch.cat(data, dim=dim)
621
+
622
+
623
+ class CannotPadNestedTensorWarning(UserWarning):
624
+ pass
625
+
626
+
627
+ @chained_operation
628
+ def pad_across_processes(tensor, dim=0, pad_index=0, pad_first=False):
629
+ """
630
+ Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so they
631
+ can safely be gathered.
632
+
633
+ Args:
634
+ tensor (nested list/tuple/dictionary of `torch.Tensor`):
635
+ The data to gather.
636
+ dim (`int`, *optional*, defaults to 0):
637
+ The dimension on which to pad.
638
+ pad_index (`int`, *optional*, defaults to 0):
639
+ The value with which to pad.
640
+ pad_first (`bool`, *optional*, defaults to `False`):
641
+ Whether to pad at the beginning or the end.
642
+ """
643
+
644
+ def _pad_across_processes(tensor, dim=0, pad_index=0, pad_first=False):
645
+ if getattr(tensor, "is_nested", False):
646
+ warnings.warn(
647
+ "Cannot pad nested tensors without more information. Leaving unprocessed.",
648
+ CannotPadNestedTensorWarning,
649
+ )
650
+ return tensor
651
+ if dim >= len(tensor.shape) or dim < -len(tensor.shape):
652
+ return tensor
653
+ # Convert negative dimensions to non-negative
654
+ if dim < 0:
655
+ dim += len(tensor.shape)
656
+
657
+ # Gather all sizes
658
+ size = torch.tensor(tensor.shape, device=tensor.device)[None]
659
+ sizes = gather(size).cpu()
660
+ # Then pad to the maximum size
661
+ max_size = max(s[dim] for s in sizes)
662
+ if max_size == tensor.shape[dim]:
663
+ return tensor
664
+
665
+ old_size = tensor.shape
666
+ new_size = list(old_size)
667
+ new_size[dim] = max_size
668
+ new_tensor = tensor.new_zeros(tuple(new_size)) + pad_index
669
+ if pad_first:
670
+ indices = tuple(
671
+ slice(max_size - old_size[dim], max_size) if i == dim else slice(None) for i in range(len(new_size))
672
+ )
673
+ else:
674
+ indices = tuple(slice(0, old_size[dim]) if i == dim else slice(None) for i in range(len(new_size)))
675
+ new_tensor[indices] = tensor
676
+ return new_tensor
677
+
678
+ return recursively_apply(
679
+ _pad_across_processes, tensor, error_on_other_type=True, dim=dim, pad_index=pad_index, pad_first=pad_first
680
+ )
681
+
682
+
683
+ def pad_input_tensors(tensor, batch_size, num_processes, dim=0):
684
+ """
685
+ Takes a `tensor` of arbitrary size and pads it so that it can work given `num_processes` needed dimensions.
686
+
687
+ New tensors are just the last input repeated.
688
+
689
+ E.g.:
690
+ Tensor: ([3,4,4]) Num processes: 4 Expected result shape: ([4,4,4])
691
+
692
+ """
693
+
694
+ def _pad_input_tensors(tensor, batch_size, num_processes, dim=0):
695
+ remainder = batch_size // num_processes
696
+ last_inputs = batch_size - (remainder * num_processes)
697
+ if batch_size // num_processes == 0:
698
+ to_pad = num_processes - batch_size
699
+ else:
700
+ to_pad = num_processes - (batch_size // num_processes)
701
+ # In the rare case that `to_pad` is negative,
702
+ # we need to pad the last inputs - the found `to_pad`
703
+ if last_inputs > to_pad & to_pad < 1:
704
+ to_pad = last_inputs - to_pad
705
+ old_size = tensor.shape
706
+ new_size = list(old_size)
707
+ new_size[0] = batch_size + to_pad
708
+ new_tensor = tensor.new_zeros(tuple(new_size))
709
+ indices = tuple(slice(0, old_size[dim]) if i == dim else slice(None) for i in range(len(new_size)))
710
+ new_tensor[indices] = tensor
711
+ return new_tensor
712
+
713
+ return recursively_apply(
714
+ _pad_input_tensors,
715
+ tensor,
716
+ error_on_other_type=True,
717
+ batch_size=batch_size,
718
+ num_processes=num_processes,
719
+ dim=dim,
720
+ )
721
+
722
+
723
+ @verify_operation
724
+ def reduce(tensor, reduction="mean", scale=1.0):
725
+ """
726
+ Recursively reduce the tensors in a nested list/tuple/dictionary of lists of tensors across all processes by the
727
+ mean of a given operation.
728
+
729
+ Args:
730
+ tensor (nested list/tuple/dictionary of `torch.Tensor`):
731
+ The data to reduce.
732
+ reduction (`str`, *optional*, defaults to `"mean"`):
733
+ A reduction method. Can be of "mean", "sum", or "none"
734
+ scale (`float`, *optional*):
735
+ A default scaling value to be applied after the reduce, only valid on XLA.
736
+
737
+ Returns:
738
+ The same data structure as `data` with all the tensors reduced.
739
+ """
740
+
741
+ def _reduce_across_processes(tensor, reduction="mean", scale=1.0):
742
+ state = PartialState()
743
+ cloned_tensor = tensor.clone()
744
+ if state.distributed_type == DistributedType.NO:
745
+ return cloned_tensor
746
+ if state.distributed_type == DistributedType.XLA:
747
+ # Some processes may have different HLO graphs than other
748
+ # processes, for example in the breakpoint API
749
+ # accelerator.set_trigger(). Use mark_step to make HLOs
750
+ # the same on all processes.
751
+ xm.mark_step()
752
+ xm.all_reduce(xm.REDUCE_SUM, [cloned_tensor], scale)
753
+ xm.mark_step()
754
+ elif state.distributed_type.value in TORCH_DISTRIBUTED_OPERATION_TYPES:
755
+ torch.distributed.all_reduce(cloned_tensor, ReduceOp.SUM)
756
+ if reduction == "mean":
757
+ cloned_tensor /= state.num_processes
758
+ return cloned_tensor
759
+
760
+ return recursively_apply(
761
+ _reduce_across_processes, tensor, error_on_other_type=True, reduction=reduction, scale=scale
762
+ )
763
+
764
+
765
+ def convert_to_fp32(tensor):
766
+ """
767
+ Recursively converts the elements nested list/tuple/dictionary of tensors in FP16/BF16 precision to FP32.
768
+
769
+ Args:
770
+ tensor (nested list/tuple/dictionary of `torch.Tensor`):
771
+ The data to convert from FP16/BF16 to FP32.
772
+
773
+ Returns:
774
+ The same data structure as `tensor` with all tensors that were in FP16/BF16 precision converted to FP32.
775
+ """
776
+
777
+ def _convert_to_fp32(tensor):
778
+ return tensor.float()
779
+
780
+ def _is_fp16_bf16_tensor(tensor):
781
+ return (is_torch_tensor(tensor) or hasattr(tensor, "dtype")) and tensor.dtype in (
782
+ torch.float16,
783
+ torch.bfloat16,
784
+ )
785
+
786
+ return recursively_apply(_convert_to_fp32, tensor, test_type=_is_fp16_bf16_tensor)
787
+
788
+
789
+ class ConvertOutputsToFp32:
790
+ """
791
+ Decorator to apply to a function outputting tensors (like a model forward pass) that ensures the outputs in FP16
792
+ precision will be convert back to FP32.
793
+
794
+ Args:
795
+ model_forward (`Callable`):
796
+ The function which outputs we want to treat.
797
+
798
+ Returns:
799
+ The same function as `model_forward` but with converted outputs.
800
+ """
801
+
802
+ def __init__(self, model_forward):
803
+ self.model_forward = model_forward
804
+ update_wrapper(self, model_forward)
805
+
806
+ def __call__(self, *args, **kwargs):
807
+ return convert_to_fp32(self.model_forward(*args, **kwargs))
808
+
809
+ def __getstate__(self):
810
+ raise pickle.PicklingError(
811
+ "Cannot pickle a prepared model with automatic mixed precision, please unwrap the model with `Accelerator.unwrap_model(model)` before pickling it."
812
+ )
813
+
814
+
815
+ def convert_outputs_to_fp32(model_forward):
816
+ model_forward = ConvertOutputsToFp32(model_forward)
817
+
818
+ def forward(*args, **kwargs):
819
+ return model_forward(*args, **kwargs)
820
+
821
+ # To act like a decorator so that it can be popped when doing `extract_model_from_parallel`
822
+ forward.__wrapped__ = model_forward
823
+
824
+ return forward
825
+
826
+
827
+ def find_device(data):
828
+ """
829
+ Finds the device on which a nested dict/list/tuple of tensors lies (assuming they are all on the same device).
830
+
831
+ Args:
832
+ (nested list/tuple/dictionary of `torch.Tensor`): The data we want to know the device of.
833
+ """
834
+ if isinstance(data, Mapping):
835
+ for obj in data.values():
836
+ device = find_device(obj)
837
+ if device is not None:
838
+ return device
839
+ elif isinstance(data, (tuple, list)):
840
+ for obj in data:
841
+ device = find_device(obj)
842
+ if device is not None:
843
+ return device
844
+ elif isinstance(data, torch.Tensor):
845
+ return data.device
846
+
847
+
848
+ @contextmanager
849
+ def GatheredParameters(params, modifier_rank=None, fwd_module=None, enabled=True):
850
+ """
851
+ Wrapper around `deepspeed.runtime.zero.GatheredParameters`, but if Zero-3 is not enabled, will be a no-op context
852
+ manager.
853
+ """
854
+ # We need to use the `AcceleratorState` here since it has access to the deepspeed plugin
855
+ if AcceleratorState().distributed_type != DistributedType.DEEPSPEED or (
856
+ AcceleratorState().deepspeed_plugin is not None
857
+ and not AcceleratorState().deepspeed_plugin.is_zero3_init_enabled()
858
+ ):
859
+ gather_param_context = nullcontext()
860
+ else:
861
+ import deepspeed
862
+
863
+ gather_param_context = deepspeed.zero.GatheredParameters(
864
+ params, modifier_rank=modifier_rank, fwd_module=fwd_module, enabled=enabled
865
+ )
866
+ with gather_param_context:
867
+ yield
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/other.py ADDED
@@ -0,0 +1,561 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import collections
16
+ import platform
17
+ import re
18
+ import socket
19
+ from codecs import encode
20
+ from collections import OrderedDict
21
+ from functools import partial, reduce
22
+ from types import MethodType
23
+ from typing import Optional
24
+
25
+ import numpy as np
26
+ import torch
27
+ from packaging.version import Version
28
+ from safetensors.torch import save_file as safe_save_file
29
+
30
+ from ..commands.config.default import write_basic_config # noqa: F401
31
+ from ..logging import get_logger
32
+ from ..state import PartialState
33
+ from .constants import FSDP_PYTORCH_VERSION
34
+ from .dataclasses import DistributedType
35
+ from .imports import (
36
+ is_deepspeed_available,
37
+ is_numpy_available,
38
+ is_torch_distributed_available,
39
+ is_torch_xla_available,
40
+ is_weights_only_available,
41
+ )
42
+ from .modeling import id_tensor_storage
43
+ from .transformer_engine import convert_model
44
+ from .versions import is_torch_version
45
+
46
+
47
+ logger = get_logger(__name__)
48
+
49
+
50
+ if is_torch_xla_available():
51
+ import torch_xla.core.xla_model as xm
52
+
53
+
54
+ def is_compiled_module(module: torch.nn.Module) -> bool:
55
+ """
56
+ Check whether the module was compiled with torch.compile()
57
+ """
58
+ if not hasattr(torch, "_dynamo"):
59
+ return False
60
+
61
+ return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
62
+
63
+
64
+ def has_compiled_regions(module: torch.nn.Module) -> bool:
65
+ """
66
+ Check whether the module has submodules that were compiled with `torch.compile()`.
67
+ """
68
+ if not hasattr(torch, "_dynamo"):
69
+ return False
70
+
71
+ if module._modules:
72
+ for submodule in module.modules():
73
+ if isinstance(submodule, torch._dynamo.eval_frame.OptimizedModule):
74
+ return True
75
+
76
+ return False
77
+
78
+
79
+ def is_repeated_blocks(module: torch.nn.Module) -> bool:
80
+ """
81
+ Check whether the module is a repeated block, i.e. `torch.nn.ModuleList` with all children of the same class. This
82
+ is useful to determine whether we should apply regional compilation to the module.
83
+ """
84
+
85
+ return isinstance(module, torch.nn.ModuleList) and all(isinstance(m, module[0].__class__) for m in module)
86
+
87
+
88
+ def has_repeated_blocks(module: torch.nn.Module) -> bool:
89
+ """
90
+ Check whether the module has repeated blocks, i.e. `torch.nn.ModuleList` with all children of the same class, at
91
+ any level of the module hierarchy. This is useful to determine whether we should apply regional compilation to the
92
+ module.
93
+ """
94
+ if module._modules:
95
+ for submodule in module.modules():
96
+ if is_repeated_blocks(submodule):
97
+ return True
98
+
99
+ return False
100
+
101
+
102
+ def compile_regions(module: torch.nn.Module, **compile_kwargs) -> torch.nn.Module:
103
+ """
104
+ Performs regional compilation where we target repeated blocks of the same class and compile them sequentially to
105
+ hit the compiler's cache. For example, in `GPT2LMHeadModel`, the repeated block/class is `GPT2Block`, and can be
106
+ accessed as `model.transformer.h[0]`. The rest of the model (e.g. model.lm_head) is compiled separately.
107
+
108
+ This allows us to speed up the compilation overhead / cold start of models like LLMs and Transformers in general.
109
+ See https://pytorch.org/tutorials/recipes/regional_compilation.html for more details.
110
+
111
+ Args:
112
+ module (`torch.nn.Module`):
113
+ The model to compile.
114
+ **compile_kwargs:
115
+ Additional keyword arguments to pass to `torch.compile()`.
116
+
117
+ Returns:
118
+ `torch.nn.Module`: A new instance of the model with some compiled regions.
119
+
120
+ Example:
121
+ ```python
122
+ >>> from accelerate.utils import compile_regions
123
+ >>> from transformers import AutoModelForCausalLM
124
+
125
+ >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
126
+ >>> compiled_model = compile_regions(model, mode="reduce-overhead")
127
+ >>> compiled_model.transformer.h[0]
128
+ OptimizedModule(
129
+ (_orig_mod): GPT2Block(
130
+ (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
131
+ (attn): GPT2Attention(
132
+ (c_attn): Conv1D(nf=2304, nx=768)
133
+ (c_proj): Conv1D(nf=768, nx=768)
134
+ (attn_dropout): Dropout(p=0.1, inplace=False)
135
+ (resid_dropout): Dropout(p=0.1, inplace=False)
136
+ )
137
+ (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
138
+ (mlp): GPT2MLP(
139
+ (c_fc): Conv1D(nf=3072, nx=768)
140
+ (c_proj): Conv1D(nf=768, nx=3072)
141
+ (act): NewGELUActivation()
142
+ (dropout): Dropout(p=0.1, inplace=False)
143
+ )
144
+ )
145
+ )
146
+ ```
147
+ """
148
+
149
+ def _compile_regions(module: torch.nn.Module, **compile_kwargs) -> torch.nn.Module:
150
+ if is_repeated_blocks(module):
151
+ new_module = torch.nn.ModuleList()
152
+ for submodule in module:
153
+ new_module.append(torch.compile(submodule, **compile_kwargs))
154
+ elif has_repeated_blocks(module):
155
+ new_module = module.__class__.__new__(module.__class__)
156
+ new_module.__dict__.update(module.__dict__)
157
+ new_module._modules = {}
158
+ for name, submodule in module.named_children():
159
+ new_module.add_module(name, _compile_regions(submodule, **compile_kwargs))
160
+ else:
161
+ new_module = torch.compile(module, **compile_kwargs)
162
+
163
+ return new_module
164
+
165
+ new_module = _compile_regions(module, **compile_kwargs)
166
+
167
+ if "_orig_mod" not in new_module.__dict__:
168
+ # Keeps a reference to the original module to decompile/unwrap it later
169
+ new_module.__dict__["_orig_mod"] = module
170
+
171
+ return new_module
172
+
173
+
174
+ def compile_regions_deepspeed(module: torch.nn.Module, **compile_kwargs):
175
+ """
176
+ Performs regional compilation the same way as `compile_regions`, but specifically for `DeepSpeedEngine.module`.
177
+ Since the model is wrapped in a `DeepSpeedEngine` and has many added hooks, offloaded parameters, etc that
178
+ `torch.compile(...)` interferes with, version of trgional compilation uses the inplace `module.compile()` method
179
+ instead.
180
+
181
+ Args:
182
+ module (`torch.nn.Module`):
183
+ The model to compile.
184
+ **compile_kwargs:
185
+ Additional keyword arguments to pass to `module.compile()`.
186
+ """
187
+
188
+ if is_repeated_blocks(module):
189
+ for submodule in module:
190
+ submodule.compile(**compile_kwargs)
191
+ elif has_repeated_blocks(module):
192
+ for child in module.children():
193
+ compile_regions_deepspeed(child, **compile_kwargs)
194
+ else: # leaf node
195
+ module.compile(**compile_kwargs)
196
+
197
+
198
+ def model_has_dtensor(model: torch.nn.Module) -> bool:
199
+ """
200
+ Check if the model has DTensor parameters.
201
+
202
+ Args:
203
+ model (`torch.nn.Module`):
204
+ The model to check.
205
+
206
+ Returns:
207
+ `bool`: Whether the model has DTensor parameters.
208
+ """
209
+ if is_torch_version(">=", "2.5.0"):
210
+ from torch.distributed.tensor import DTensor
211
+ else:
212
+ # from torch 2.0.0 (oldest supported accelerate torch version), DTensor is in torch.distributed._tensor
213
+ from torch.distributed._tensor import DTensor
214
+
215
+ return any(isinstance(p, DTensor) for p in model.parameters())
216
+
217
+
218
+ def extract_model_from_parallel(
219
+ model, keep_fp32_wrapper: bool = True, keep_torch_compile: bool = True, recursive: bool = False
220
+ ):
221
+ """
222
+ Extract a model from its distributed containers.
223
+
224
+ Args:
225
+ model (`torch.nn.Module`):
226
+ The model to extract.
227
+ keep_fp32_wrapper (`bool`, *optional*):
228
+ Whether to remove mixed precision hooks from the model.
229
+ keep_torch_compile (`bool`, *optional*):
230
+ Whether to unwrap compiled model.
231
+ recursive (`bool`, *optional*, defaults to `False`):
232
+ Whether to recursively extract all cases of `module.module` from `model` as well as unwrap child sublayers
233
+ recursively, not just the top-level distributed containers.
234
+
235
+ Returns:
236
+ `torch.nn.Module`: The extracted model.
237
+ """
238
+ options = (torch.nn.parallel.DistributedDataParallel, torch.nn.DataParallel)
239
+
240
+ is_compiled = is_compiled_module(model)
241
+ has_compiled = has_compiled_regions(model)
242
+
243
+ if is_compiled:
244
+ compiled_model = model
245
+ model = model._orig_mod
246
+ elif has_compiled:
247
+ compiled_model = model
248
+ model = model.__dict__["_orig_mod"]
249
+
250
+ if is_deepspeed_available():
251
+ from deepspeed import DeepSpeedEngine
252
+
253
+ options += (DeepSpeedEngine,)
254
+
255
+ if is_torch_version(">=", FSDP_PYTORCH_VERSION) and is_torch_distributed_available():
256
+ from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
257
+
258
+ options += (FSDP,)
259
+
260
+ while isinstance(model, options):
261
+ model = model.module
262
+
263
+ if recursive:
264
+ # This is needed in cases such as using FSDPv2 on XLA
265
+ def _recursive_unwrap(module):
266
+ # Wrapped modules are standardly wrapped as `module`, similar to the cases earlier
267
+ # with DDP, DataParallel, DeepSpeed, and FSDP
268
+ if hasattr(module, "module"):
269
+ unwrapped_module = _recursive_unwrap(module.module)
270
+ else:
271
+ unwrapped_module = module
272
+ # Next unwrap child sublayers recursively
273
+ for name, child in unwrapped_module.named_children():
274
+ setattr(unwrapped_module, name, _recursive_unwrap(child))
275
+ return unwrapped_module
276
+
277
+ # Start with top-level
278
+ model = _recursive_unwrap(model)
279
+
280
+ if not keep_fp32_wrapper:
281
+ forward = model.forward
282
+ original_forward = model.__dict__.pop("_original_forward", None)
283
+ if original_forward is not None:
284
+ while hasattr(forward, "__wrapped__"):
285
+ forward = forward.__wrapped__
286
+ if forward == original_forward:
287
+ break
288
+ model.forward = MethodType(forward, model)
289
+ if getattr(model, "_converted_to_transformer_engine", False):
290
+ convert_model(model, to_transformer_engine=False)
291
+
292
+ if keep_torch_compile:
293
+ if is_compiled:
294
+ compiled_model._orig_mod = model
295
+ model = compiled_model
296
+ elif has_compiled:
297
+ compiled_model.__dict__["_orig_mod"] = model
298
+ model = compiled_model
299
+
300
+ return model
301
+
302
+
303
+ def wait_for_everyone():
304
+ """
305
+ Introduces a blocking point in the script, making sure all processes have reached this point before continuing.
306
+
307
+ <Tip warning={true}>
308
+
309
+ Make sure all processes will reach this instruction otherwise one of your processes will hang forever.
310
+
311
+ </Tip>
312
+ """
313
+ PartialState().wait_for_everyone()
314
+
315
+
316
+ def clean_state_dict_for_safetensors(state_dict: dict):
317
+ """
318
+ Cleans the state dictionary from a model and removes tensor aliasing if present.
319
+
320
+ Args:
321
+ state_dict (`dict`):
322
+ The state dictionary from a model
323
+ """
324
+ ptrs = collections.defaultdict(list)
325
+ # When bnb serialization is used, weights in state dict can be strings
326
+ for name, tensor in state_dict.items():
327
+ if not isinstance(tensor, str):
328
+ ptrs[id_tensor_storage(tensor)].append(name)
329
+
330
+ # These are all pointers of tensors with shared memory
331
+ shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
332
+ warn_names = set()
333
+ for names in shared_ptrs.values():
334
+ # When not all duplicates have been cleaned, we still remove those keys but put a clear warning.
335
+ # If the link between tensors was done at runtime then `from_pretrained` will not get
336
+ # the key back leading to random tensor. A proper warning will be shown
337
+ # during reload (if applicable), but since the file is not necessarily compatible with
338
+ # the config, better show a proper warning.
339
+ found_names = [name for name in names if name in state_dict]
340
+ warn_names.update(found_names[1:])
341
+ for name in found_names[1:]:
342
+ del state_dict[name]
343
+ if len(warn_names) > 0:
344
+ logger.warning(
345
+ f"Removed shared tensor {warn_names} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading",
346
+ )
347
+ state_dict = {k: v.contiguous() if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()}
348
+ return state_dict
349
+
350
+
351
+ def save(obj, f, save_on_each_node: bool = False, safe_serialization: bool = False):
352
+ """
353
+ Save the data to disk. Use in place of `torch.save()`.
354
+
355
+ Args:
356
+ obj:
357
+ The data to save
358
+ f:
359
+ The file (or file-like object) to use to save the data
360
+ save_on_each_node (`bool`, *optional*, defaults to `False`):
361
+ Whether to only save on the global main process
362
+ safe_serialization (`bool`, *optional*, defaults to `False`):
363
+ Whether to save `obj` using `safetensors` or the traditional PyTorch way (that uses `pickle`).
364
+ """
365
+ # When TorchXLA is enabled, it's necessary to transfer all data to the CPU before saving.
366
+ # Another issue arises with `id_tensor_storage`, which treats all XLA tensors as identical.
367
+ # If tensors remain on XLA, calling `clean_state_dict_for_safetensors` will result in only
368
+ # one XLA tensor remaining.
369
+ if PartialState().distributed_type == DistributedType.XLA:
370
+ obj = xm._maybe_convert_to_cpu(obj)
371
+ # Check if it's a model and remove duplicates
372
+ if safe_serialization:
373
+ save_func = partial(safe_save_file, metadata={"format": "pt"})
374
+ if isinstance(obj, OrderedDict):
375
+ obj = clean_state_dict_for_safetensors(obj)
376
+ else:
377
+ save_func = torch.save
378
+
379
+ if PartialState().is_main_process and not save_on_each_node:
380
+ save_func(obj, f)
381
+ elif PartialState().is_local_main_process and save_on_each_node:
382
+ save_func(obj, f)
383
+
384
+
385
+ # The following are considered "safe" globals to reconstruct various types of objects when using `weights_only=True`
386
+ # These should be added and then removed after loading in the file
387
+ np_core = np._core if is_numpy_available("2.0.0") else np.core
388
+ TORCH_SAFE_GLOBALS = [
389
+ # numpy arrays are just numbers, not objects, so we can reconstruct them safely
390
+ np_core.multiarray._reconstruct,
391
+ np.ndarray,
392
+ # The following are needed for the RNG states
393
+ encode,
394
+ np.dtype,
395
+ ]
396
+
397
+ if is_numpy_available("1.25.0"):
398
+ TORCH_SAFE_GLOBALS.append(np.dtypes.UInt32DType)
399
+
400
+
401
+ def load(f, map_location=None, **kwargs):
402
+ """
403
+ Compatible drop-in replacement of `torch.load()` which allows for `weights_only` to be used if `torch` version is
404
+ 2.4.0 or higher. Otherwise will ignore the kwarg.
405
+
406
+ Will also add (and then remove) an exception for numpy arrays
407
+
408
+ Args:
409
+ f:
410
+ The file (or file-like object) to use to load the data
411
+ map_location:
412
+ a function, `torch.device`, string or a dict specifying how to remap storage locations
413
+ **kwargs:
414
+ Additional keyword arguments to pass to `torch.load()`.
415
+ """
416
+ try:
417
+ if is_weights_only_available():
418
+ old_safe_globals = torch.serialization.get_safe_globals()
419
+ if "weights_only" not in kwargs:
420
+ kwargs["weights_only"] = True
421
+ torch.serialization.add_safe_globals(TORCH_SAFE_GLOBALS)
422
+ else:
423
+ kwargs.pop("weights_only", None)
424
+ loaded_obj = torch.load(f, map_location=map_location, **kwargs)
425
+ finally:
426
+ if is_weights_only_available():
427
+ torch.serialization.clear_safe_globals()
428
+ if old_safe_globals:
429
+ torch.serialization.add_safe_globals(old_safe_globals)
430
+ return loaded_obj
431
+
432
+
433
+ def get_pretty_name(obj):
434
+ """
435
+ Gets a pretty name from `obj`.
436
+ """
437
+ if not hasattr(obj, "__qualname__") and not hasattr(obj, "__name__"):
438
+ obj = getattr(obj, "__class__", obj)
439
+ if hasattr(obj, "__qualname__"):
440
+ return obj.__qualname__
441
+ if hasattr(obj, "__name__"):
442
+ return obj.__name__
443
+ return str(obj)
444
+
445
+
446
+ def merge_dicts(source, destination):
447
+ """
448
+ Recursively merges two dictionaries.
449
+
450
+ Args:
451
+ source (`dict`): The dictionary to merge into `destination`.
452
+ destination (`dict`): The dictionary to merge `source` into.
453
+ """
454
+ for key, value in source.items():
455
+ if isinstance(value, dict):
456
+ node = destination.setdefault(key, {})
457
+ merge_dicts(value, node)
458
+ else:
459
+ destination[key] = value
460
+
461
+ return destination
462
+
463
+
464
+ def is_port_in_use(port: Optional[int] = None) -> bool:
465
+ """
466
+ Checks if a port is in use on `localhost`. Useful for checking if multiple `accelerate launch` commands have been
467
+ run and need to see if the port is already in use.
468
+ """
469
+ if port is None:
470
+ port = 29500
471
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
472
+ return s.connect_ex(("localhost", port)) == 0
473
+
474
+
475
+ def get_free_port() -> int:
476
+ """
477
+ Gets a free port on `localhost`. Useful for automatic port selection when port 0 is specified in distributed
478
+ training scenarios.
479
+
480
+ Returns:
481
+ int: An available port number
482
+ """
483
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
484
+ s.bind(("", 0)) # bind to port 0 for OS to assign a free port
485
+ return s.getsockname()[1]
486
+
487
+
488
+ def convert_bytes(size):
489
+ "Converts `size` from bytes to the largest possible unit"
490
+ for x in ["bytes", "KB", "MB", "GB", "TB"]:
491
+ if size < 1024.0:
492
+ return f"{round(size, 2)} {x}"
493
+ size /= 1024.0
494
+
495
+ return f"{round(size, 2)} PB"
496
+
497
+
498
+ def check_os_kernel():
499
+ """Warns if the kernel version is below the recommended minimum on Linux."""
500
+ # see issue #1929
501
+ info = platform.uname()
502
+ system = info.system
503
+ if system != "Linux":
504
+ return
505
+
506
+ _, version, *_ = re.split(r"(\d+\.\d+\.\d+)", info.release)
507
+ min_version = "5.5.0"
508
+ if Version(version) < Version(min_version):
509
+ msg = (
510
+ f"Detected kernel version {version}, which is below the recommended minimum of {min_version}; this can "
511
+ "cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher."
512
+ )
513
+ logger.warning(msg, main_process_only=True)
514
+
515
+
516
+ def recursive_getattr(obj, attr: str):
517
+ """
518
+ Recursive `getattr`.
519
+
520
+ Args:
521
+ obj:
522
+ A class instance holding the attribute.
523
+ attr (`str`):
524
+ The attribute that is to be retrieved, e.g. 'attribute1.attribute2'.
525
+ """
526
+
527
+ def _getattr(obj, attr):
528
+ return getattr(obj, attr)
529
+
530
+ return reduce(_getattr, [obj] + attr.split("."))
531
+
532
+
533
+ def get_module_children_bottom_up(model: torch.nn.Module, return_fqns: bool = False) -> list[torch.nn.Module]:
534
+ """Traverse the model in bottom-up order and return the children modules in that order.
535
+
536
+ Args:
537
+ model (`torch.nn.Module`): the model to get the children of
538
+
539
+ Returns:
540
+ `list[torch.nn.Module]`: a list of children modules of `model` in bottom-up order. The last element is the
541
+ `model` itself.
542
+ """
543
+ top = model if not return_fqns else ("", model)
544
+ stack = [top]
545
+ ordered_modules = []
546
+ while stack:
547
+ current_module = stack.pop()
548
+ if return_fqns:
549
+ current_module_name, current_module = current_module
550
+ for name, attr in current_module.named_children():
551
+ if isinstance(attr, torch.nn.Module):
552
+ if return_fqns:
553
+ child_name = current_module_name + "." + name if current_module_name else name
554
+ stack.append((child_name, attr))
555
+ else:
556
+ stack.append(attr)
557
+ if return_fqns:
558
+ ordered_modules.append((current_module_name, current_module))
559
+ else:
560
+ ordered_modules.append(current_module)
561
+ return ordered_modules[::-1]
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/random.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import random
16
+ from typing import Optional, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+
21
+ from ..state import AcceleratorState
22
+ from .constants import CUDA_DISTRIBUTED_TYPES
23
+ from .dataclasses import DistributedType, RNGType
24
+ from .imports import (
25
+ is_hpu_available,
26
+ is_mlu_available,
27
+ is_musa_available,
28
+ is_npu_available,
29
+ is_sdaa_available,
30
+ is_torch_xla_available,
31
+ is_xpu_available,
32
+ )
33
+
34
+
35
+ if is_torch_xla_available():
36
+ import torch_xla.core.xla_model as xm
37
+
38
+
39
+ def set_seed(seed: int, device_specific: bool = False, deterministic: bool = False):
40
+ """
41
+ Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
42
+
43
+ Args:
44
+ seed (`int`):
45
+ The seed to set.
46
+ device_specific (`bool`, *optional*, defaults to `False`):
47
+ Whether to differ the seed on each device slightly with `self.process_index`.
48
+ deterministic (`bool`, *optional*, defaults to `False`):
49
+ Whether to use deterministic algorithms where available. Can slow down training.
50
+ """
51
+ if device_specific:
52
+ seed += AcceleratorState().process_index
53
+ random.seed(seed)
54
+ np.random.seed(seed)
55
+ torch.manual_seed(seed)
56
+ if is_xpu_available():
57
+ torch.xpu.manual_seed_all(seed)
58
+ elif is_npu_available():
59
+ torch.npu.manual_seed_all(seed)
60
+ elif is_mlu_available():
61
+ torch.mlu.manual_seed_all(seed)
62
+ elif is_sdaa_available():
63
+ torch.sdaa.manual_seed_all(seed)
64
+ elif is_musa_available():
65
+ torch.musa.manual_seed_all(seed)
66
+ elif is_hpu_available():
67
+ torch.hpu.manual_seed_all(seed)
68
+ else:
69
+ torch.cuda.manual_seed_all(seed)
70
+ # ^^ safe to call this function even if cuda is not available
71
+ if is_torch_xla_available():
72
+ xm.set_rng_state(seed)
73
+
74
+ if deterministic:
75
+ torch.use_deterministic_algorithms(True)
76
+
77
+
78
+ def synchronize_rng_state(rng_type: Optional[RNGType] = None, generator: Optional[torch.Generator] = None):
79
+ # Get the proper rng state
80
+ if rng_type == RNGType.TORCH:
81
+ rng_state = torch.get_rng_state()
82
+ elif rng_type == RNGType.CUDA:
83
+ rng_state = torch.cuda.get_rng_state()
84
+ elif rng_type == RNGType.XLA:
85
+ assert is_torch_xla_available(), "Can't synchronize XLA seeds as torch_xla is unavailable."
86
+ rng_state = torch.tensor(xm.get_rng_state())
87
+ elif rng_type == RNGType.NPU:
88
+ assert is_npu_available(), "Can't synchronize NPU seeds on an environment without NPUs."
89
+ rng_state = torch.npu.get_rng_state()
90
+ elif rng_type == RNGType.MLU:
91
+ assert is_mlu_available(), "Can't synchronize MLU seeds on an environment without MLUs."
92
+ rng_state = torch.mlu.get_rng_state()
93
+ elif rng_type == RNGType.SDAA:
94
+ assert is_sdaa_available(), "Can't synchronize SDAA seeds on an environment without SDAAs."
95
+ rng_state = torch.sdaa.get_rng_state()
96
+ elif rng_type == RNGType.MUSA:
97
+ assert is_musa_available(), "Can't synchronize MUSA seeds on an environment without MUSAs."
98
+ rng_state = torch.musa.get_rng_state()
99
+ elif rng_type == RNGType.XPU:
100
+ assert is_xpu_available(), "Can't synchronize XPU seeds on an environment without XPUs."
101
+ rng_state = torch.xpu.get_rng_state()
102
+ elif rng_type == RNGType.HPU:
103
+ assert is_hpu_available(), "Can't synchronize HPU seeds on an environment without HPUs."
104
+ rng_state = torch.hpu.get_rng_state()
105
+ elif rng_type == RNGType.GENERATOR:
106
+ assert generator is not None, "Need a generator to synchronize its seed."
107
+ rng_state = generator.get_state()
108
+
109
+ # Broadcast the rng state from device 0 to other devices
110
+ state = AcceleratorState()
111
+ if state.distributed_type == DistributedType.XLA:
112
+ rng_state = rng_state.to(xm.xla_device())
113
+ xm.collective_broadcast([rng_state])
114
+ xm.mark_step()
115
+ rng_state = rng_state.cpu()
116
+ elif (
117
+ state.distributed_type in CUDA_DISTRIBUTED_TYPES
118
+ or state.distributed_type == DistributedType.MULTI_MLU
119
+ or state.distributed_type == DistributedType.MULTI_SDAA
120
+ or state.distributed_type == DistributedType.MULTI_MUSA
121
+ or state.distributed_type == DistributedType.MULTI_NPU
122
+ or state.distributed_type == DistributedType.MULTI_XPU
123
+ or state.distributed_type == DistributedType.MULTI_HPU
124
+ ):
125
+ rng_state = rng_state.to(state.device)
126
+ torch.distributed.broadcast(rng_state, 0)
127
+ rng_state = rng_state.cpu()
128
+ elif state.distributed_type == DistributedType.MULTI_CPU:
129
+ torch.distributed.broadcast(rng_state, 0)
130
+
131
+ # Set the broadcast rng state
132
+ if rng_type == RNGType.TORCH:
133
+ torch.set_rng_state(rng_state)
134
+ elif rng_type == RNGType.CUDA:
135
+ torch.cuda.set_rng_state(rng_state)
136
+ elif rng_type == RNGType.NPU:
137
+ torch.npu.set_rng_state(rng_state)
138
+ elif rng_type == RNGType.MLU:
139
+ torch.mlu.set_rng_state(rng_state)
140
+ elif rng_type == RNGType.SDAA:
141
+ torch.sdaa.set_rng_state(rng_state)
142
+ elif rng_type == RNGType.MUSA:
143
+ torch.musa.set_rng_state(rng_state)
144
+ elif rng_type == RNGType.XPU:
145
+ torch.xpu.set_rng_state(rng_state)
146
+ elif rng_state == RNGType.HPU:
147
+ torch.hpu.set_rng_state(rng_state)
148
+ elif rng_type == RNGType.XLA:
149
+ xm.set_rng_state(rng_state.item())
150
+ elif rng_type == RNGType.GENERATOR:
151
+ generator.set_state(rng_state)
152
+
153
+
154
+ def synchronize_rng_states(rng_types: list[Union[str, RNGType]], generator: Optional[torch.Generator] = None):
155
+ for rng_type in rng_types:
156
+ synchronize_rng_state(RNGType(rng_type), generator=generator)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/rich.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .imports import is_rich_available
16
+
17
+
18
+ if is_rich_available():
19
+ from rich.traceback import install
20
+
21
+ install(show_locals=False)
22
+
23
+ else:
24
+ raise ModuleNotFoundError("To use the rich extension, install rich with `pip install rich`")
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/torch_xla.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import importlib.metadata
16
+ import subprocess
17
+ import sys
18
+
19
+
20
+ def install_xla(upgrade: bool = False):
21
+ """
22
+ Helper function to install appropriate xla wheels based on the `torch` version in Google Colaboratory.
23
+
24
+ Args:
25
+ upgrade (`bool`, *optional*, defaults to `False`):
26
+ Whether to upgrade `torch` and install the latest `torch_xla` wheels.
27
+
28
+ Example:
29
+
30
+ ```python
31
+ >>> from accelerate.utils import install_xla
32
+
33
+ >>> install_xla(upgrade=True)
34
+ ```
35
+ """
36
+ in_colab = False
37
+ if "IPython" in sys.modules:
38
+ in_colab = "google.colab" in str(sys.modules["IPython"].get_ipython())
39
+
40
+ if in_colab:
41
+ if upgrade:
42
+ torch_install_cmd = ["pip", "install", "-U", "torch"]
43
+ subprocess.run(torch_install_cmd, check=True)
44
+ # get the current version of torch
45
+ torch_version = importlib.metadata.version("torch")
46
+ torch_version_trunc = torch_version[: torch_version.rindex(".")]
47
+ xla_wheel = f"https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-{torch_version_trunc}-cp37-cp37m-linux_x86_64.whl"
48
+ xla_install_cmd = ["pip", "install", xla_wheel]
49
+ subprocess.run(xla_install_cmd, check=True)
50
+ else:
51
+ raise RuntimeError("`install_xla` utility works only on google colab.")
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/tqdm.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from .imports import is_tqdm_available
17
+
18
+
19
+ if is_tqdm_available():
20
+ from tqdm.auto import tqdm as _tqdm
21
+
22
+ from ..state import PartialState
23
+
24
+
25
+ def tqdm(*args, main_process_only: bool = True, **kwargs):
26
+ """
27
+ Wrapper around `tqdm.tqdm` that optionally displays only on the main process.
28
+
29
+ Args:
30
+ main_process_only (`bool`, *optional*):
31
+ Whether to display the progress bar only on the main process
32
+ """
33
+ if not is_tqdm_available():
34
+ raise ImportError("Accelerate's `tqdm` module requires `tqdm` to be installed. Please run `pip install tqdm`.")
35
+ if len(args) > 0 and isinstance(args[0], bool):
36
+ raise ValueError(
37
+ "Passing `True` or `False` as the first argument to Accelerate's `tqdm` wrapper is unsupported. "
38
+ "Please use the `main_process_only` keyword argument instead."
39
+ )
40
+ disable = kwargs.pop("disable", False)
41
+ if main_process_only and not disable:
42
+ disable = PartialState().local_process_index != 0
43
+ return _tqdm(*args, **kwargs, disable=disable)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/transformer_engine.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from types import MethodType
16
+
17
+ import torch.nn as nn
18
+
19
+ from .imports import is_hpu_available, is_transformer_engine_available
20
+ from .operations import GatheredParameters
21
+
22
+
23
+ # Do not import `transformer_engine` at package level to avoid potential issues
24
+
25
+
26
+ def convert_model(model, to_transformer_engine=True, _convert_linear=True, _convert_ln=True):
27
+ """
28
+ Recursively converts the linear and layernorm layers of a model to their `transformers_engine` counterpart.
29
+ """
30
+ if not is_transformer_engine_available():
31
+ raise ImportError("Using `convert_model` requires transformer_engine to be installed.")
32
+
33
+ if is_hpu_available():
34
+ import intel_transformer_engine as te
35
+
36
+ if not hasattr(te, "LayerNorm"):
37
+ # HPU does not have a LayerNorm implementation in TE
38
+ te.LayerNorm = nn.LayerNorm
39
+ else:
40
+ import transformer_engine.pytorch as te
41
+
42
+ for name, module in model.named_children():
43
+ if isinstance(module, nn.Linear) and to_transformer_engine and _convert_linear:
44
+ has_bias = module.bias is not None
45
+ params_to_gather = [module.weight]
46
+ if has_bias:
47
+ params_to_gather.append(module.bias)
48
+
49
+ with GatheredParameters(params_to_gather, modifier_rank=0):
50
+ if any(p % 16 != 0 for p in module.weight.shape):
51
+ return
52
+ te_module = te.Linear(
53
+ module.in_features, module.out_features, bias=has_bias, params_dtype=module.weight.dtype
54
+ )
55
+ te_module.weight.copy_(module.weight)
56
+ if has_bias:
57
+ te_module.bias.copy_(module.bias)
58
+
59
+ setattr(model, name, te_module)
60
+ # Note: @xrsrke (Phuc) found that te.LayerNorm doesn't have any real memory savings or speedups over nn.LayerNorm
61
+ elif isinstance(module, nn.LayerNorm) and to_transformer_engine and _convert_ln:
62
+ with GatheredParameters([module.weight, module.bias], modifier_rank=0):
63
+ has_bias = module.bias is not None
64
+ te_module = te.LayerNorm(module.normalized_shape[0], eps=module.eps, params_dtype=module.weight.dtype)
65
+ te_module.weight.copy_(module.weight)
66
+ if has_bias:
67
+ te_module.bias.copy_(module.bias)
68
+
69
+ setattr(model, name, te_module)
70
+ elif isinstance(module, te.Linear) and not to_transformer_engine and _convert_linear:
71
+ has_bias = module.bias is not None
72
+ new_module = nn.Linear(
73
+ module.in_features, module.out_features, bias=has_bias, params_dtype=module.weight.dtype
74
+ )
75
+ new_module.weight.copy_(module.weight)
76
+ if has_bias:
77
+ new_module.bias.copy_(module.bias)
78
+
79
+ setattr(model, name, new_module)
80
+ elif isinstance(module, te.LayerNorm) and not to_transformer_engine and _convert_ln:
81
+ new_module = nn.LayerNorm(module.normalized_shape[0], eps=module.eps, params_dtype=module.weight.dtype)
82
+ new_module.weight.copy_(module.weight)
83
+ new_module.bias.copy_(module.bias)
84
+
85
+ setattr(model, name, new_module)
86
+ else:
87
+ convert_model(
88
+ module,
89
+ to_transformer_engine=to_transformer_engine,
90
+ _convert_linear=_convert_linear,
91
+ _convert_ln=_convert_ln,
92
+ )
93
+
94
+
95
+ def has_transformer_engine_layers(model):
96
+ """
97
+ Returns whether a given model has some `transformer_engine` layer or not.
98
+ """
99
+ if not is_transformer_engine_available():
100
+ raise ImportError("Using `has_transformer_engine_layers` requires transformer_engine to be installed.")
101
+
102
+ if is_hpu_available():
103
+ import intel_transformer_engine as te
104
+
105
+ module_cls_to_check = te.Linear
106
+ else:
107
+ import transformer_engine.pytorch as te
108
+
109
+ module_cls_to_check = (te.LayerNorm, te.Linear, te.TransformerLayer)
110
+
111
+ for m in model.modules():
112
+ if isinstance(m, module_cls_to_check):
113
+ return True
114
+
115
+ return False
116
+
117
+
118
+ def contextual_fp8_autocast(model_forward, fp8_recipe, use_during_eval=False):
119
+ """
120
+ Wrapper for a model's forward method to apply FP8 autocast. Is context aware, meaning that by default it will
121
+ disable FP8 autocast during eval mode, which is generally better for more accurate metrics.
122
+ """
123
+ if not is_transformer_engine_available():
124
+ raise ImportError("Using `contextual_fp8_autocast` requires transformer_engine to be installed.")
125
+
126
+ if is_hpu_available():
127
+ from intel_transformer_engine import fp8_autocast
128
+ else:
129
+ from transformer_engine.pytorch import fp8_autocast
130
+
131
+ def forward(self, *args, **kwargs):
132
+ enabled = use_during_eval or self.training
133
+ with fp8_autocast(enabled=enabled, fp8_recipe=fp8_recipe):
134
+ return model_forward(*args, **kwargs)
135
+
136
+ # To act like a decorator so that it can be popped when doing `extract_model_from_parallel`
137
+ forward.__wrapped__ = model_forward
138
+
139
+ return forward
140
+
141
+
142
+ def apply_fp8_autowrap(model, fp8_recipe_handler):
143
+ """
144
+ Applies FP8 context manager to the model's forward method
145
+ """
146
+ if not is_transformer_engine_available():
147
+ raise ImportError("Using `apply_fp8_autowrap` requires transformer_engine to be installed.")
148
+
149
+ if is_hpu_available():
150
+ import intel_transformer_engine.recipe as te_recipe
151
+
152
+ is_fp8_block_scaling_available = False
153
+ message = "MXFP8 block scaling is not available on HPU."
154
+
155
+ else:
156
+ import transformer_engine.common.recipe as te_recipe
157
+ import transformer_engine.pytorch as te
158
+
159
+ is_fp8_block_scaling_available, message = te.fp8.check_mxfp8_support()
160
+
161
+ kwargs = fp8_recipe_handler.to_kwargs() if fp8_recipe_handler is not None else {}
162
+ if "fp8_format" in kwargs:
163
+ kwargs["fp8_format"] = getattr(te_recipe.Format, kwargs["fp8_format"])
164
+ use_during_eval = kwargs.pop("use_autocast_during_eval", False)
165
+ use_mxfp8_block_scaling = kwargs.pop("use_mxfp8_block_scaling", False)
166
+
167
+ if use_mxfp8_block_scaling and not is_fp8_block_scaling_available:
168
+ raise ValueError(f"MXFP8 block scaling is not available: {message}")
169
+
170
+ if use_mxfp8_block_scaling:
171
+ if "amax_compute_algo" in kwargs:
172
+ raise ValueError("`amax_compute_algo` is not supported for MXFP8 block scaling.")
173
+ if "amax_history_len" in kwargs:
174
+ raise ValueError("`amax_history_len` is not supported for MXFP8 block scaling.")
175
+ fp8_recipe = te_recipe.MXFP8BlockScaling(**kwargs)
176
+ else:
177
+ fp8_recipe = te_recipe.DelayedScaling(**kwargs)
178
+
179
+ new_forward = contextual_fp8_autocast(model.forward, fp8_recipe, use_during_eval)
180
+
181
+ if hasattr(model.forward, "__func__"):
182
+ model.forward = MethodType(new_forward, model)
183
+ else:
184
+ model.forward = new_forward
185
+
186
+ return model
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/utils/versions.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import importlib.metadata
16
+ from typing import Union
17
+
18
+ from packaging.version import Version, parse
19
+
20
+ from .constants import STR_OPERATION_TO_FUNC
21
+
22
+
23
+ torch_version = parse(importlib.metadata.version("torch"))
24
+
25
+
26
+ def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str):
27
+ """
28
+ Compares a library version to some requirement using a given operation.
29
+
30
+ Args:
31
+ library_or_version (`str` or `packaging.version.Version`):
32
+ A library name or a version to check.
33
+ operation (`str`):
34
+ A string representation of an operator, such as `">"` or `"<="`.
35
+ requirement_version (`str`):
36
+ The version to compare the library version against
37
+ """
38
+ if operation not in STR_OPERATION_TO_FUNC.keys():
39
+ raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}")
40
+ operation = STR_OPERATION_TO_FUNC[operation]
41
+ if isinstance(library_or_version, str):
42
+ library_or_version = parse(importlib.metadata.version(library_or_version))
43
+ return operation(library_or_version, parse(requirement_version))
44
+
45
+
46
+ def is_torch_version(operation: str, version: str):
47
+ """
48
+ Compares the current PyTorch version to a given reference with an operation.
49
+
50
+ Args:
51
+ operation (`str`):
52
+ A string representation of an operator, such as `">"` or `"<="`
53
+ version (`str`):
54
+ A string version of PyTorch
55
+ """
56
+ return compare_versions(torch_version, operation, version)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/annotated_doc/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (283 Bytes). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/annotated_doc/__pycache__/main.cpython-312.pyc ADDED
Binary file (1.93 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/cuda_pathfinder-1.4.0.dist-info/licenses/LICENSE ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ Apache License
3
+ Version 2.0, January 2004
4
+ http://www.apache.org/licenses/
5
+
6
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7
+
8
+ 1. Definitions.
9
+
10
+ "License" shall mean the terms and conditions for use, reproduction,
11
+ and distribution as defined by Sections 1 through 9 of this document.
12
+
13
+ "Licensor" shall mean the copyright owner or entity authorized by
14
+ the copyright owner that is granting the License.
15
+
16
+ "Legal Entity" shall mean the union of the acting entity and all
17
+ other entities that control, are controlled by, or are under common
18
+ control with that entity. For the purposes of this definition,
19
+ "control" means (i) the power, direct or indirect, to cause the
20
+ direction or management of such entity, whether by contract or
21
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
22
+ outstanding shares, or (iii) beneficial ownership of such entity.
23
+
24
+ "You" (or "Your") shall mean an individual or Legal Entity
25
+ exercising permissions granted by this License.
26
+
27
+ "Source" form shall mean the preferred form for making modifications,
28
+ including but not limited to software source code, documentation
29
+ source, and configuration files.
30
+
31
+ "Object" form shall mean any form resulting from mechanical
32
+ transformation or translation of a Source form, including but
33
+ not limited to compiled object code, generated documentation,
34
+ and conversions to other media types.
35
+
36
+ "Work" shall mean the work of authorship, whether in Source or
37
+ Object form, made available under the License, as indicated by a
38
+ copyright notice that is included in or attached to the work
39
+ (an example is provided in the Appendix below).
40
+
41
+ "Derivative Works" shall mean any work, whether in Source or Object
42
+ form, that is based on (or derived from) the Work and for which the
43
+ editorial revisions, annotations, elaborations, or other modifications
44
+ represent, as a whole, an original work of authorship. For the purposes
45
+ of this License, Derivative Works shall not include works that remain
46
+ separable from, or merely link (or bind by name) to the interfaces of,
47
+ the Work and Derivative Works thereof.
48
+
49
+ "Contribution" shall mean any work of authorship, including
50
+ the original version of the Work and any modifications or additions
51
+ to that Work or Derivative Works thereof, that is intentionally
52
+ submitted to Licensor for inclusion in the Work by the copyright owner
53
+ or by an individual or Legal Entity authorized to submit on behalf of
54
+ the copyright owner. For the purposes of this definition, "submitted"
55
+ means any form of electronic, verbal, or written communication sent
56
+ to the Licensor or its representatives, including but not limited to
57
+ communication on electronic mailing lists, source code control systems,
58
+ and issue tracking systems that are managed by, or on behalf of, the
59
+ Licensor for the purpose of discussing and improving the Work, but
60
+ excluding communication that is conspicuously marked or otherwise
61
+ designated in writing by the copyright owner as "Not a Contribution."
62
+
63
+ "Contributor" shall mean Licensor and any individual or Legal Entity
64
+ on behalf of whom a Contribution has been received by Licensor and
65
+ subsequently incorporated within the Work.
66
+
67
+ 2. Grant of Copyright License. Subject to the terms and conditions of
68
+ this License, each Contributor hereby grants to You a perpetual,
69
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70
+ copyright license to reproduce, prepare Derivative Works of,
71
+ publicly display, publicly perform, sublicense, and distribute the
72
+ Work and such Derivative Works in Source or Object form.
73
+
74
+ 3. Grant of Patent License. Subject to the terms and conditions of
75
+ this License, each Contributor hereby grants to You a perpetual,
76
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77
+ (except as stated in this section) patent license to make, have made,
78
+ use, offer to sell, sell, import, and otherwise transfer the Work,
79
+ where such license applies only to those patent claims licensable
80
+ by such Contributor that are necessarily infringed by their
81
+ Contribution(s) alone or by combination of their Contribution(s)
82
+ with the Work to which such Contribution(s) was submitted. If You
83
+ institute patent litigation against any entity (including a
84
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
85
+ or a Contribution incorporated within the Work constitutes direct
86
+ or contributory patent infringement, then any patent licenses
87
+ granted to You under this License for that Work shall terminate
88
+ as of the date such litigation is filed.
89
+
90
+ 4. Redistribution. You may reproduce and distribute copies of the
91
+ Work or Derivative Works thereof in any medium, with or without
92
+ modifications, and in Source or Object form, provided that You
93
+ meet the following conditions:
94
+
95
+ (a) You must give any other recipients of the Work or
96
+ Derivative Works a copy of this License; and
97
+
98
+ (b) You must cause any modified files to carry prominent notices
99
+ stating that You changed the files; and
100
+
101
+ (c) You must retain, in the Source form of any Derivative Works
102
+ that You distribute, all copyright, patent, trademark, and
103
+ attribution notices from the Source form of the Work,
104
+ excluding those notices that do not pertain to any part of
105
+ the Derivative Works; and
106
+
107
+ (d) If the Work includes a "NOTICE" text file as part of its
108
+ distribution, then any Derivative Works that You distribute must
109
+ include a readable copy of the attribution notices contained
110
+ within such NOTICE file, excluding those notices that do not
111
+ pertain to any part of the Derivative Works, in at least one
112
+ of the following places: within a NOTICE text file distributed
113
+ as part of the Derivative Works; within the Source form or
114
+ documentation, if provided along with the Derivative Works; or,
115
+ within a display generated by the Derivative Works, if and
116
+ wherever such third-party notices normally appear. The contents
117
+ of the NOTICE file are for informational purposes only and
118
+ do not modify the License. You may add Your own attribution
119
+ notices within Derivative Works that You distribute, alongside
120
+ or as an addendum to the NOTICE text from the Work, provided
121
+ that such additional attribution notices cannot be construed
122
+ as modifying the License.
123
+
124
+ You may add Your own copyright statement to Your modifications and
125
+ may provide additional or different license terms and conditions
126
+ for use, reproduction, or distribution of Your modifications, or
127
+ for any such Derivative Works as a whole, provided Your use,
128
+ reproduction, and distribution of the Work otherwise complies with
129
+ the conditions stated in this License.
130
+
131
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
132
+ any Contribution intentionally submitted for inclusion in the Work
133
+ by You to the Licensor shall be under the terms and conditions of
134
+ this License, without any additional terms or conditions.
135
+ Notwithstanding the above, nothing herein shall supersede or modify
136
+ the terms of any separate license agreement you may have executed
137
+ with Licensor regarding such Contributions.
138
+
139
+ 6. Trademarks. This License does not grant permission to use the trade
140
+ names, trademarks, service marks, or product names of the Licensor,
141
+ except as required for reasonable and customary use in describing the
142
+ origin of the Work and reproducing the content of the NOTICE file.
143
+
144
+ 7. Disclaimer of Warranty. Unless required by applicable law or
145
+ agreed to in writing, Licensor provides the Work (and each
146
+ Contributor provides its Contributions) on an "AS IS" BASIS,
147
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148
+ implied, including, without limitation, any warranties or conditions
149
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150
+ PARTICULAR PURPOSE. You are solely responsible for determining the
151
+ appropriateness of using or redistributing the Work and assume any
152
+ risks associated with Your exercise of permissions under this License.
153
+
154
+ 8. Limitation of Liability. In no event and under no legal theory,
155
+ whether in tort (including negligence), contract, or otherwise,
156
+ unless required by applicable law (such as deliberate and grossly
157
+ negligent acts) or agreed to in writing, shall any Contributor be
158
+ liable to You for damages, including any direct, indirect, special,
159
+ incidental, or consequential damages of any character arising as a
160
+ result of this License or out of the use or inability to use the
161
+ Work (including but not limited to damages for loss of goodwill,
162
+ work stoppage, computer failure or malfunction, or any and all
163
+ other commercial damages or losses), even if such Contributor
164
+ has been advised of the possibility of such damages.
165
+
166
+ 9. Accepting Warranty or Additional Liability. While redistributing
167
+ the Work or Derivative Works thereof, You may choose to offer,
168
+ and charge a fee for, acceptance of support, warranty, indemnity,
169
+ or other liability obligations and/or rights consistent with this
170
+ License. However, in accepting such obligations, You may act only
171
+ on Your own behalf and on Your sole responsibility, not on behalf
172
+ of any other Contributor, and only if You agree to indemnify,
173
+ defend, and hold each Contributor harmless for any liability
174
+ incurred by, or claims asserted against, such Contributor by reason
175
+ of your accepting any such warranty or additional liability.
176
+
177
+ END OF TERMS AND CONDITIONS
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.38 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/__version__.cpython-312.pyc ADDED
Binary file (619 Bytes). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_align.cpython-312.pyc ADDED
Binary file (1.37 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_align_getter.cpython-312.pyc ADDED
Binary file (1.99 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_base.cpython-312.pyc ADDED
Binary file (3.77 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_column.cpython-312.pyc ADDED
Binary file (18.5 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_common.cpython-312.pyc ADDED
Binary file (3.47 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_container.cpython-312.pyc ADDED
Binary file (9.41 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_converter.cpython-312.pyc ADDED
Binary file (5.21 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_dataproperty.cpython-312.pyc ADDED
Binary file (15.8 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_extractor.cpython-312.pyc ADDED
Binary file (34.7 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_formatter.cpython-312.pyc ADDED
Binary file (4.91 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_function.cpython-312.pyc ADDED
Binary file (5.44 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_interface.cpython-312.pyc ADDED
Binary file (1.64 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_line_break.cpython-312.pyc ADDED
Binary file (546 Bytes). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/_preprocessor.cpython-312.pyc ADDED
Binary file (8.11 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/__pycache__/typing.cpython-312.pyc ADDED
Binary file (1.97 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/logger/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from ._logger import logger, set_logger # type: ignore
2
+
3
+
4
+ __all__ = (
5
+ "logger",
6
+ "set_logger",
7
+ )
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/logger/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (307 Bytes). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/logger/__pycache__/_logger.cpython-312.pyc ADDED
Binary file (950 Bytes). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/dataproperty/logger/__pycache__/_null_logger.cpython-312.pyc ADDED
Binary file (2.04 kB). View file