xiaoanyu123 committed on
Commit
f45df05
·
verified ·
1 Parent(s): 59ad30d

Add files using upload-large-folder tool

Browse files
Files changed (20) hide show
  1. pythonProject/.venv/Lib/site-packages/accelerate/commands/launch.py +1245 -0
  2. pythonProject/.venv/Lib/site-packages/colorama-0.4.6.dist-info/INSTALLER +1 -0
  3. pythonProject/.venv/Lib/site-packages/colorama-0.4.6.dist-info/METADATA +441 -0
  4. pythonProject/.venv/Lib/site-packages/colorama-0.4.6.dist-info/RECORD +31 -0
  5. pythonProject/.venv/Lib/site-packages/colorama-0.4.6.dist-info/WHEEL +5 -0
  6. pythonProject/.venv/Lib/site-packages/colorama-0.4.6.dist-info/licenses/LICENSE.txt +27 -0
  7. pythonProject/.venv/Lib/site-packages/colorama/__pycache__/__init__.cpython-310.pyc +0 -0
  8. pythonProject/.venv/Lib/site-packages/colorama/ansi.py +102 -0
  9. pythonProject/.venv/Lib/site-packages/colorama/ansitowin32.py +277 -0
  10. pythonProject/.venv/Lib/site-packages/colorama/initialise.py +121 -0
  11. pythonProject/.venv/Lib/site-packages/colorama/win32.py +180 -0
  12. pythonProject/.venv/Lib/site-packages/diffusers/callbacks.py +244 -0
  13. pythonProject/.venv/Lib/site-packages/diffusers/configuration_utils.py +769 -0
  14. pythonProject/.venv/Lib/site-packages/diffusers/dependency_versions_check.py +34 -0
  15. pythonProject/.venv/Lib/site-packages/diffusers/dependency_versions_table.py +54 -0
  16. pythonProject/.venv/Lib/site-packages/diffusers/image_processor.py +1451 -0
  17. pythonProject/.venv/Lib/site-packages/diffusers/optimization.py +361 -0
  18. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/marigold/__pycache__/pipeline_marigold_normals.cpython-310.pyc +0 -0
  19. pythonProject/.venv/Lib/site-packages/diffusers/py.typed +0 -0
  20. pythonProject/.venv/Lib/site-packages/diffusers/training_utils.py +730 -0
pythonProject/.venv/Lib/site-packages/accelerate/commands/launch.py ADDED
@@ -0,0 +1,1245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import argparse
18
+ import importlib
19
+ import logging
20
+ import os
21
+ import subprocess
22
+ import sys
23
+ from pathlib import Path
24
+
25
+ import psutil
26
+ import torch
27
+
28
+ from accelerate.commands.config import default_config_file, load_config_from_file
29
+ from accelerate.commands.config.config_args import SageMakerConfig
30
+ from accelerate.commands.config.config_utils import DYNAMO_BACKENDS
31
+ from accelerate.commands.utils import CustomArgumentParser
32
+ from accelerate.state import get_int_from_env
33
+ from accelerate.utils import (
34
+ ComputeEnvironment,
35
+ DistributedType,
36
+ PrepareForLaunch,
37
+ _filter_args,
38
+ check_cuda_p2p_ib_support,
39
+ convert_dict_to_env_variables,
40
+ is_bf16_available,
41
+ is_deepspeed_available,
42
+ is_hpu_available,
43
+ is_mlu_available,
44
+ is_musa_available,
45
+ is_npu_available,
46
+ is_rich_available,
47
+ is_sagemaker_available,
48
+ is_sdaa_available,
49
+ is_torch_xla_available,
50
+ is_xpu_available,
51
+ patch_environment,
52
+ prepare_deepspeed_cmd_env,
53
+ prepare_multi_gpu_env,
54
+ prepare_sagemager_args_inputs,
55
+ prepare_simple_launcher_cmd_env,
56
+ prepare_tpu,
57
+ str_to_bool,
58
+ )
59
+ from accelerate.utils.constants import DEEPSPEED_MULTINODE_LAUNCHERS, TORCH_DYNAMO_MODES
60
+
61
+
62
+ if is_rich_available():
63
+ from rich import get_console
64
+ from rich.logging import RichHandler
65
+
66
+ FORMAT = "%(message)s"
67
+ logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()])
68
+
69
+
70
+ logger = logging.getLogger(__name__)
71
+
72
+
73
# Maps a platform-selection CLI flag (after normalization by `clean_option`)
# to the title of the argument group it enables in the filtered `--help`
# output produced by `CustomHelpFormatter`.
options_to_group = {
    "multi_gpu": "Distributed GPUs",
    "tpu": "TPU",
    "use_deepspeed": "DeepSpeed Arguments",
    "use_fsdp": "FSDP Arguments",
    "use_megatron_lm": "Megatron-LM Arguments",
    "fp8_backend": "FP8 Arguments",
}
81
+
82
+
83
def clean_option(option):
    """Normalize a CLI token for comparison against `options_to_group` keys.

    Collapses any `--fp8_backend...` variant (e.g. `--fp8_backend=te`) to the
    bare backend flag, then strips the leading `--` and converts the remaining
    hyphens to underscores so that `--num-processes` and `--num_processes`
    compare equal.

    Returns the normalized flag name, or ``None`` when the token is not a
    long option (e.g. a positional script path). Callers only use the result
    for membership tests, so ``None`` is harmless there.
    """
    if "fp8_backend" in option:
        # All fp8_backend spellings map to the plain flag name.
        option = "--fp8_backend"
    if option.startswith("--"):
        return option[2:].replace("-", "_")
    # Fix: make the "not a long option" result explicit instead of silently
    # falling off the end of the function.
    return None
89
+
90
+
91
class CustomHelpFormatter(argparse.HelpFormatter):
    """
    Help formatter that hides every argument not relevant to the platform
    flags actually present on the command line, so that `--help` only shows
    the argument groups the user is using.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Argument-group titles that are always displayed, regardless of
        # which platform flags were passed.
        self.titles = [
            "Hardware Selection Arguments",
            "Resource Selection Arguments",
            "Training Paradigm Arguments",
            "positional arguments",
            "optional arguments",
        ]

    def add_argument(self, action: argparse.Action):
        # Strip the program / subcommand tokens so only user-supplied
        # arguments remain.
        if "accelerate" in sys.argv[0] and "launch" in sys.argv[1:]:
            cli_args = sys.argv[2:]
        else:
            cli_args = sys.argv[1:]

        if len(cli_args) > 1:
            cli_args = [clean_option(token) for token in cli_args]
            selected = [token for token in cli_args if token in options_to_group]
            used_titles = [options_to_group[flag] for flag in selected]
            title = action.container.title
            if title not in self.titles + used_titles:
                # Group belongs to a platform the user did not select.
                action.help = argparse.SUPPRESS
            elif title in ("Hardware Selection Arguments", "Training Paradigm Arguments"):
                # Within the always-shown selector groups, only surface the
                # flags that were actually passed, and mark them as active.
                if set(action.option_strings).isdisjoint(set(cli_args)):
                    action.help = argparse.SUPPRESS
                else:
                    action.help = action.help + " (currently selected)"

        # Advertise only the underscore spelling of each flag in help text
        # (drop aliases containing hyphens past the leading `--`).
        action.option_strings = [s for s in action.option_strings if "-" not in s[2:]]
        super().add_argument(action)

    def end_section(self):
        # A section left with fewer than two entries after filtering is
        # noise; blank it out entirely before closing.
        # NOTE(review): relies on argparse's private `_current_section`
        # internals, as the original did.
        if len(self._current_section.items) < 2:
            self._current_section.items = []
            self._current_section.heading = ""
        super().end_section()
139
+
140
+
141
+ def launch_command_parser(subparsers=None):
142
+ description = "Launch a python script in a distributed scenario. Arguments can be passed in with either hyphens (`--num-processes=2`) or underscores (`--num_processes=2`)"
143
+ if subparsers is not None:
144
+ parser = subparsers.add_parser(
145
+ "launch", description=description, add_help=False, allow_abbrev=False, formatter_class=CustomHelpFormatter
146
+ )
147
+ else:
148
+ parser = CustomArgumentParser(
149
+ "Accelerate launch command",
150
+ description=description,
151
+ add_help=False,
152
+ allow_abbrev=False,
153
+ formatter_class=CustomHelpFormatter,
154
+ )
155
+
156
+ parser.add_argument("-h", "--help", action="help", help="Show this help message and exit.")
157
+
158
+ parser.add_argument(
159
+ "--config_file",
160
+ default=None,
161
+ help="The config file to use for the default values in the launching script.",
162
+ )
163
+ parser.add_argument(
164
+ "--quiet",
165
+ "-q",
166
+ action="store_true",
167
+ help="Silence subprocess errors from the launch stack trace and only show the relevant tracebacks. (Only applicable to DeepSpeed and single-process configurations)",
168
+ )
169
+ # Hardware selection arguments
170
+ hardware_args = parser.add_argument_group(
171
+ "Hardware Selection Arguments", "Arguments for selecting the hardware to be used."
172
+ )
173
+ hardware_args.add_argument(
174
+ "--cpu", default=False, action="store_true", help="Whether or not to force the training on the CPU."
175
+ )
176
+ hardware_args.add_argument(
177
+ "--multi_gpu",
178
+ default=False,
179
+ action="store_true",
180
+ help="Whether or not this should launch a distributed GPU training.",
181
+ )
182
+ hardware_args.add_argument(
183
+ "--tpu", default=False, action="store_true", help="Whether or not this should launch a TPU training."
184
+ )
185
+ # Resource selection arguments
186
+ resource_args = parser.add_argument_group(
187
+ "Resource Selection Arguments", "Arguments for fine-tuning how available hardware should be used."
188
+ )
189
+ resource_args.add_argument(
190
+ "--mixed_precision",
191
+ type=str,
192
+ choices=["no", "fp16", "bf16", "fp8"],
193
+ help="Whether or not to use mixed precision training. "
194
+ "Choose between FP16 and BF16 (bfloat16) training. "
195
+ "BF16 training is only supported on Nvidia Ampere GPUs and PyTorch 1.10 or later.",
196
+ )
197
+ resource_args.add_argument(
198
+ "--num_processes", type=int, default=None, help="The total number of processes to be launched in parallel."
199
+ )
200
+ resource_args.add_argument(
201
+ "--num_machines", type=int, default=None, help="The total number of machines used in this training."
202
+ )
203
+ resource_args.add_argument(
204
+ "--num_cpu_threads_per_process",
205
+ type=int,
206
+ default=None,
207
+ help="The number of CPU threads per process. Can be tuned for optimal performance.",
208
+ )
209
+ resource_args.add_argument(
210
+ "--enable_cpu_affinity",
211
+ default=False,
212
+ action="store_true",
213
+ help="Whether or not CPU affinity and balancing should be enabled. Currently only supported on NVIDIA hardware.",
214
+ )
215
+ # Dynamo arguments
216
+ resource_args.add_argument(
217
+ "--dynamo_backend",
218
+ type=str,
219
+ choices=["no"] + [b.lower() for b in DYNAMO_BACKENDS],
220
+ help="Choose a backend to optimize your training with dynamo, see more at "
221
+ "https://github.com/pytorch/torchdynamo.",
222
+ )
223
+ resource_args.add_argument(
224
+ "--dynamo_mode",
225
+ type=str,
226
+ default="default",
227
+ choices=TORCH_DYNAMO_MODES,
228
+ help="Choose a mode to optimize your training with dynamo.",
229
+ )
230
+ resource_args.add_argument(
231
+ "--dynamo_use_fullgraph",
232
+ default=False,
233
+ action="store_true",
234
+ help="Whether to use full graph mode for dynamo or it is ok to break model into several subgraphs",
235
+ )
236
+ resource_args.add_argument(
237
+ "--dynamo_use_dynamic",
238
+ default=False,
239
+ action="store_true",
240
+ help="Whether to enable dynamic shape tracing.",
241
+ )
242
+ resource_args.add_argument(
243
+ "--dynamo_use_regional_compilation",
244
+ default=False,
245
+ action="store_true",
246
+ help="Whether to enable regional compilation.",
247
+ )
248
+
249
+ # Training Paradigm arguments
250
+ paradigm_args = parser.add_argument_group(
251
+ "Training Paradigm Arguments", "Arguments for selecting which training paradigm to be used."
252
+ )
253
+ paradigm_args.add_argument(
254
+ "--use_deepspeed",
255
+ default=False,
256
+ action="store_true",
257
+ help="Whether to use deepspeed.",
258
+ )
259
+ paradigm_args.add_argument(
260
+ "--use_fsdp",
261
+ default=False,
262
+ action="store_true",
263
+ help="Whether to use fsdp.",
264
+ )
265
+ paradigm_args.add_argument(
266
+ "--use_parallelism_config",
267
+ default=False,
268
+ action="store_true",
269
+ help="Whether to use the parallelism config to configure the N-d distributed training.",
270
+ )
271
+ paradigm_args.add_argument(
272
+ "--use_megatron_lm",
273
+ default=False,
274
+ action="store_true",
275
+ help="Whether to use Megatron-LM.",
276
+ )
277
+
278
+ paradigm_args.add_argument(
279
+ "--use_xpu",
280
+ default=None,
281
+ action="store_true",
282
+ help="Whether to use IPEX plugin to speed up training on XPU specifically. This argument is deprecated and ignored, will be removed in Accelerate v1.20.",
283
+ )
284
+
285
+ # distributed GPU training arguments
286
+ distributed_args = parser.add_argument_group("Distributed GPUs", "Arguments related to distributed GPU training.")
287
+ distributed_args.add_argument(
288
+ "--gpu_ids",
289
+ default=None,
290
+ help="What GPUs (by id) should be used for training on this machine as a comma-separated list",
291
+ )
292
+ distributed_args.add_argument(
293
+ "--same_network",
294
+ default=False,
295
+ action="store_true",
296
+ help="Whether all machines used for multinode training exist on the same local network.",
297
+ )
298
+ distributed_args.add_argument(
299
+ "--machine_rank", type=int, default=None, help="The rank of the machine on which this script is launched."
300
+ )
301
+ distributed_args.add_argument(
302
+ "--main_process_ip", type=str, default=None, help="The IP address of the machine of rank 0."
303
+ )
304
+ distributed_args.add_argument(
305
+ "--main_process_port",
306
+ type=int,
307
+ default=None,
308
+ help="The port to use to communicate with the machine of rank 0.",
309
+ )
310
+ distributed_args.add_argument(
311
+ "-t",
312
+ "--tee",
313
+ default="0",
314
+ type=str,
315
+ help="Tee std streams into a log file and also to console.",
316
+ )
317
+ distributed_args.add_argument(
318
+ "--log_dir",
319
+ type=str,
320
+ default=None,
321
+ help=(
322
+ "Base directory to use for log files when using torchrun/torch.distributed.run as launcher. "
323
+ "Use with --tee to redirect std streams info log files."
324
+ ),
325
+ )
326
+ distributed_args.add_argument(
327
+ "--role",
328
+ type=str,
329
+ default="default",
330
+ help="User-defined role for the workers.",
331
+ )
332
+ # Rendezvous related arguments
333
+ distributed_args.add_argument(
334
+ "--rdzv_backend",
335
+ type=str,
336
+ default="static",
337
+ help="The rendezvous method to use, such as 'static' (the default) or 'c10d'",
338
+ )
339
+ distributed_args.add_argument(
340
+ "--rdzv_conf",
341
+ type=str,
342
+ default="",
343
+ help="Additional rendezvous configuration (<key1>=<value1>,<key2>=<value2>,...).",
344
+ )
345
+ distributed_args.add_argument(
346
+ "--max_restarts",
347
+ type=int,
348
+ default=0,
349
+ help="Maximum number of worker group restarts before failing.",
350
+ )
351
+ distributed_args.add_argument(
352
+ "--monitor_interval",
353
+ type=float,
354
+ default=0.1,
355
+ help="Interval, in seconds, to monitor the state of workers.",
356
+ )
357
+ parser.add_argument(
358
+ "-m",
359
+ "--module",
360
+ action="store_true",
361
+ help="Change each process to interpret the launch script as a Python module, executing with the same behavior as 'python -m'.",
362
+ )
363
+ parser.add_argument(
364
+ "--no_python",
365
+ action="store_true",
366
+ help="Skip prepending the training script with 'python' - just execute it directly. Useful when the script is not a Python script.",
367
+ )
368
+
369
+ # TPU arguments
370
+ tpu_args = parser.add_argument_group("TPU", "Arguments related to TPU.")
371
+ tpu_args.add_argument(
372
+ "--tpu_cluster",
373
+ action="store_true",
374
+ dest="tpu_use_cluster",
375
+ help="Whether to use a GCP TPU pod for training.",
376
+ )
377
+ tpu_args.add_argument(
378
+ "--no_tpu_cluster",
379
+ action="store_false",
380
+ dest="tpu_use_cluster",
381
+ help="Should not be passed explicitly, this is for internal use only.",
382
+ )
383
+ tpu_args.add_argument(
384
+ "--tpu_use_sudo",
385
+ action="store_true",
386
+ help="Whether to use `sudo` when running the TPU training script in each pod.",
387
+ )
388
+ tpu_args.add_argument(
389
+ "--vm",
390
+ type=str,
391
+ action="append",
392
+ help=(
393
+ "List of single Compute VM instance names. "
394
+ "If not provided we assume usage of instance groups. For TPU pods."
395
+ ),
396
+ )
397
+ tpu_args.add_argument(
398
+ "--env",
399
+ type=str,
400
+ action="append",
401
+ help="List of environment variables to set on the Compute VM instances. For TPU pods.",
402
+ )
403
+ tpu_args.add_argument(
404
+ "--main_training_function",
405
+ type=str,
406
+ default=None,
407
+ help="The name of the main function to be executed in your script (only for TPU training).",
408
+ )
409
+ tpu_args.add_argument(
410
+ "--downcast_bf16",
411
+ action="store_true",
412
+ help="Whether when using bf16 precision on TPUs if both float and double tensors are cast to bfloat16 or if double tensors remain as float32.",
413
+ )
414
+
415
+ # DeepSpeed arguments
416
+ deepspeed_args = parser.add_argument_group("DeepSpeed Arguments", "Arguments related to DeepSpeed.")
417
+ deepspeed_args.add_argument(
418
+ "--deepspeed_config_file",
419
+ default=None,
420
+ type=str,
421
+ help="DeepSpeed config file.",
422
+ )
423
+ deepspeed_args.add_argument(
424
+ "--zero_stage",
425
+ default=None,
426
+ type=int,
427
+ help="DeepSpeed's ZeRO optimization stage (useful only when `use_deepspeed` flag is passed). "
428
+ "If unspecified, will default to `2`.",
429
+ )
430
+ deepspeed_args.add_argument(
431
+ "--offload_optimizer_device",
432
+ default=None,
433
+ type=str,
434
+ help="Decides where (none|cpu|nvme) to offload optimizer states (useful only when `use_deepspeed` flag is passed). "
435
+ "If unspecified, will default to 'none'.",
436
+ )
437
+ deepspeed_args.add_argument(
438
+ "--offload_param_device",
439
+ default=None,
440
+ type=str,
441
+ help="Decides where (none|cpu|nvme) to offload parameters (useful only when `use_deepspeed` flag is passed). "
442
+ "If unspecified, will default to 'none'.",
443
+ )
444
+ deepspeed_args.add_argument(
445
+ "--offload_optimizer_nvme_path",
446
+ default=None,
447
+ type=str,
448
+ help="Decides Nvme Path to offload optimizer states (useful only when `use_deepspeed` flag is passed). "
449
+ "If unspecified, will default to 'none'.",
450
+ )
451
+ deepspeed_args.add_argument(
452
+ "--offload_param_nvme_path",
453
+ default=None,
454
+ type=str,
455
+ help="Decides Nvme Path to offload parameters (useful only when `use_deepspeed` flag is passed). "
456
+ "If unspecified, will default to 'none'.",
457
+ )
458
+ deepspeed_args.add_argument(
459
+ "--gradient_accumulation_steps",
460
+ default=None,
461
+ type=int,
462
+ help="No of gradient_accumulation_steps used in your training script (useful only when `use_deepspeed` flag is passed). "
463
+ "If unspecified, will default to `1`.",
464
+ )
465
+ deepspeed_args.add_argument(
466
+ "--gradient_clipping",
467
+ default=None,
468
+ type=float,
469
+ help="gradient clipping value used in your training script (useful only when `use_deepspeed` flag is passed). "
470
+ "If unspecified, will default to `1.0`.",
471
+ )
472
+ deepspeed_args.add_argument(
473
+ "--zero3_init_flag",
474
+ default=None,
475
+ type=str,
476
+ help="Decides Whether (true|false) to enable `deepspeed.zero.Init` for constructing massive models. "
477
+ "Only applicable with DeepSpeed ZeRO Stage-3. If unspecified, will default to `true`.",
478
+ )
479
+ deepspeed_args.add_argument(
480
+ "--zero3_save_16bit_model",
481
+ default=None,
482
+ type=str,
483
+ help="Decides Whether (true|false) to save 16-bit model weights when using ZeRO Stage-3. "
484
+ "Only applicable with DeepSpeed ZeRO Stage-3. If unspecified, will default to `false`.",
485
+ )
486
+ deepspeed_args.add_argument(
487
+ "--deepspeed_hostfile",
488
+ default=None,
489
+ type=str,
490
+ help="DeepSpeed hostfile for configuring multi-node compute resources.",
491
+ )
492
+ deepspeed_args.add_argument(
493
+ "--deepspeed_exclusion_filter",
494
+ default=None,
495
+ type=str,
496
+ help="DeepSpeed exclusion filter string when using mutli-node setup.",
497
+ )
498
+ deepspeed_args.add_argument(
499
+ "--deepspeed_inclusion_filter",
500
+ default=None,
501
+ type=str,
502
+ help="DeepSpeed inclusion filter string when using mutli-node setup.",
503
+ )
504
+ deepspeed_args.add_argument(
505
+ "--deepspeed_multinode_launcher",
506
+ default=None,
507
+ type=str,
508
+ help="DeepSpeed multi-node launcher to use, e.g. `pdsh`, `standard`, `openmpi`, `mvapich`, `mpich`, `slurm`, `nossh` (requires DeepSpeed >= 0.14.5). If unspecified, will default to `pdsh`.",
509
+ )
510
+ deepspeed_args.add_argument(
511
+ "--deepspeed_moe_layer_cls_names",
512
+ default=None,
513
+ type=str,
514
+ help="comma-separated list of transformer MoE layer class names (case-sensitive) to wrap ,e.g, `MixtralSparseMoeBlock`, `Qwen2MoeSparseMoeBlock`, `JetMoEAttention,JetMoEBlock` ..."
515
+ " (useful only when `use_deepspeed` flag is passed).",
516
+ )
517
+
518
+ # fsdp arguments
519
+ fsdp_args = parser.add_argument_group("FSDP Arguments", "Arguments related to Fully Shared Data Parallelism.")
520
+ fsdp_args.add_argument(
521
+ "--fsdp_version",
522
+ type=str,
523
+ default="1",
524
+ choices=["1", "2"],
525
+ help="FSDP version to use. (useful only when `use_fsdp` flag is passed).",
526
+ )
527
+ fsdp_args.add_argument(
528
+ "--fsdp_offload_params",
529
+ default="false",
530
+ type=str,
531
+ help="Decides Whether (true|false) to offload parameters and gradients to CPU. (useful only when `use_fsdp` flag is passed).",
532
+ )
533
+ fsdp_args.add_argument(
534
+ "--fsdp_min_num_params",
535
+ type=int,
536
+ default=1e8,
537
+ help="FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `use_fsdp` flag is passed).",
538
+ )
539
+ # We enable this for backwards compatibility, throw a warning if this is set in `FullyShardedDataParallelPlugin`
540
+ fsdp_args.add_argument(
541
+ "--fsdp_sharding_strategy",
542
+ type=str,
543
+ default="FULL_SHARD",
544
+ help="FSDP's sharding strategy. (useful only when `use_fsdp` flag is passed and `fsdp_version=1`).",
545
+ )
546
+ fsdp_args.add_argument(
547
+ "--fsdp_reshard_after_forward",
548
+ type=str,
549
+ default="true",
550
+ help="FSDP's Reshard After Forward Strategy. (useful only when `use_fsdp` flag is passed). Supports either boolean (FSDP2) or `FULL_SHARD | SHARD_GRAD_OP | NO_RESHARD` (FSDP1).",
551
+ )
552
+ fsdp_args.add_argument(
553
+ "--fsdp_auto_wrap_policy",
554
+ type=str,
555
+ default=None,
556
+ help="FSDP's auto wrap policy. (useful only when `use_fsdp` flag is passed).",
557
+ )
558
+ fsdp_args.add_argument(
559
+ "--fsdp_transformer_layer_cls_to_wrap",
560
+ default=None,
561
+ type=str,
562
+ help="Transformer layer class name (case-sensitive) to wrap ,e.g, `BertLayer`, `GPTJBlock`, `T5Block` .... "
563
+ "(useful only when `use_fsdp` flag is passed).",
564
+ )
565
+ fsdp_args.add_argument(
566
+ "--fsdp_backward_prefetch",
567
+ default=None,
568
+ type=str,
569
+ help="FSDP's backward prefetch policy. (useful only when `use_fsdp` flag is passed).",
570
+ )
571
+ fsdp_args.add_argument(
572
+ "--fsdp_state_dict_type",
573
+ default=None,
574
+ type=str,
575
+ help="FSDP's state dict type. (useful only when `use_fsdp` flag is passed).",
576
+ )
577
+ fsdp_args.add_argument(
578
+ "--fsdp_forward_prefetch",
579
+ default="false",
580
+ type=str,
581
+ help="If True, then FSDP explicitly prefetches the next upcoming "
582
+ "all-gather while executing in the forward pass (useful only when `use_fsdp` flag is passed).",
583
+ )
584
+ fsdp_args.add_argument(
585
+ "--fsdp_use_orig_params",
586
+ default="true",
587
+ type=str,
588
+ help="If True, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable paramteres."
589
+ " (useful only when `use_fsdp` flag is passed).",
590
+ )
591
+ fsdp_args.add_argument(
592
+ "--fsdp_cpu_ram_efficient_loading",
593
+ default="true",
594
+ type=str,
595
+ help="If True, only the first process loads the pretrained model checkoint while all other processes have empty weights. "
596
+ "Only applicable for 🤗 Transformers. When using this, `--fsdp_sync_module_states` needs to True. "
597
+ "(useful only when `use_fsdp` flag is passed).",
598
+ )
599
+ fsdp_args.add_argument(
600
+ "--fsdp_sync_module_states",
601
+ default="true",
602
+ type=str,
603
+ help="If True, each individually wrapped FSDP unit will broadcast module parameters from rank 0."
604
+ " (useful only when `use_fsdp` flag is passed).",
605
+ )
606
+ fsdp_args.add_argument(
607
+ "--fsdp_activation_checkpointing",
608
+ default="false",
609
+ type=str,
610
+ help="Decides Whether (true|false) intermediate activations are freed during the forward pass, and a checkpoint is left as a placeholder. (useful only when `use_fsdp` flag is passed).",
611
+ )
612
+
613
+ # megatron_lm args
614
+ megatron_lm_args = parser.add_argument_group("Megatron-LM Arguments", "Arguments related to Megatron-LM.")
615
+ megatron_lm_args.add_argument(
616
+ "--megatron_lm_tp_degree",
617
+ type=int,
618
+ default=1,
619
+ help="Megatron-LM's Tensor Parallelism (TP) degree. (useful only when `use_megatron_lm` flag is passed).",
620
+ )
621
+ megatron_lm_args.add_argument(
622
+ "--megatron_lm_pp_degree",
623
+ type=int,
624
+ default=1,
625
+ help="Megatron-LM's Pipeline Parallelism (PP) degree. (useful only when `use_megatron_lm` flag is passed).",
626
+ )
627
+ megatron_lm_args.add_argument(
628
+ "--megatron_lm_num_micro_batches",
629
+ type=int,
630
+ default=None,
631
+ help="Megatron-LM's number of micro batches when PP degree > 1. (useful only when `use_megatron_lm` flag is passed).",
632
+ )
633
+ megatron_lm_args.add_argument(
634
+ "--megatron_lm_sequence_parallelism",
635
+ default=None,
636
+ type=str,
637
+ help="Decides Whether (true|false) to enable Sequence Parallelism when TP degree > 1. "
638
+ "(useful only when `use_megatron_lm` flag is passed).",
639
+ )
640
+ megatron_lm_args.add_argument(
641
+ "--megatron_lm_recompute_activations",
642
+ default=None,
643
+ type=str,
644
+ help="Decides Whether (true|false) to enable Selective Activation Recomputation. "
645
+ "(useful only when `use_megatron_lm` flag is passed).",
646
+ )
647
+ megatron_lm_args.add_argument(
648
+ "--megatron_lm_use_distributed_optimizer",
649
+ default=None,
650
+ type=str,
651
+ help="Decides Whether (true|false) to use distributed optimizer "
652
+ "which shards optimizer state and gradients across Data Pralellel (DP) ranks. "
653
+ "(useful only when `use_megatron_lm` flag is passed).",
654
+ )
655
+ megatron_lm_args.add_argument(
656
+ "--megatron_lm_gradient_clipping",
657
+ default=1.0,
658
+ type=float,
659
+ help="Megatron-LM's gradient clipping value based on global L2 Norm (0 to disable). "
660
+ "(useful only when `use_megatron_lm` flag is passed).",
661
+ )
662
+
663
+ # FP8 arguments
664
+ fp8_args = parser.add_argument_group(
665
+ "FP8 Arguments", "Arguments related to FP8 training (requires `--mixed_precision=fp8`)"
666
+ )
667
+ fp8_args.add_argument(
668
+ "--fp8_backend",
669
+ type=str,
670
+ choices=["te", "msamp"],
671
+ help="Choose a backend to train with FP8 (te: TransformerEngine, msamp: MS-AMP)",
672
+ )
673
+ fp8_args.add_argument(
674
+ "--fp8_use_autocast_during_eval",
675
+ default=False,
676
+ action="store_true",
677
+ help="Whether to use FP8 autocast during eval mode (useful only when `--fp8_backend=te` is passed). Generally better metrics are found when this is not passed.",
678
+ )
679
+ fp8_args.add_argument(
680
+ "--fp8_margin",
681
+ type=int,
682
+ default=0,
683
+ help="The margin to use for the gradient scaling (useful only when `--fp8_backend=te` is passed).",
684
+ )
685
+ fp8_args.add_argument(
686
+ "--fp8_interval",
687
+ type=int,
688
+ default=1,
689
+ help="The interval to use for how often the scaling factor is recomputed (useful only when `--fp8_backend=te` is passed).",
690
+ )
691
+ fp8_args.add_argument(
692
+ "--fp8_format",
693
+ type=str,
694
+ default="HYBRID",
695
+ choices=["HYBRID", "E4M3", "E5M2"],
696
+ help="The format to use for the FP8 recipe (useful only when `--fp8_backend=te` is passed).",
697
+ )
698
+ fp8_args.add_argument(
699
+ "--fp8_amax_history_len",
700
+ type=int,
701
+ default=1024,
702
+ help="The length of the history to use for the scaling factor computation (useful only when `--fp8_backend=te` is passed).",
703
+ )
704
+ fp8_args.add_argument(
705
+ "--fp8_amax_compute_algo",
706
+ type=str,
707
+ default="most_recent",
708
+ choices=["max", "most_recent"],
709
+ help="The algorithm to use for the scaling factor computation. (useful only when `--fp8_backend=te` is passed).",
710
+ )
711
+ fp8_args.add_argument(
712
+ "--fp8_override_linear_precision",
713
+ type=lambda x: tuple(map(str_to_bool, x.split(","))),
714
+ default=(False, False, False),
715
+ help="Whether or not to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision. Should be passed in a comma-separated string of booleans (useful only when `--fp8_backend=te` is passed).",
716
+ )
717
+ fp8_args.add_argument(
718
+ "--fp8_opt_level",
719
+ type=str,
720
+ default="O2",
721
+ choices=["O1", "O2"],
722
+ help="What level of 8-bit collective communication should be used with MS-AMP (useful only when `--fp8_backend=msamp` is passed).",
723
+ )
724
+
725
+ # AWS arguments
726
+ aws_args = parser.add_argument_group("AWS Arguments", "Arguments related to AWS.")
727
+ aws_args.add_argument(
728
+ "--aws_access_key_id",
729
+ type=str,
730
+ default=None,
731
+ help="The AWS_ACCESS_KEY_ID used to launch the Amazon SageMaker training job",
732
+ )
733
+ aws_args.add_argument(
734
+ "--aws_secret_access_key",
735
+ type=str,
736
+ default=None,
737
+ help="The AWS_SECRET_ACCESS_KEY used to launch the Amazon SageMaker training job.",
738
+ )
739
+ parser.add_argument(
740
+ "--debug",
741
+ action="store_true",
742
+ help="Whether to print out the torch.distributed stack trace when something fails.",
743
+ )
744
+ parser.add_argument(
745
+ "training_script",
746
+ type=str,
747
+ help=(
748
+ "The full path to the script to be launched in parallel, followed by all the arguments for the training "
749
+ "script."
750
+ ),
751
+ )
752
+
753
+ # MPI arguments
754
+ mpirun_args = parser.add_argument_group("MPI Arguments", "Arguments related to mpirun for Multi-CPU")
755
+ mpirun_args.add_argument(
756
+ "--mpirun_hostfile",
757
+ type=str,
758
+ default=None,
759
+ help="Location for a hostfile for using Accelerate to launch a multi-CPU training job with mpirun. This will "
760
+ "get passed to the MPI --hostfile or -f parameter, depending on which MPI program is installed.",
761
+ )
762
+ mpirun_args.add_argument(
763
+ "--mpirun_ccl",
764
+ type=int,
765
+ default=1,
766
+ help="The number of oneCCL worker threads when using Accelerate to launch multi-CPU training with mpirun.",
767
+ )
768
+
769
+ # ParallelismConfig arguments
770
+ parallelism_config_args = parser.add_argument_group(
771
+ "ParallelismConfig Arguments",
772
+ "Arguments related to the ParallelismConfig used for distributed training.",
773
+ )
774
+ parallelism_config_args.add_argument(
775
+ "--parallelism_config_dp_replicate_size",
776
+ type=int,
777
+ default=1,
778
+ help="The number of processes for data parallel training. Defaults to 1 (no data parallelism).",
779
+ )
780
+
781
+ parallelism_config_args.add_argument(
782
+ "--parallelism_config_dp_shard_size",
783
+ type=int,
784
+ default=1,
785
+ help="The number of processes for FSDP sharding. Defaults to 1 (No FSDP sharding).",
786
+ )
787
+
788
+ parallelism_config_args.add_argument(
789
+ "--parallelism_config_tp_size",
790
+ type=int,
791
+ default=1,
792
+ help="The number of processes for tensor parallel training. Defaults to 1 (no tensor parallelism).",
793
+ )
794
+
795
+ parallelism_config_args.add_argument(
796
+ "--parallelism_config_cp_size",
797
+ type=int,
798
+ default=1,
799
+ help="The number of processese for context parallel training. Defaults to 1 (no context parallelism).",
800
+ )
801
+ parallelism_config_args.add_argument(
802
+ "--parallelism_config_cp_comm_strategy",
803
+ type=str,
804
+ default="allgather",
805
+ help="The communication strategy for context parallel training. Defaults to 'allgather'. Other option is alltoall",
806
+ )
807
+
808
+ # Other arguments of the training scripts
809
+ parser.add_argument("training_script_args", nargs=argparse.REMAINDER, help="Arguments of the training script.")
810
+
811
+ if subparsers is not None:
812
+ parser.set_defaults(func=launch_command)
813
+ return parser
814
+
815
+
816
def simple_launcher(args):
    """Run the training script as a single local subprocess (no distributed setup).

    Mirrors the subprocess exit status: raises `CalledProcessError` on failure,
    or exits with status 1 when `--quiet` was requested.
    """
    cmd, current_env = prepare_simple_launcher_cmd_env(args)

    returncode = subprocess.Popen(cmd, env=current_env).wait()
    if returncode == 0:
        return
    if args.quiet:
        # Suppress the Python traceback but still signal failure to the shell.
        sys.exit(1)
    raise subprocess.CalledProcessError(returncode=returncode, cmd=cmd)
826
+
827
+
828
def multi_gpu_launcher(args):
    """Launch the training script across multiple devices via `torch.distributed.run`."""
    import torch.distributed.run as distrib_run

    current_env = prepare_multi_gpu_env(args)
    if not check_cuda_p2p_ib_support():
        # RTX 4000-series cards: disable P2P/IB unless the user already did.
        newly_disabled = [var for var in ("NCCL_P2P_DISABLE", "NCCL_IB_DISABLE") if var not in current_env]
        for var in newly_disabled:
            current_env[var] = "1"
        if newly_disabled:
            logger.warning(
                "Using RTX 4000 series which doesn't support faster communication speedups. Ensuring P2P and IB communications are disabled."
            )

    debug = getattr(args, "debug", False)
    # Reduce the namespace to the flags torch.distributed.run understands.
    args = _filter_args(
        args,
        distrib_run.get_args_parser(),
        ["--training_script", args.training_script, "--training_script_args", args.training_script_args],
    )

    with patch_environment(**current_env):
        try:
            distrib_run.run(args)
        except Exception:
            if is_rich_available() and debug:
                # Render the failure nicely when --debug was requested.
                rich_console = get_console()
                rich_console.print("\n[bold red]Using --debug, `torch.distributed` Stack Trace:[/bold red]")
                rich_console.print_exception(suppress=[__file__], show_locals=False)
            else:
                raise
861
+
862
+
863
def deepspeed_launcher(args):
    """Launch the training script through DeepSpeed.

    Multi-node runs with a launcher other than ``DEEPSPEED_MULTINODE_LAUNCHERS[1]``
    are executed as a subprocess (after persisting the environment to DeepSpeed's
    environment file); everything else runs in-process via `torch.distributed.run`.
    """
    import torch.distributed.run as distrib_run

    if not is_deepspeed_available():
        raise ImportError("DeepSpeed is not installed => run `pip3 install deepspeed` or build it from source.")
    else:
        from deepspeed.launcher.runner import DEEPSPEED_ENVIRONMENT_NAME

    cmd, current_env = prepare_deepspeed_cmd_env(args)
    # RTX 4000-series cards lack P2P/IB support; force-disable them unless the
    # user already set the NCCL variables explicitly.
    if not check_cuda_p2p_ib_support():
        message = "Using RTX 4000 series which doesn't support faster communication speedups. Ensuring P2P and IB communications are disabled."
        warn = False
        if "NCCL_P2P_DISABLE" not in current_env:
            current_env["NCCL_P2P_DISABLE"] = "1"
            warn = True
        if "NCCL_IB_DISABLE" not in current_env:
            current_env["NCCL_IB_DISABLE"] = "1"
            warn = True
        if warn:
            logger.warning(message)

    if args.num_machines > 1 and args.deepspeed_multinode_launcher != DEEPSPEED_MULTINODE_LAUNCHERS[1]:
        # Persist the environment so DeepSpeed can propagate it to other nodes.
        with open(DEEPSPEED_ENVIRONMENT_NAME, "a") as f:
            valid_env_items = convert_dict_to_env_variables(current_env)
            # NOTE(review): `> 1` skips the write when exactly one item is valid —
            # possibly intended as `>= 1`; confirm upstream intent.
            if len(valid_env_items) > 1:
                f.writelines(valid_env_items)

        process = subprocess.Popen(cmd, env=current_env)
        process.wait()
        if process.returncode != 0:
            if not args.quiet:
                raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
            else:
                # --quiet: hide the traceback but keep a non-zero exit status.
                sys.exit(1)
    else:
        debug = getattr(args, "debug", False)
        # Keep only the flags understood by torch.distributed.run's parser.
        args = _filter_args(
            args,
            distrib_run.get_args_parser(),
            ["--training_script", args.training_script, "--training_script_args", args.training_script_args],
        )
        with patch_environment(**current_env):
            try:
                distrib_run.run(args)
            except Exception:
                if is_rich_available() and debug:
                    # Pretty-print the stack trace with rich when --debug is set.
                    console = get_console()
                    console.print("\n[bold red]Using --debug, `torch.distributed` Stack Trace:[/bold red]")
                    console.print_exception(suppress=[__file__], show_locals=False)
                else:
                    raise
914
+
915
+
916
def tpu_launcher(args):
    """Run the training function on TPU via `torch_xla` multiprocessing.

    Imports the training script as a module and spawns its
    `--main_training_function` through `xmp.spawn`.
    """
    import torch_xla.distributed.xla_multiprocessing as xmp

    if args.no_python:
        raise ValueError("--no_python cannot be used with TPU launcher")

    args, current_env = prepare_tpu(args, {})

    if args.module:
        module_name = args.training_script
    else:
        # Make the script importable as a module from its own directory.
        script_path = Path(args.training_script)
        sys.path.append(str(script_path.parent.resolve()))
        module_name = script_path.stem

    module = importlib.import_module(module_name)
    if not hasattr(module, args.main_training_function):
        raise ValueError(
            f"Your training script should have a function named {args.main_training_function}, or you should pass a "
            "different value to `--main_training_function`."
        )

    # Forward the remaining CLI arguments to the training module.
    sys.argv = [module.__file__] + args.training_script_args

    entry_fn = getattr(module, args.main_training_function)
    with patch_environment(**current_env):
        xmp.spawn(PrepareForLaunch(entry_fn), args=())
945
+
946
+
947
def tpu_pod_launcher(args):
    """Launch training across a TPU pod by re-invoking ``accelerate-launch`` on
    each worker through ``torch_xla.distributed.xla_dist``."""
    from torch_xla.distributed import xla_dist

    current_env = {}
    args, current_env = prepare_tpu(args, current_env, True)
    debug = getattr(args, "debug", False)

    training_script = args.training_script
    training_script_args = args.training_script_args
    # Keep only the flags understood by xla_dist's own argument parser.
    new_args = _filter_args(
        args, xla_dist.get_args_parser(), ["--tpu", args.tpu_name, "--positional", "", "--restart-tpuvm-pod-server"]
    )

    if args.tpu_use_sudo:
        new_cmd = ["sudo"]
    else:
        new_cmd = []

    # Command executed on each worker: a single-machine accelerate-launch run.
    new_cmd += [
        "accelerate-launch",
        "--tpu",
        "--no_tpu_cluster",
        "--num_machines",
        "1",
        "--mixed_precision",
        "no",
        "--dynamo_backend",
        "no",
        "--num_processes",
        str(args.num_processes),
        "--main_training_function",
        str(args.main_training_function),
        training_script,
    ] + training_script_args

    new_args.positional = new_cmd
    # xla_dist exposes docker_* flags, but docker is unsupported here; collect
    # any that were set and refuse to launch.
    bad_flags = ""
    for arg in vars(new_args):
        if arg.startswith("docker_"):
            value = getattr(new_args, arg)
            if value != "" and value is not None:
                bad_flags += f'{arg}="{value}"\n'
    if bad_flags != "":
        raise ValueError(
            f"Docker containers are not supported for TPU pod launcher currently, please remove the following flags:\n{bad_flags}"
        )
    new_args.env = [f"{k}={v}" for k, v in current_env.items()]
    # Marker env var — presumably read by the spawned processes to detect pod
    # mode; confirm against the rest of the package.
    new_args.env.append("ACCELERATE_IN_TPU_POD=1")
    try:
        xla_dist.resolve_and_execute(new_args)
    except Exception:
        if is_rich_available() and debug:
            # Pretty-print the failure with rich when --debug was passed.
            console = get_console()
            console.print("\n[bold red]Using --debug, `torch_xla.xla_dist` Stack Trace:[/bold red]")
            console.print_exception(suppress=[__file__], show_locals=False)
        else:
            raise
1004
+
1005
+
1006
def sagemaker_launcher(sagemaker_config: SageMakerConfig, args):
    """Submit the training job to Amazon SageMaker through a HuggingFace estimator."""
    if not is_sagemaker_available():
        raise ImportError(
            "Please install sagemaker to be able to launch training on Amazon SageMaker with `pip install accelerate[sagemaker]`"
        )
    if args.module or args.no_python:
        raise ValueError(
            "SageMaker requires a python training script file and cannot be used with --module or --no_python"
        )

    # Imported lazily: the sagemaker SDK is an optional dependency.
    from sagemaker.huggingface import HuggingFace

    args, sagemaker_inputs = prepare_sagemager_args_inputs(sagemaker_config, args)

    estimator = HuggingFace(**args)
    estimator.fit(inputs=sagemaker_inputs)
    print(f"You can find your model data at: {estimator.model_data}")
1024
+
1025
+
1026
def _validate_launch_command(args):
    """Sanity-check and complete the parsed launch arguments.

    Unset arguments are filled in from the accelerate config file when one is
    available, otherwise from hardware introspection (device counts), warning
    about every default that was applied.

    Fix: the "Less than two GPU ids" error message contained a duplicated word
    ("run on on multiple GPUs").

    Returns:
        tuple: `(args, defaults, mp_from_config_flag)` where `defaults` is the
        loaded config (or None) and `mp_from_config_flag` tells whether
        `mixed_precision` came from the config file rather than the CLI.
    """
    # Sanity checks
    if sum([args.multi_gpu, args.cpu, args.tpu, args.use_deepspeed, args.use_fsdp]) > 1:
        raise ValueError(
            "You can only use one of `--cpu`, `--multi_gpu`, `--tpu`, `--use_deepspeed`, `--use_fsdp` at a time."
        )
    if args.multi_gpu and (args.num_processes is not None) and (args.num_processes < 2):
        raise ValueError("You need to use at least 2 processes to use `--multi_gpu`.")

    if (not args.use_fsdp or args.fsdp_version == 1) and args.use_parallelism_config:
        raise ValueError("You cannot use `--use_parallelism_config` without `--use_fsdp` and `--fsdp_version=2`. ")

    defaults = None
    warned = []
    mp_from_config_flag = False
    # Get the default from the config file.
    if args.config_file is not None or os.path.isfile(default_config_file) and not args.cpu:
        defaults = load_config_from_file(args.config_file)
        # Only infer the distributed mode from the config when the user did not
        # request one explicitly on the command line.
        if (
            not args.multi_gpu
            and not args.tpu
            and not args.tpu_use_cluster
            and not args.use_deepspeed
            and not args.use_fsdp
            and not args.use_megatron_lm
        ):
            args.use_deepspeed = defaults.distributed_type == DistributedType.DEEPSPEED
            args.multi_gpu = (
                True
                if defaults.distributed_type
                in (
                    DistributedType.MULTI_GPU,
                    DistributedType.MULTI_NPU,
                    DistributedType.MULTI_MLU,
                    DistributedType.MULTI_SDAA,
                    DistributedType.MULTI_MUSA,
                    DistributedType.MULTI_XPU,
                    DistributedType.MULTI_HPU,
                )
                else False
            )
            args.tpu = defaults.distributed_type == DistributedType.XLA
            args.use_fsdp = defaults.distributed_type == DistributedType.FSDP
            args.use_megatron_lm = defaults.distributed_type == DistributedType.MEGATRON_LM
            args.tpu_use_cluster = defaults.tpu_use_cluster if args.tpu else False
            args.use_parallelism_config = defaults.parallelism_config != {}
        if args.gpu_ids is None:
            if defaults.gpu_ids is not None:
                args.gpu_ids = defaults.gpu_ids
            else:
                args.gpu_ids = "all"

        if args.multi_gpu and args.num_machines is None:
            args.num_machines = defaults.num_machines

        if len(args.gpu_ids.split(",")) < 2 and (args.gpu_ids != "all") and args.multi_gpu and args.num_machines <= 1:
            raise ValueError(
                "Less than two GPU ids were configured and tried to run on multiple GPUs. "
                "Please ensure at least two are specified for `--gpu_ids`, or use `--gpu_ids='all'`."
            )
        if defaults.compute_environment == ComputeEnvironment.LOCAL_MACHINE:
            # Update args with the defaults
            for name, attr in defaults.__dict__.items():
                if isinstance(attr, dict):
                    # Copy defaults.somedict.somearg to args.somearg and
                    # defaults.fsdp_config.x to args.fsdp_x
                    for key, value in attr.items():
                        if name == "fsdp_config" and not key.startswith("fsdp"):
                            key = "fsdp_" + key
                        elif name == "fp8_config" and not key.startswith("fp8"):
                            key = "fp8_" + key
                        if hasattr(args, "nondefault") and key not in args.nondefault:
                            setattr(args, key, value)
                elif (
                    name not in ["compute_environment", "mixed_precision", "distributed_type"]
                    and getattr(args, name, None) is None
                ):
                    # Those args are handled separately
                    setattr(args, name, attr)
        if not args.debug:
            args.debug = defaults.debug

        if not args.mixed_precision:
            if defaults.mixed_precision is None:
                args.mixed_precision = "no"
            else:
                args.mixed_precision = defaults.mixed_precision
                mp_from_config_flag = True
        else:
            native_amp = is_bf16_available(True)
            if (
                args.mixed_precision == "bf16"
                and not native_amp
                and not (args.tpu and is_torch_xla_available(check_is_tpu=True))
            ):
                raise ValueError("bf16 mixed precision requires PyTorch >= 1.10 and a supported device.")

        # Silently set the default here
        if args.dynamo_backend is None:
            args.dynamo_backend = "no"
        if args.num_processes == -1:
            raise ValueError("You need to manually pass in `--num_processes` using this config yaml.")
    else:
        # No config file: derive defaults from the available hardware.
        if args.num_processes is None:
            if is_xpu_available():
                args.num_processes = torch.xpu.device_count()
            elif is_mlu_available():
                args.num_processes = torch.mlu.device_count()
            elif is_sdaa_available():
                args.num_processes = torch.sdaa.device_count()
            elif is_musa_available():
                args.num_processes = torch.musa.device_count()
            elif is_npu_available():
                args.num_processes = torch.npu.device_count()
            elif is_hpu_available():
                args.num_processes = torch.hpu.device_count()
            else:
                args.num_processes = torch.cuda.device_count()
            warned.append(f"\t`--num_processes` was set to a value of `{args.num_processes}`")
        if args.debug is None:
            args.debug = False
        # Auto-enable multi-GPU when more than one device of any backend exists.
        if (
            not args.multi_gpu
            and args.num_processes > 1
            and (
                (is_xpu_available() and torch.xpu.device_count() > 1)
                or (is_npu_available() and torch.npu.device_count() > 1)
                or (is_hpu_available() and torch.hpu.device_count() > 1)
                or (is_mlu_available() and torch.mlu.device_count() > 1)
                or (is_sdaa_available() and torch.sdaa.device_count() > 1)
                or (is_musa_available() and torch.musa.device_count() > 1)
                or (torch.cuda.is_available() and torch.cuda.device_count() > 1)
            )
        ):
            warned.append(
                "\t\tMore than one GPU was found, enabling multi-GPU training.\n"
                "\t\tIf this was unintended please pass in `--num_processes=1`."
            )
            args.multi_gpu = True
        if args.num_machines is None:
            warned.append("\t`--num_machines` was set to a value of `1`")
            args.num_machines = 1
        if args.mixed_precision is None:
            warned.append("\t`--mixed_precision` was set to a value of `'no'`")
            args.mixed_precision = "no"
        if not hasattr(args, "use_cpu"):
            args.use_cpu = args.cpu
        if args.dynamo_backend is None:
            warned.append("\t`--dynamo_backend` was set to a value of `'no'`")
            args.dynamo_backend = "no"
    if args.debug:
        logger.debug("Running script in debug mode, expect distributed operations to be slightly slower.")

    is_aws_env_disabled = defaults is None or (
        defaults is not None and defaults.compute_environment != ComputeEnvironment.AMAZON_SAGEMAKER
    )
    if is_aws_env_disabled and args.num_cpu_threads_per_process is None:
        args.num_cpu_threads_per_process = get_int_from_env(["OMP_NUM_THREADS"], 1)
        if args.use_cpu and args.num_processes >= 1 and get_int_from_env(["OMP_NUM_THREADS"], 0) == 0:
            # Spread the physical cores evenly over the local ranks.
            local_size = get_int_from_env(
                ["MPI_LOCALNRANKS", "OMPI_COMM_WORLD_LOCAL_SIZE", "MV2_COMM_WORLD_LOCAL_SIZE"],
                max(int(args.num_processes / args.num_machines), 1),
            )
            threads_per_process = int(psutil.cpu_count(logical=False) / local_size)
            if threads_per_process > 1:
                args.num_cpu_threads_per_process = threads_per_process
                warned.append(
                    f"\t`--num_cpu_threads_per_process` was set to `{args.num_cpu_threads_per_process}` to improve out-of-box performance when training on CPUs"
                )

    if args.use_xpu is not None:
        logger.warning(
            "use_xpu is deprecated and ignored, will be removed in Accelerate v1.20. "
            "XPU is a PyTorch native citizen now, we don't need extra argument to enable it any more."
        )

    if any(warned):
        message = "The following values were not passed to `accelerate launch` and had defaults used instead:\n"
        message += "\n".join(warned)
        message += (
            "\nTo avoid this warning pass in values for each of the problematic parameters or run `accelerate config`."
        )
        logger.warning(message)
    return args, defaults, mp_from_config_flag
1210
+
1211
+
1212
def launch_command(args):
    """Validate the parsed CLI arguments, then dispatch to the matching launcher."""
    args, defaults, mp_from_config_flag = _validate_launch_command(args)

    if args.use_deepspeed and not args.cpu:
        # Record which DeepSpeed fields came from the accelerate config file.
        config_fields = list(defaults.deepspeed_config.keys()) if defaults else []
        if mp_from_config_flag:
            config_fields.append("mixed_precision")
        args.deepspeed_fields_from_accelerate_config = ",".join(config_fields)
        deepspeed_launcher(args)
    elif not args.cpu and (args.use_fsdp or args.use_megatron_lm or args.multi_gpu):
        # FSDP, Megatron-LM and plain multi-GPU all go through torch.distributed.
        multi_gpu_launcher(args)
    elif args.tpu and not args.cpu:
        launcher = tpu_pod_launcher if args.tpu_use_cluster else tpu_launcher
        launcher(args)
    elif defaults is not None and defaults.compute_environment == ComputeEnvironment.AMAZON_SAGEMAKER:
        sagemaker_launcher(defaults, args)
    else:
        simple_launcher(args)
1236
+
1237
+
1238
def main():
    """CLI entry point: parse `accelerate launch` arguments and run the command."""
    launch_command(launch_command_parser().parse_args())


if __name__ == "__main__":
    main()
pythonProject/.venv/Lib/site-packages/colorama-0.4.6.dist-info/INSTALLER ADDED
@@ -0,0 +1 @@
 
 
1
+ pip
pythonProject/.venv/Lib/site-packages/colorama-0.4.6.dist-info/METADATA ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.1
2
+ Name: colorama
3
+ Version: 0.4.6
4
+ Summary: Cross-platform colored terminal text.
5
+ Project-URL: Homepage, https://github.com/tartley/colorama
6
+ Author-email: Jonathan Hartley <tartley@tartley.com>
7
+ License-File: LICENSE.txt
8
+ Keywords: ansi,color,colour,crossplatform,terminal,text,windows,xplatform
9
+ Classifier: Development Status :: 5 - Production/Stable
10
+ Classifier: Environment :: Console
11
+ Classifier: Intended Audience :: Developers
12
+ Classifier: License :: OSI Approved :: BSD License
13
+ Classifier: Operating System :: OS Independent
14
+ Classifier: Programming Language :: Python
15
+ Classifier: Programming Language :: Python :: 2
16
+ Classifier: Programming Language :: Python :: 2.7
17
+ Classifier: Programming Language :: Python :: 3
18
+ Classifier: Programming Language :: Python :: 3.7
19
+ Classifier: Programming Language :: Python :: 3.8
20
+ Classifier: Programming Language :: Python :: 3.9
21
+ Classifier: Programming Language :: Python :: 3.10
22
+ Classifier: Programming Language :: Python :: Implementation :: CPython
23
+ Classifier: Programming Language :: Python :: Implementation :: PyPy
24
+ Classifier: Topic :: Terminals
25
+ Requires-Python: !=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7
26
+ Description-Content-Type: text/x-rst
27
+
28
+ .. image:: https://img.shields.io/pypi/v/colorama.svg
29
+ :target: https://pypi.org/project/colorama/
30
+ :alt: Latest Version
31
+
32
+ .. image:: https://img.shields.io/pypi/pyversions/colorama.svg
33
+ :target: https://pypi.org/project/colorama/
34
+ :alt: Supported Python versions
35
+
36
+ .. image:: https://github.com/tartley/colorama/actions/workflows/test.yml/badge.svg
37
+ :target: https://github.com/tartley/colorama/actions/workflows/test.yml
38
+ :alt: Build Status
39
+
40
+ Colorama
41
+ ========
42
+
43
+ Makes ANSI escape character sequences (for producing colored terminal text and
44
+ cursor positioning) work under MS Windows.
45
+
46
+ .. |donate| image:: https://www.paypalobjects.com/en_US/i/btn/btn_donate_SM.gif
47
+ :target: https://www.paypal.com/cgi-bin/webscr?cmd=_donations&business=2MZ9D2GMLYCUJ&item_name=Colorama&currency_code=USD
48
+ :alt: Donate with Paypal
49
+
50
+ `PyPI for releases <https://pypi.org/project/colorama/>`_ |
51
+ `Github for source <https://github.com/tartley/colorama>`_ |
52
+ `Colorama for enterprise on Tidelift <https://github.com/tartley/colorama/blob/master/ENTERPRISE.md>`_
53
+
54
+ If you find Colorama useful, please |donate| to the authors. Thank you!
55
+
56
+ Installation
57
+ ------------
58
+
59
+ Tested on CPython 2.7, 3.7, 3.8, 3.9 and 3.10 and Pypy 2.7 and 3.8.
60
+
61
+ No requirements other than the standard library.
62
+
63
+ .. code-block:: bash
64
+
65
+ pip install colorama
66
+ # or
67
+ conda install -c anaconda colorama
68
+
69
+ Description
70
+ -----------
71
+
72
+ ANSI escape character sequences have long been used to produce colored terminal
73
+ text and cursor positioning on Unix and Macs. Colorama makes this work on
74
+ Windows, too, by wrapping ``stdout``, stripping ANSI sequences it finds (which
75
+ would appear as gobbledygook in the output), and converting them into the
76
+ appropriate win32 calls to modify the state of the terminal. On other platforms,
77
+ Colorama does nothing.
78
+
79
+ This has the upshot of providing a simple cross-platform API for printing
80
+ colored terminal text from Python, and has the happy side-effect that existing
81
+ applications or libraries which use ANSI sequences to produce colored output on
82
+ Linux or Macs can now also work on Windows, simply by calling
83
+ ``colorama.just_fix_windows_console()`` (since v0.4.6) or ``colorama.init()``
84
+ (all versions, but may have other side-effects – see below).
85
+
86
+ An alternative approach is to install ``ansi.sys`` on Windows machines, which
87
+ provides the same behaviour for all applications running in terminals. Colorama
88
+ is intended for situations where that isn't easy (e.g., maybe your app doesn't
89
+ have an installer.)
90
+
91
+ Demo scripts in the source code repository print some colored text using
92
+ ANSI sequences. Compare their output under Gnome-terminal's built in ANSI
93
+ handling, versus on Windows Command-Prompt using Colorama:
94
+
95
+ .. image:: https://github.com/tartley/colorama/raw/master/screenshots/ubuntu-demo.png
96
+ :width: 661
97
+ :height: 357
98
+ :alt: ANSI sequences on Ubuntu under gnome-terminal.
99
+
100
+ .. image:: https://github.com/tartley/colorama/raw/master/screenshots/windows-demo.png
101
+ :width: 668
102
+ :height: 325
103
+ :alt: Same ANSI sequences on Windows, using Colorama.
104
+
105
+ These screenshots show that, on Windows, Colorama does not support ANSI 'dim
106
+ text'; it looks the same as 'normal text'.
107
+
108
+ Usage
109
+ -----
110
+
111
+ Initialisation
112
+ ..............
113
+
114
+ If the only thing you want from Colorama is to get ANSI escapes to work on
115
+ Windows, then run:
116
+
117
+ .. code-block:: python
118
+
119
+ from colorama import just_fix_windows_console
120
+ just_fix_windows_console()
121
+
122
+ If you're on a recent version of Windows 10 or better, and your stdout/stderr
123
+ are pointing to a Windows console, then this will flip the magic configuration
124
+ switch to enable Windows' built-in ANSI support.
125
+
126
+ If you're on an older version of Windows, and your stdout/stderr are pointing to
127
+ a Windows console, then this will wrap ``sys.stdout`` and/or ``sys.stderr`` in a
128
+ magic file object that intercepts ANSI escape sequences and issues the
129
+ appropriate Win32 calls to emulate them.
130
+
131
+ In all other circumstances, it does nothing whatsoever. Basically the idea is
132
+ that this makes Windows act like Unix with respect to ANSI escape handling.
133
+
134
+ It's safe to call this function multiple times. It's safe to call this function
135
+ on non-Windows platforms, but it won't do anything. It's safe to call this
136
+ function when one or both of your stdout/stderr are redirected to a file – it
137
+ won't do anything to those streams.
138
+
139
+ Alternatively, you can use the older interface with more features (but also more
140
+ potential footguns):
141
+
142
+ .. code-block:: python
143
+
144
+ from colorama import init
145
+ init()
146
+
147
+ This does the same thing as ``just_fix_windows_console``, except for the
148
+ following differences:
149
+
150
+ - It's not safe to call ``init`` multiple times; you can end up with multiple
151
+ layers of wrapping and broken ANSI support.
152
+
153
+ - Colorama will apply a heuristic to guess whether stdout/stderr support ANSI,
154
+ and if it thinks they don't, then it will wrap ``sys.stdout`` and
155
+ ``sys.stderr`` in a magic file object that strips out ANSI escape sequences
156
+ before printing them. This happens on all platforms, and can be convenient if
157
+ you want to write your code to emit ANSI escape sequences unconditionally, and
158
+ let Colorama decide whether they should actually be output. But note that
159
+ Colorama's heuristic is not particularly clever.
160
+
161
+ - ``init`` also accepts explicit keyword args to enable/disable various
162
+ functionality – see below.
163
+
164
+ To stop using Colorama before your program exits, simply call ``deinit()``.
165
+ This will restore ``stdout`` and ``stderr`` to their original values, so that
166
+ Colorama is disabled. To resume using Colorama again, call ``reinit()``; it is
167
+ cheaper than calling ``init()`` again (but does the same thing).
168
+
169
+ Most users should depend on ``colorama >= 0.4.6``, and use
170
+ ``just_fix_windows_console``. The old ``init`` interface will be supported
171
+ indefinitely for backwards compatibility, but we don't plan to fix any issues
172
+ with it, also for backwards compatibility.
173
+
174
+ Colored Output
175
+ ..............
176
+
177
+ Cross-platform printing of colored text can then be done using Colorama's
178
+ constant shorthand for ANSI escape sequences. These are deliberately
179
+ rudimentary, see below.
180
+
181
+ .. code-block:: python
182
+
183
+ from colorama import Fore, Back, Style
184
+ print(Fore.RED + 'some red text')
185
+ print(Back.GREEN + 'and with a green background')
186
+ print(Style.DIM + 'and in dim text')
187
+ print(Style.RESET_ALL)
188
+ print('back to normal now')
189
+
190
+ ...or simply by manually printing ANSI sequences from your own code:
191
+
192
+ .. code-block:: python
193
+
194
+ print('\033[31m' + 'some red text')
195
+ print('\033[39m') # and reset to default color
196
+
197
+ ...or, Colorama can be used in conjunction with existing ANSI libraries
198
+ such as the venerable `Termcolor <https://pypi.org/project/termcolor/>`_,
199
+ the fabulous `Blessings <https://pypi.org/project/blessings/>`_,
200
+ or the incredible `Rich <https://pypi.org/project/rich/>`_.
201
+
202
+ If you wish Colorama's Fore, Back and Style constants were more capable,
203
+ then consider using one of the above highly capable libraries to generate
204
+ colors, etc, and use Colorama just for its primary purpose: to convert
205
+ those ANSI sequences to also work on Windows:
206
+
207
+ SIMILARLY, do not send PRs adding the generation of new ANSI types to Colorama.
208
+ We are only interested in converting ANSI codes to win32 API calls, not
209
+ shortcuts like the above to generate ANSI characters.
210
+
211
+ .. code-block:: python
212
+
213
+ from colorama import just_fix_windows_console
214
+ from termcolor import colored
215
+
216
+ # use Colorama to make Termcolor work on Windows too
217
+ just_fix_windows_console()
218
+
219
+ # then use Termcolor for all colored text output
220
+ print(colored('Hello, World!', 'green', 'on_red'))
221
+
222
+ Available formatting constants are::
223
+
224
+ Fore: BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE, RESET.
225
+ Back: BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE, RESET.
226
+ Style: DIM, NORMAL, BRIGHT, RESET_ALL
227
+
228
+ ``Style.RESET_ALL`` resets foreground, background, and brightness. Colorama will
229
+ perform this reset automatically on program exit.
230
+
231
+ These are fairly well supported, but not part of the standard::
232
+
233
+ Fore: LIGHTBLACK_EX, LIGHTRED_EX, LIGHTGREEN_EX, LIGHTYELLOW_EX, LIGHTBLUE_EX, LIGHTMAGENTA_EX, LIGHTCYAN_EX, LIGHTWHITE_EX
234
+ Back: LIGHTBLACK_EX, LIGHTRED_EX, LIGHTGREEN_EX, LIGHTYELLOW_EX, LIGHTBLUE_EX, LIGHTMAGENTA_EX, LIGHTCYAN_EX, LIGHTWHITE_EX
235
+
236
+ Cursor Positioning
237
+ ..................
238
+
239
+ ANSI codes to reposition the cursor are supported. See ``demos/demo06.py`` for
240
+ an example of how to generate them.
241
+
242
+ Init Keyword Args
243
+ .................
244
+
245
+ ``init()`` accepts some ``**kwargs`` to override default behaviour.
246
+
247
+ init(autoreset=False):
248
+ If you find yourself repeatedly sending reset sequences to turn off color
249
+ changes at the end of every print, then ``init(autoreset=True)`` will
250
+ automate that:
251
+
252
+ .. code-block:: python
253
+
254
+ from colorama import init
255
+ init(autoreset=True)
256
+ print(Fore.RED + 'some red text')
257
+ print('automatically back to default color again')
258
+
259
+ init(strip=None):
260
+ Pass ``True`` or ``False`` to override whether ANSI codes should be
261
+ stripped from the output. The default behaviour is to strip if on Windows
262
+ or if output is redirected (not a tty).
263
+
264
+ init(convert=None):
265
+ Pass ``True`` or ``False`` to override whether to convert ANSI codes in the
266
+ output into win32 calls. The default behaviour is to convert if on Windows
267
+ and output is to a tty (terminal).
268
+
269
+ init(wrap=True):
270
+ On Windows, Colorama works by replacing ``sys.stdout`` and ``sys.stderr``
271
+ with proxy objects, which override the ``.write()`` method to do their work.
272
+ If this wrapping causes you problems, then this can be disabled by passing
273
+ ``init(wrap=False)``. The default behaviour is to wrap if ``autoreset`` or
274
+ ``strip`` or ``convert`` are True.
275
+
276
+ When wrapping is disabled, colored printing on non-Windows platforms will
277
+ continue to work as normal. To do cross-platform colored output, you can
278
+ use Colorama's ``AnsiToWin32`` proxy directly:
279
+
280
+ .. code-block:: python
281
+
282
+ import sys
283
+ from colorama import init, AnsiToWin32
284
+ init(wrap=False)
285
+ stream = AnsiToWin32(sys.stderr).stream
286
+
287
+ # Python 2
288
+ print >>stream, Fore.BLUE + 'blue text on stderr'
289
+
290
+ # Python 3
291
+ print(Fore.BLUE + 'blue text on stderr', file=stream)
292
+
293
+ Recognised ANSI Sequences
294
+ .........................
295
+
296
+ ANSI sequences generally take the form::
297
+
298
+ ESC [ <param> ; <param> ... <command>
299
+
300
+ Where ``<param>`` is an integer, and ``<command>`` is a single letter. Zero or
301
+ more params are passed to a ``<command>``. If no params are passed, it is
302
+ generally synonymous with passing a single zero. No spaces exist in the
303
+ sequence; they have been inserted here simply to read more easily.
304
+
305
+ The only ANSI sequences that Colorama converts into win32 calls are::
306
+
307
+ ESC [ 0 m # reset all (colors and brightness)
308
+ ESC [ 1 m # bright
309
+ ESC [ 2 m # dim (looks same as normal brightness)
310
+ ESC [ 22 m # normal brightness
311
+
312
+ # FOREGROUND:
313
+ ESC [ 30 m # black
314
+ ESC [ 31 m # red
315
+ ESC [ 32 m # green
316
+ ESC [ 33 m # yellow
317
+ ESC [ 34 m # blue
318
+ ESC [ 35 m # magenta
319
+ ESC [ 36 m # cyan
320
+ ESC [ 37 m # white
321
+ ESC [ 39 m # reset
322
+
323
+ # BACKGROUND
324
+ ESC [ 40 m # black
325
+ ESC [ 41 m # red
326
+ ESC [ 42 m # green
327
+ ESC [ 43 m # yellow
328
+ ESC [ 44 m # blue
329
+ ESC [ 45 m # magenta
330
+ ESC [ 46 m # cyan
331
+ ESC [ 47 m # white
332
+ ESC [ 49 m # reset
333
+
334
+ # cursor positioning
335
+ ESC [ y;x H # position cursor at x across, y down
336
+ ESC [ y;x f # position cursor at x across, y down
337
+ ESC [ n A # move cursor n lines up
338
+ ESC [ n B # move cursor n lines down
339
+ ESC [ n C # move cursor n characters forward
340
+ ESC [ n D # move cursor n characters backward
341
+
342
+ # clear the screen
343
+ ESC [ mode J # clear the screen
344
+
345
+ # clear the line
346
+ ESC [ mode K # clear the line
347
+
348
+ Multiple numeric params to the ``'m'`` command can be combined into a single
349
+ sequence::
350
+
351
+ ESC [ 36 ; 45 ; 1 m # bright cyan text on magenta background
352
+
353
+ All other ANSI sequences of the form ``ESC [ <param> ; <param> ... <command>``
354
+ are silently stripped from the output on Windows.
355
+
356
+ Any other form of ANSI sequence, such as single-character codes or alternative
357
+ initial characters, are not recognised or stripped. It would be cool to add
358
+ them though. Let me know if it would be useful for you, via the Issues on
359
+ GitHub.
360
+
361
+ Status & Known Problems
362
+ -----------------------
363
+
364
+ I've personally only tested it on Windows XP (CMD, Console2), Ubuntu
365
+ (gnome-terminal, xterm), and OS X.
366
+
367
+ Some valid ANSI sequences aren't recognised.
368
+
369
+ If you're hacking on the code, see `README-hacking.md`_. ESPECIALLY, see the
370
+ explanation there of why we do not want PRs that allow Colorama to generate new
371
+ types of ANSI codes.
372
+
373
+ See outstanding issues and wish-list:
374
+ https://github.com/tartley/colorama/issues
375
+
376
+ If anything doesn't work for you, or doesn't do what you expected or hoped for,
377
+ I'd love to hear about it on that issues list, would be delighted by patches,
378
+ and would be happy to grant commit access to anyone who submits a working patch
379
+ or two.
380
+
381
+ .. _README-hacking.md: README-hacking.md
382
+
383
+ License
384
+ -------
385
+
386
+ Copyright Jonathan Hartley & Arnon Yaari, 2013-2020. BSD 3-Clause license; see
387
+ LICENSE file.
388
+
389
+ Professional support
390
+ --------------------
391
+
392
+ .. |tideliftlogo| image:: https://cdn2.hubspot.net/hubfs/4008838/website/logos/logos_for_download/Tidelift_primary-shorthand-logo.png
393
+ :alt: Tidelift
394
+ :target: https://tidelift.com/subscription/pkg/pypi-colorama?utm_source=pypi-colorama&utm_medium=referral&utm_campaign=readme
395
+
396
+ .. list-table::
397
+ :widths: 10 100
398
+
399
+ * - |tideliftlogo|
400
+ - Professional support for colorama is available as part of the
401
+ `Tidelift Subscription`_.
402
+ Tidelift gives software development teams a single source for purchasing
403
+ and maintaining their software, with professional grade assurances from
404
+ the experts who know it best, while seamlessly integrating with existing
405
+ tools.
406
+
407
+ .. _Tidelift Subscription: https://tidelift.com/subscription/pkg/pypi-colorama?utm_source=pypi-colorama&utm_medium=referral&utm_campaign=readme
408
+
409
+ Thanks
410
+ ------
411
+
412
+ See the CHANGELOG for more thanks!
413
+
414
+ * Marc Schlaich (schlamar) for a ``setup.py`` fix for Python2.5.
415
+ * Marc Abramowitz, reported & fixed a crash on exit with closed ``stdout``,
416
+ providing a solution to issue #7's setuptools/distutils debate,
417
+ and other fixes.
418
+ * User 'eryksun', for guidance on correctly instantiating ``ctypes.windll``.
419
+ * Matthew McCormick for politely pointing out a longstanding crash on non-Win.
420
+ * Ben Hoyt, for a magnificent fix under 64-bit Windows.
421
+ * Jesse at Empty Square for submitting a fix for examples in the README.
422
+ * User 'jamessp', an observant documentation fix for cursor positioning.
423
+ * User 'vaal1239', Dave Mckee & Lackner Kristof for a tiny but much-needed Win7
424
+ fix.
425
+ * Julien Stuyck, for wisely suggesting Python3 compatible updates to README.
426
+ * Daniel Griffith for multiple fabulous patches.
427
+ * Oscar Lesta for a valuable fix to stop ANSI chars being sent to non-tty
428
+ output.
429
+ * Roger Binns, for many suggestions, valuable feedback, & bug reports.
430
+ * Tim Golden for thought and much appreciated feedback on the initial idea.
431
+ * User 'Zearin' for updates to the README file.
432
+ * John Szakmeister for adding support for light colors
433
+ * Charles Merriam for adding documentation to demos
434
+ * Jurko for a fix on 64-bit Windows CPython2.5 w/o ctypes
435
+ * Florian Bruhin for a fix when stdout or stderr are None
436
+ * Thomas Weininger for fixing ValueError on Windows
437
+ * Remi Rampin for better Github integration and fixes to the README file
438
+ * Simeon Visser for closing a file handle using 'with' and updating classifiers
439
+ to include Python 3.3 and 3.4
440
+ * Andy Neff for fixing RESET of LIGHT_EX colors.
441
+ * Jonathan Hartley for the initial idea and implementation.
pythonProject/.venv/Lib/site-packages/colorama-0.4.6.dist-info/RECORD ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ colorama-0.4.6.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
2
+ colorama-0.4.6.dist-info/METADATA,sha256=e67SnrUMOym9sz_4TjF3vxvAV4T3aF7NyqRHHH3YEMw,17158
3
+ colorama-0.4.6.dist-info/RECORD,,
4
+ colorama-0.4.6.dist-info/WHEEL,sha256=cdcF4Fbd0FPtw2EMIOwH-3rSOTUdTCeOSXRMD1iLUb8,105
5
+ colorama-0.4.6.dist-info/licenses/LICENSE.txt,sha256=ysNcAmhuXQSlpxQL-zs25zrtSWZW6JEQLkKIhteTAxg,1491
6
+ colorama/__init__.py,sha256=wePQA4U20tKgYARySLEC047ucNX-g8pRLpYBuiHlLb8,266
7
+ colorama/__pycache__/__init__.cpython-310.pyc,,
8
+ colorama/__pycache__/ansi.cpython-310.pyc,,
9
+ colorama/__pycache__/ansitowin32.cpython-310.pyc,,
10
+ colorama/__pycache__/initialise.cpython-310.pyc,,
11
+ colorama/__pycache__/win32.cpython-310.pyc,,
12
+ colorama/__pycache__/winterm.cpython-310.pyc,,
13
+ colorama/ansi.py,sha256=Top4EeEuaQdBWdteKMEcGOTeKeF19Q-Wo_6_Cj5kOzQ,2522
14
+ colorama/ansitowin32.py,sha256=vPNYa3OZbxjbuFyaVo0Tmhmy1FZ1lKMWCnT7odXpItk,11128
15
+ colorama/initialise.py,sha256=-hIny86ClXo39ixh5iSCfUIa2f_h_bgKRDW7gqs-KLU,3325
16
+ colorama/tests/__init__.py,sha256=MkgPAEzGQd-Rq0w0PZXSX2LadRWhUECcisJY8lSrm4Q,75
17
+ colorama/tests/__pycache__/__init__.cpython-310.pyc,,
18
+ colorama/tests/__pycache__/ansi_test.cpython-310.pyc,,
19
+ colorama/tests/__pycache__/ansitowin32_test.cpython-310.pyc,,
20
+ colorama/tests/__pycache__/initialise_test.cpython-310.pyc,,
21
+ colorama/tests/__pycache__/isatty_test.cpython-310.pyc,,
22
+ colorama/tests/__pycache__/utils.cpython-310.pyc,,
23
+ colorama/tests/__pycache__/winterm_test.cpython-310.pyc,,
24
+ colorama/tests/ansi_test.py,sha256=FeViDrUINIZcr505PAxvU4AjXz1asEiALs9GXMhwRaE,2839
25
+ colorama/tests/ansitowin32_test.py,sha256=RN7AIhMJ5EqDsYaCjVo-o4u8JzDD4ukJbmevWKS70rY,10678
26
+ colorama/tests/initialise_test.py,sha256=BbPy-XfyHwJ6zKozuQOvNvQZzsx9vdb_0bYXn7hsBTc,6741
27
+ colorama/tests/isatty_test.py,sha256=Pg26LRpv0yQDB5Ac-sxgVXG7hsA1NYvapFgApZfYzZg,1866
28
+ colorama/tests/utils.py,sha256=1IIRylG39z5-dzq09R_ngufxyPZxgldNbrxKxUGwGKE,1079
29
+ colorama/tests/winterm_test.py,sha256=qoWFPEjym5gm2RuMwpf3pOis3a5r_PJZFCzK254JL8A,3709
30
+ colorama/win32.py,sha256=YQOKwMTwtGBbsY4dL5HYTvwTeP9wIQra5MvPNddpxZs,6181
31
+ colorama/winterm.py,sha256=XCQFDHjPi6AHYNdZwy0tA02H-Jh48Jp-HvCjeLeLp3U,7134
pythonProject/.venv/Lib/site-packages/colorama-0.4.6.dist-info/WHEEL ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.11.1
3
+ Root-Is-Purelib: true
4
+ Tag: py2-none-any
5
+ Tag: py3-none-any
pythonProject/.venv/Lib/site-packages/colorama-0.4.6.dist-info/licenses/LICENSE.txt ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright (c) 2010 Jonathan Hartley
2
+ All rights reserved.
3
+
4
+ Redistribution and use in source and binary forms, with or without
5
+ modification, are permitted provided that the following conditions are met:
6
+
7
+ * Redistributions of source code must retain the above copyright notice, this
8
+ list of conditions and the following disclaimer.
9
+
10
+ * Redistributions in binary form must reproduce the above copyright notice,
11
+ this list of conditions and the following disclaimer in the documentation
12
+ and/or other materials provided with the distribution.
13
+
14
+ * Neither the name of the copyright holders, nor those of its contributors
15
+ may be used to endorse or promote products derived from this software without
16
+ specific prior written permission.
17
+
18
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
19
+ ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
20
+ WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
pythonProject/.venv/Lib/site-packages/colorama/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (451 Bytes). View file
 
pythonProject/.venv/Lib/site-packages/colorama/ansi.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Jonathan Hartley 2013. BSD 3-Clause license, see LICENSE file.
2
+ '''
3
+ This module generates ANSI character codes to printing colors to terminals.
4
+ See: http://en.wikipedia.org/wiki/ANSI_escape_code
5
+ '''
6
+
7
+ CSI = '\033['
8
+ OSC = '\033]'
9
+ BEL = '\a'
10
+
11
+
12
+ def code_to_chars(code):
13
+ return CSI + str(code) + 'm'
14
+
15
+ def set_title(title):
16
+ return OSC + '2;' + title + BEL
17
+
18
+ def clear_screen(mode=2):
19
+ return CSI + str(mode) + 'J'
20
+
21
+ def clear_line(mode=2):
22
+ return CSI + str(mode) + 'K'
23
+
24
+
25
+ class AnsiCodes(object):
26
+ def __init__(self):
27
+ # the subclasses declare class attributes which are numbers.
28
+ # Upon instantiation we define instance attributes, which are the same
29
+ # as the class attributes but wrapped with the ANSI escape sequence
30
+ for name in dir(self):
31
+ if not name.startswith('_'):
32
+ value = getattr(self, name)
33
+ setattr(self, name, code_to_chars(value))
34
+
35
+
36
+ class AnsiCursor(object):
37
+ def UP(self, n=1):
38
+ return CSI + str(n) + 'A'
39
+ def DOWN(self, n=1):
40
+ return CSI + str(n) + 'B'
41
+ def FORWARD(self, n=1):
42
+ return CSI + str(n) + 'C'
43
+ def BACK(self, n=1):
44
+ return CSI + str(n) + 'D'
45
+ def POS(self, x=1, y=1):
46
+ return CSI + str(y) + ';' + str(x) + 'H'
47
+
48
+
49
+ class AnsiFore(AnsiCodes):
50
+ BLACK = 30
51
+ RED = 31
52
+ GREEN = 32
53
+ YELLOW = 33
54
+ BLUE = 34
55
+ MAGENTA = 35
56
+ CYAN = 36
57
+ WHITE = 37
58
+ RESET = 39
59
+
60
+ # These are fairly well supported, but not part of the standard.
61
+ LIGHTBLACK_EX = 90
62
+ LIGHTRED_EX = 91
63
+ LIGHTGREEN_EX = 92
64
+ LIGHTYELLOW_EX = 93
65
+ LIGHTBLUE_EX = 94
66
+ LIGHTMAGENTA_EX = 95
67
+ LIGHTCYAN_EX = 96
68
+ LIGHTWHITE_EX = 97
69
+
70
+
71
+ class AnsiBack(AnsiCodes):
72
+ BLACK = 40
73
+ RED = 41
74
+ GREEN = 42
75
+ YELLOW = 43
76
+ BLUE = 44
77
+ MAGENTA = 45
78
+ CYAN = 46
79
+ WHITE = 47
80
+ RESET = 49
81
+
82
+ # These are fairly well supported, but not part of the standard.
83
+ LIGHTBLACK_EX = 100
84
+ LIGHTRED_EX = 101
85
+ LIGHTGREEN_EX = 102
86
+ LIGHTYELLOW_EX = 103
87
+ LIGHTBLUE_EX = 104
88
+ LIGHTMAGENTA_EX = 105
89
+ LIGHTCYAN_EX = 106
90
+ LIGHTWHITE_EX = 107
91
+
92
+
93
+ class AnsiStyle(AnsiCodes):
94
+ BRIGHT = 1
95
+ DIM = 2
96
+ NORMAL = 22
97
+ RESET_ALL = 0
98
+
99
+ Fore = AnsiFore()
100
+ Back = AnsiBack()
101
+ Style = AnsiStyle()
102
+ Cursor = AnsiCursor()
pythonProject/.venv/Lib/site-packages/colorama/ansitowin32.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Jonathan Hartley 2013. BSD 3-Clause license, see LICENSE file.
2
+ import re
3
+ import sys
4
+ import os
5
+
6
+ from .ansi import AnsiFore, AnsiBack, AnsiStyle, Style, BEL
7
+ from .winterm import enable_vt_processing, WinTerm, WinColor, WinStyle
8
+ from .win32 import windll, winapi_test
9
+
10
+
11
+ winterm = None
12
+ if windll is not None:
13
+ winterm = WinTerm()
14
+
15
+
16
+ class StreamWrapper(object):
17
+ '''
18
+ Wraps a stream (such as stdout), acting as a transparent proxy for all
19
+ attribute access apart from method 'write()', which is delegated to our
20
+ Converter instance.
21
+ '''
22
+ def __init__(self, wrapped, converter):
23
+ # double-underscore everything to prevent clashes with names of
24
+ # attributes on the wrapped stream object.
25
+ self.__wrapped = wrapped
26
+ self.__convertor = converter
27
+
28
+ def __getattr__(self, name):
29
+ return getattr(self.__wrapped, name)
30
+
31
+ def __enter__(self, *args, **kwargs):
32
+ # special method lookup bypasses __getattr__/__getattribute__, see
33
+ # https://stackoverflow.com/questions/12632894/why-doesnt-getattr-work-with-exit
34
+ # thus, contextlib magic methods are not proxied via __getattr__
35
+ return self.__wrapped.__enter__(*args, **kwargs)
36
+
37
+ def __exit__(self, *args, **kwargs):
38
+ return self.__wrapped.__exit__(*args, **kwargs)
39
+
40
+ def __setstate__(self, state):
41
+ self.__dict__ = state
42
+
43
+ def __getstate__(self):
44
+ return self.__dict__
45
+
46
+ def write(self, text):
47
+ self.__convertor.write(text)
48
+
49
+ def isatty(self):
50
+ stream = self.__wrapped
51
+ if 'PYCHARM_HOSTED' in os.environ:
52
+ if stream is not None and (stream is sys.__stdout__ or stream is sys.__stderr__):
53
+ return True
54
+ try:
55
+ stream_isatty = stream.isatty
56
+ except AttributeError:
57
+ return False
58
+ else:
59
+ return stream_isatty()
60
+
61
+ @property
62
+ def closed(self):
63
+ stream = self.__wrapped
64
+ try:
65
+ return stream.closed
66
+ # AttributeError in the case that the stream doesn't support being closed
67
+ # ValueError for the case that the stream has already been detached when atexit runs
68
+ except (AttributeError, ValueError):
69
+ return True
70
+
71
+
72
+ class AnsiToWin32(object):
73
+ '''
74
+ Implements a 'write()' method which, on Windows, will strip ANSI character
75
+ sequences from the text, and if outputting to a tty, will convert them into
76
+ win32 function calls.
77
+ '''
78
+ ANSI_CSI_RE = re.compile('\001?\033\\[((?:\\d|;)*)([a-zA-Z])\002?') # Control Sequence Introducer
79
+ ANSI_OSC_RE = re.compile('\001?\033\\]([^\a]*)(\a)\002?') # Operating System Command
80
+
81
+ def __init__(self, wrapped, convert=None, strip=None, autoreset=False):
82
+ # The wrapped stream (normally sys.stdout or sys.stderr)
83
+ self.wrapped = wrapped
84
+
85
+ # should we reset colors to defaults after every .write()
86
+ self.autoreset = autoreset
87
+
88
+ # create the proxy wrapping our output stream
89
+ self.stream = StreamWrapper(wrapped, self)
90
+
91
+ on_windows = os.name == 'nt'
92
+ # We test if the WinAPI works, because even if we are on Windows
93
+ # we may be using a terminal that doesn't support the WinAPI
94
+ # (e.g. Cygwin Terminal). In this case it's up to the terminal
95
+ # to support the ANSI codes.
96
+ conversion_supported = on_windows and winapi_test()
97
+ try:
98
+ fd = wrapped.fileno()
99
+ except Exception:
100
+ fd = -1
101
+ system_has_native_ansi = not on_windows or enable_vt_processing(fd)
102
+ have_tty = not self.stream.closed and self.stream.isatty()
103
+ need_conversion = conversion_supported and not system_has_native_ansi
104
+
105
+ # should we strip ANSI sequences from our output?
106
+ if strip is None:
107
+ strip = need_conversion or not have_tty
108
+ self.strip = strip
109
+
110
+ # should we should convert ANSI sequences into win32 calls?
111
+ if convert is None:
112
+ convert = need_conversion and have_tty
113
+ self.convert = convert
114
+
115
+ # dict of ansi codes to win32 functions and parameters
116
+ self.win32_calls = self.get_win32_calls()
117
+
118
+ # are we wrapping stderr?
119
+ self.on_stderr = self.wrapped is sys.stderr
120
+
121
+ def should_wrap(self):
122
+ '''
123
+ True if this class is actually needed. If false, then the output
124
+ stream will not be affected, nor will win32 calls be issued, so
125
+ wrapping stdout is not actually required. This will generally be
126
+ False on non-Windows platforms, unless optional functionality like
127
+ autoreset has been requested using kwargs to init()
128
+ '''
129
+ return self.convert or self.strip or self.autoreset
130
+
131
+ def get_win32_calls(self):
132
+ if self.convert and winterm:
133
+ return {
134
+ AnsiStyle.RESET_ALL: (winterm.reset_all, ),
135
+ AnsiStyle.BRIGHT: (winterm.style, WinStyle.BRIGHT),
136
+ AnsiStyle.DIM: (winterm.style, WinStyle.NORMAL),
137
+ AnsiStyle.NORMAL: (winterm.style, WinStyle.NORMAL),
138
+ AnsiFore.BLACK: (winterm.fore, WinColor.BLACK),
139
+ AnsiFore.RED: (winterm.fore, WinColor.RED),
140
+ AnsiFore.GREEN: (winterm.fore, WinColor.GREEN),
141
+ AnsiFore.YELLOW: (winterm.fore, WinColor.YELLOW),
142
+ AnsiFore.BLUE: (winterm.fore, WinColor.BLUE),
143
+ AnsiFore.MAGENTA: (winterm.fore, WinColor.MAGENTA),
144
+ AnsiFore.CYAN: (winterm.fore, WinColor.CYAN),
145
+ AnsiFore.WHITE: (winterm.fore, WinColor.GREY),
146
+ AnsiFore.RESET: (winterm.fore, ),
147
+ AnsiFore.LIGHTBLACK_EX: (winterm.fore, WinColor.BLACK, True),
148
+ AnsiFore.LIGHTRED_EX: (winterm.fore, WinColor.RED, True),
149
+ AnsiFore.LIGHTGREEN_EX: (winterm.fore, WinColor.GREEN, True),
150
+ AnsiFore.LIGHTYELLOW_EX: (winterm.fore, WinColor.YELLOW, True),
151
+ AnsiFore.LIGHTBLUE_EX: (winterm.fore, WinColor.BLUE, True),
152
+ AnsiFore.LIGHTMAGENTA_EX: (winterm.fore, WinColor.MAGENTA, True),
153
+ AnsiFore.LIGHTCYAN_EX: (winterm.fore, WinColor.CYAN, True),
154
+ AnsiFore.LIGHTWHITE_EX: (winterm.fore, WinColor.GREY, True),
155
+ AnsiBack.BLACK: (winterm.back, WinColor.BLACK),
156
+ AnsiBack.RED: (winterm.back, WinColor.RED),
157
+ AnsiBack.GREEN: (winterm.back, WinColor.GREEN),
158
+ AnsiBack.YELLOW: (winterm.back, WinColor.YELLOW),
159
+ AnsiBack.BLUE: (winterm.back, WinColor.BLUE),
160
+ AnsiBack.MAGENTA: (winterm.back, WinColor.MAGENTA),
161
+ AnsiBack.CYAN: (winterm.back, WinColor.CYAN),
162
+ AnsiBack.WHITE: (winterm.back, WinColor.GREY),
163
+ AnsiBack.RESET: (winterm.back, ),
164
+ AnsiBack.LIGHTBLACK_EX: (winterm.back, WinColor.BLACK, True),
165
+ AnsiBack.LIGHTRED_EX: (winterm.back, WinColor.RED, True),
166
+ AnsiBack.LIGHTGREEN_EX: (winterm.back, WinColor.GREEN, True),
167
+ AnsiBack.LIGHTYELLOW_EX: (winterm.back, WinColor.YELLOW, True),
168
+ AnsiBack.LIGHTBLUE_EX: (winterm.back, WinColor.BLUE, True),
169
+ AnsiBack.LIGHTMAGENTA_EX: (winterm.back, WinColor.MAGENTA, True),
170
+ AnsiBack.LIGHTCYAN_EX: (winterm.back, WinColor.CYAN, True),
171
+ AnsiBack.LIGHTWHITE_EX: (winterm.back, WinColor.GREY, True),
172
+ }
173
+ return dict()
174
+
175
+ def write(self, text):
176
+ if self.strip or self.convert:
177
+ self.write_and_convert(text)
178
+ else:
179
+ self.wrapped.write(text)
180
+ self.wrapped.flush()
181
+ if self.autoreset:
182
+ self.reset_all()
183
+
184
+
185
+ def reset_all(self):
186
+ if self.convert:
187
+ self.call_win32('m', (0,))
188
+ elif not self.strip and not self.stream.closed:
189
+ self.wrapped.write(Style.RESET_ALL)
190
+
191
+
192
+ def write_and_convert(self, text):
193
+ '''
194
+ Write the given text to our wrapped stream, stripping any ANSI
195
+ sequences from the text, and optionally converting them into win32
196
+ calls.
197
+ '''
198
+ cursor = 0
199
+ text = self.convert_osc(text)
200
+ for match in self.ANSI_CSI_RE.finditer(text):
201
+ start, end = match.span()
202
+ self.write_plain_text(text, cursor, start)
203
+ self.convert_ansi(*match.groups())
204
+ cursor = end
205
+ self.write_plain_text(text, cursor, len(text))
206
+
207
+
208
+ def write_plain_text(self, text, start, end):
209
+ if start < end:
210
+ self.wrapped.write(text[start:end])
211
+ self.wrapped.flush()
212
+
213
+
214
+ def convert_ansi(self, paramstring, command):
215
+ if self.convert:
216
+ params = self.extract_params(command, paramstring)
217
+ self.call_win32(command, params)
218
+
219
+
220
+ def extract_params(self, command, paramstring):
221
+ if command in 'Hf':
222
+ params = tuple(int(p) if len(p) != 0 else 1 for p in paramstring.split(';'))
223
+ while len(params) < 2:
224
+ # defaults:
225
+ params = params + (1,)
226
+ else:
227
+ params = tuple(int(p) for p in paramstring.split(';') if len(p) != 0)
228
+ if len(params) == 0:
229
+ # defaults:
230
+ if command in 'JKm':
231
+ params = (0,)
232
+ elif command in 'ABCD':
233
+ params = (1,)
234
+
235
+ return params
236
+
237
+
238
+ def call_win32(self, command, params):
239
+ if command == 'm':
240
+ for param in params:
241
+ if param in self.win32_calls:
242
+ func_args = self.win32_calls[param]
243
+ func = func_args[0]
244
+ args = func_args[1:]
245
+ kwargs = dict(on_stderr=self.on_stderr)
246
+ func(*args, **kwargs)
247
+ elif command in 'J':
248
+ winterm.erase_screen(params[0], on_stderr=self.on_stderr)
249
+ elif command in 'K':
250
+ winterm.erase_line(params[0], on_stderr=self.on_stderr)
251
+ elif command in 'Hf': # cursor position - absolute
252
+ winterm.set_cursor_position(params, on_stderr=self.on_stderr)
253
+ elif command in 'ABCD': # cursor position - relative
254
+ n = params[0]
255
+ # A - up, B - down, C - forward, D - back
256
+ x, y = {'A': (0, -n), 'B': (0, n), 'C': (n, 0), 'D': (-n, 0)}[command]
257
+ winterm.cursor_adjust(x, y, on_stderr=self.on_stderr)
258
+
259
+
260
+ def convert_osc(self, text):
261
+ for match in self.ANSI_OSC_RE.finditer(text):
262
+ start, end = match.span()
263
+ text = text[:start] + text[end:]
264
+ paramstring, command = match.groups()
265
+ if command == BEL:
266
+ if paramstring.count(";") == 1:
267
+ params = paramstring.split(";")
268
+ # 0 - change title and icon (we will only change title)
269
+ # 1 - change icon (we don't support this)
270
+ # 2 - change title
271
+ if params[0] in '02':
272
+ winterm.set_title(params[1])
273
+ return text
274
+
275
+
276
+ def flush(self):
277
+ self.wrapped.flush()
pythonProject/.venv/Lib/site-packages/colorama/initialise.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Jonathan Hartley 2013. BSD 3-Clause license, see LICENSE file.
2
+ import atexit
3
+ import contextlib
4
+ import sys
5
+
6
+ from .ansitowin32 import AnsiToWin32
7
+
8
+
9
+ def _wipe_internal_state_for_tests():
10
+ global orig_stdout, orig_stderr
11
+ orig_stdout = None
12
+ orig_stderr = None
13
+
14
+ global wrapped_stdout, wrapped_stderr
15
+ wrapped_stdout = None
16
+ wrapped_stderr = None
17
+
18
+ global atexit_done
19
+ atexit_done = False
20
+
21
+ global fixed_windows_console
22
+ fixed_windows_console = False
23
+
24
+ try:
25
+ # no-op if it wasn't registered
26
+ atexit.unregister(reset_all)
27
+ except AttributeError:
28
+ # python 2: no atexit.unregister. Oh well, we did our best.
29
+ pass
30
+
31
+
32
+ def reset_all():
33
+ if AnsiToWin32 is not None: # Issue #74: objects might become None at exit
34
+ AnsiToWin32(orig_stdout).reset_all()
35
+
36
+
37
+ def init(autoreset=False, convert=None, strip=None, wrap=True):
38
+
39
+ if not wrap and any([autoreset, convert, strip]):
40
+ raise ValueError('wrap=False conflicts with any other arg=True')
41
+
42
+ global wrapped_stdout, wrapped_stderr
43
+ global orig_stdout, orig_stderr
44
+
45
+ orig_stdout = sys.stdout
46
+ orig_stderr = sys.stderr
47
+
48
+ if sys.stdout is None:
49
+ wrapped_stdout = None
50
+ else:
51
+ sys.stdout = wrapped_stdout = \
52
+ wrap_stream(orig_stdout, convert, strip, autoreset, wrap)
53
+ if sys.stderr is None:
54
+ wrapped_stderr = None
55
+ else:
56
+ sys.stderr = wrapped_stderr = \
57
+ wrap_stream(orig_stderr, convert, strip, autoreset, wrap)
58
+
59
+ global atexit_done
60
+ if not atexit_done:
61
+ atexit.register(reset_all)
62
+ atexit_done = True
63
+
64
+
65
+ def deinit():
66
+ if orig_stdout is not None:
67
+ sys.stdout = orig_stdout
68
+ if orig_stderr is not None:
69
+ sys.stderr = orig_stderr
70
+
71
+
72
+ def just_fix_windows_console():
73
+ global fixed_windows_console
74
+
75
+ if sys.platform != "win32":
76
+ return
77
+ if fixed_windows_console:
78
+ return
79
+ if wrapped_stdout is not None or wrapped_stderr is not None:
80
+ # Someone already ran init() and it did stuff, so we won't second-guess them
81
+ return
82
+
83
+ # On newer versions of Windows, AnsiToWin32.__init__ will implicitly enable the
84
+ # native ANSI support in the console as a side-effect. We only need to actually
85
+ # replace sys.stdout/stderr if we're in the old-style conversion mode.
86
+ new_stdout = AnsiToWin32(sys.stdout, convert=None, strip=None, autoreset=False)
87
+ if new_stdout.convert:
88
+ sys.stdout = new_stdout
89
+ new_stderr = AnsiToWin32(sys.stderr, convert=None, strip=None, autoreset=False)
90
+ if new_stderr.convert:
91
+ sys.stderr = new_stderr
92
+
93
+ fixed_windows_console = True
94
+
95
+ @contextlib.contextmanager
96
+ def colorama_text(*args, **kwargs):
97
+ init(*args, **kwargs)
98
+ try:
99
+ yield
100
+ finally:
101
+ deinit()
102
+
103
+
104
+ def reinit():
105
+ if wrapped_stdout is not None:
106
+ sys.stdout = wrapped_stdout
107
+ if wrapped_stderr is not None:
108
+ sys.stderr = wrapped_stderr
109
+
110
+
111
+ def wrap_stream(stream, convert, strip, autoreset, wrap):
112
+ if wrap:
113
+ wrapper = AnsiToWin32(stream,
114
+ convert=convert, strip=strip, autoreset=autoreset)
115
+ if wrapper.should_wrap():
116
+ stream = wrapper.stream
117
+ return stream
118
+
119
+
120
+ # Use this for initial setup as well, to reduce code duplication
121
+ _wipe_internal_state_for_tests()
pythonProject/.venv/Lib/site-packages/colorama/win32.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright Jonathan Hartley 2013. BSD 3-Clause license, see LICENSE file.

# Standard device identifiers, from winbase.h.
STDOUT = -11
STDERR = -12

ENABLE_VIRTUAL_TERMINAL_PROCESSING = 0x0004

try:
    import ctypes
    from ctypes import LibraryLoader
    windll = LibraryLoader(ctypes.WinDLL)
    from ctypes import wintypes
except (AttributeError, ImportError):
    # Not on Windows (ctypes.WinDLL missing) or ctypes unavailable:
    # export inert stand-ins so importing modules keep working.
    windll = None
    SetConsoleTextAttribute = lambda *_: None
    winapi_test = lambda *_: None
else:
    from ctypes import byref, Structure, c_char, POINTER

    COORD = wintypes._COORD

    class CONSOLE_SCREEN_BUFFER_INFO(Structure):
        """struct in wincon.h."""
        _fields_ = [
            ("dwSize", COORD),
            ("dwCursorPosition", COORD),
            ("wAttributes", wintypes.WORD),
            ("srWindow", wintypes.SMALL_RECT),
            ("dwMaximumWindowSize", COORD),
        ]

        def __str__(self):
            # Flat comma-separated dump of every field, for debugging.
            return '(%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,%d)' % (
                self.dwSize.Y, self.dwSize.X,
                self.dwCursorPosition.Y, self.dwCursorPosition.X,
                self.wAttributes,
                self.srWindow.Top, self.srWindow.Left,
                self.srWindow.Bottom, self.srWindow.Right,
                self.dwMaximumWindowSize.Y, self.dwMaximumWindowSize.X,
            )

    # kernel32 prototypes: declaring argtypes/restype lets ctypes validate
    # argument conversion and interpret return values correctly.
    _GetStdHandle = windll.kernel32.GetStdHandle
    _GetStdHandle.argtypes = [wintypes.DWORD]
    _GetStdHandle.restype = wintypes.HANDLE

    _GetConsoleScreenBufferInfo = windll.kernel32.GetConsoleScreenBufferInfo
    _GetConsoleScreenBufferInfo.argtypes = [
        wintypes.HANDLE, POINTER(CONSOLE_SCREEN_BUFFER_INFO)]
    _GetConsoleScreenBufferInfo.restype = wintypes.BOOL

    _SetConsoleTextAttribute = windll.kernel32.SetConsoleTextAttribute
    _SetConsoleTextAttribute.argtypes = [wintypes.HANDLE, wintypes.WORD]
    _SetConsoleTextAttribute.restype = wintypes.BOOL

    _SetConsoleCursorPosition = windll.kernel32.SetConsoleCursorPosition
    _SetConsoleCursorPosition.argtypes = [wintypes.HANDLE, COORD]
    _SetConsoleCursorPosition.restype = wintypes.BOOL

    _FillConsoleOutputCharacterA = windll.kernel32.FillConsoleOutputCharacterA
    _FillConsoleOutputCharacterA.argtypes = [
        wintypes.HANDLE, c_char, wintypes.DWORD, COORD, POINTER(wintypes.DWORD)]
    _FillConsoleOutputCharacterA.restype = wintypes.BOOL

    _FillConsoleOutputAttribute = windll.kernel32.FillConsoleOutputAttribute
    _FillConsoleOutputAttribute.argtypes = [
        wintypes.HANDLE, wintypes.WORD, wintypes.DWORD, COORD, POINTER(wintypes.DWORD)]
    _FillConsoleOutputAttribute.restype = wintypes.BOOL

    _SetConsoleTitleW = windll.kernel32.SetConsoleTitleW
    _SetConsoleTitleW.argtypes = [wintypes.LPCWSTR]
    _SetConsoleTitleW.restype = wintypes.BOOL

    _GetConsoleMode = windll.kernel32.GetConsoleMode
    _GetConsoleMode.argtypes = [wintypes.HANDLE, POINTER(wintypes.DWORD)]
    _GetConsoleMode.restype = wintypes.BOOL

    _SetConsoleMode = windll.kernel32.SetConsoleMode
    _SetConsoleMode.argtypes = [wintypes.HANDLE, wintypes.DWORD]
    _SetConsoleMode.restype = wintypes.BOOL

    def _winapi_test(handle):
        # A handle is console-backed iff GetConsoleScreenBufferInfo succeeds on it.
        buffer_info = CONSOLE_SCREEN_BUFFER_INFO()
        return bool(_GetConsoleScreenBufferInfo(handle, byref(buffer_info)))

    def winapi_test():
        """Return True if either stdout or stderr is attached to a real console."""
        return any(_winapi_test(h) for h in
                   (_GetStdHandle(STDOUT), _GetStdHandle(STDERR)))

    def GetConsoleScreenBufferInfo(stream_id=STDOUT):
        """Fetch the screen-buffer info struct for the given std stream."""
        buffer_info = CONSOLE_SCREEN_BUFFER_INFO()
        _GetConsoleScreenBufferInfo(_GetStdHandle(stream_id), byref(buffer_info))
        return buffer_info

    def SetConsoleTextAttribute(stream_id, attrs):
        return _SetConsoleTextAttribute(_GetStdHandle(stream_id), attrs)

    def SetConsoleCursorPosition(stream_id, position, adjust=True):
        position = COORD(*position)
        # Out-of-range coordinates are silently ignored.
        if position.Y <= 0 or position.X <= 0:
            return
        # Translate ANSI conventions to Win32's SetConsoleCursorPosition:
        # ANSI is 1-based and ordered (row, col); Win32 is 0-based, (col, row).
        adjusted_position = COORD(position.Y - 1, position.X - 1)
        if adjust:
            # Account for the viewport's scroll offset.
            sr = GetConsoleScreenBufferInfo(STDOUT).srWindow
            adjusted_position.Y += sr.Top
            adjusted_position.X += sr.Left
        return _SetConsoleCursorPosition(_GetStdHandle(stream_id), adjusted_position)

    def FillConsoleOutputCharacter(stream_id, char, length, start):
        """Write `char` `length` times starting at `start`; return cells written."""
        handle = _GetStdHandle(stream_id)
        written = wintypes.DWORD(0)
        # Note that this is hard-coded for ANSI (vs wide) bytes.
        _FillConsoleOutputCharacterA(
            handle, c_char(char.encode()), wintypes.DWORD(length), start, byref(written))
        return written.value

    def FillConsoleOutputAttribute(stream_id, attr, length, start):
        ''' FillConsoleOutputAttribute( hConsole, csbi.wAttributes, dwConSize, coordScreen, &cCharsWritten )'''
        handle = _GetStdHandle(stream_id)
        written = wintypes.DWORD(0)
        # Note that this is hard-coded for ANSI (vs wide) bytes.
        return _FillConsoleOutputAttribute(
            handle, wintypes.WORD(attr), wintypes.DWORD(length), start, byref(written))

    def SetConsoleTitle(title):
        return _SetConsoleTitleW(title)

    def GetConsoleMode(handle):
        """Return the console mode flags for `handle`; raise WinError on failure."""
        mode = wintypes.DWORD()
        if not _GetConsoleMode(handle, byref(mode)):
            raise ctypes.WinError()
        return mode.value

    def SetConsoleMode(handle, mode):
        """Set the console mode flags for `handle`; raise WinError on failure."""
        if not _SetConsoleMode(handle, mode):
            raise ctypes.WinError()
+ raise ctypes.WinError()
pythonProject/.venv/Lib/site-packages/diffusers/callbacks.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List
2
+
3
+ from .configuration_utils import ConfigMixin, register_to_config
4
+ from .utils import CONFIG_NAME
5
+
6
+
7
class PipelineCallback(ConfigMixin):
    """
    Base class for the official pipeline callbacks, giving every callback a consistent interface.

    Subclasses must provide:
    `tensor_inputs`: the tensor names this callback consumes — only names present in the pipeline's
    `._callback_tensor_inputs` attribute are available.
    `callback_fn`: the callback's core logic.
    """

    config_name = CONFIG_NAME

    @register_to_config
    def __init__(self, cutoff_step_ratio=1.0, cutoff_step_index=None):
        super().__init__()

        ratio_given = cutoff_step_ratio is not None
        index_given = cutoff_step_index is not None

        # Exactly one of the two cutoff specifications must be supplied.
        if ratio_given == index_given:
            raise ValueError("Either cutoff_step_ratio or cutoff_step_index should be provided, not both or none.")

        # The ratio form must be a float inside [0, 1].
        if ratio_given and (not isinstance(cutoff_step_ratio, float) or not (0.0 <= cutoff_step_ratio <= 1.0)):
            raise ValueError("cutoff_step_ratio must be a float between 0.0 and 1.0.")

    @property
    def tensor_inputs(self) -> List[str]:
        raise NotImplementedError(f"You need to set the attribute `tensor_inputs` for {self.__class__}")

    def callback_fn(self, pipeline, step_index, timesteps, callback_kwargs) -> Dict[str, Any]:
        raise NotImplementedError(f"You need to implement the method `callback_fn` for {self.__class__}")

    def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        # Delegate straight to the subclass implementation.
        return self.callback_fn(pipeline, step_index, timestep, callback_kwargs)
44
+
45
+
46
class MultiPipelineCallbacks:
    """
    Bundles several `PipelineCallback` objects behind the single-callback interface so a pipeline
    can treat them as one.
    """

    def __init__(self, callbacks: List[PipelineCallback]):
        self.callbacks = callbacks

    @property
    def tensor_inputs(self) -> List[str]:
        # Concatenation of every wrapped callback's tensor names, in callback order.
        names: List[str] = []
        for cb in self.callbacks:
            names.extend(cb.tensor_inputs)
        return names

    def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        """
        Invoke every wrapped callback in order, threading `callback_kwargs` through the chain,
        and return the final kwargs.
        """
        for cb in self.callbacks:
            callback_kwargs = cb(pipeline, step_index, timestep, callback_kwargs)
        return callback_kwargs
67
+
68
+
69
class SDCFGCutoffCallback(PipelineCallback):
    """
    Callback for Stable Diffusion pipelines that switches classifier-free guidance off once the
    denoising loop reaches the cutoff step (taken from `cutoff_step_index`, or computed as
    `cutoff_step_ratio` times the number of timesteps).

    Note: this mutates the pipeline by forcing `_guidance_scale` to 0.0 from the cutoff step onwards.
    """

    tensor_inputs = ["prompt_embeds"]

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        # Prefer the explicit step index; otherwise derive the step from the ratio.
        if self.config.cutoff_step_index is not None:
            cutoff_step = self.config.cutoff_step_index
        else:
            cutoff_step = int(pipeline.num_timesteps * self.config.cutoff_step_ratio)

        if step_index == cutoff_step:
            # Keep only the trailing batch entry, i.e. the conditional text embeddings.
            key = self.tensor_inputs[0]
            callback_kwargs[key] = callback_kwargs[key][-1:]
            pipeline._guidance_scale = 0.0
        return callback_kwargs
96
+
97
+
98
class SDXLCFGCutoffCallback(PipelineCallback):
    """
    Callback for the base Stable Diffusion XL pipelines that switches classifier-free guidance off
    once the denoising loop reaches the cutoff step (taken from `cutoff_step_index`, or computed as
    `cutoff_step_ratio` times the number of timesteps).

    Note: this mutates the pipeline by forcing `_guidance_scale` to 0.0 from the cutoff step onwards.
    """

    tensor_inputs = [
        "prompt_embeds",
        "add_text_embeds",
        "add_time_ids",
    ]

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        # Prefer the explicit step index; otherwise derive the step from the ratio.
        if self.config.cutoff_step_index is not None:
            cutoff_step = self.config.cutoff_step_index
        else:
            cutoff_step = int(pipeline.num_timesteps * self.config.cutoff_step_ratio)

        if step_index == cutoff_step:
            # For each CFG-batched tensor keep only the trailing entry — the conditional
            # text embeddings, pooled embeddings, and added time vector respectively.
            for key in self.tensor_inputs:
                callback_kwargs[key] = callback_kwargs[key][-1:]
            pipeline._guidance_scale = 0.0
        return callback_kwargs
138
+
139
+
140
class SDXLControlnetCFGCutoffCallback(PipelineCallback):
    """
    Callback for the ControlNet Stable Diffusion XL pipelines that switches classifier-free guidance
    off once the denoising loop reaches the cutoff step (taken from `cutoff_step_index`, or computed
    as `cutoff_step_ratio` times the number of timesteps).

    Note: this mutates the pipeline by forcing `_guidance_scale` to 0.0 from the cutoff step onwards.
    """

    tensor_inputs = [
        "prompt_embeds",
        "add_text_embeds",
        "add_time_ids",
        "image",
    ]

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        # Prefer the explicit step index; otherwise derive the step from the ratio.
        if self.config.cutoff_step_index is not None:
            cutoff_step = self.config.cutoff_step_index
        else:
            cutoff_step = int(pipeline.num_timesteps * self.config.cutoff_step_ratio)

        if step_index == cutoff_step:
            # For each CFG-batched tensor keep only the trailing (conditional) entry;
            # "image" is the ControlNet conditioning input.
            for key in self.tensor_inputs:
                callback_kwargs[key] = callback_kwargs[key][-1:]
            pipeline._guidance_scale = 0.0
        return callback_kwargs
186
+
187
+
188
class IPAdapterScaleCutoffCallback(PipelineCallback):
    """
    Callback for any pipeline inheriting `IPAdapterMixin`: once the denoising loop reaches the
    cutoff step (taken from `cutoff_step_index`, or computed as `cutoff_step_ratio` times the number
    of timesteps), the IP-Adapter scale is set to 0.0.

    Note: this mutates the IP-Adapter attention processors by zeroing their scale after the cutoff.
    """

    tensor_inputs = []

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        # Prefer the explicit step index; otherwise derive the step from the ratio.
        if self.config.cutoff_step_index is not None:
            cutoff_step = self.config.cutoff_step_index
        else:
            cutoff_step = int(pipeline.num_timesteps * self.config.cutoff_step_ratio)

        if step_index == cutoff_step:
            pipeline.set_ip_adapter_scale(0.0)
        return callback_kwargs
210
+
211
+
212
class SD3CFGCutoffCallback(PipelineCallback):
    """
    Callback for Stable Diffusion 3 pipelines that switches classifier-free guidance off once the
    denoising loop reaches the cutoff step (taken from `cutoff_step_index`, or computed as
    `cutoff_step_ratio` times the number of timesteps).

    Note: this mutates the pipeline by forcing `_guidance_scale` to 0.0 from the cutoff step onwards.
    """

    tensor_inputs = ["prompt_embeds", "pooled_prompt_embeds"]

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        # Prefer the explicit step index; otherwise derive the step from the ratio.
        if self.config.cutoff_step_index is not None:
            cutoff_step = self.config.cutoff_step_index
        else:
            cutoff_step = int(pipeline.num_timesteps * self.config.cutoff_step_ratio)

        if step_index == cutoff_step:
            # Keep only the trailing batch entry of each tensor — the conditional text
            # embeddings and the conditional pooled embeddings respectively.
            for key in self.tensor_inputs:
                callback_kwargs[key] = callback_kwargs[key][-1:]
            pipeline._guidance_scale = 0.0
        return callback_kwargs
pythonProject/.venv/Lib/site-packages/diffusers/configuration_utils.py ADDED
@@ -0,0 +1,769 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 The HuggingFace Inc. team.
3
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ConfigMixin base class and utilities."""
17
+
18
+ import dataclasses
19
+ import functools
20
+ import importlib
21
+ import inspect
22
+ import json
23
+ import os
24
+ import re
25
+ from collections import OrderedDict
26
+ from pathlib import Path
27
+ from typing import Any, Dict, Optional, Tuple, Union
28
+
29
+ import numpy as np
30
+ from huggingface_hub import DDUFEntry, create_repo, hf_hub_download
31
+ from huggingface_hub.utils import (
32
+ EntryNotFoundError,
33
+ RepositoryNotFoundError,
34
+ RevisionNotFoundError,
35
+ validate_hf_hub_args,
36
+ )
37
+ from requests import HTTPError
38
+ from typing_extensions import Self
39
+
40
+ from . import __version__
41
+ from .utils import (
42
+ HUGGINGFACE_CO_RESOLVE_ENDPOINT,
43
+ DummyObject,
44
+ deprecate,
45
+ extract_commit_hash,
46
+ http_user_agent,
47
+ logging,
48
+ )
49
+
50
+
51
+ logger = logging.get_logger(__name__)
52
+
53
+ _re_configuration_file = re.compile(r"config\.(.*)\.json")
54
+
55
+
56
class FrozenDict(OrderedDict):
    """
    An `OrderedDict` that exposes its entries as attributes and rejects every mutating operation
    after construction (`pop`, `update`, `setdefault`, item/attribute assignment and deletion).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Mirror every entry as an instance attribute so `config.key` works
        # alongside `config["key"]`.
        for key, value in self.items():
            setattr(self, key, value)

        # From this point on, all writes are rejected.
        self.__frozen = True

    def __delitem__(self, *args, **kwargs):
        raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")

    def setdefault(self, *args, **kwargs):
        raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")

    def pop(self, *args, **kwargs):
        raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")

    def update(self, *args, **kwargs):
        raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")

    def __setattr__(self, name, value):
        # Bug fix: `self.__frozen` is stored under its name-mangled form `_FrozenDict__frozen`,
        # but the original guard tested `hasattr(self, "__frozen")` with the unmangled string
        # (strings are not mangled), which never exists — so the freeze never engaged and
        # post-construction assignment silently succeeded.
        if getattr(self, "_FrozenDict__frozen", False):
            raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
        super().__setattr__(name, value)

    def __setitem__(self, name, value):
        # Same mangled-name fix as `__setattr__`; message also corrected to name `__setitem__`.
        if getattr(self, "_FrozenDict__frozen", False):
            raise Exception(f"You cannot use ``__setitem__`` on a {self.__class__.__name__} instance.")
        super().__setitem__(name, value)
86
+
87
+
88
+ class ConfigMixin:
89
+ r"""
90
+ Base class for all configuration classes. All configuration parameters are stored under `self.config`. Also
91
+ provides the [`~ConfigMixin.from_config`] and [`~ConfigMixin.save_config`] methods for loading, downloading, and
92
+ saving classes that inherit from [`ConfigMixin`].
93
+
94
+ Class attributes:
95
+ - **config_name** (`str`) -- A filename under which the config should stored when calling
96
+ [`~ConfigMixin.save_config`] (should be overridden by parent class).
97
+ - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
98
+ overridden by subclass).
99
+ - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
100
+ - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the `init` function
101
+ should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
102
+ subclass).
103
+ """
104
+
105
+ config_name = None
106
+ ignore_for_config = []
107
+ has_compatibles = False
108
+
109
+ _deprecated_kwargs = []
110
+
111
+ def register_to_config(self, **kwargs):
112
+ if self.config_name is None:
113
+ raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
114
+ # Special case for `kwargs` used in deprecation warning added to schedulers
115
+ # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
116
+ # or solve in a more general way.
117
+ kwargs.pop("kwargs", None)
118
+
119
+ if not hasattr(self, "_internal_dict"):
120
+ internal_dict = kwargs
121
+ else:
122
+ previous_dict = dict(self._internal_dict)
123
+ internal_dict = {**self._internal_dict, **kwargs}
124
+ logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
125
+
126
+ self._internal_dict = FrozenDict(internal_dict)
127
+
128
+ def __getattr__(self, name: str) -> Any:
129
+ """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
130
+ config attributes directly. See https://github.com/huggingface/diffusers/pull/3129
131
+
132
+ This function is mostly copied from PyTorch's __getattr__ overwrite:
133
+ https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
134
+ """
135
+
136
+ is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
137
+ is_attribute = name in self.__dict__
138
+
139
+ if is_in_config and not is_attribute:
140
+ deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'."
141
+ deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
142
+ return self._internal_dict[name]
143
+
144
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
145
+
146
+ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
147
+ """
148
+ Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the
149
+ [`~ConfigMixin.from_config`] class method.
150
+
151
+ Args:
152
+ save_directory (`str` or `os.PathLike`):
153
+ Directory where the configuration JSON file is saved (will be created if it does not exist).
154
+ push_to_hub (`bool`, *optional*, defaults to `False`):
155
+ Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
156
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
157
+ namespace).
158
+ kwargs (`Dict[str, Any]`, *optional*):
159
+ Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
160
+ """
161
+ if os.path.isfile(save_directory):
162
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
163
+
164
+ os.makedirs(save_directory, exist_ok=True)
165
+
166
+ # If we save using the predefined names, we can load using `from_config`
167
+ output_config_file = os.path.join(save_directory, self.config_name)
168
+
169
+ self.to_json_file(output_config_file)
170
+ logger.info(f"Configuration saved in {output_config_file}")
171
+
172
+ if push_to_hub:
173
+ commit_message = kwargs.pop("commit_message", None)
174
+ private = kwargs.pop("private", None)
175
+ create_pr = kwargs.pop("create_pr", False)
176
+ token = kwargs.pop("token", None)
177
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
178
+ repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
179
+ subfolder = kwargs.pop("subfolder", None)
180
+
181
+ self._upload_folder(
182
+ save_directory,
183
+ repo_id,
184
+ token=token,
185
+ commit_message=commit_message,
186
+ create_pr=create_pr,
187
+ subfolder=subfolder,
188
+ )
189
+
190
    @classmethod
    def from_config(
        cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs
    ) -> Union[Self, Tuple[Self, Dict[str, Any]]]:
        r"""
        Instantiate a Python class from a config dictionary.

        Parameters:
            config (`Dict[str, Any]`):
                A config dictionary from which the Python class is instantiated. Make sure to only load configuration
                files of compatible classes.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                Whether kwargs that are not consumed by the Python class should be returned or not.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to update the configuration object (after it is loaded) and initiate the Python class.
                `**kwargs` are passed directly to the underlying scheduler/model's `__init__` method and eventually
                overwrite the same named arguments in `config`.

        Returns:
            [`ModelMixin`] or [`SchedulerMixin`]:
                A model or scheduler object instantiated from a config dictionary.

        Examples:

        ```python
        >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler

        >>> # Download scheduler from huggingface.co and cache.
        >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")

        >>> # Instantiate DDIM scheduler class with same config as DDPM
        >>> scheduler = DDIMScheduler.from_config(scheduler.config)

        >>> # Instantiate PNDM scheduler class with same config as DDPM
        >>> scheduler = PNDMScheduler.from_config(scheduler.config)
        ```
        """
        # <===== TO BE REMOVED WITH DEPRECATION
        # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
        # Legacy call pattern: a model path passed via keyword is treated as the config.
        if "pretrained_model_name_or_path" in kwargs:
            config = kwargs.pop("pretrained_model_name_or_path")

        if config is None:
            raise ValueError("Please make sure to provide a config as the first positional argument.")
        # ======>

        # Deprecated path: `config` is a model id / path rather than a dict — load it,
        # emitting a class-appropriate deprecation message first.
        if not isinstance(config, dict):
            deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
            if "Scheduler" in cls.__name__:
                deprecation_message += (
                    f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
                    " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
                    " be removed in v1.0.0."
                )
            elif "Model" in cls.__name__:
                deprecation_message += (
                    f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
                    f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
                    " instead. This functionality will be removed in v1.0.0."
                )
            deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
            config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)

        # Split the config into: args accepted by `cls.__init__`, leftover kwargs,
        # and hidden (underscore-prefixed / compatibility) entries.
        init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)

        # Allow dtype to be specified on initialization
        if "dtype" in unused_kwargs:
            init_dict["dtype"] = unused_kwargs.pop("dtype")

        # add possible deprecated kwargs
        for deprecated_kwarg in cls._deprecated_kwargs:
            if deprecated_kwarg in unused_kwargs:
                init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)

        # Return model and optionally state and/or unused_kwargs
        model = cls(**init_dict)

        # make sure to also save config parameters that might be used for compatible classes
        # update _class_name
        if "_class_name" in hidden_dict:
            hidden_dict["_class_name"] = cls.__name__

        model.register_to_config(**hidden_dict)

        # add hidden kwargs of compatible classes to unused_kwargs
        unused_kwargs = {**unused_kwargs, **hidden_dict}

        if return_unused_kwargs:
            return (model, unused_kwargs)
        else:
            return model
281
+
282
+ @classmethod
283
+ def get_config_dict(cls, *args, **kwargs):
284
+ deprecation_message = (
285
+ f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
286
+ " removed in version v1.0.0"
287
+ )
288
+ deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
289
+ return cls.load_config(*args, **kwargs)
290
+
291
+ @classmethod
292
+ @validate_hf_hub_args
293
+ def load_config(
294
+ cls,
295
+ pretrained_model_name_or_path: Union[str, os.PathLike],
296
+ return_unused_kwargs=False,
297
+ return_commit_hash=False,
298
+ **kwargs,
299
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
300
+ r"""
301
+ Load a model or scheduler configuration.
302
+
303
+ Parameters:
304
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
305
+ Can be either:
306
+
307
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
308
+ the Hub.
309
+ - A path to a *directory* (for example `./my_model_directory`) containing model weights saved with
310
+ [`~ConfigMixin.save_config`].
311
+
312
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
313
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
314
+ is not used.
315
+ force_download (`bool`, *optional*, defaults to `False`):
316
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
317
+ cached versions if they exist.
318
+ proxies (`Dict[str, str]`, *optional*):
319
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
320
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
321
+ output_loading_info(`bool`, *optional*, defaults to `False`):
322
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
323
+ local_files_only (`bool`, *optional*, defaults to `False`):
324
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
325
+ won't be downloaded from the Hub.
326
+ token (`str` or *bool*, *optional*):
327
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
328
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
329
+ revision (`str`, *optional*, defaults to `"main"`):
330
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
331
+ allowed by Git.
332
+ subfolder (`str`, *optional*, defaults to `""`):
333
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
334
+ return_unused_kwargs (`bool`, *optional*, defaults to `False):
335
+ Whether unused keyword arguments of the config are returned.
336
+ return_commit_hash (`bool`, *optional*, defaults to `False):
337
+ Whether the `commit_hash` of the loaded configuration are returned.
338
+
339
+ Returns:
340
+ `dict`:
341
+ A dictionary of all the parameters stored in a JSON configuration file.
342
+
343
+ """
344
+ cache_dir = kwargs.pop("cache_dir", None)
345
+ local_dir = kwargs.pop("local_dir", None)
346
+ local_dir_use_symlinks = kwargs.pop("local_dir_use_symlinks", "auto")
347
+ force_download = kwargs.pop("force_download", False)
348
+ proxies = kwargs.pop("proxies", None)
349
+ token = kwargs.pop("token", None)
350
+ local_files_only = kwargs.pop("local_files_only", False)
351
+ revision = kwargs.pop("revision", None)
352
+ _ = kwargs.pop("mirror", None)
353
+ subfolder = kwargs.pop("subfolder", None)
354
+ user_agent = kwargs.pop("user_agent", {})
355
+ dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
356
+
357
+ user_agent = {**user_agent, "file_type": "config"}
358
+ user_agent = http_user_agent(user_agent)
359
+
360
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
361
+
362
+ if cls.config_name is None:
363
+ raise ValueError(
364
+ "`self.config_name` is not defined. Note that one should not load a config from "
365
+ "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
366
+ )
367
+ # Custom path for now
368
+ if dduf_entries:
369
+ if subfolder is not None:
370
+ raise ValueError(
371
+ "DDUF file only allow for 1 level of directory (e.g transformer/model1/model.safetentors is not allowed). "
372
+ "Please check the DDUF structure"
373
+ )
374
+ config_file = cls._get_config_file_from_dduf(pretrained_model_name_or_path, dduf_entries)
375
+ elif os.path.isfile(pretrained_model_name_or_path):
376
+ config_file = pretrained_model_name_or_path
377
+ elif os.path.isdir(pretrained_model_name_or_path):
378
+ if subfolder is not None and os.path.isfile(
379
+ os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
380
+ ):
381
+ config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
382
+ elif os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
383
+ # Load from a PyTorch checkpoint
384
+ config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
385
+ else:
386
+ raise EnvironmentError(
387
+ f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
388
+ )
389
+ else:
390
+ try:
391
+ # Load from URL or cache if already cached
392
+ config_file = hf_hub_download(
393
+ pretrained_model_name_or_path,
394
+ filename=cls.config_name,
395
+ cache_dir=cache_dir,
396
+ force_download=force_download,
397
+ proxies=proxies,
398
+ local_files_only=local_files_only,
399
+ token=token,
400
+ user_agent=user_agent,
401
+ subfolder=subfolder,
402
+ revision=revision,
403
+ local_dir=local_dir,
404
+ local_dir_use_symlinks=local_dir_use_symlinks,
405
+ )
406
+ except RepositoryNotFoundError:
407
+ raise EnvironmentError(
408
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
409
+ " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
410
+ " token having permission to this repo with `token` or log in with `hf auth login`."
411
+ )
412
+ except RevisionNotFoundError:
413
+ raise EnvironmentError(
414
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
415
+ " this model name. Check the model page at"
416
+ f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
417
+ )
418
+ except EntryNotFoundError:
419
+ raise EnvironmentError(
420
+ f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
421
+ )
422
+ except HTTPError as err:
423
+ raise EnvironmentError(
424
+ "There was a specific connection error when trying to load"
425
+ f" {pretrained_model_name_or_path}:\n{err}"
426
+ )
427
+ except ValueError:
428
+ raise EnvironmentError(
429
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
430
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
431
+ f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
432
+ " run the library in offline mode at"
433
+ " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
434
+ )
435
+ except EnvironmentError:
436
+ raise EnvironmentError(
437
+ f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
438
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
439
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
440
+ f"containing a {cls.config_name} file"
441
+ )
442
+ try:
443
+ config_dict = cls._dict_from_json_file(config_file, dduf_entries=dduf_entries)
444
+
445
+ commit_hash = extract_commit_hash(config_file)
446
+ except (json.JSONDecodeError, UnicodeDecodeError):
447
+ raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
448
+
449
+ if not (return_unused_kwargs or return_commit_hash):
450
+ return config_dict
451
+
452
+ outputs = (config_dict,)
453
+
454
+ if return_unused_kwargs:
455
+ outputs += (kwargs,)
456
+
457
+ if return_commit_hash:
458
+ outputs += (commit_hash,)
459
+
460
+ return outputs
461
+
462
+ @staticmethod
463
+ def _get_init_keys(input_class):
464
+ return set(dict(inspect.signature(input_class.__init__).parameters).keys())
465
+
466
    @classmethod
    def extract_init_dict(cls, config_dict, **kwargs):
        """
        Split `config_dict` (plus overriding `kwargs`) into three dicts:

        - `init_dict`: the kwargs accepted by `cls.__init__` (kwargs take precedence over the config),
        - `unused_kwargs`: everything left over from both sources,
        - `hidden_config_dict`: original entries not consumed by `__init__`, kept for compatible classes.
        """
        # Skip keys that were not present in the original config, so default __init__ values were used
        used_defaults = config_dict.get("_use_default_values", [])
        config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"}

        # 0. Copy origin config dict
        original_dict = dict(config_dict.items())

        # 1. Retrieve expected config attributes from __init__ signature
        expected_keys = cls._get_init_keys(cls)
        expected_keys.remove("self")
        # remove general kwargs if present in dict
        if "kwargs" in expected_keys:
            expected_keys.remove("kwargs")
        # remove flax internal keys
        if hasattr(cls, "_flax_internal_args"):
            for arg in cls._flax_internal_args:
                expected_keys.remove(arg)

        # 2. Remove attributes that cannot be expected from expected config attributes
        # remove keys to be ignored
        if len(cls.ignore_for_config) > 0:
            expected_keys = expected_keys - set(cls.ignore_for_config)

        # load diffusers library to import compatible and original scheduler
        diffusers_library = importlib.import_module(__name__.split(".")[0])

        if cls.has_compatibles:
            # DummyObject placeholders stand in for classes whose backend isn't installed; skip them.
            compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
        else:
            compatible_classes = []

        # Keys accepted only by compatible classes (not by `cls` itself) are stripped from the
        # config so they don't reach __init__; they survive in `hidden_config_dict` below.
        expected_keys_comp_cls = set()
        for c in compatible_classes:
            expected_keys_c = cls._get_init_keys(c)
            expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
        expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
        config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}

        # remove attributes from orig class that cannot be expected
        orig_cls_name = config_dict.pop("_class_name", cls.__name__)
        if (
            isinstance(orig_cls_name, str)
            and orig_cls_name != cls.__name__
            and hasattr(diffusers_library, orig_cls_name)
        ):
            orig_cls = getattr(diffusers_library, orig_cls_name)
            unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
            config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
        elif not isinstance(orig_cls_name, str) and not isinstance(orig_cls_name, (list, tuple)):
            raise ValueError(
                "Make sure that the `_class_name` is of type string or list of string (for custom pipelines)."
            )

        # remove private attributes
        config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}

        # remove quantization_config
        config_dict = {k: v for k, v in config_dict.items() if k != "quantization_config"}

        # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
        init_dict = {}
        for key in expected_keys:
            # if config param is passed to kwarg and is present in config dict
            # it should overwrite existing config dict key
            if key in kwargs and key in config_dict:
                config_dict[key] = kwargs.pop(key)

            if key in kwargs:
                # overwrite key
                init_dict[key] = kwargs.pop(key)
            elif key in config_dict:
                # use value from config dict
                init_dict[key] = config_dict.pop(key)

        # 4. Give nice warning if unexpected values have been passed
        if len(config_dict) > 0:
            logger.warning(
                f"The config attributes {config_dict} were passed to {cls.__name__}, "
                "but are not expected and will be ignored. Please verify your "
                f"{cls.config_name} configuration file."
            )

        # 5. Give nice info if config attributes are initialized to default because they have not been passed
        passed_keys = set(init_dict.keys())
        if len(expected_keys - passed_keys) > 0:
            logger.info(
                f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
            )

        # 6. Define unused keyword arguments
        unused_kwargs = {**config_dict, **kwargs}

        # 7. Define "hidden" config parameters that were saved for compatible classes
        hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}

        return init_dict, unused_kwargs, hidden_config_dict
564
+
565
+ @classmethod
566
+ def _dict_from_json_file(
567
+ cls, json_file: Union[str, os.PathLike], dduf_entries: Optional[Dict[str, DDUFEntry]] = None
568
+ ):
569
+ if dduf_entries:
570
+ text = dduf_entries[json_file].read_text()
571
+ else:
572
+ with open(json_file, "r", encoding="utf-8") as reader:
573
+ text = reader.read()
574
+ return json.loads(text)
575
+
576
+ def __repr__(self):
577
+ return f"{self.__class__.__name__} {self.to_json_string()}"
578
+
579
    @property
    def config(self) -> Dict[str, Any]:
        """
        Returns the config of the class as a frozen dictionary

        Returns:
            `Dict[str, Any]`: Config of the class.
        """
        # NOTE(review): `_internal_dict` is presumably populated by `register_to_config` —
        # treat the returned mapping as read-only.
        return self._internal_dict
588
+
589
    def to_json_string(self) -> str:
        """
        Serializes the configuration instance to a JSON string.

        Returns:
            `str`:
                String containing all the attributes that make up the configuration instance in JSON format.
        """
        config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
        config_dict["_class_name"] = self.__class__.__name__
        config_dict["_diffusers_version"] = __version__

        def to_json_saveable(value):
            # Recursively coerce values into JSON-serializable equivalents.
            if isinstance(value, np.ndarray):
                value = value.tolist()
            elif isinstance(value, Path):
                value = value.as_posix()
            elif hasattr(value, "to_dict") and callable(value.to_dict):
                value = value.to_dict()
            elif isinstance(value, list):
                value = [to_json_saveable(v) for v in value]
            return value

        if "quantization_config" in config_dict:
            # quantization_config may be an object with `to_dict` rather than a plain dict.
            config_dict["quantization_config"] = (
                config_dict.quantization_config.to_dict()
                if not isinstance(config_dict.quantization_config, dict)
                else config_dict.quantization_config
            )

        config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
        # Don't save "_ignore_files" or "_use_default_values"
        config_dict.pop("_ignore_files", None)
        config_dict.pop("_use_default_values", None)
        # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
        _ = config_dict.pop("_pre_quantization_dtype", None)

        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
627
+
628
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
629
+ """
630
+ Save the configuration instance's parameters to a JSON file.
631
+
632
+ Args:
633
+ json_file_path (`str` or `os.PathLike`):
634
+ Path to the JSON file to save a configuration instance's parameters.
635
+ """
636
+ with open(json_file_path, "w", encoding="utf-8") as writer:
637
+ writer.write(self.to_json_string())
638
+
639
+ @classmethod
640
+ def _get_config_file_from_dduf(cls, pretrained_model_name_or_path: str, dduf_entries: Dict[str, DDUFEntry]):
641
+ # paths inside a DDUF file must always be "/"
642
+ config_file = (
643
+ cls.config_name
644
+ if pretrained_model_name_or_path == ""
645
+ else "/".join([pretrained_model_name_or_path, cls.config_name])
646
+ )
647
+ if config_file not in dduf_entries:
648
+ raise ValueError(
649
+ f"We did not manage to find the file {config_file} in the dduf file. We only have the following files {dduf_entries.keys()}"
650
+ )
651
+ return config_file
652
+
653
+
654
def register_to_config(init):
    r"""
    Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
    automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that
    shouldn't be registered in the config, use the `ignore_for_config` class variable

    Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
    """

    @functools.wraps(init)
    def inner_init(self, *args, **kwargs):
        # Ignore private kwargs in the init.
        init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
        config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
        if not isinstance(self, ConfigMixin):
            raise RuntimeError(
                f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
                "not inherit from `ConfigMixin`."
            )

        ignore = getattr(self, "ignore_for_config", [])
        # Get positional arguments aligned with kwargs
        new_kwargs = {}
        signature = inspect.signature(init)
        # `i > 0` skips the implicit `self` parameter.
        parameters = {
            name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
        }
        for arg, name in zip(args, parameters.keys()):
            new_kwargs[name] = arg

        # Then add all kwargs (falling back to the signature defaults for anything not passed)
        new_kwargs.update(
            {
                k: init_kwargs.get(k, default)
                for k, default in parameters.items()
                if k not in ignore and k not in new_kwargs
            }
        )

        # Take note of the parameters that were not present in the loaded config
        if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
            new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))

        new_kwargs = {**config_init_kwargs, **new_kwargs}
        # Register *before* calling the wrapped init so the config is available inside it.
        getattr(self, "register_to_config")(**new_kwargs)
        init(self, *args, **init_kwargs)

    return inner_init
702
+
703
+
704
def flax_register_to_config(cls):
    """
    Class decorator for Flax modules inheriting from `ConfigMixin`: wraps `cls.__init__` so that all
    dataclass fields (whether passed positionally, by keyword, or left at their defaults) are sent to
    `register_to_config` before the original init runs. Fields named in `cls._flax_internal_args` and
    `dtype` are excluded from the registered config.

    Returns:
        The same class with its `__init__` replaced by the registering wrapper.
    """
    original_init = cls.__init__

    @functools.wraps(original_init)
    def init(self, *args, **kwargs):
        if not isinstance(self, ConfigMixin):
            raise RuntimeError(
                f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
                "not inherit from `ConfigMixin`."
            )

        # Ignore private kwargs in the init. Retrieve all passed attributes
        init_kwargs = dict(kwargs.items())

        # Retrieve default values
        fields = dataclasses.fields(self)
        default_kwargs = {}
        for field in fields:
            # ignore flax specific attributes
            if field.name in self._flax_internal_args:
                continue
            # Fix: use the public `dataclasses.MISSING` sentinel instead of comparing
            # `type(field.default) == dataclasses._MISSING_TYPE` (private type, non-idiomatic
            # type comparison). Behavior is identical: MISSING is the unique _MISSING_TYPE instance.
            if field.default is dataclasses.MISSING:
                default_kwargs[field.name] = None
            else:
                default_kwargs[field.name] = getattr(self, field.name)

        # Make sure init_kwargs override default kwargs
        new_kwargs = {**default_kwargs, **init_kwargs}
        # dtype should be part of `init_kwargs`, but not `new_kwargs`
        if "dtype" in new_kwargs:
            new_kwargs.pop("dtype")

        # Get positional arguments aligned with kwargs
        for i, arg in enumerate(args):
            name = fields[i].name
            new_kwargs[name] = arg

        # Take note of the parameters that were not present in the loaded config
        if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
            new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))

        getattr(self, "register_to_config")(**new_kwargs)
        original_init(self, *args, **kwargs)

    cls.__init__ = init
    return cls
750
+
751
+
752
class LegacyConfigMixin(ConfigMixin):
    r"""
    A subclass of `ConfigMixin` to resolve class mapping from legacy classes (like `Transformer2DModel`) to more
    pipeline-specific classes (like `DiTTransformer2DModel`).
    """

    @classmethod
    def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
        # To prevent dependency import problem.
        from .models.model_loading_utils import _fetch_remapped_cls_from_config

        # resolve remapping
        remapped_class = _fetch_remapped_cls_from_config(config, cls)

        if remapped_class is cls:
            # No remapping applied: fall back to the plain ConfigMixin.from_config behavior.
            return super(LegacyConfigMixin, remapped_class).from_config(config, return_unused_kwargs, **kwargs)
        else:
            # Delegate construction to the remapped (more specific) class.
            return remapped_class.from_config(config, return_unused_kwargs, **kwargs)
pythonProject/.venv/Lib/site-packages/diffusers/dependency_versions_check.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .dependency_versions_table import deps
16
+ from .utils.versions import require_version, require_version_core
17
+
18
+
19
# define which module versions we always want to check at run time
# (usually the ones defined in `install_requires` in setup.py)
#
# order specific notes:
# - tqdm must be checked before tokenizers

pkgs_to_check_at_runtime = ["python", "requests", "filelock", "numpy"]
for pkg in pkgs_to_check_at_runtime:
    if pkg not in deps:
        raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
    require_version_core(deps[pkg])
31
+
32
+
33
def dep_version_check(pkg, hint=None):
    """Assert that the installed version of `pkg` satisfies the pin recorded in `deps`.

    Args:
        pkg: Key into the `deps` table (the pip distribution name).
        hint: Optional extra text forwarded to `require_version` for the failure message.
    """
    require_version(deps[pkg], hint)
pythonProject/.venv/Lib/site-packages/diffusers/dependency_versions_table.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# THIS FILE HAS BEEN AUTOGENERATED. To update:
# 1. modify the `_deps` dict in setup.py
# 2. run `make deps_table_update`
# Mapping: pip distribution name -> full requirement specifier (used by the runtime checks).
deps = {
    "Pillow": "Pillow",
    "accelerate": "accelerate>=0.31.0",
    "compel": "compel==0.1.8",
    "datasets": "datasets",
    "filelock": "filelock",
    "flax": "flax>=0.4.1",
    "hf-doc-builder": "hf-doc-builder>=0.3.0",
    "huggingface-hub": "huggingface-hub>=0.34.0",
    "requests-mock": "requests-mock==1.10.0",
    "importlib_metadata": "importlib_metadata",
    "invisible-watermark": "invisible-watermark>=0.2.0",
    "isort": "isort>=5.5.4",
    "jax": "jax>=0.4.1",
    "jaxlib": "jaxlib>=0.4.1",
    "Jinja2": "Jinja2",
    "k-diffusion": "k-diffusion==0.0.12",
    "torchsde": "torchsde",
    "note_seq": "note_seq",
    "librosa": "librosa",
    "numpy": "numpy",
    "parameterized": "parameterized",
    "peft": "peft>=0.17.0",
    "protobuf": "protobuf>=3.20.3,<4",
    "pytest": "pytest",
    "pytest-timeout": "pytest-timeout",
    "pytest-xdist": "pytest-xdist",
    "python": "python>=3.8.0",
    "ruff": "ruff==0.9.10",
    "safetensors": "safetensors>=0.3.1",
    "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
    "GitPython": "GitPython<3.1.19",
    "scipy": "scipy",
    "onnx": "onnx",
    "optimum_quanto": "optimum_quanto>=0.2.6",
    "gguf": "gguf>=0.10.0",
    "torchao": "torchao>=0.7.0",
    "bitsandbytes": "bitsandbytes>=0.43.3",
    "nvidia_modelopt[hf]": "nvidia_modelopt[hf]>=0.33.1",
    "regex": "regex!=2019.12.17",
    "requests": "requests",
    "tensorboard": "tensorboard",
    "tiktoken": "tiktoken>=0.7.0",
    "torch": "torch>=1.4",
    "torchvision": "torchvision",
    "transformers": "transformers>=4.41.2",
    "urllib3": "urllib3<=2.0.0",
    "black": "black",
    "phonemizer": "phonemizer",
    "opencv-python": "opencv-python",
}
pythonProject/.venv/Lib/site-packages/diffusers/image_processor.py ADDED
@@ -0,0 +1,1451 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ import warnings
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import PIL.Image
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from PIL import Image, ImageFilter, ImageOps
24
+
25
+ from .configuration_utils import ConfigMixin, register_to_config
26
+ from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
27
+
28
+
29
# Accepted image input types for pipelines: a single image (PIL image, NumPy array,
# or torch tensor) or a homogeneous list of them.
PipelineImageInput = Union[
    PIL.Image.Image,
    np.ndarray,
    torch.Tensor,
    List[PIL.Image.Image],
    List[np.ndarray],
    List[torch.Tensor],
]

# Depth inputs accept exactly the same formats as regular image inputs.
PipelineDepthInput = PipelineImageInput
39
+
40
+
41
def is_valid_image(image) -> bool:
    r"""
    Checks if the input is a valid single image.

    A valid image can be:
    - A `PIL.Image.Image`.
    - A 2D or 3D `np.ndarray` or `torch.Tensor` (grayscale or color image).

    Args:
        image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
            The image to validate. It can be a PIL image, a NumPy array, or a torch tensor.

    Returns:
        `bool`:
            `True` if the input is a valid image, `False` otherwise.
    """
    if isinstance(image, PIL.Image.Image):
        return True
    # Arrays/tensors count as single images only when 2D (grayscale) or 3D (with channels).
    return isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3)
58
+
59
+
60
def is_valid_image_imagelist(images):
    r"""
    Checks if the input is a valid image or list of images.

    Accepted formats:
    - A 4D tensor or numpy array (a batch of images).
    - A single valid image: `PIL.Image.Image`, or a 2D/3D `np.ndarray` / `torch.Tensor`.
    - A list whose every element is a valid single image.

    Args:
        images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`):
            The image(s) to check.

    Returns:
        `bool`:
            `True` if the input is valid, `False` otherwise.
    """
    is_batched_array = isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4
    if is_batched_array or is_valid_image(images):
        return True
    if isinstance(images, list):
        return all(is_valid_image(img) for img in images)
    return False
86
+
87
+
88
+ class VaeImageProcessor(ConfigMixin):
89
+ """
90
+ Image processor for VAE.
91
+
92
+ Args:
93
+ do_resize (`bool`, *optional*, defaults to `True`):
94
+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
95
+ `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
96
+ vae_scale_factor (`int`, *optional*, defaults to `8`):
97
+ VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
98
+ resample (`str`, *optional*, defaults to `lanczos`):
99
+ Resampling filter to use when resizing the image.
100
+ do_normalize (`bool`, *optional*, defaults to `True`):
101
+ Whether to normalize the image to [-1,1].
102
+ do_binarize (`bool`, *optional*, defaults to `False`):
103
+ Whether to binarize the image to 0/1.
104
+ do_convert_rgb (`bool`, *optional*, defaults to be `False`):
105
+ Whether to convert the images to RGB format.
106
+ do_convert_grayscale (`bool`, *optional*, defaults to be `False`):
107
+ Whether to convert the images to grayscale format.
108
+ """
109
+
110
+ config_name = CONFIG_NAME
111
+
112
+ @register_to_config
113
+ def __init__(
114
+ self,
115
+ do_resize: bool = True,
116
+ vae_scale_factor: int = 8,
117
+ vae_latent_channels: int = 4,
118
+ resample: str = "lanczos",
119
+ reducing_gap: int = None,
120
+ do_normalize: bool = True,
121
+ do_binarize: bool = False,
122
+ do_convert_rgb: bool = False,
123
+ do_convert_grayscale: bool = False,
124
+ ):
125
+ super().__init__()
126
+ if do_convert_rgb and do_convert_grayscale:
127
+ raise ValueError(
128
+ "`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`,"
129
+ " if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.",
130
+ " if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`",
131
+ )
132
+
133
+ @staticmethod
134
+ def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
135
+ r"""
136
+ Convert a numpy image or a batch of images to a PIL image.
137
+
138
+ Args:
139
+ images (`np.ndarray`):
140
+ The image array to convert to PIL format.
141
+
142
+ Returns:
143
+ `List[PIL.Image.Image]`:
144
+ A list of PIL images.
145
+ """
146
+ if images.ndim == 3:
147
+ images = images[None, ...]
148
+ images = (images * 255).round().astype("uint8")
149
+ if images.shape[-1] == 1:
150
+ # special case for grayscale (single channel) images
151
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
152
+ else:
153
+ pil_images = [Image.fromarray(image) for image in images]
154
+
155
+ return pil_images
156
+
157
+ @staticmethod
158
+ def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
159
+ r"""
160
+ Convert a PIL image or a list of PIL images to NumPy arrays.
161
+
162
+ Args:
163
+ images (`PIL.Image.Image` or `List[PIL.Image.Image]`):
164
+ The PIL image or list of images to convert to NumPy format.
165
+
166
+ Returns:
167
+ `np.ndarray`:
168
+ A NumPy array representation of the images.
169
+ """
170
+ if not isinstance(images, list):
171
+ images = [images]
172
+ images = [np.array(image).astype(np.float32) / 255.0 for image in images]
173
+ images = np.stack(images, axis=0)
174
+
175
+ return images
176
+
177
+ @staticmethod
178
+ def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
179
+ r"""
180
+ Convert a NumPy image to a PyTorch tensor.
181
+
182
+ Args:
183
+ images (`np.ndarray`):
184
+ The NumPy image array to convert to PyTorch format.
185
+
186
+ Returns:
187
+ `torch.Tensor`:
188
+ A PyTorch tensor representation of the images.
189
+ """
190
+ if images.ndim == 3:
191
+ images = images[..., None]
192
+
193
+ images = torch.from_numpy(images.transpose(0, 3, 1, 2))
194
+ return images
195
+
196
+ @staticmethod
197
+ def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
198
+ r"""
199
+ Convert a PyTorch tensor to a NumPy image.
200
+
201
+ Args:
202
+ images (`torch.Tensor`):
203
+ The PyTorch tensor to convert to NumPy format.
204
+
205
+ Returns:
206
+ `np.ndarray`:
207
+ A NumPy array representation of the images.
208
+ """
209
+ images = images.cpu().permute(0, 2, 3, 1).float().numpy()
210
+ return images
211
+
212
+ @staticmethod
213
+ def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
214
+ r"""
215
+ Normalize an image array to [-1,1].
216
+
217
+ Args:
218
+ images (`np.ndarray` or `torch.Tensor`):
219
+ The image array to normalize.
220
+
221
+ Returns:
222
+ `np.ndarray` or `torch.Tensor`:
223
+ The normalized image array.
224
+ """
225
+ return 2.0 * images - 1.0
226
+
227
+ @staticmethod
228
+ def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
229
+ r"""
230
+ Denormalize an image array to [0,1].
231
+
232
+ Args:
233
+ images (`np.ndarray` or `torch.Tensor`):
234
+ The image array to denormalize.
235
+
236
+ Returns:
237
+ `np.ndarray` or `torch.Tensor`:
238
+ The denormalized image array.
239
+ """
240
+ return (images * 0.5 + 0.5).clamp(0, 1)
241
+
242
+ @staticmethod
243
+ def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
244
+ r"""
245
+ Converts a PIL image to RGB format.
246
+
247
+ Args:
248
+ image (`PIL.Image.Image`):
249
+ The PIL image to convert to RGB.
250
+
251
+ Returns:
252
+ `PIL.Image.Image`:
253
+ The RGB-converted PIL image.
254
+ """
255
+ image = image.convert("RGB")
256
+
257
+ return image
258
+
259
+ @staticmethod
260
+ def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image:
261
+ r"""
262
+ Converts a given PIL image to grayscale.
263
+
264
+ Args:
265
+ image (`PIL.Image.Image`):
266
+ The input image to convert.
267
+
268
+ Returns:
269
+ `PIL.Image.Image`:
270
+ The image converted to grayscale.
271
+ """
272
+ image = image.convert("L")
273
+
274
+ return image
275
+
276
    @staticmethod
    def blur(image: PIL.Image.Image, blur_factor: int = 4) -> PIL.Image.Image:
        r"""
        Applies Gaussian blur to an image.

        Args:
            image (`PIL.Image.Image`):
                The PIL image to blur.
            blur_factor (`int`, *optional*, defaults to 4):
                Radius passed to `ImageFilter.GaussianBlur`; larger values blur more.

        Returns:
            `PIL.Image.Image`:
                The blurred PIL image.
        """
        image = image.filter(ImageFilter.GaussianBlur(blur_factor))

        return image
292
+
293
    @staticmethod
    def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0):
        r"""
        Finds a rectangular region that contains all masked ares in an image, and expands region to match the aspect
        ratio of the original image; for example, if user drew mask in a 128x32 region, and the dimensions for
        processing are 512x512, the region will be expanded to 128x128.

        Args:
            mask_image (PIL.Image.Image): Mask image.
            width (int): Width of the image to be processed.
            height (int): Height of the image to be processed.
            pad (int, optional): Padding to be added to the crop region. Defaults to 0.

        Returns:
            tuple: (x1, y1, x2, y2) represent a rectangular region that contains all masked ares in an image and
            matches the original aspect ratio.
        """

        mask_image = mask_image.convert("L")
        mask = np.array(mask_image)

        # 1. find a rectangular region that contains all masked ares in an image
        h, w = mask.shape
        # Count fully-zero columns from the left edge until the first column containing a mask pixel.
        crop_left = 0
        for i in range(w):
            if not (mask[:, i] == 0).all():
                break
            crop_left += 1

        # Same scan from the right edge.
        crop_right = 0
        for i in reversed(range(w)):
            if not (mask[:, i] == 0).all():
                break
            crop_right += 1

        # Fully-zero rows from the top.
        crop_top = 0
        for i in range(h):
            if not (mask[i] == 0).all():
                break
            crop_top += 1

        # Fully-zero rows from the bottom.
        crop_bottom = 0
        for i in reversed(range(h)):
            if not (mask[i] == 0).all():
                break
            crop_bottom += 1

        # 2. add padding to the crop region (clamped to the mask bounds)
        x1, y1, x2, y2 = (
            int(max(crop_left - pad, 0)),
            int(max(crop_top - pad, 0)),
            int(min(w - crop_right + pad, w)),
            int(min(h - crop_bottom + pad, h)),
        )

        # 3. expands crop region to match the aspect ratio of the image to be processed
        # NOTE: if the mask is entirely empty, x2 - x1 (or y2 - y1) can be 0 and this divides by zero.
        ratio_crop_region = (x2 - x1) / (y2 - y1)
        ratio_processing = width / height

        if ratio_crop_region > ratio_processing:
            # Region is too wide: grow it vertically, centered, then shift/clamp into bounds.
            desired_height = (x2 - x1) / ratio_processing
            desired_height_diff = int(desired_height - (y2 - y1))
            y1 -= desired_height_diff // 2
            y2 += desired_height_diff - desired_height_diff // 2
            if y2 >= mask_image.height:
                diff = y2 - mask_image.height
                y2 -= diff
                y1 -= diff
            if y1 < 0:
                y2 -= y1
                y1 -= y1
            if y2 >= mask_image.height:
                y2 = mask_image.height
        else:
            # Region is too tall: grow it horizontally, centered, then shift/clamp into bounds.
            desired_width = (y2 - y1) * ratio_processing
            desired_width_diff = int(desired_width - (x2 - x1))
            x1 -= desired_width_diff // 2
            x2 += desired_width_diff - desired_width_diff // 2
            if x2 >= mask_image.width:
                diff = x2 - mask_image.width
                x2 -= diff
                x1 -= diff
            if x1 < 0:
                x2 -= x1
                x1 -= x1
            if x2 >= mask_image.width:
                x2 = mask_image.width

        return x1, y1, x2, y2
382
+
383
    def _resize_and_fill(
        self,
        image: PIL.Image.Image,
        width: int,
        height: int,
    ) -> PIL.Image.Image:
        r"""
        Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
        the image within the dimensions, filling empty with data from image.

        Args:
            image (`PIL.Image.Image`):
                The image to resize and fill.
            width (`int`):
                The width to resize the image to.
            height (`int`):
                The height to resize the image to.

        Returns:
            `PIL.Image.Image`:
                The resized and filled image.
        """

        ratio = width / height
        src_ratio = image.width / image.height

        # Scale so the image fits inside (width, height) while keeping aspect ratio:
        # one dimension matches the target exactly, the other is scaled proportionally.
        src_w = width if ratio < src_ratio else image.width * height // image.height
        src_h = height if ratio >= src_ratio else image.height * width // image.width

        resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
        res = Image.new("RGB", (width, height))
        # Paste the resized image centered on the target canvas.
        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))

        if ratio < src_ratio:
            # Empty horizontal bands above/below: fill each by stretching the image's
            # boundary row. The zero-height source box (e.g. (0, 0, width, 0)) makes
            # PIL sample the edge row and stretch it to fill_height.
            fill_height = height // 2 - src_h // 2
            if fill_height > 0:
                res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
                res.paste(
                    resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)),
                    box=(0, fill_height + src_h),
                )
        elif ratio > src_ratio:
            # Empty vertical bands left/right: same trick with zero-width source boxes
            # stretching the image's boundary column.
            fill_width = width // 2 - src_w // 2
            if fill_width > 0:
                res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
                res.paste(
                    resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)),
                    box=(fill_width + src_w, 0),
                )

        return res
434
+
435
+ def _resize_and_crop(
436
+ self,
437
+ image: PIL.Image.Image,
438
+ width: int,
439
+ height: int,
440
+ ) -> PIL.Image.Image:
441
+ r"""
442
+ Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
443
+ the image within the dimensions, cropping the excess.
444
+
445
+ Args:
446
+ image (`PIL.Image.Image`):
447
+ The image to resize and crop.
448
+ width (`int`):
449
+ The width to resize the image to.
450
+ height (`int`):
451
+ The height to resize the image to.
452
+
453
+ Returns:
454
+ `PIL.Image.Image`:
455
+ The resized and cropped image.
456
+ """
457
+ ratio = width / height
458
+ src_ratio = image.width / image.height
459
+
460
+ src_w = width if ratio > src_ratio else image.width * height // image.height
461
+ src_h = height if ratio <= src_ratio else image.height * width // image.width
462
+
463
+ resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
464
+ res = Image.new("RGB", (width, height))
465
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
466
+ return res
467
+
468
    def resize(
        self,
        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
        height: int,
        width: int,
        resize_mode: str = "default",  # "default", "fill", "crop"
    ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
        """
        Resize image.

        Args:
            image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
                The image input, can be a PIL image, numpy array or pytorch tensor.
            height (`int`):
                The height to resize to.
            width (`int`):
                The width to resize to.
            resize_mode (`str`, *optional*, defaults to `default`):
                The resize mode to use, can be one of `default` or `fill`. If `default`, will resize the image to fit
                within the specified width and height, and it may not maintaining the original aspect ratio. If `fill`,
                will resize the image to fit within the specified width and height, maintaining the aspect ratio, and
                then center the image within the dimensions, filling empty with data from image. If `crop`, will resize
                the image to fit within the specified width and height, maintaining the aspect ratio, and then center
                the image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
                supported for PIL image input.

        Returns:
            `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
                The resized image.

        Raises:
            ValueError: If `resize_mode` is not `default` for a non-PIL input, or is unknown.
        """
        # `fill`/`crop` need PIL-level paste operations, so reject tensor/array input early.
        if resize_mode != "default" and not isinstance(image, PIL.Image.Image):
            raise ValueError(f"Only PIL image input is supported for resize_mode {resize_mode}")
        if isinstance(image, PIL.Image.Image):
            if resize_mode == "default":
                # Plain PIL resize; resample filter and reducing_gap come from the processor config.
                image = image.resize(
                    (width, height),
                    resample=PIL_INTERPOLATION[self.config.resample],
                    reducing_gap=self.config.reducing_gap,
                )
            elif resize_mode == "fill":
                image = self._resize_and_fill(image, width, height)
            elif resize_mode == "crop":
                image = self._resize_and_crop(image, width, height)
            else:
                raise ValueError(f"resize_mode {resize_mode} is not supported")

        elif isinstance(image, torch.Tensor):
            # Tensor input is resized with torch's default (nearest) interpolation.
            image = torch.nn.functional.interpolate(
                image,
                size=(height, width),
            )
        elif isinstance(image, np.ndarray):
            # Round-trip through a torch tensor to reuse the same interpolation path.
            image = self.numpy_to_pt(image)
            image = torch.nn.functional.interpolate(
                image,
                size=(height, width),
            )
            image = self.pt_to_numpy(image)

        return image
528
+
529
+ def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
530
+ """
531
+ Create a mask.
532
+
533
+ Args:
534
+ image (`PIL.Image.Image`):
535
+ The image input, should be a PIL image.
536
+
537
+ Returns:
538
+ `PIL.Image.Image`:
539
+ The binarized image. Values less than 0.5 are set to 0, values greater than 0.5 are set to 1.
540
+ """
541
+ image[image < 0.5] = 0
542
+ image[image >= 0.5] = 1
543
+
544
+ return image
545
+
546
    def _denormalize_conditionally(
        self, images: torch.Tensor, do_denormalize: Optional[List[bool]] = None
    ) -> torch.Tensor:
        r"""
        Denormalize a batch of images based on a condition list.

        Args:
            images (`torch.Tensor`):
                The input image tensor, batch-first.
            do_denormalize (`Optional[List[bool]`, *optional*, defaults to `None`):
                A list of booleans indicating whether to denormalize each image in the batch. If `None`, will use the
                value of `do_normalize` in the `VaeImageProcessor` config.

        Returns:
            `torch.Tensor`: The (selectively) denormalized batch.
        """
        # Fast path: apply one decision to the whole batch at once.
        if do_denormalize is None:
            return self.denormalize(images) if self.config.do_normalize else images

        # Per-image decision: denormalize selected items, then re-stack into a batch.
        return torch.stack(
            [self.denormalize(images[i]) if do_denormalize[i] else images[i] for i in range(images.shape[0])]
        )
565
+
566
+ def get_default_height_width(
567
+ self,
568
+ image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
569
+ height: Optional[int] = None,
570
+ width: Optional[int] = None,
571
+ ) -> Tuple[int, int]:
572
+ r"""
573
+ Returns the height and width of the image, downscaled to the next integer multiple of `vae_scale_factor`.
574
+
575
+ Args:
576
+ image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
577
+ The image input, which can be a PIL image, NumPy array, or PyTorch tensor. If it is a NumPy array, it
578
+ should have shape `[batch, height, width]` or `[batch, height, width, channels]`. If it is a PyTorch
579
+ tensor, it should have shape `[batch, channels, height, width]`.
580
+ height (`Optional[int]`, *optional*, defaults to `None`):
581
+ The height of the preprocessed image. If `None`, the height of the `image` input will be used.
582
+ width (`Optional[int]`, *optional*, defaults to `None`):
583
+ The width of the preprocessed image. If `None`, the width of the `image` input will be used.
584
+
585
+ Returns:
586
+ `Tuple[int, int]`:
587
+ A tuple containing the height and width, both resized to the nearest integer multiple of
588
+ `vae_scale_factor`.
589
+ """
590
+
591
+ if height is None:
592
+ if isinstance(image, PIL.Image.Image):
593
+ height = image.height
594
+ elif isinstance(image, torch.Tensor):
595
+ height = image.shape[2]
596
+ else:
597
+ height = image.shape[1]
598
+
599
+ if width is None:
600
+ if isinstance(image, PIL.Image.Image):
601
+ width = image.width
602
+ elif isinstance(image, torch.Tensor):
603
+ width = image.shape[3]
604
+ else:
605
+ width = image.shape[2]
606
+
607
+ width, height = (
608
+ x - x % self.config.vae_scale_factor for x in (width, height)
609
+ ) # resize to integer multiple of vae_scale_factor
610
+
611
+ return height, width
612
+
613
    def preprocess(
        self,
        image: PipelineImageInput,
        height: Optional[int] = None,
        width: Optional[int] = None,
        resize_mode: str = "default",  # "default", "fill", "crop"
        crops_coords: Optional[Tuple[int, int, int, int]] = None,
    ) -> torch.Tensor:
        """
        Preprocess the image input.

        Args:
            image (`PipelineImageInput`):
                The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of
                supported formats.
            height (`int`, *optional*):
                The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default
                height.
            width (`int`, *optional*):
                The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width.
            resize_mode (`str`, *optional*, defaults to `default`):
                The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within
                the specified width and height, and it may not maintaining the original aspect ratio. If `fill`, will
                resize the image to fit within the specified width and height, maintaining the aspect ratio, and then
                center the image within the dimensions, filling empty with data from image. If `crop`, will resize the
                image to fit within the specified width and height, maintaining the aspect ratio, and then center the
                image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
                supported for PIL image input.
            crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
                The crop coordinates for each image in the batch. If `None`, will not crop the image.

        Returns:
            `torch.Tensor`:
                The preprocessed image as a `B x C x H x W` tensor — unless the input is already a latent
                (channel count equals `vae_latent_channels`), which is returned unchanged.
        """
        supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)

        # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
        if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
            if isinstance(image, torch.Tensor):
                # if image is a pytorch tensor could have 2 possible shapes:
                # 1. batch x height x width: we should insert the channel dimension at position 1
                # 2. channel x height x width: we should insert batch dimension at position 0,
                # however, since both channel and batch dimension has same size 1, it is same to insert at position 1
                # for simplicity, we insert a dimension of size 1 at position 1 for both cases
                image = image.unsqueeze(1)
            else:
                # if it is a numpy array, it could have 2 possible shapes:
                # 1. batch x height x width: insert channel dimension on last position
                # 2. height x width x channel: insert batch dimension on first position
                if image.shape[-1] == 1:
                    image = np.expand_dims(image, axis=0)
                else:
                    image = np.expand_dims(image, axis=-1)

        # Deprecated input shapes: lists of 4-D arrays/tensors are merged into one batch.
        if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4:
            warnings.warn(
                "Passing `image` as a list of 4d np.ndarray is deprecated."
                "Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray",
                FutureWarning,
            )
            image = np.concatenate(image, axis=0)
        if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
            warnings.warn(
                "Passing `image` as a list of 4d torch.Tensor is deprecated."
                "Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor",
                FutureWarning,
            )
            image = torch.cat(image, axis=0)

        if not is_valid_image_imagelist(image):
            raise ValueError(
                f"Input is in incorrect format. Currently, we only support {', '.join(str(x) for x in supported_formats)}"
            )
        # Normalize to a list so all three branches below can assume list input.
        if not isinstance(image, list):
            image = [image]

        if isinstance(image[0], PIL.Image.Image):
            # PIL path: optional crop, resize, color conversion, then convert to tensor.
            if crops_coords is not None:
                image = [i.crop(crops_coords) for i in image]
            if self.config.do_resize:
                height, width = self.get_default_height_width(image[0], height, width)
                image = [self.resize(i, height, width, resize_mode=resize_mode) for i in image]
            if self.config.do_convert_rgb:
                image = [self.convert_to_rgb(i) for i in image]
            elif self.config.do_convert_grayscale:
                image = [self.convert_to_grayscale(i) for i in image]
            image = self.pil_to_numpy(image)  # to np
            image = self.numpy_to_pt(image)  # to pt

        elif isinstance(image[0], np.ndarray):
            # NumPy path: merge into one batch (4-D items are concatenated, 3-D stacked).
            image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)

            image = self.numpy_to_pt(image)

            height, width = self.get_default_height_width(image, height, width)
            if self.config.do_resize:
                image = self.resize(image, height, width)

        elif isinstance(image[0], torch.Tensor):
            # Tensor path: batch, then optionally treat as latents (returned untouched).
            image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)

            if self.config.do_convert_grayscale and image.ndim == 3:
                image = image.unsqueeze(1)

            channel = image.shape[1]
            # don't need any preprocess if the image is latents
            if channel == self.config.vae_latent_channels:
                return image

            height, width = self.get_default_height_width(image, height, width)
            if self.config.do_resize:
                image = self.resize(image, height, width)

        # expected range [0,1], normalize to [-1,1]
        do_normalize = self.config.do_normalize
        if do_normalize and image.min() < 0:
            # Negative values indicate the caller already normalized to [-1,1]; skip with a warning.
            warnings.warn(
                "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
                f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
                FutureWarning,
            )
            do_normalize = False
        if do_normalize:
            image = self.normalize(image)

        if self.config.do_binarize:
            image = self.binarize(image)

        return image
743
+
744
    def postprocess(
        self,
        image: torch.Tensor,
        output_type: str = "pil",
        do_denormalize: Optional[List[bool]] = None,
    ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
        """
        Postprocess the image output from tensor to `output_type`.

        Args:
            image (`torch.Tensor`):
                The image input, should be a pytorch tensor with shape `B x C x H x W`.
            output_type (`str`, *optional*, defaults to `pil`):
                The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
            do_denormalize (`List[bool]`, *optional*, defaults to `None`):
                Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
                `VaeImageProcessor` config.

        Returns:
            `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
                The postprocessed image.

        Raises:
            ValueError: If `image` is not a PyTorch tensor.
        """
        if not isinstance(image, torch.Tensor):
            raise ValueError(
                f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
            )
        # Unknown output types are coerced to "np" with a deprecation notice rather than erroring.
        if output_type not in ["latent", "pt", "np", "pil"]:
            deprecation_message = (
                f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
                "`pil`, `np`, `pt`, `latent`"
            )
            deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
            output_type = "np"

        # "latent" returns the raw tensor with no denormalization.
        if output_type == "latent":
            return image

        image = self._denormalize_conditionally(image, do_denormalize)

        if output_type == "pt":
            return image

        image = self.pt_to_numpy(image)

        if output_type == "np":
            return image

        # Only "pil" remains at this point, given the validation above.
        if output_type == "pil":
            return self.numpy_to_pil(image)
793
+
794
    def apply_overlay(
        self,
        mask: PIL.Image.Image,
        init_image: PIL.Image.Image,
        image: PIL.Image.Image,
        crop_coords: Optional[Tuple[int, int, int, int]] = None,
    ) -> PIL.Image.Image:
        r"""
        Applies an overlay of the mask and the inpainted image on the original image.

        Args:
            mask (`PIL.Image.Image`):
                The mask image that highlights regions to overlay.
            init_image (`PIL.Image.Image`):
                The original image to which the overlay is applied.
            image (`PIL.Image.Image`):
                The image to overlay onto the original.
            crop_coords (`Tuple[int, int, int, int]`, *optional*):
                Coordinates to crop the image. If provided, the image will be cropped accordingly.

        Returns:
            `PIL.Image.Image`:
                The final image with the overlay applied.
        """

        width, height = init_image.width, init_image.height

        # Cut out the unmasked parts of the original: the inverted mask is used as the
        # paste alpha, so only pixels outside the mask survive. "RGBa" is PIL's
        # premultiplied-alpha mode, used here as an intermediate before alpha compositing.
        init_image_masked = PIL.Image.new("RGBa", (width, height))
        init_image_masked.paste(init_image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert("L")))

        init_image_masked = init_image_masked.convert("RGBA")

        if crop_coords is not None:
            # Re-embed the (cropped-region) result at its original location on a
            # full-size canvas before compositing.
            x, y, x2, y2 = crop_coords
            w = x2 - x
            h = y2 - y
            base_image = PIL.Image.new("RGBA", (width, height))
            image = self.resize(image, height=h, width=w, resize_mode="crop")
            base_image.paste(image, (x, y))
            image = base_image.convert("RGB")

        # Composite the preserved original pixels over the generated image.
        image = image.convert("RGBA")
        image.alpha_composite(init_image_masked)
        image = image.convert("RGB")

        return image
840
+
841
+
842
class InpaintProcessor(ConfigMixin):
    """
    Image processor for inpainting image and mask.

    Wraps two `VaeImageProcessor` instances — one for the image, one for the mask —
    each configured from the constructor arguments below.
    """

    config_name = CONFIG_NAME

    @register_to_config
    def __init__(
        self,
        do_resize: bool = True,
        vae_scale_factor: int = 8,
        vae_latent_channels: int = 4,
        resample: str = "lanczos",
        reducing_gap: int = None,
        do_normalize: bool = True,
        do_binarize: bool = False,
        do_convert_grayscale: bool = False,
        mask_do_normalize: bool = False,
        mask_do_binarize: bool = True,
        mask_do_convert_grayscale: bool = True,
    ):
        super().__init__()

        # Processor for the RGB image; normalization on, binarization off by default.
        self._image_processor = VaeImageProcessor(
            do_resize=do_resize,
            vae_scale_factor=vae_scale_factor,
            vae_latent_channels=vae_latent_channels,
            resample=resample,
            reducing_gap=reducing_gap,
            do_normalize=do_normalize,
            do_binarize=do_binarize,
            do_convert_grayscale=do_convert_grayscale,
        )
        # Processor for the mask; by default grayscale + binarize, no normalization.
        self._mask_processor = VaeImageProcessor(
            do_resize=do_resize,
            vae_scale_factor=vae_scale_factor,
            vae_latent_channels=vae_latent_channels,
            resample=resample,
            reducing_gap=reducing_gap,
            do_normalize=mask_do_normalize,
            do_binarize=mask_do_binarize,
            do_convert_grayscale=mask_do_convert_grayscale,
        )

    def preprocess(
        self,
        image: PIL.Image.Image,
        mask: PIL.Image.Image = None,
        height: int = None,
        width: int = None,
        padding_mask_crop: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Preprocess the image and mask.

        NOTE(review): the return arity differs by input — a single tensor when `mask`
        is `None`, otherwise a 3-tuple `(processed_image, processed_mask,
        postprocessing_kwargs)`; callers must handle both shapes.

        Args:
            image (`PIL.Image.Image`): The image to preprocess.
            mask (`PIL.Image.Image`, *optional*): The inpainting mask.
            height (`int`, *optional*): Target height.
            width (`int`, *optional*): Target width.
            padding_mask_crop (`int`, *optional*):
                Padding around the mask's bounding box; when set, both image and mask
                are cropped to that region before processing.

        Raises:
            ValueError: If `padding_mask_crop` is given without a `mask`.
        """
        if mask is None and padding_mask_crop is not None:
            raise ValueError("mask must be provided if padding_mask_crop is provided")

        # if mask is None, same behavior as regular image processor
        if mask is None:
            return self._image_processor.preprocess(image, height=height, width=width)

        if padding_mask_crop is not None:
            # Crop to the mask's padded bounding box and fill when resizing.
            crops_coords = self._image_processor.get_crop_region(mask, width, height, pad=padding_mask_crop)
            resize_mode = "fill"
        else:
            crops_coords = None
            resize_mode = "default"

        processed_image = self._image_processor.preprocess(
            image,
            height=height,
            width=width,
            crops_coords=crops_coords,
            resize_mode=resize_mode,
        )

        processed_mask = self._mask_processor.preprocess(
            mask,
            height=height,
            width=width,
            resize_mode=resize_mode,
            crops_coords=crops_coords,
        )

        # Carry the originals through so postprocess() can re-apply the overlay.
        if crops_coords is not None:
            postprocessing_kwargs = {
                "crops_coords": crops_coords,
                "original_image": image,
                "original_mask": mask,
            }
        else:
            postprocessing_kwargs = {
                "crops_coords": None,
                "original_image": None,
                "original_mask": None,
            }

        return processed_image, processed_mask, postprocessing_kwargs

    def postprocess(
        self,
        image: torch.Tensor,
        output_type: str = "pil",
        original_image: Optional[PIL.Image.Image] = None,
        original_mask: Optional[PIL.Image.Image] = None,
        crops_coords: Optional[Tuple[int, int, int, int]] = None,
    ) -> Tuple[PIL.Image.Image, PIL.Image.Image]:
        """
        Postprocess the image, optionally apply mask overlay.

        Args:
            image (`torch.Tensor`): The generated image batch, `B x C x H x W`.
            output_type (`str`, *optional*, defaults to `"pil"`): Output format.
            original_image (`PIL.Image.Image`, *optional*): Required with `crops_coords`.
            original_mask (`PIL.Image.Image`, *optional*): Required with `crops_coords`.
            crops_coords (`Tuple[int, int, int, int]`, *optional*):
                Crop region from `preprocess`; triggers the overlay path (PIL only).

        Raises:
            ValueError: If `crops_coords` is given without originals, or with a
                non-PIL `output_type`.
        """
        image = self._image_processor.postprocess(
            image,
            output_type=output_type,
        )
        # optionally apply the mask overlay
        if crops_coords is not None and (original_image is None or original_mask is None):
            raise ValueError("original_image and original_mask must be provided if crops_coords is provided")

        elif crops_coords is not None and output_type != "pil":
            raise ValueError("output_type must be 'pil' if crops_coords is provided")

        elif crops_coords is not None:
            image = [
                self._image_processor.apply_overlay(original_mask, original_image, i, crops_coords) for i in image
            ]

        return image
971
+
972
+
973
+ class VaeImageProcessorLDM3D(VaeImageProcessor):
974
+ """
975
+ Image processor for VAE LDM3D.
976
+
977
+ Args:
978
+ do_resize (`bool`, *optional*, defaults to `True`):
979
+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
980
+ vae_scale_factor (`int`, *optional*, defaults to `8`):
981
+ VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
982
+ resample (`str`, *optional*, defaults to `lanczos`):
983
+ Resampling filter to use when resizing the image.
984
+ do_normalize (`bool`, *optional*, defaults to `True`):
985
+ Whether to normalize the image to [-1,1].
986
+ """
987
+
988
+ config_name = CONFIG_NAME
989
+
990
    @register_to_config
    def __init__(
        self,
        do_resize: bool = True,
        vae_scale_factor: int = 8,
        resample: str = "lanczos",
        do_normalize: bool = True,
    ):
        # The arguments are captured into `self.config` by `@register_to_config`;
        # the parent is initialized with its own defaults (no args forwarded).
        super().__init__()
999
+
1000
+ @staticmethod
1001
+ def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
1002
+ r"""
1003
+ Convert a NumPy image or a batch of images to a list of PIL images.
1004
+
1005
+ Args:
1006
+ images (`np.ndarray`):
1007
+ The input NumPy array of images, which can be a single image or a batch.
1008
+
1009
+ Returns:
1010
+ `List[PIL.Image.Image]`:
1011
+ A list of PIL images converted from the input NumPy array.
1012
+ """
1013
+ if images.ndim == 3:
1014
+ images = images[None, ...]
1015
+ images = (images * 255).round().astype("uint8")
1016
+ if images.shape[-1] == 1:
1017
+ # special case for grayscale (single channel) images
1018
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
1019
+ else:
1020
+ pil_images = [Image.fromarray(image[:, :, :3]) for image in images]
1021
+
1022
+ return pil_images
1023
+
1024
+ @staticmethod
1025
+ def depth_pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
1026
+ r"""
1027
+ Convert a PIL image or a list of PIL images to NumPy arrays.
1028
+
1029
+ Args:
1030
+ images (`Union[List[PIL.Image.Image], PIL.Image.Image]`):
1031
+ The input image or list of images to be converted.
1032
+
1033
+ Returns:
1034
+ `np.ndarray`:
1035
+ A NumPy array of the converted images.
1036
+ """
1037
+ if not isinstance(images, list):
1038
+ images = [images]
1039
+
1040
+ images = [np.array(image).astype(np.float32) / (2**16 - 1) for image in images]
1041
+ images = np.stack(images, axis=0)
1042
+ return images
1043
+
1044
+ @staticmethod
1045
+ def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
1046
+ r"""
1047
+ Convert an RGB-like depth image to a depth map.
1048
+
1049
+ Args:
1050
+ image (`Union[np.ndarray, torch.Tensor]`):
1051
+ The RGB-like depth image to convert.
1052
+
1053
+ Returns:
1054
+ `Union[np.ndarray, torch.Tensor]`:
1055
+ The corresponding depth map.
1056
+ """
1057
+ return image[:, :, 1] * 2**8 + image[:, :, 2]
1058
+
1059
    def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
        r"""
        Convert a NumPy depth image or a batch of images to a list of 16-bit PIL images.

        Args:
            images (`np.ndarray`):
                The input NumPy array of depth images, which can be a single image or a batch.
                Expected last-dimension size is 6 (RGB + RGB-packed depth) or
                4 (RGB + single-channel depth).

        Returns:
            `List[PIL.Image.Image]`:
                A list of PIL images (mode "I;16") converted from the input NumPy depth images.

        Raises:
            Exception: If the channel layout is neither 6 nor 4.
        """
        # Treat a single image as a batch of one.
        if images.ndim == 3:
            images = images[None, ...]
        # Depth occupies all channels after the first three (RGB).
        images_depth = images[:, :, :, 3:]
        if images.shape[-1] == 6:
            # Depth packed into two 8-bit channels; unpack via rgblike_to_depthmap.
            images_depth = (images_depth * 255).round().astype("uint8")
            pil_images = [
                Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16") for image_depth in images_depth
            ]
        elif images.shape[-1] == 4:
            # Single depth channel scaled to the full uint16 range.
            # NOTE(review): image_depth keeps a trailing singleton channel here — confirm
            # Image.fromarray accepts that shape for mode "I;16".
            images_depth = (images_depth * 65535.0).astype(np.uint16)
            pil_images = [Image.fromarray(image_depth, mode="I;16") for image_depth in images_depth]
        else:
            raise Exception("Not supported")

        return pil_images
1086
+
1087
    def postprocess(
        self,
        image: torch.Tensor,
        output_type: str = "pil",
        do_denormalize: Optional[List[bool]] = None,
    ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
        """
        Postprocess the image output from tensor to `output_type`, splitting the
        result into an RGB part and a depth part.

        Args:
            image (`torch.Tensor`):
                The image input, should be a pytorch tensor with shape `B x C x H x W`.
            output_type (`str`, *optional*, defaults to `pil`):
                The output type of the image; only `np` and `pil` are actually
                supported here (others raise after the deprecation coercion).
            do_denormalize (`List[bool]`, *optional*, defaults to `None`):
                Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
                `VaeImageProcessor` config.

        Returns:
            A 2-tuple `(rgb, depth)` in the requested format.

        Raises:
            ValueError: If `image` is not a PyTorch tensor.
            Exception: If the (coerced) `output_type` is not `np` or `pil`.
        """
        if not isinstance(image, torch.Tensor):
            raise ValueError(
                f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
            )
        # Unknown output types degrade to "np" with a deprecation notice.
        if output_type not in ["latent", "pt", "np", "pil"]:
            deprecation_message = (
                f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
                "`pil`, `np`, `pt`, `latent`"
            )
            deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
            output_type = "np"

        image = self._denormalize_conditionally(image, do_denormalize)

        image = self.pt_to_numpy(image)

        if output_type == "np":
            # Split channels: first 3 are RGB; the rest encode depth (packed when 6 channels).
            if image.shape[-1] == 6:
                image_depth = np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0)
            else:
                image_depth = image[:, :, :, 3:]
            return image[:, :, :, :3], image_depth

        if output_type == "pil":
            return self.numpy_to_pil(image), self.numpy_to_depth(image)
        else:
            # NOTE: "latent" and "pt" fall through to this error — unlike the base
            # class, this subclass does not support them.
            raise Exception(f"This type {output_type} is not supported")
1137
def preprocess(
    self,
    rgb: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
    depth: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
    height: Optional[int] = None,
    width: Optional[int] = None,
    target_res: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Preprocess the image input. Accepted formats are PIL images, NumPy arrays, or PyTorch tensors.

    Args:
        rgb (`Union[torch.Tensor, PIL.Image.Image, np.ndarray]`):
            The RGB input image, which can be a single image or a batch.
        depth (`Union[torch.Tensor, PIL.Image.Image, np.ndarray]`):
            The depth input image, which can be a single image or a batch.
        height (`Optional[int]`, *optional*, defaults to `None`):
            The desired height of the processed image. If `None`, defaults to the height of the input image.
        width (`Optional[int]`, *optional*, defaults to `None`):
            The desired width of the processed image. If `None`, defaults to the width of the input image.
        target_res (`Optional[int]`, *optional*, defaults to `None`):
            Target resolution for resizing the images. If specified, overrides height and width.

    Returns:
        `Tuple[torch.Tensor, torch.Tensor]`:
            A tuple containing the processed RGB and depth images as PyTorch tensors.
    """
    supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)

    # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
    if self.config.do_convert_grayscale and isinstance(rgb, (torch.Tensor, np.ndarray)) and rgb.ndim == 3:
        raise Exception("This is not yet supported")

    # Normalize single inputs to 1-element lists so the per-format branches below
    # can uniformly operate on batches.
    if isinstance(rgb, supported_formats):
        rgb = [rgb]
        depth = [depth]
    elif not (isinstance(rgb, list) and all(isinstance(i, supported_formats) for i in rgb)):
        # NOTE(review): `', '.join(supported_formats)` joins a tuple of classes, not
        # strings, and would raise TypeError when this error path fires — confirm.
        raise ValueError(
            f"Input is in incorrect format: {[type(i) for i in rgb]}. Currently, we only support {', '.join(supported_formats)}"
        )

    if isinstance(rgb[0], PIL.Image.Image):
        if self.config.do_convert_rgb:
            raise Exception("This is not yet supported")
            # rgb = [self.convert_to_rgb(i) for i in rgb]
            # depth = [self.convert_to_depth(i) for i in depth]  #TODO define convert_to_depth
        if self.config.do_resize or target_res:
            # When `target_res` is given it replaces the (height, width) pair entirely;
            # presumably it is an (h, w) tuple despite being annotated `int` — TODO confirm.
            height, width = self.get_default_height_width(rgb[0], height, width) if not target_res else target_res
            rgb = [self.resize(i, height, width) for i in rgb]
            depth = [self.resize(i, height, width) for i in depth]
        rgb = self.pil_to_numpy(rgb)  # to np
        rgb = self.numpy_to_pt(rgb)  # to pt

        depth = self.depth_pil_to_numpy(depth)  # to np
        depth = self.numpy_to_pt(depth)  # to pt

    elif isinstance(rgb[0], np.ndarray):
        rgb = np.concatenate(rgb, axis=0) if rgb[0].ndim == 4 else np.stack(rgb, axis=0)
        rgb = self.numpy_to_pt(rgb)
        height, width = self.get_default_height_width(rgb, height, width)
        if self.config.do_resize:
            rgb = self.resize(rgb, height, width)

        # NOTE(review): this condition checks `rgb[0].ndim`, but `rgb` was reassigned
        # to a 4-D tensor above (so `rgb[0].ndim` is 3 and the stack branch always
        # runs); it likely was meant to check `depth[0].ndim` — confirm before fixing.
        depth = np.concatenate(depth, axis=0) if rgb[0].ndim == 4 else np.stack(depth, axis=0)
        depth = self.numpy_to_pt(depth)
        height, width = self.get_default_height_width(depth, height, width)
        if self.config.do_resize:
            depth = self.resize(depth, height, width)

    elif isinstance(rgb[0], torch.Tensor):
        raise Exception("This is not yet supported")
        # rgb = torch.cat(rgb, axis=0) if rgb[0].ndim == 4 else torch.stack(rgb, axis=0)

        # if self.config.do_convert_grayscale and rgb.ndim == 3:
        #     rgb = rgb.unsqueeze(1)

        # channel = rgb.shape[1]

        # height, width = self.get_default_height_width(rgb, height, width)
        # if self.config.do_resize:
        #     rgb = self.resize(rgb, height, width)

        # depth = torch.cat(depth, axis=0) if depth[0].ndim == 4 else torch.stack(depth, axis=0)

        # if self.config.do_convert_grayscale and depth.ndim == 3:
        #     depth = depth.unsqueeze(1)

        # channel = depth.shape[1]
        # # don't need any preprocess if the image is latents
        # if depth == 4:
        #     return rgb, depth

        # height, width = self.get_default_height_width(depth, height, width)
        # if self.config.do_resize:
        #     depth = self.resize(depth, height, width)
    # expected range [0,1], normalize to [-1,1]
    do_normalize = self.config.do_normalize
    if rgb.min() < 0 and do_normalize:
        # Negative values indicate the caller already normalized to [-1,1];
        # warn and skip re-normalization rather than double-apply it.
        warnings.warn(
            "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
            f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{rgb.min()},{rgb.max()}]",
            FutureWarning,
        )
        do_normalize = False

    if do_normalize:
        rgb = self.normalize(rgb)
        depth = self.normalize(depth)

    if self.config.do_binarize:
        rgb = self.binarize(rgb)
        depth = self.binarize(depth)

    return rgb, depth
1251
+
1252
+
1253
class IPAdapterMaskProcessor(VaeImageProcessor):
    """
    Image processor for IP Adapter image masks.

    Defaults differ from the base `VaeImageProcessor` to suit masks: no [-1,1]
    normalization, binarization to 0/1, and grayscale conversion are enabled.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
        vae_scale_factor (`int`, *optional*, defaults to `8`):
            VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
        resample (`str`, *optional*, defaults to `lanczos`):
            Resampling filter to use when resizing the image.
        do_normalize (`bool`, *optional*, defaults to `False`):
            Whether to normalize the image to [-1,1].
        do_binarize (`bool`, *optional*, defaults to `True`):
            Whether to binarize the image to 0/1.
        do_convert_grayscale (`bool`, *optional*, defaults to `True`):
            Whether to convert the images to grayscale format.
    """

    config_name = CONFIG_NAME

    @register_to_config
    def __init__(
        self,
        do_resize: bool = True,
        vae_scale_factor: int = 8,
        resample: str = "lanczos",
        do_normalize: bool = False,
        do_binarize: bool = True,
        do_convert_grayscale: bool = True,
    ):
        # Forward everything to the base processor; only the defaults change.
        super().__init__(
            do_resize=do_resize,
            vae_scale_factor=vae_scale_factor,
            resample=resample,
            do_normalize=do_normalize,
            do_binarize=do_binarize,
            do_convert_grayscale=do_convert_grayscale,
        )

    @staticmethod
    def downsample(mask: torch.Tensor, batch_size: int, num_queries: int, value_embed_dim: int):
        """
        Downsample the provided mask tensor to match the expected dimensions for scaled dot-product
        attention. If the aspect ratio of the mask does not match the aspect ratio of the output
        image, a warning is issued and the flattened mask is padded or truncated to fit.

        Args:
            mask (`torch.Tensor`):
                The input mask tensor generated with `IPAdapterMaskProcessor.preprocess()`.
            batch_size (`int`):
                The batch size.
            num_queries (`int`):
                The number of queries.
            value_embed_dim (`int`):
                The dimensionality of the value embeddings.

        Returns:
            `torch.Tensor`:
                The downsampled mask tensor.
        """
        mask_height, mask_width = mask.shape[1], mask.shape[2]
        aspect_ratio = mask_width / mask_height

        # Choose a (mask_h, mask_w) grid whose area covers `num_queries` while
        # approximately preserving the mask's aspect ratio.
        mask_h = int(math.sqrt(num_queries / aspect_ratio))
        mask_h = mask_h + int((num_queries % mask_h) != 0)
        mask_w = num_queries // mask_h

        mask_downsample = F.interpolate(mask.unsqueeze(0), size=(mask_h, mask_w), mode="bicubic").squeeze(0)

        # Broadcast a single mask across the whole batch when needed.
        if mask_downsample.shape[0] < batch_size:
            mask_downsample = mask_downsample.repeat(batch_size, 1, 1)

        # Flatten spatial dims so the mask lines up with the query sequence.
        mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1)

        downsampled_area = mask_h * mask_w
        aspect_mismatch_warning = (
            "The aspect ratio of the mask does not match the aspect ratio of the output image. "
            "Please update your masks or adjust the output size for optimal performance."
        )
        if downsampled_area < num_queries:
            # Pad with zeros up to `num_queries` when the grid is too small.
            warnings.warn(aspect_mismatch_warning, UserWarning)
            mask_downsample = F.pad(mask_downsample, (0, num_queries - mask_downsample.shape[1]), value=0.0)
        elif downsampled_area > num_queries:
            # Drop the trailing positions when the grid is too large.
            warnings.warn(aspect_mismatch_warning, UserWarning)
            mask_downsample = mask_downsample[:, :num_queries]

        # Repeat along a new last dimension to match the SDPA output shape.
        return mask_downsample.unsqueeze(-1).repeat(1, 1, value_embed_dim)
1355
+
1356
+
1357
class PixArtImageProcessor(VaeImageProcessor):
    """
    Image processor for PixArt image resize and crop.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
            `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
        vae_scale_factor (`int`, *optional*, defaults to `8`):
            VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
        resample (`str`, *optional*, defaults to `lanczos`):
            Resampling filter to use when resizing the image.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image to [-1,1].
        do_binarize (`bool`, *optional*, defaults to `False`):
            Whether to binarize the image to 0/1.
        do_convert_grayscale (`bool`, *optional*, defaults to `False`):
            Whether to convert the images to grayscale format.
    """

    @register_to_config
    def __init__(
        self,
        do_resize: bool = True,
        vae_scale_factor: int = 8,
        resample: str = "lanczos",
        do_normalize: bool = True,
        do_binarize: bool = False,
        do_convert_grayscale: bool = False,
    ):
        # Delegate to the base processor; this subclass only adds the static helpers below.
        super().__init__(
            do_resize=do_resize,
            vae_scale_factor=vae_scale_factor,
            resample=resample,
            do_normalize=do_normalize,
            do_binarize=do_binarize,
            do_convert_grayscale=do_convert_grayscale,
        )

    @staticmethod
    def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
        r"""
        Return the binned (height, width) whose aspect ratio is closest to the input's.

        Args:
            height (`int`): The height of the image.
            width (`int`): The width of the image.
            ratios (`dict`): A dictionary where keys are aspect ratios and values are tuples of (height, width).

        Returns:
            `Tuple[int, int]`: The closest binned height and width.
        """
        aspect = float(height / width)
        # Keys are compared numerically, so they may be strings or numbers.
        nearest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect))
        binned_h, binned_w = ratios[nearest_ratio]
        return int(binned_h), int(binned_w)

    @staticmethod
    def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor:
        r"""
        Resize a batch of images to cover the target size, then center-crop to it exactly.

        Args:
            samples (`torch.Tensor`):
                A tensor of shape (N, C, H, W) where N is the batch size, C is the number of channels, H is the height,
                and W is the width.
            new_width (`int`): The desired width of the output images.
            new_height (`int`): The desired height of the output images.

        Returns:
            `torch.Tensor`: A tensor containing the resized and cropped images.
        """
        cur_height, cur_width = samples.shape[2], samples.shape[3]

        if cur_height != new_height or cur_width != new_width:
            # Scale so that BOTH target dimensions are covered (max ratio), then crop the excess.
            scale = max(new_height / cur_height, new_width / cur_width)
            resized_height = int(cur_height * scale)
            resized_width = int(cur_width * scale)

            samples = F.interpolate(
                samples, size=(resized_height, resized_width), mode="bilinear", align_corners=False
            )

            # Center crop to the exact target size.
            left = (resized_width - new_width) // 2
            top = (resized_height - new_height) // 2
            samples = samples[:, :, top : top + new_height, left : left + new_width]

        return samples
pythonProject/.venv/Lib/site-packages/diffusers/optimization.py ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch optimization for diffusion models."""
16
+
17
+ import math
18
+ from enum import Enum
19
+ from typing import Optional, Union
20
+
21
+ from torch.optim import Optimizer
22
+ from torch.optim.lr_scheduler import LambdaLR
23
+
24
+ from .utils import logging
25
+
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+
30
class SchedulerType(Enum):
    """String identifiers for the learning-rate schedules supported by `get_scheduler`."""

    LINEAR = "linear"
    COSINE = "cosine"
    COSINE_WITH_RESTARTS = "cosine_with_restarts"
    POLYNOMIAL = "polynomial"
    CONSTANT = "constant"
    CONSTANT_WITH_WARMUP = "constant_with_warmup"
    PIECEWISE_CONSTANT = "piecewise_constant"
38
+
39
+
40
def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1) -> LambdaLR:
    """
    Create a schedule with a constant learning rate, using the learning rate set in optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def _unit_multiplier(_step: int):
        # The multiplier is always 1, so the optimizer's base lr is kept unchanged.
        return 1

    return LambdaLR(optimizer, _unit_multiplier, last_epoch=last_epoch)
54
+
55
+
56
def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1) -> LambdaLR:
    """
    Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
    increases linearly between 0 and the initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step: int):
        # After warmup the multiplier stays at 1; during warmup it ramps linearly.
        if current_step >= num_warmup_steps:
            return 1.0
        return float(current_step) / float(max(1.0, num_warmup_steps))

    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
79
+
80
+
81
def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1) -> LambdaLR:
    """
    Create a piecewise-constant learning-rate schedule from a textual rule string.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        step_rules (`string`):
            The rules for the learning rate, e.g. step_rules="1:10,0.1:20,0.01:30,0.005": multiply the base lr
            by 1 while the step is below 10, by 0.1 while it is below 20, by 0.01 while it is below 30, and by
            0.005 for all later steps (boundaries are absolute step indices).
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    # Each rule except the last is "multiplier:boundary"; the last is the tail multiplier.
    *interval_rules, tail_rule = step_rules.split(",")
    boundary_to_multiplier = {}
    for rule in interval_rules:
        multiplier_str, boundary_str = rule.split(":")
        boundary_to_multiplier[int(boundary_str)] = float(multiplier_str)
    tail_multiplier = float(tail_rule)

    def rule_func(steps: int) -> float:
        # Use the multiplier of the first boundary the step has not yet reached.
        for boundary in sorted(boundary_to_multiplier):
            if steps < boundary:
                return boundary_to_multiplier[boundary]
        return tail_multiplier

    return LambdaLR(optimizer, rule_func, last_epoch=last_epoch)
121
+
122
+
123
def get_linear_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, last_epoch: int = -1
) -> LambdaLR:
    """
    Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
    a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step: int):
        # Linear ramp-up during warmup.
        if current_step < num_warmup_steps:
            return current_step / max(1, num_warmup_steps)
        # Linear decay afterwards, clamped so the factor never goes negative.
        remaining = num_training_steps - current_step
        decay_span = max(1, num_training_steps - num_warmup_steps)
        return max(0.0, remaining / decay_span)

    return LambdaLR(optimizer, lr_lambda, last_epoch)
152
+
153
+
154
def get_cosine_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
) -> LambdaLR:
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function between the
    initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
    initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`float`, *optional*, defaults to 0.5):
            The number of periods of the cosine function in a schedule (the default is to just decrease from the max
            value to 0 following a half-cosine).
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step):
        # Linear warmup from 0 to 1.
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # Cosine decay over the remaining steps; `num_cycles` periods are traversed
        # (0.5 periods == a single half-cosine from 1 down to 0). Clamped at 0.
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

    return LambdaLR(optimizer, lr_lambda, last_epoch)
186
+
187
+
188
def get_cosine_with_hard_restarts_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
) -> LambdaLR:
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function between the
    initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
    linearly between 0 and the initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`int`, *optional*, defaults to 1):
            The number of hard restarts to use.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step):
        # Linear warmup from 0 to 1.
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        if progress >= 1.0:
            return 0.0
        # Position within the current cycle; the modulo restarts the cosine at 1.0
        # at every cycle boundary.
        cycle_position = (float(num_cycles) * progress) % 1.0
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * cycle_position)))

    return LambdaLR(optimizer, lr_lambda, last_epoch)
221
+
222
+
223
def get_polynomial_decay_schedule_with_warmup(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    lr_end: float = 1e-7,
    power: float = 1.0,
    last_epoch: int = -1,
) -> LambdaLR:
    """
    Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
    optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
    initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        lr_end (`float`, *optional*, defaults to 1e-7):
            The end LR.
        power (`float`, *optional*, defaults to 1.0):
            Power factor.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
    implementation at
    https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.

    Raises:
        ValueError: If `lr_end` is not strictly smaller than the optimizer's initial lr.
    """
    lr_init = optimizer.defaults["lr"]
    if not lr_init > lr_end:
        raise ValueError(f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})")

    def lr_lambda(current_step: int):
        # LambdaLR multiplies the returned factor by lr_init, so absolute target
        # learning rates are divided back out below.
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        if current_step > num_training_steps:
            return lr_end / lr_init
        decay_steps = num_training_steps - num_warmup_steps
        pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
        decayed = (lr_init - lr_end) * pct_remaining**power + lr_end
        return decayed / lr_init

    return LambdaLR(optimizer, lr_lambda, last_epoch)
276
+
277
+
278
# Dispatch table used by `get_scheduler` to map a `SchedulerType` member to the
# factory function that builds the corresponding `LambdaLR` schedule.
TYPE_TO_SCHEDULER_FUNCTION = {
    SchedulerType.LINEAR: get_linear_schedule_with_warmup,
    SchedulerType.COSINE: get_cosine_schedule_with_warmup,
    SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
    SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
    SchedulerType.CONSTANT: get_constant_schedule,
    SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
    SchedulerType.PIECEWISE_CONSTANT: get_piecewise_constant_schedule,
}
287
+
288
+
289
def get_scheduler(
    name: Union[str, SchedulerType],
    optimizer: Optimizer,
    step_rules: Optional[str] = None,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
    num_cycles: int = 1,
    power: float = 1.0,
    last_epoch: int = -1,
) -> LambdaLR:
    """
    Unified API to get any scheduler from its name.

    Args:
        name (`str` or `SchedulerType`):
            The name of the scheduler to use.
        optimizer (`torch.optim.Optimizer`):
            The optimizer that will be used during training.
        step_rules (`str`, *optional*):
            A string representing the step rules to use. This is only used by the `PIECEWISE_CONSTANT` scheduler.
        num_warmup_steps (`int`, *optional*):
            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
            optional), the function will raise an error if it's unset and the scheduler type requires it.
        num_training_steps (`int``, *optional*):
            The number of training steps to do. This is not required by all schedulers (hence the argument being
            optional), the function will raise an error if it's unset and the scheduler type requires it.
        num_cycles (`int`, *optional*):
            The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
        power (`float`, *optional*, defaults to 1.0):
            Power factor. See `POLYNOMIAL` scheduler
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Raises:
        ValueError: If a required argument (`num_warmup_steps`, `num_training_steps`)
            is missing for the requested scheduler type.
    """
    name = SchedulerType(name)
    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]

    # Schedulers that need neither warmup nor a step budget.
    if name == SchedulerType.CONSTANT:
        return schedule_func(optimizer, last_epoch=last_epoch)
    if name == SchedulerType.PIECEWISE_CONSTANT:
        return schedule_func(optimizer, step_rules=step_rules, last_epoch=last_epoch)

    # All other schedulers require `num_warmup_steps`.
    if num_warmup_steps is None:
        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")

    if name == SchedulerType.CONSTANT_WITH_WARMUP:
        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, last_epoch=last_epoch)

    # All other schedulers require `num_training_steps`.
    if num_training_steps is None:
        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")

    # Common keyword arguments, plus the scheduler-specific extras.
    kwargs = {
        "num_warmup_steps": num_warmup_steps,
        "num_training_steps": num_training_steps,
        "last_epoch": last_epoch,
    }
    if name == SchedulerType.COSINE_WITH_RESTARTS:
        kwargs["num_cycles"] = num_cycles
    elif name == SchedulerType.POLYNOMIAL:
        kwargs["power"] = power

    return schedule_func(optimizer, **kwargs)
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/marigold/__pycache__/pipeline_marigold_normals.cpython-310.pyc ADDED
Binary file (22.6 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/py.typed ADDED
File without changes
pythonProject/.venv/Lib/site-packages/diffusers/training_utils.py ADDED
@@ -0,0 +1,730 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import copy
3
+ import gc
4
+ import math
5
+ import random
6
+ import re
7
+ import warnings
8
+ from contextlib import contextmanager
9
+ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
10
+
11
+ import numpy as np
12
+ import torch
13
+
14
+ from .models import UNet2DConditionModel
15
+ from .pipelines import DiffusionPipeline
16
+ from .schedulers import SchedulerMixin
17
+ from .utils import (
18
+ convert_state_dict_to_diffusers,
19
+ convert_state_dict_to_peft,
20
+ deprecate,
21
+ is_peft_available,
22
+ is_torch_npu_available,
23
+ is_torchvision_available,
24
+ is_transformers_available,
25
+ )
26
+
27
+
28
+ if is_transformers_available():
29
+ import transformers
30
+
31
+ if transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
32
+ import deepspeed
33
+
34
+ if is_peft_available():
35
+ from peft import set_peft_model_state_dict
36
+
37
+ if is_torchvision_available():
38
+ from torchvision import transforms
39
+
40
+ if is_torch_npu_available():
41
+ import torch_npu # noqa: F401
42
+
43
+
44
def set_seed(seed: int):
    """
    Helper function for reproducible behavior: seeds `random`, `numpy`, and `torch` (including the
    NPU or CUDA device RNGs) with the same value.

    Args:
        seed (`int`): The seed to set.

    Returns:
        `None`
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if is_torch_npu_available():
        torch.npu.manual_seed_all(seed)
    else:
        # Seeding CUDA is a no-op when no CUDA device is present, so this call is always safe.
        torch.cuda.manual_seed_all(seed)
62
+
63
+
64
def compute_snr(noise_scheduler, timesteps):
    """
    Computes SNR as per
    https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
    for the given timesteps using the provided noise scheduler.

    Args:
        noise_scheduler (`NoiseScheduler`):
            An object containing the noise schedule parameters, specifically `alphas_cumprod`, which is used to compute
            the SNR values.
        timesteps (`torch.Tensor`):
            A tensor of timesteps for which the SNR is computed.

    Returns:
        `torch.Tensor`: A tensor containing the computed SNR values for each timestep.
    """

    def _gather_and_expand(values):
        # Index the per-timestep schedule values, then append trailing singleton
        # dimensions until the result broadcasts over the shape of `timesteps`.
        # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
        gathered = values.to(device=timesteps.device)[timesteps].float()
        while gathered.dim() < timesteps.dim():
            gathered = gathered.unsqueeze(-1)
        return gathered.expand(timesteps.shape)

    alphas_cumprod = noise_scheduler.alphas_cumprod
    alpha = _gather_and_expand(alphas_cumprod**0.5)
    sigma = _gather_and_expand((1.0 - alphas_cumprod) ** 0.5)

    # SNR is defined as (alpha / sigma)^2.
    return (alpha / sigma) ** 2
99
+
100
+
101
def resolve_interpolation_mode(interpolation_type: str):
    """
    Maps a string describing an interpolation function to the corresponding torchvision `InterpolationMode` enum. The
    full list of supported enums is documented at
    https://pytorch.org/vision/0.9/transforms.html#torchvision.transforms.functional.InterpolationMode.

    Args:
        interpolation_type (`str`):
            A string describing an interpolation method. Currently, `bilinear`, `bicubic`, `box`, `nearest`,
            `nearest_exact`, `hamming`, and `lanczos` are supported, corresponding to the supported interpolation modes
            in torchvision.

    Returns:
        `torchvision.transforms.InterpolationMode`: an `InterpolationMode` enum used by torchvision's `resize`
        transform.
    """
    if not is_torchvision_available():
        raise ImportError(
            "Please make sure to install `torchvision` to be able to use the `resolve_interpolation_mode()` function."
        )

    # Lookup table from the accepted string to the torchvision enum attribute name.
    # The enum member is fetched lazily so only the requested mode is ever touched.
    _ENUM_NAME_BY_TYPE = {
        "bilinear": "BILINEAR",
        "bicubic": "BICUBIC",
        "box": "BOX",
        "nearest": "NEAREST",
        "nearest_exact": "NEAREST_EXACT",
        "hamming": "HAMMING",
        "lanczos": "LANCZOS",
    }
    if interpolation_type not in _ENUM_NAME_BY_TYPE:
        raise ValueError(
            f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation"
            f" modes are `bilinear`, `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
        )

    return getattr(transforms.InterpolationMode, _ENUM_NAME_BY_TYPE[interpolation_type])
143
+
144
+
145
def compute_dream_and_update_latents(
    unet: UNet2DConditionModel,
    noise_scheduler: SchedulerMixin,
    timesteps: torch.Tensor,
    noise: torch.Tensor,
    noisy_latents: torch.Tensor,
    target: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    dream_detail_preservation: float = 1.0,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Implements "DREAM (Diffusion Rectification and Estimation-Adaptive Models)" from
    https://huggingface.co/papers/2312.00210. DREAM helps align training with sampling to help training be more
    efficient and accurate at the cost of an extra forward step without gradients.

    Args:
        `unet`: The state unet to use to make a prediction.
        `noise_scheduler`: The noise scheduler used to add noise for the given timestep.
        `timesteps`: The timesteps for the noise_scheduler to user.
        `noise`: A tensor of noise in the shape of noisy_latents.
        `noisy_latents`: Previously noise latents from the training loop.
        `target`: The ground-truth tensor to predict after eps is removed.
        `encoder_hidden_states`: Text embeddings from the text model.
        `dream_detail_preservation`: A float value that indicates detail preservation level.
          See reference.

    Returns:
        `tuple[torch.Tensor, torch.Tensor]`: Adjusted noisy_latents and target.

    Raises:
        NotImplementedError: If the scheduler is configured for v-prediction.
        ValueError: For any other unknown `prediction_type`.
    """
    # Gather per-timestep alpha-bar; the three trailing `None`s add singleton dims so the
    # result broadcasts over (batch, channels, height, width) latents.
    # NOTE(review): the indexing assumes 4-D latents — confirm against callers.
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)[timesteps, None, None, None]
    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

    # The paper uses lambda = sqrt(1 - alpha) ** p, with p = 1 in their experiments.
    dream_lambda = sqrt_one_minus_alphas_cumprod**dream_detail_preservation

    pred = None
    # Extra gradient-free forward pass: the model's own prediction drives the rectification.
    with torch.no_grad():
        pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    _noisy_latents, _target = (None, None)
    if noise_scheduler.config.prediction_type == "epsilon":
        predicted_noise = pred
        # Residual between true and predicted noise, detached so no gradients flow
        # through the rectification term; scaled in place by lambda.
        delta_noise = (noise - predicted_noise).detach()
        delta_noise.mul_(dream_lambda)
        # Shift both the input latents and the training target by the (scaled) residual.
        _noisy_latents = noisy_latents.add(sqrt_one_minus_alphas_cumprod * delta_noise)
        _target = target.add(delta_noise)
    elif noise_scheduler.config.prediction_type == "v_prediction":
        raise NotImplementedError("DREAM has not been implemented for v-prediction")
    else:
        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

    return _noisy_latents, _target
197
+
198
+
199
def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
    r"""
    Collects the LoRA weights attached to `unet`.

    Returns:
        A state dict containing just the LoRA parameters.
    """
    collected = {}

    for module_name, submodule in unet.named_modules():
        # Only modules that were patched for LoRA expose `set_lora_layer`.
        if not hasattr(submodule, "set_lora_layer"):
            continue
        lora_layer = getattr(submodule, "lora_layer")
        if lora_layer is None:
            continue
        for matrix_name, weight in lora_layer.state_dict().items():
            # The matrix name can either be "down" or "up".
            collected[f"{module_name}.lora.{matrix_name}"] = weight

    return collected
216
+
217
+
218
def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32):
    """
    Casts the training parameters of the model to the specified data type.

    Args:
        model: The PyTorch model (or list of models) whose parameters will be cast.
        dtype: The data type to which the trainable parameters will be cast.
    """
    models = model if isinstance(model, list) else [model]
    for current_model in models:
        for param in current_model.parameters():
            # Frozen parameters are left as-is; only trainable ones are upcast to fp32.
            if param.requires_grad:
                param.data = param.to(dtype)
233
+
234
+
235
def _set_state_dict_into_text_encoder(
    lora_state_dict: Dict[str, torch.Tensor], prefix: str, text_encoder: torch.nn.Module
):
    """
    Sets the `lora_state_dict` into `text_encoder` coming from `transformers`.

    Args:
        lora_state_dict: The state dictionary to be set.
        prefix: String identifier to retrieve the portion of the state dict that belongs to `text_encoder`.
        text_encoder: Where the `lora_state_dict` is to be set.
    """

    # Strip only the *leading* prefix from each matching key. The previous
    # `k.replace(prefix, "")` removed every occurrence of the prefix, which would
    # corrupt keys whose remainder happens to contain the prefix string again.
    text_encoder_state_dict = {
        k[len(prefix) :]: v for k, v in lora_state_dict.items() if k.startswith(prefix)
    }
    # Convert diffusers-style LoRA keys to the PEFT naming scheme before loading.
    text_encoder_state_dict = convert_state_dict_to_peft(convert_state_dict_to_diffusers(text_encoder_state_dict))
    set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default")
252
+
253
+
254
+ def _collate_lora_metadata(modules_to_save: Dict[str, torch.nn.Module]) -> Dict[str, Any]:
255
+ metadatas = {}
256
+ for module_name, module in modules_to_save.items():
257
+ if module is not None:
258
+ metadatas[f"{module_name}_lora_adapter_metadata"] = module.peft_config["default"].to_dict()
259
+ return metadatas
260
+
261
+
262
def compute_density_for_timestep_sampling(
    weighting_scheme: str,
    batch_size: int,
    logit_mean: float = None,
    logit_std: float = None,
    mode_scale: float = None,
    device: Union[torch.device, str] = "cpu",
    generator: Optional[torch.Generator] = None,
):
    """
    Compute the density for sampling the timesteps when doing SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://huggingface.co/papers/2403.03206v1.
    """
    if weighting_scheme == "logit_normal":
        # Gaussian samples squashed through a sigmoid -> logit-normal distribution on (0, 1).
        gaussian = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device=device, generator=generator)
        return torch.nn.functional.sigmoid(gaussian)

    uniform = torch.rand(size=(batch_size,), device=device, generator=generator)
    if weighting_scheme == "mode":
        # Mode-centred resampling of the uniform draw, controlled by `mode_scale`.
        return 1 - uniform - mode_scale * (torch.cos(math.pi * uniform / 2) ** 2 - 1 + uniform)
    # Any other scheme falls back to plain uniform sampling.
    return uniform
287
+
288
+
289
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
    """
    Computes loss weighting scheme for SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://huggingface.co/papers/2403.03206v1.
    """
    if weighting_scheme == "sigma_sqrt":
        # Inverse-square weighting: 1 / sigma^2.
        return (sigmas**-2.0).float()
    if weighting_scheme == "cosmap":
        denominator = 1 - 2 * sigmas + 2 * sigmas**2
        return 2 / (math.pi * denominator)
    # Every other scheme uses uniform (all-ones) weighting.
    return torch.ones_like(sigmas)
305
+
306
+
307
def free_memory():
    """
    Runs garbage collection, then clears the cache of the first available accelerator backend.
    """
    gc.collect()

    # (availability check, cache-clearing action) pairs, probed in priority order.
    # The actions are lambdas so backend attributes are only touched when selected.
    backend_cache_clearers = (
        (torch.cuda.is_available, lambda: torch.cuda.empty_cache()),
        (torch.backends.mps.is_available, lambda: torch.mps.empty_cache()),
        (is_torch_npu_available, lambda: torch_npu.npu.empty_cache()),
        (lambda: hasattr(torch, "xpu") and torch.xpu.is_available(), lambda: torch.xpu.empty_cache()),
    )
    for is_available, empty_cache in backend_cache_clearers:
        if is_available():
            empty_cache()
            break
321
+
322
+
323
@contextmanager
def offload_models(
    *modules: Union[torch.nn.Module, DiffusionPipeline], device: Union[str, torch.device], offload: bool = True
):
    """
    Context manager that, if offload=True, moves each module to `device` on enter, then moves it back to its original
    device on exit.

    Args:
        device (`str` or `torch.Device`): Device to move the `modules` to.
        offload (`bool`): Flag to enable offloading.
    """
    original_devices = []
    if offload:
        if any(isinstance(m, DiffusionPipeline) for m in modules):
            # A pipeline exposes a single `.device`; only one pipeline may be passed at a time.
            assert len(modules) == 1
            original_devices = [modules[0].device]
        else:
            # Plain modules: remember the device of each module's first parameter.
            original_devices = [next(m.parameters()).device for m in modules]
        # Move everything to the requested target device.
        for module in modules:
            module.to(device)

    try:
        yield
    finally:
        if offload:
            # Restore every module to wherever it originally lived.
            for module, previous_device in zip(modules, original_devices):
                module.to(previous_device)
355
+
356
+
357
def parse_buckets_string(buckets_str):
    """Parses a string defining buckets into a list of (height, width) tuples.

    The expected format is semicolon-separated ``"height,width"`` pairs, e.g. ``"512,512;768,576"``.

    Args:
        buckets_str: The bucket definition string.

    Returns:
        A list of ``(height, width)`` integer tuples.

    Raises:
        ValueError: If the string is empty, a pair is malformed, or a dimension is not positive.
    """
    if not buckets_str:
        raise ValueError("Bucket string cannot be empty.")

    parsed_buckets = []
    for pair_str in buckets_str.strip().split(";"):
        match = re.match(r"^\s*(\d+)\s*,\s*(\d+)\s*$", pair_str)
        if not match:
            raise ValueError(f"Invalid bucket format: '{pair_str}'. Expected 'height,width'.")
        # The regex guarantees both groups are digit strings, so int() cannot fail here.
        # (The previous try/except around these lines caught the positive-dimension
        # ValueError below and re-raised it with a misleading "Invalid integer" message.)
        height = int(match.group(1))
        width = int(match.group(2))
        if height <= 0 or width <= 0:
            raise ValueError("Bucket dimensions must be positive integers.")
        if height % 8 != 0 or width % 8 != 0:
            warnings.warn(f"Bucket dimension ({height},{width}) not divisible by 8. This might cause issues.")
        parsed_buckets.append((height, width))

    if not parsed_buckets:
        raise ValueError("No valid buckets found in the provided string.")

    return parsed_buckets
383
+
384
+
385
def find_nearest_bucket(h, w, bucket_options):
    """Returns the index of the bucket whose aspect ratio is closest to ``h``:``w``.

    Closeness is the cross-multiplied distance ``|h * bucket_w - w * bucket_h|``, which is zero
    exactly when the aspect ratios match and avoids any division. Ties are resolved in favour of
    the bucket appearing *last* in ``bucket_options``; an empty list yields ``None``.
    """
    best_idx = None
    smallest_distance = float("inf")
    for idx, (bucket_h, bucket_w) in enumerate(bucket_options):
        distance = abs(h * bucket_w - w * bucket_h)
        # `<=` (not `<`) keeps the original last-wins tie-breaking behavior.
        if distance <= smallest_distance:
            smallest_distance, best_idx = distance, idx
    return best_idx
395
+
396
+
397
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMAModel:
    """
    Exponential Moving Average of models weights
    """

    def __init__(
        self,
        parameters: Iterable[torch.nn.Parameter],
        decay: float = 0.9999,
        min_decay: float = 0.0,
        update_after_step: int = 0,
        use_ema_warmup: bool = False,
        inv_gamma: Union[float, int] = 1.0,
        power: Union[float, int] = 2 / 3,
        foreach: bool = False,
        model_cls: Optional[Any] = None,
        model_config: Dict[str, Any] = None,
        **kwargs,
    ):
        """
        Args:
            parameters (Iterable[torch.nn.Parameter]): The parameters to track.
            decay (float): The decay factor for the exponential moving average.
            min_decay (float): The minimum decay factor for the exponential moving average.
            update_after_step (int): The number of steps to wait before starting to update the EMA weights.
            use_ema_warmup (bool): Whether to use EMA warmup.
            inv_gamma (float):
                Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.
            power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.
            foreach (bool): Use torch._foreach functions for updating shadow parameters. Should be faster.
            device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA
                weights will be stored on CPU.

        @crowsonkb's notes on EMA Warmup:
            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
            to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
            gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
            at 215.4k steps).
        """

        if isinstance(parameters, torch.nn.Module):
            deprecation_message = (
                "Passing a `torch.nn.Module` to `ExponentialMovingAverage` is deprecated. "
                "Please pass the parameters of the module instead."
            )
            deprecate(
                "passing a `torch.nn.Module` to `ExponentialMovingAverage`",
                "1.0.0",
                deprecation_message,
                standard_warn=False,
            )
            parameters = parameters.parameters()

            # set use_ema_warmup to True if a torch.nn.Module is passed for backwards compatibility
            use_ema_warmup = True

        if kwargs.get("max_value", None) is not None:
            deprecation_message = "The `max_value` argument is deprecated. Please use `decay` instead."
            deprecate("max_value", "1.0.0", deprecation_message, standard_warn=False)
            decay = kwargs["max_value"]

        if kwargs.get("min_value", None) is not None:
            deprecation_message = "The `min_value` argument is deprecated. Please use `min_decay` instead."
            deprecate("min_value", "1.0.0", deprecation_message, standard_warn=False)
            min_decay = kwargs["min_value"]

        parameters = list(parameters)
        # `shadow_params` hold the running EMA copy of every tracked parameter.
        self.shadow_params = [p.clone().detach() for p in parameters]

        if kwargs.get("device", None) is not None:
            deprecation_message = "The `device` argument is deprecated. Please use `to` instead."
            deprecate("device", "1.0.0", deprecation_message, standard_warn=False)
            self.to(device=kwargs["device"])

        # Scratch storage used by `store()` / `restore()`.
        self.temp_stored_params = None

        self.decay = decay
        self.min_decay = min_decay
        self.update_after_step = update_after_step
        self.use_ema_warmup = use_ema_warmup
        self.inv_gamma = inv_gamma
        self.power = power
        self.optimization_step = 0
        self.cur_decay_value = None  # set in `step()`
        self.foreach = foreach

        self.model_cls = model_cls
        self.model_config = model_config

    @classmethod
    def from_pretrained(cls, path, model_cls, foreach=False) -> "EMAModel":
        """Rebuild an `EMAModel` from a checkpoint previously written by `save_pretrained`."""
        # The extra EMA hyperparameters were stored in the model config; split them out here.
        _, ema_kwargs = model_cls.from_config(path, return_unused_kwargs=True)
        model = model_cls.from_pretrained(path)

        ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config, foreach=foreach)

        ema_model.load_state_dict(ema_kwargs)
        return ema_model

    def save_pretrained(self, path):
        """Save the EMA weights (and EMA hyperparameters, via the model config) to `path`."""
        if self.model_cls is None:
            raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.")

        if self.model_config is None:
            raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.")

        model = self.model_cls.from_config(self.model_config)
        state_dict = self.state_dict()
        # The tensors themselves are written through the model; only the scalar
        # EMA hyperparameters are registered on the config.
        state_dict.pop("shadow_params", None)

        model.register_to_config(**state_dict)
        self.copy_to(model.parameters())
        model.save_pretrained(path)

    def get_decay(self, optimization_step: int) -> float:
        """
        Compute the decay factor for the exponential moving average.
        """
        step = max(0, optimization_step - self.update_after_step - 1)

        if step <= 0:
            return 0.0

        if self.use_ema_warmup:
            cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power
        else:
            cur_decay_value = (1 + step) / (10 + step)

        cur_decay_value = min(cur_decay_value, self.decay)
        # make sure decay is not smaller than min_decay
        cur_decay_value = max(cur_decay_value, self.min_decay)
        return cur_decay_value

    @torch.no_grad()
    def step(self, parameters: Iterable[torch.nn.Parameter]):
        """Advance the EMA one optimization step toward the current values of `parameters`."""
        if isinstance(parameters, torch.nn.Module):
            deprecation_message = (
                "Passing a `torch.nn.Module` to `ExponentialMovingAverage.step` is deprecated. "
                "Please pass the parameters of the module instead."
            )
            deprecate(
                "passing a `torch.nn.Module` to `ExponentialMovingAverage.step`",
                "1.0.0",
                deprecation_message,
                standard_warn=False,
            )
            parameters = parameters.parameters()

        parameters = list(parameters)

        self.optimization_step += 1

        # Compute the decay factor for the exponential moving average.
        decay = self.get_decay(self.optimization_step)
        self.cur_decay_value = decay
        one_minus_decay = 1 - decay

        context_manager = contextlib.nullcontext()

        if self.foreach:
            # Batched update path: one fused `_foreach` call instead of a per-parameter loop.
            # Under DeepSpeed ZeRO-3 the (sharded) parameters must be gathered first.
            if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
                context_manager = deepspeed.zero.GatheredParameters(parameters, modifier_rank=None)

            with context_manager:
                params_grad = [param for param in parameters if param.requires_grad]
                s_params_grad = [
                    s_param for s_param, param in zip(self.shadow_params, parameters) if param.requires_grad
                ]

                # Non-trainable parameters are simply mirrored, not averaged.
                if len(params_grad) < len(parameters):
                    torch._foreach_copy_(
                        [s_param for s_param, param in zip(self.shadow_params, parameters) if not param.requires_grad],
                        [param for param in parameters if not param.requires_grad],
                        non_blocking=True,
                    )

                # s = s - (1 - decay) * (s - p), i.e. s = decay * s + (1 - decay) * p.
                torch._foreach_sub_(
                    s_params_grad, torch._foreach_sub(s_params_grad, params_grad), alpha=one_minus_decay
                )

        else:
            for s_param, param in zip(self.shadow_params, parameters):
                if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
                    context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)

                with context_manager:
                    if param.requires_grad:
                        # s = decay * s + (1 - decay) * p, written as an in-place subtraction.
                        s_param.sub_(one_minus_decay * (s_param - param))
                    else:
                        s_param.copy_(param)

    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        """
        Copy current averaged parameters into given collection of parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored moving averages. If `None`, the parameters with which this
                `ExponentialMovingAverage` was initialized will be used.
        """
        parameters = list(parameters)
        if self.foreach:
            torch._foreach_copy_(
                [param.data for param in parameters],
                [s_param.to(param.device).data for s_param, param in zip(self.shadow_params, parameters)],
            )
        else:
            for s_param, param in zip(self.shadow_params, parameters):
                param.data.copy_(s_param.to(param.device).data)

    def pin_memory(self) -> None:
        r"""
        Move internal buffers of the ExponentialMovingAverage to pinned memory. Useful for non-blocking transfers for
        offloading EMA params to the host.
        """

        self.shadow_params = [p.pin_memory() for p in self.shadow_params]

    def to(self, device=None, dtype=None, non_blocking=False) -> None:
        r"""
        Move internal buffers of the ExponentialMovingAverage to `device`.

        Args:
            device: like `device` argument to `torch.Tensor.to`
        """
        # .to() on the tensors handles None correctly
        # Non-floating (e.g. integer) buffers only move device; dtype is not cast on them.
        self.shadow_params = [
            p.to(device=device, dtype=dtype, non_blocking=non_blocking)
            if p.is_floating_point()
            else p.to(device=device, non_blocking=non_blocking)
            for p in self.shadow_params
        ]

    def state_dict(self) -> dict:
        r"""
        Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during
        checkpointing to save the ema state dict.
        """
        # Following PyTorch conventions, references to tensors are returned:
        # "returns a reference to the state and not its copy!" -
        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
        return {
            "decay": self.decay,
            "min_decay": self.min_decay,
            "optimization_step": self.optimization_step,
            "update_after_step": self.update_after_step,
            "use_ema_warmup": self.use_ema_warmup,
            "inv_gamma": self.inv_gamma,
            "power": self.power,
            "shadow_params": self.shadow_params,
        }

    def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        r"""
        Saves the current parameters for restoring later.

        Args:
            parameters: Iterable of `torch.nn.Parameter`. The parameters to be temporarily stored.
        """
        self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]

    def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        r"""
        Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters
        without: affecting the original optimization process. Store the parameters before the `copy_to()` method. After
        validation (or model saving), use this to restore the former parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters. If `None`, the parameters with which this
                `ExponentialMovingAverage` was initialized will be used.
        """

        if self.temp_stored_params is None:
            raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`")
        if self.foreach:
            torch._foreach_copy_(
                [param.data for param in parameters], [c_param.data for c_param in self.temp_stored_params]
            )
        else:
            for c_param, param in zip(self.temp_stored_params, parameters):
                param.data.copy_(c_param.data)

        # Better memory-wise.
        self.temp_stored_params = None

    def load_state_dict(self, state_dict: dict) -> None:
        r"""
        Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the
        ema state dict.

        Args:
            state_dict (dict): EMA state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # deepcopy, to be consistent with module API
        state_dict = copy.deepcopy(state_dict)

        self.decay = state_dict.get("decay", self.decay)
        if self.decay < 0.0 or self.decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")

        self.min_decay = state_dict.get("min_decay", self.min_decay)
        if not isinstance(self.min_decay, float):
            raise ValueError("Invalid min_decay")

        self.optimization_step = state_dict.get("optimization_step", self.optimization_step)
        if not isinstance(self.optimization_step, int):
            raise ValueError("Invalid optimization_step")

        self.update_after_step = state_dict.get("update_after_step", self.update_after_step)
        if not isinstance(self.update_after_step, int):
            raise ValueError("Invalid update_after_step")

        self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
        if not isinstance(self.use_ema_warmup, bool):
            raise ValueError("Invalid use_ema_warmup")

        self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
        if not isinstance(self.inv_gamma, (float, int)):
            raise ValueError("Invalid inv_gamma")

        self.power = state_dict.get("power", self.power)
        if not isinstance(self.power, (float, int)):
            raise ValueError("Invalid power")

        shadow_params = state_dict.get("shadow_params", None)
        if shadow_params is not None:
            self.shadow_params = shadow_params
            if not isinstance(self.shadow_params, list):
                raise ValueError("shadow_params must be a list")
            if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
                raise ValueError("shadow_params must all be Tensors")