xiaoanyu123 committed on
Commit
e77236c
·
verified ·
1 Parent(s): 6698547

Add files using upload-large-folder tool

Browse files
Files changed (20) hide show
  1. pythonProject/diffusers-main/build/lib/diffusers/__init__.py +1384 -0
  2. pythonProject/diffusers-main/build/lib/diffusers/commands/__init__.py +27 -0
  3. pythonProject/diffusers-main/build/lib/diffusers/commands/custom_blocks.py +134 -0
  4. pythonProject/diffusers-main/build/lib/diffusers/commands/diffusers_cli.py +45 -0
  5. pythonProject/diffusers-main/build/lib/diffusers/commands/env.py +180 -0
  6. pythonProject/diffusers-main/build/lib/diffusers/commands/fp16_safetensors.py +132 -0
  7. pythonProject/diffusers-main/build/lib/diffusers/experimental/__init__.py +1 -0
  8. pythonProject/diffusers-main/build/lib/diffusers/experimental/rl/__init__.py +1 -0
  9. pythonProject/diffusers-main/build/lib/diffusers/experimental/rl/value_guided_sampling.py +153 -0
  10. pythonProject/diffusers-main/build/lib/diffusers/guiders/adaptive_projected_guidance.py +188 -0
  11. pythonProject/diffusers-main/build/lib/diffusers/guiders/auto_guidance.py +190 -0
  12. pythonProject/diffusers-main/build/lib/diffusers/guiders/classifier_free_guidance.py +141 -0
  13. pythonProject/diffusers-main/build/lib/diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  14. pythonProject/diffusers-main/build/lib/diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  15. pythonProject/diffusers-main/build/lib/diffusers/guiders/guider_utils.py +315 -0
  16. pythonProject/diffusers-main/build/lib/diffusers/guiders/perturbed_attention_guidance.py +271 -0
  17. pythonProject/diffusers-main/build/lib/diffusers/guiders/skip_layer_guidance.py +262 -0
  18. pythonProject/diffusers-main/build/lib/diffusers/guiders/smoothed_energy_guidance.py +251 -0
  19. pythonProject/diffusers-main/build/lib/diffusers/training_utils.py +730 -0
  20. pythonProject/diffusers-main/build/lib/diffusers/video_processor.py +113 -0
pythonProject/diffusers-main/build/lib/diffusers/__init__.py ADDED
@@ -0,0 +1,1384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __version__ = "0.36.0.dev0"
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ from .utils import (
6
+ DIFFUSERS_SLOW_IMPORT,
7
+ OptionalDependencyNotAvailable,
8
+ _LazyModule,
9
+ is_accelerate_available,
10
+ is_bitsandbytes_available,
11
+ is_flax_available,
12
+ is_gguf_available,
13
+ is_k_diffusion_available,
14
+ is_librosa_available,
15
+ is_note_seq_available,
16
+ is_nvidia_modelopt_available,
17
+ is_onnx_available,
18
+ is_opencv_available,
19
+ is_optimum_quanto_available,
20
+ is_scipy_available,
21
+ is_sentencepiece_available,
22
+ is_torch_available,
23
+ is_torchao_available,
24
+ is_torchsde_available,
25
+ is_transformers_available,
26
+ )
27
+
28
+
29
+ # Lazy Import based on
30
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/__init__.py
31
+
32
+ # When adding a new object to this init, please add it to `_import_structure`. The `_import_structure` is a dictionary submodule to list of object names,
33
+ # and is used to defer the actual importing for when the objects are requested.
34
+ # This way `import diffusers` provides the names in the namespace without actually importing anything (and especially none of the backends).
35
+
36
+ _import_structure = {
37
+ "configuration_utils": ["ConfigMixin"],
38
+ "guiders": [],
39
+ "hooks": [],
40
+ "loaders": ["FromOriginalModelMixin"],
41
+ "models": [],
42
+ "modular_pipelines": [],
43
+ "pipelines": [],
44
+ "quantizers.pipe_quant_config": ["PipelineQuantizationConfig"],
45
+ "quantizers.quantization_config": [],
46
+ "schedulers": [],
47
+ "utils": [
48
+ "OptionalDependencyNotAvailable",
49
+ "is_flax_available",
50
+ "is_inflect_available",
51
+ "is_invisible_watermark_available",
52
+ "is_k_diffusion_available",
53
+ "is_k_diffusion_version",
54
+ "is_librosa_available",
55
+ "is_note_seq_available",
56
+ "is_onnx_available",
57
+ "is_scipy_available",
58
+ "is_torch_available",
59
+ "is_torchsde_available",
60
+ "is_transformers_available",
61
+ "is_transformers_version",
62
+ "is_unidecode_available",
63
+ "logging",
64
+ ],
65
+ }
66
+
67
+ try:
68
+ if not is_torch_available() and not is_accelerate_available() and not is_bitsandbytes_available():
69
+ raise OptionalDependencyNotAvailable()
70
+ except OptionalDependencyNotAvailable:
71
+ from .utils import dummy_bitsandbytes_objects
72
+
73
+ _import_structure["utils.dummy_bitsandbytes_objects"] = [
74
+ name for name in dir(dummy_bitsandbytes_objects) if not name.startswith("_")
75
+ ]
76
+ else:
77
+ _import_structure["quantizers.quantization_config"].append("BitsAndBytesConfig")
78
+
79
+ try:
80
+ if not is_torch_available() and not is_accelerate_available() and not is_gguf_available():
81
+ raise OptionalDependencyNotAvailable()
82
+ except OptionalDependencyNotAvailable:
83
+ from .utils import dummy_gguf_objects
84
+
85
+ _import_structure["utils.dummy_gguf_objects"] = [
86
+ name for name in dir(dummy_gguf_objects) if not name.startswith("_")
87
+ ]
88
+ else:
89
+ _import_structure["quantizers.quantization_config"].append("GGUFQuantizationConfig")
90
+
91
+ try:
92
+ if not is_torch_available() and not is_accelerate_available() and not is_torchao_available():
93
+ raise OptionalDependencyNotAvailable()
94
+ except OptionalDependencyNotAvailable:
95
+ from .utils import dummy_torchao_objects
96
+
97
+ _import_structure["utils.dummy_torchao_objects"] = [
98
+ name for name in dir(dummy_torchao_objects) if not name.startswith("_")
99
+ ]
100
+ else:
101
+ _import_structure["quantizers.quantization_config"].append("TorchAoConfig")
102
+
103
+ try:
104
+ if not is_torch_available() and not is_accelerate_available() and not is_optimum_quanto_available():
105
+ raise OptionalDependencyNotAvailable()
106
+ except OptionalDependencyNotAvailable:
107
+ from .utils import dummy_optimum_quanto_objects
108
+
109
+ _import_structure["utils.dummy_optimum_quanto_objects"] = [
110
+ name for name in dir(dummy_optimum_quanto_objects) if not name.startswith("_")
111
+ ]
112
+ else:
113
+ _import_structure["quantizers.quantization_config"].append("QuantoConfig")
114
+
115
+ try:
116
+ if not is_torch_available() and not is_accelerate_available() and not is_nvidia_modelopt_available():
117
+ raise OptionalDependencyNotAvailable()
118
+ except OptionalDependencyNotAvailable:
119
+ from .utils import dummy_nvidia_modelopt_objects
120
+
121
+ _import_structure["utils.dummy_nvidia_modelopt_objects"] = [
122
+ name for name in dir(dummy_nvidia_modelopt_objects) if not name.startswith("_")
123
+ ]
124
+ else:
125
+ _import_structure["quantizers.quantization_config"].append("NVIDIAModelOptConfig")
126
+
127
+ try:
128
+ if not is_onnx_available():
129
+ raise OptionalDependencyNotAvailable()
130
+ except OptionalDependencyNotAvailable:
131
+ from .utils import dummy_onnx_objects # noqa F403
132
+
133
+ _import_structure["utils.dummy_onnx_objects"] = [
134
+ name for name in dir(dummy_onnx_objects) if not name.startswith("_")
135
+ ]
136
+
137
+ else:
138
+ _import_structure["pipelines"].extend(["OnnxRuntimeModel"])
139
+
140
+ try:
141
+ if not is_torch_available():
142
+ raise OptionalDependencyNotAvailable()
143
+ except OptionalDependencyNotAvailable:
144
+ from .utils import dummy_pt_objects # noqa F403
145
+
146
+ _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
147
+
148
+ else:
149
+ _import_structure["guiders"].extend(
150
+ [
151
+ "AdaptiveProjectedGuidance",
152
+ "AutoGuidance",
153
+ "ClassifierFreeGuidance",
154
+ "ClassifierFreeZeroStarGuidance",
155
+ "FrequencyDecoupledGuidance",
156
+ "PerturbedAttentionGuidance",
157
+ "SkipLayerGuidance",
158
+ "SmoothedEnergyGuidance",
159
+ "TangentialClassifierFreeGuidance",
160
+ ]
161
+ )
162
+ _import_structure["hooks"].extend(
163
+ [
164
+ "FasterCacheConfig",
165
+ "FirstBlockCacheConfig",
166
+ "HookRegistry",
167
+ "LayerSkipConfig",
168
+ "PyramidAttentionBroadcastConfig",
169
+ "SmoothedEnergyGuidanceConfig",
170
+ "apply_faster_cache",
171
+ "apply_first_block_cache",
172
+ "apply_layer_skip",
173
+ "apply_pyramid_attention_broadcast",
174
+ ]
175
+ )
176
+ _import_structure["models"].extend(
177
+ [
178
+ "AllegroTransformer3DModel",
179
+ "AsymmetricAutoencoderKL",
180
+ "AttentionBackendName",
181
+ "AuraFlowTransformer2DModel",
182
+ "AutoencoderDC",
183
+ "AutoencoderKL",
184
+ "AutoencoderKLAllegro",
185
+ "AutoencoderKLCogVideoX",
186
+ "AutoencoderKLCosmos",
187
+ "AutoencoderKLHunyuanVideo",
188
+ "AutoencoderKLLTXVideo",
189
+ "AutoencoderKLMagvit",
190
+ "AutoencoderKLMochi",
191
+ "AutoencoderKLQwenImage",
192
+ "AutoencoderKLTemporalDecoder",
193
+ "AutoencoderKLWan",
194
+ "AutoencoderOobleck",
195
+ "AutoencoderTiny",
196
+ "AutoModel",
197
+ "BriaTransformer2DModel",
198
+ "CacheMixin",
199
+ "ChromaTransformer2DModel",
200
+ "CogVideoXTransformer3DModel",
201
+ "CogView3PlusTransformer2DModel",
202
+ "CogView4Transformer2DModel",
203
+ "ConsisIDTransformer3DModel",
204
+ "ConsistencyDecoderVAE",
205
+ "ControlNetModel",
206
+ "ControlNetUnionModel",
207
+ "ControlNetXSAdapter",
208
+ "CosmosTransformer3DModel",
209
+ "DiTTransformer2DModel",
210
+ "EasyAnimateTransformer3DModel",
211
+ "FluxControlNetModel",
212
+ "FluxMultiControlNetModel",
213
+ "FluxTransformer2DModel",
214
+ "HiDreamImageTransformer2DModel",
215
+ "HunyuanDiT2DControlNetModel",
216
+ "HunyuanDiT2DModel",
217
+ "HunyuanDiT2DMultiControlNetModel",
218
+ "HunyuanVideoFramepackTransformer3DModel",
219
+ "HunyuanVideoTransformer3DModel",
220
+ "I2VGenXLUNet",
221
+ "Kandinsky3UNet",
222
+ "LatteTransformer3DModel",
223
+ "LTXVideoTransformer3DModel",
224
+ "Lumina2Transformer2DModel",
225
+ "LuminaNextDiT2DModel",
226
+ "MochiTransformer3DModel",
227
+ "ModelMixin",
228
+ "MotionAdapter",
229
+ "MultiAdapter",
230
+ "MultiControlNetModel",
231
+ "OmniGenTransformer2DModel",
232
+ "PixArtTransformer2DModel",
233
+ "PriorTransformer",
234
+ "QwenImageControlNetModel",
235
+ "QwenImageMultiControlNetModel",
236
+ "QwenImageTransformer2DModel",
237
+ "SanaControlNetModel",
238
+ "SanaTransformer2DModel",
239
+ "SD3ControlNetModel",
240
+ "SD3MultiControlNetModel",
241
+ "SD3Transformer2DModel",
242
+ "SkyReelsV2Transformer3DModel",
243
+ "SparseControlNetModel",
244
+ "StableAudioDiTModel",
245
+ "StableCascadeUNet",
246
+ "T2IAdapter",
247
+ "T5FilmDecoder",
248
+ "Transformer2DModel",
249
+ "TransformerTemporalModel",
250
+ "UNet1DModel",
251
+ "UNet2DConditionModel",
252
+ "UNet2DModel",
253
+ "UNet3DConditionModel",
254
+ "UNetControlNetXSModel",
255
+ "UNetMotionModel",
256
+ "UNetSpatioTemporalConditionModel",
257
+ "UVit2DModel",
258
+ "VQModel",
259
+ "WanTransformer3DModel",
260
+ "WanVACETransformer3DModel",
261
+ "attention_backend",
262
+ ]
263
+ )
264
+ _import_structure["modular_pipelines"].extend(
265
+ [
266
+ "ComponentsManager",
267
+ "ComponentSpec",
268
+ "ModularPipeline",
269
+ "ModularPipelineBlocks",
270
+ ]
271
+ )
272
+ _import_structure["optimization"] = [
273
+ "get_constant_schedule",
274
+ "get_constant_schedule_with_warmup",
275
+ "get_cosine_schedule_with_warmup",
276
+ "get_cosine_with_hard_restarts_schedule_with_warmup",
277
+ "get_linear_schedule_with_warmup",
278
+ "get_polynomial_decay_schedule_with_warmup",
279
+ "get_scheduler",
280
+ ]
281
+ _import_structure["pipelines"].extend(
282
+ [
283
+ "AudioPipelineOutput",
284
+ "AutoPipelineForImage2Image",
285
+ "AutoPipelineForInpainting",
286
+ "AutoPipelineForText2Image",
287
+ "ConsistencyModelPipeline",
288
+ "DanceDiffusionPipeline",
289
+ "DDIMPipeline",
290
+ "DDPMPipeline",
291
+ "DiffusionPipeline",
292
+ "DiTPipeline",
293
+ "ImagePipelineOutput",
294
+ "KarrasVePipeline",
295
+ "LDMPipeline",
296
+ "LDMSuperResolutionPipeline",
297
+ "PNDMPipeline",
298
+ "RePaintPipeline",
299
+ "ScoreSdeVePipeline",
300
+ "StableDiffusionMixin",
301
+ ]
302
+ )
303
+ _import_structure["quantizers"] = ["DiffusersQuantizer"]
304
+ _import_structure["schedulers"].extend(
305
+ [
306
+ "AmusedScheduler",
307
+ "CMStochasticIterativeScheduler",
308
+ "CogVideoXDDIMScheduler",
309
+ "CogVideoXDPMScheduler",
310
+ "DDIMInverseScheduler",
311
+ "DDIMParallelScheduler",
312
+ "DDIMScheduler",
313
+ "DDPMParallelScheduler",
314
+ "DDPMScheduler",
315
+ "DDPMWuerstchenScheduler",
316
+ "DEISMultistepScheduler",
317
+ "DPMSolverMultistepInverseScheduler",
318
+ "DPMSolverMultistepScheduler",
319
+ "DPMSolverSinglestepScheduler",
320
+ "EDMDPMSolverMultistepScheduler",
321
+ "EDMEulerScheduler",
322
+ "EulerAncestralDiscreteScheduler",
323
+ "EulerDiscreteScheduler",
324
+ "FlowMatchEulerDiscreteScheduler",
325
+ "FlowMatchHeunDiscreteScheduler",
326
+ "FlowMatchLCMScheduler",
327
+ "HeunDiscreteScheduler",
328
+ "IPNDMScheduler",
329
+ "KarrasVeScheduler",
330
+ "KDPM2AncestralDiscreteScheduler",
331
+ "KDPM2DiscreteScheduler",
332
+ "LCMScheduler",
333
+ "PNDMScheduler",
334
+ "RePaintScheduler",
335
+ "SASolverScheduler",
336
+ "SchedulerMixin",
337
+ "SCMScheduler",
338
+ "ScoreSdeVeScheduler",
339
+ "TCDScheduler",
340
+ "UnCLIPScheduler",
341
+ "UniPCMultistepScheduler",
342
+ "VQDiffusionScheduler",
343
+ ]
344
+ )
345
+ _import_structure["training_utils"] = ["EMAModel"]
346
+
347
+ try:
348
+ if not (is_torch_available() and is_scipy_available()):
349
+ raise OptionalDependencyNotAvailable()
350
+ except OptionalDependencyNotAvailable:
351
+ from .utils import dummy_torch_and_scipy_objects # noqa F403
352
+
353
+ _import_structure["utils.dummy_torch_and_scipy_objects"] = [
354
+ name for name in dir(dummy_torch_and_scipy_objects) if not name.startswith("_")
355
+ ]
356
+
357
+ else:
358
+ _import_structure["schedulers"].extend(["LMSDiscreteScheduler"])
359
+
360
+ try:
361
+ if not (is_torch_available() and is_torchsde_available()):
362
+ raise OptionalDependencyNotAvailable()
363
+ except OptionalDependencyNotAvailable:
364
+ from .utils import dummy_torch_and_torchsde_objects # noqa F403
365
+
366
+ _import_structure["utils.dummy_torch_and_torchsde_objects"] = [
367
+ name for name in dir(dummy_torch_and_torchsde_objects) if not name.startswith("_")
368
+ ]
369
+
370
+ else:
371
+ _import_structure["schedulers"].extend(["CosineDPMSolverMultistepScheduler", "DPMSolverSDEScheduler"])
372
+
373
+ try:
374
+ if not (is_torch_available() and is_transformers_available()):
375
+ raise OptionalDependencyNotAvailable()
376
+ except OptionalDependencyNotAvailable:
377
+ from .utils import dummy_torch_and_transformers_objects # noqa F403
378
+
379
+ _import_structure["utils.dummy_torch_and_transformers_objects"] = [
380
+ name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_")
381
+ ]
382
+
383
+ else:
384
+ _import_structure["modular_pipelines"].extend(
385
+ [
386
+ "FluxAutoBlocks",
387
+ "FluxModularPipeline",
388
+ "QwenImageAutoBlocks",
389
+ "QwenImageEditAutoBlocks",
390
+ "QwenImageEditModularPipeline",
391
+ "QwenImageModularPipeline",
392
+ "StableDiffusionXLAutoBlocks",
393
+ "StableDiffusionXLModularPipeline",
394
+ "WanAutoBlocks",
395
+ "WanModularPipeline",
396
+ ]
397
+ )
398
+ _import_structure["pipelines"].extend(
399
+ [
400
+ "AllegroPipeline",
401
+ "AltDiffusionImg2ImgPipeline",
402
+ "AltDiffusionPipeline",
403
+ "AmusedImg2ImgPipeline",
404
+ "AmusedInpaintPipeline",
405
+ "AmusedPipeline",
406
+ "AnimateDiffControlNetPipeline",
407
+ "AnimateDiffPAGPipeline",
408
+ "AnimateDiffPipeline",
409
+ "AnimateDiffSDXLPipeline",
410
+ "AnimateDiffSparseControlNetPipeline",
411
+ "AnimateDiffVideoToVideoControlNetPipeline",
412
+ "AnimateDiffVideoToVideoPipeline",
413
+ "AudioLDM2Pipeline",
414
+ "AudioLDM2ProjectionModel",
415
+ "AudioLDM2UNet2DConditionModel",
416
+ "AudioLDMPipeline",
417
+ "AuraFlowPipeline",
418
+ "BlipDiffusionControlNetPipeline",
419
+ "BlipDiffusionPipeline",
420
+ "BriaPipeline",
421
+ "ChromaImg2ImgPipeline",
422
+ "ChromaPipeline",
423
+ "CLIPImageProjection",
424
+ "CogVideoXFunControlPipeline",
425
+ "CogVideoXImageToVideoPipeline",
426
+ "CogVideoXPipeline",
427
+ "CogVideoXVideoToVideoPipeline",
428
+ "CogView3PlusPipeline",
429
+ "CogView4ControlPipeline",
430
+ "CogView4Pipeline",
431
+ "ConsisIDPipeline",
432
+ "Cosmos2TextToImagePipeline",
433
+ "Cosmos2VideoToWorldPipeline",
434
+ "CosmosTextToWorldPipeline",
435
+ "CosmosVideoToWorldPipeline",
436
+ "CycleDiffusionPipeline",
437
+ "EasyAnimateControlPipeline",
438
+ "EasyAnimateInpaintPipeline",
439
+ "EasyAnimatePipeline",
440
+ "FluxControlImg2ImgPipeline",
441
+ "FluxControlInpaintPipeline",
442
+ "FluxControlNetImg2ImgPipeline",
443
+ "FluxControlNetInpaintPipeline",
444
+ "FluxControlNetPipeline",
445
+ "FluxControlPipeline",
446
+ "FluxFillPipeline",
447
+ "FluxImg2ImgPipeline",
448
+ "FluxInpaintPipeline",
449
+ "FluxKontextInpaintPipeline",
450
+ "FluxKontextPipeline",
451
+ "FluxPipeline",
452
+ "FluxPriorReduxPipeline",
453
+ "HiDreamImagePipeline",
454
+ "HunyuanDiTControlNetPipeline",
455
+ "HunyuanDiTPAGPipeline",
456
+ "HunyuanDiTPipeline",
457
+ "HunyuanSkyreelsImageToVideoPipeline",
458
+ "HunyuanVideoFramepackPipeline",
459
+ "HunyuanVideoImageToVideoPipeline",
460
+ "HunyuanVideoPipeline",
461
+ "I2VGenXLPipeline",
462
+ "IFImg2ImgPipeline",
463
+ "IFImg2ImgSuperResolutionPipeline",
464
+ "IFInpaintingPipeline",
465
+ "IFInpaintingSuperResolutionPipeline",
466
+ "IFPipeline",
467
+ "IFSuperResolutionPipeline",
468
+ "ImageTextPipelineOutput",
469
+ "Kandinsky3Img2ImgPipeline",
470
+ "Kandinsky3Pipeline",
471
+ "KandinskyCombinedPipeline",
472
+ "KandinskyImg2ImgCombinedPipeline",
473
+ "KandinskyImg2ImgPipeline",
474
+ "KandinskyInpaintCombinedPipeline",
475
+ "KandinskyInpaintPipeline",
476
+ "KandinskyPipeline",
477
+ "KandinskyPriorPipeline",
478
+ "KandinskyV22CombinedPipeline",
479
+ "KandinskyV22ControlnetImg2ImgPipeline",
480
+ "KandinskyV22ControlnetPipeline",
481
+ "KandinskyV22Img2ImgCombinedPipeline",
482
+ "KandinskyV22Img2ImgPipeline",
483
+ "KandinskyV22InpaintCombinedPipeline",
484
+ "KandinskyV22InpaintPipeline",
485
+ "KandinskyV22Pipeline",
486
+ "KandinskyV22PriorEmb2EmbPipeline",
487
+ "KandinskyV22PriorPipeline",
488
+ "LatentConsistencyModelImg2ImgPipeline",
489
+ "LatentConsistencyModelPipeline",
490
+ "LattePipeline",
491
+ "LDMTextToImagePipeline",
492
+ "LEditsPPPipelineStableDiffusion",
493
+ "LEditsPPPipelineStableDiffusionXL",
494
+ "LTXConditionPipeline",
495
+ "LTXImageToVideoPipeline",
496
+ "LTXLatentUpsamplePipeline",
497
+ "LTXPipeline",
498
+ "Lumina2Pipeline",
499
+ "Lumina2Text2ImgPipeline",
500
+ "LuminaPipeline",
501
+ "LuminaText2ImgPipeline",
502
+ "MarigoldDepthPipeline",
503
+ "MarigoldIntrinsicsPipeline",
504
+ "MarigoldNormalsPipeline",
505
+ "MochiPipeline",
506
+ "MusicLDMPipeline",
507
+ "OmniGenPipeline",
508
+ "PaintByExamplePipeline",
509
+ "PIAPipeline",
510
+ "PixArtAlphaPipeline",
511
+ "PixArtSigmaPAGPipeline",
512
+ "PixArtSigmaPipeline",
513
+ "QwenImageControlNetInpaintPipeline",
514
+ "QwenImageControlNetPipeline",
515
+ "QwenImageEditInpaintPipeline",
516
+ "QwenImageEditPipeline",
517
+ "QwenImageImg2ImgPipeline",
518
+ "QwenImageInpaintPipeline",
519
+ "QwenImagePipeline",
520
+ "ReduxImageEncoder",
521
+ "SanaControlNetPipeline",
522
+ "SanaPAGPipeline",
523
+ "SanaPipeline",
524
+ "SanaSprintImg2ImgPipeline",
525
+ "SanaSprintPipeline",
526
+ "SemanticStableDiffusionPipeline",
527
+ "ShapEImg2ImgPipeline",
528
+ "ShapEPipeline",
529
+ "SkyReelsV2DiffusionForcingImageToVideoPipeline",
530
+ "SkyReelsV2DiffusionForcingPipeline",
531
+ "SkyReelsV2DiffusionForcingVideoToVideoPipeline",
532
+ "SkyReelsV2ImageToVideoPipeline",
533
+ "SkyReelsV2Pipeline",
534
+ "StableAudioPipeline",
535
+ "StableAudioProjectionModel",
536
+ "StableCascadeCombinedPipeline",
537
+ "StableCascadeDecoderPipeline",
538
+ "StableCascadePriorPipeline",
539
+ "StableDiffusion3ControlNetInpaintingPipeline",
540
+ "StableDiffusion3ControlNetPipeline",
541
+ "StableDiffusion3Img2ImgPipeline",
542
+ "StableDiffusion3InpaintPipeline",
543
+ "StableDiffusion3PAGImg2ImgPipeline",
544
+ "StableDiffusion3PAGImg2ImgPipeline",
545
+ "StableDiffusion3PAGPipeline",
546
+ "StableDiffusion3Pipeline",
547
+ "StableDiffusionAdapterPipeline",
548
+ "StableDiffusionAttendAndExcitePipeline",
549
+ "StableDiffusionControlNetImg2ImgPipeline",
550
+ "StableDiffusionControlNetInpaintPipeline",
551
+ "StableDiffusionControlNetPAGInpaintPipeline",
552
+ "StableDiffusionControlNetPAGPipeline",
553
+ "StableDiffusionControlNetPipeline",
554
+ "StableDiffusionControlNetXSPipeline",
555
+ "StableDiffusionDepth2ImgPipeline",
556
+ "StableDiffusionDiffEditPipeline",
557
+ "StableDiffusionGLIGENPipeline",
558
+ "StableDiffusionGLIGENTextImagePipeline",
559
+ "StableDiffusionImageVariationPipeline",
560
+ "StableDiffusionImg2ImgPipeline",
561
+ "StableDiffusionInpaintPipeline",
562
+ "StableDiffusionInpaintPipelineLegacy",
563
+ "StableDiffusionInstructPix2PixPipeline",
564
+ "StableDiffusionLatentUpscalePipeline",
565
+ "StableDiffusionLDM3DPipeline",
566
+ "StableDiffusionModelEditingPipeline",
567
+ "StableDiffusionPAGImg2ImgPipeline",
568
+ "StableDiffusionPAGInpaintPipeline",
569
+ "StableDiffusionPAGPipeline",
570
+ "StableDiffusionPanoramaPipeline",
571
+ "StableDiffusionParadigmsPipeline",
572
+ "StableDiffusionPipeline",
573
+ "StableDiffusionPipelineSafe",
574
+ "StableDiffusionPix2PixZeroPipeline",
575
+ "StableDiffusionSAGPipeline",
576
+ "StableDiffusionUpscalePipeline",
577
+ "StableDiffusionXLAdapterPipeline",
578
+ "StableDiffusionXLControlNetImg2ImgPipeline",
579
+ "StableDiffusionXLControlNetInpaintPipeline",
580
+ "StableDiffusionXLControlNetPAGImg2ImgPipeline",
581
+ "StableDiffusionXLControlNetPAGPipeline",
582
+ "StableDiffusionXLControlNetPipeline",
583
+ "StableDiffusionXLControlNetUnionImg2ImgPipeline",
584
+ "StableDiffusionXLControlNetUnionInpaintPipeline",
585
+ "StableDiffusionXLControlNetUnionPipeline",
586
+ "StableDiffusionXLControlNetXSPipeline",
587
+ "StableDiffusionXLImg2ImgPipeline",
588
+ "StableDiffusionXLInpaintPipeline",
589
+ "StableDiffusionXLInstructPix2PixPipeline",
590
+ "StableDiffusionXLPAGImg2ImgPipeline",
591
+ "StableDiffusionXLPAGInpaintPipeline",
592
+ "StableDiffusionXLPAGPipeline",
593
+ "StableDiffusionXLPipeline",
594
+ "StableUnCLIPImg2ImgPipeline",
595
+ "StableUnCLIPPipeline",
596
+ "StableVideoDiffusionPipeline",
597
+ "TextToVideoSDPipeline",
598
+ "TextToVideoZeroPipeline",
599
+ "TextToVideoZeroSDXLPipeline",
600
+ "UnCLIPImageVariationPipeline",
601
+ "UnCLIPPipeline",
602
+ "UniDiffuserModel",
603
+ "UniDiffuserPipeline",
604
+ "UniDiffuserTextDecoder",
605
+ "VersatileDiffusionDualGuidedPipeline",
606
+ "VersatileDiffusionImageVariationPipeline",
607
+ "VersatileDiffusionPipeline",
608
+ "VersatileDiffusionTextToImagePipeline",
609
+ "VideoToVideoSDPipeline",
610
+ "VisualClozeGenerationPipeline",
611
+ "VisualClozePipeline",
612
+ "VQDiffusionPipeline",
613
+ "WanImageToVideoPipeline",
614
+ "WanPipeline",
615
+ "WanVACEPipeline",
616
+ "WanVideoToVideoPipeline",
617
+ "WuerstchenCombinedPipeline",
618
+ "WuerstchenDecoderPipeline",
619
+ "WuerstchenPriorPipeline",
620
+ ]
621
+ )
622
+
623
+
624
+ try:
625
+ if not (is_torch_available() and is_transformers_available() and is_opencv_available()):
626
+ raise OptionalDependencyNotAvailable()
627
+ except OptionalDependencyNotAvailable:
628
+ from .utils import dummy_torch_and_transformers_and_opencv_objects # noqa F403
629
+
630
+ _import_structure["utils.dummy_torch_and_transformers_and_opencv_objects"] = [
631
+ name for name in dir(dummy_torch_and_transformers_and_opencv_objects) if not name.startswith("_")
632
+ ]
633
+
634
+ else:
635
+ _import_structure["pipelines"].extend(["ConsisIDPipeline"])
636
+
637
+ try:
638
+ if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
639
+ raise OptionalDependencyNotAvailable()
640
+ except OptionalDependencyNotAvailable:
641
+ from .utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403
642
+
643
+ _import_structure["utils.dummy_torch_and_transformers_and_k_diffusion_objects"] = [
644
+ name for name in dir(dummy_torch_and_transformers_and_k_diffusion_objects) if not name.startswith("_")
645
+ ]
646
+
647
+ else:
648
+ _import_structure["pipelines"].extend(["StableDiffusionKDiffusionPipeline", "StableDiffusionXLKDiffusionPipeline"])
649
+
650
+ try:
651
+ if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()):
652
+ raise OptionalDependencyNotAvailable()
653
+ except OptionalDependencyNotAvailable:
654
+ from .utils import dummy_torch_and_transformers_and_sentencepiece_objects # noqa F403
655
+
656
+ _import_structure["utils.dummy_torch_and_transformers_and_sentencepiece_objects"] = [
657
+ name for name in dir(dummy_torch_and_transformers_and_sentencepiece_objects) if not name.startswith("_")
658
+ ]
659
+
660
+ else:
661
+ _import_structure["pipelines"].extend(["KolorsImg2ImgPipeline", "KolorsPAGPipeline", "KolorsPipeline"])
662
+
663
+ try:
664
+ if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
665
+ raise OptionalDependencyNotAvailable()
666
+ except OptionalDependencyNotAvailable:
667
+ from .utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403
668
+
669
+ _import_structure["utils.dummy_torch_and_transformers_and_onnx_objects"] = [
670
+ name for name in dir(dummy_torch_and_transformers_and_onnx_objects) if not name.startswith("_")
671
+ ]
672
+
673
+ else:
674
+ _import_structure["pipelines"].extend(
675
+ [
676
+ "OnnxStableDiffusionImg2ImgPipeline",
677
+ "OnnxStableDiffusionInpaintPipeline",
678
+ "OnnxStableDiffusionInpaintPipelineLegacy",
679
+ "OnnxStableDiffusionPipeline",
680
+ "OnnxStableDiffusionUpscalePipeline",
681
+ "StableDiffusionOnnxPipeline",
682
+ ]
683
+ )
684
+
685
+ try:
686
+ if not (is_torch_available() and is_librosa_available()):
687
+ raise OptionalDependencyNotAvailable()
688
+ except OptionalDependencyNotAvailable:
689
+ from .utils import dummy_torch_and_librosa_objects # noqa F403
690
+
691
+ _import_structure["utils.dummy_torch_and_librosa_objects"] = [
692
+ name for name in dir(dummy_torch_and_librosa_objects) if not name.startswith("_")
693
+ ]
694
+
695
+ else:
696
+ _import_structure["pipelines"].extend(["AudioDiffusionPipeline", "Mel"])
697
+
698
+ try:
699
+ if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
700
+ raise OptionalDependencyNotAvailable()
701
+ except OptionalDependencyNotAvailable:
702
+ from .utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403
703
+
704
+ _import_structure["utils.dummy_transformers_and_torch_and_note_seq_objects"] = [
705
+ name for name in dir(dummy_transformers_and_torch_and_note_seq_objects) if not name.startswith("_")
706
+ ]
707
+
708
+
709
+ else:
710
+ _import_structure["pipelines"].extend(["SpectrogramDiffusionPipeline"])
711
+
712
+ try:
713
+ if not is_flax_available():
714
+ raise OptionalDependencyNotAvailable()
715
+ except OptionalDependencyNotAvailable:
716
+ from .utils import dummy_flax_objects # noqa F403
717
+
718
+ _import_structure["utils.dummy_flax_objects"] = [
719
+ name for name in dir(dummy_flax_objects) if not name.startswith("_")
720
+ ]
721
+
722
+
723
+ else:
724
+ _import_structure["models.controlnets.controlnet_flax"] = ["FlaxControlNetModel"]
725
+ _import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"]
726
+ _import_structure["models.unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
727
+ _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
728
+ _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"])
729
+ _import_structure["schedulers"].extend(
730
+ [
731
+ "FlaxDDIMScheduler",
732
+ "FlaxDDPMScheduler",
733
+ "FlaxDPMSolverMultistepScheduler",
734
+ "FlaxEulerDiscreteScheduler",
735
+ "FlaxKarrasVeScheduler",
736
+ "FlaxLMSDiscreteScheduler",
737
+ "FlaxPNDMScheduler",
738
+ "FlaxSchedulerMixin",
739
+ "FlaxScoreSdeVeScheduler",
740
+ ]
741
+ )
742
+
743
+
744
+ try:
745
+ if not (is_flax_available() and is_transformers_available()):
746
+ raise OptionalDependencyNotAvailable()
747
+ except OptionalDependencyNotAvailable:
748
+ from .utils import dummy_flax_and_transformers_objects # noqa F403
749
+
750
+ _import_structure["utils.dummy_flax_and_transformers_objects"] = [
751
+ name for name in dir(dummy_flax_and_transformers_objects) if not name.startswith("_")
752
+ ]
753
+
754
+
755
+ else:
756
+ _import_structure["pipelines"].extend(
757
+ [
758
+ "FlaxStableDiffusionControlNetPipeline",
759
+ "FlaxStableDiffusionImg2ImgPipeline",
760
+ "FlaxStableDiffusionInpaintPipeline",
761
+ "FlaxStableDiffusionPipeline",
762
+ "FlaxStableDiffusionXLPipeline",
763
+ ]
764
+ )
765
+
766
+ try:
767
+ if not (is_note_seq_available()):
768
+ raise OptionalDependencyNotAvailable()
769
+ except OptionalDependencyNotAvailable:
770
+ from .utils import dummy_note_seq_objects # noqa F403
771
+
772
+ _import_structure["utils.dummy_note_seq_objects"] = [
773
+ name for name in dir(dummy_note_seq_objects) if not name.startswith("_")
774
+ ]
775
+
776
+
777
+ else:
778
+ _import_structure["pipelines"].extend(["MidiProcessor"])
779
+
780
+ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
781
+ from .configuration_utils import ConfigMixin
782
+ from .quantizers import PipelineQuantizationConfig
783
+
784
+ try:
785
+ if not is_bitsandbytes_available():
786
+ raise OptionalDependencyNotAvailable()
787
+ except OptionalDependencyNotAvailable:
788
+ from .utils.dummy_bitsandbytes_objects import *
789
+ else:
790
+ from .quantizers.quantization_config import BitsAndBytesConfig
791
+
792
+ try:
793
+ if not is_gguf_available():
794
+ raise OptionalDependencyNotAvailable()
795
+ except OptionalDependencyNotAvailable:
796
+ from .utils.dummy_gguf_objects import *
797
+ else:
798
+ from .quantizers.quantization_config import GGUFQuantizationConfig
799
+
800
+ try:
801
+ if not is_torchao_available():
802
+ raise OptionalDependencyNotAvailable()
803
+ except OptionalDependencyNotAvailable:
804
+ from .utils.dummy_torchao_objects import *
805
+ else:
806
+ from .quantizers.quantization_config import TorchAoConfig
807
+
808
+ try:
809
+ if not is_optimum_quanto_available():
810
+ raise OptionalDependencyNotAvailable()
811
+ except OptionalDependencyNotAvailable:
812
+ from .utils.dummy_optimum_quanto_objects import *
813
+ else:
814
+ from .quantizers.quantization_config import QuantoConfig
815
+
816
+ try:
817
+ if not is_nvidia_modelopt_available():
818
+ raise OptionalDependencyNotAvailable()
819
+ except OptionalDependencyNotAvailable:
820
+ from .utils.dummy_nvidia_modelopt_objects import *
821
+ else:
822
+ from .quantizers.quantization_config import NVIDIAModelOptConfig
823
+
824
+ try:
825
+ if not is_onnx_available():
826
+ raise OptionalDependencyNotAvailable()
827
+ except OptionalDependencyNotAvailable:
828
+ from .utils.dummy_onnx_objects import * # noqa F403
829
+ else:
830
+ from .pipelines import OnnxRuntimeModel
831
+
832
+ try:
833
+ if not is_torch_available():
834
+ raise OptionalDependencyNotAvailable()
835
+ except OptionalDependencyNotAvailable:
836
+ from .utils.dummy_pt_objects import * # noqa F403
837
+ else:
838
+ from .guiders import (
839
+ AdaptiveProjectedGuidance,
840
+ AutoGuidance,
841
+ ClassifierFreeGuidance,
842
+ ClassifierFreeZeroStarGuidance,
843
+ FrequencyDecoupledGuidance,
844
+ PerturbedAttentionGuidance,
845
+ SkipLayerGuidance,
846
+ SmoothedEnergyGuidance,
847
+ TangentialClassifierFreeGuidance,
848
+ )
849
+ from .hooks import (
850
+ FasterCacheConfig,
851
+ FirstBlockCacheConfig,
852
+ HookRegistry,
853
+ LayerSkipConfig,
854
+ PyramidAttentionBroadcastConfig,
855
+ SmoothedEnergyGuidanceConfig,
856
+ apply_faster_cache,
857
+ apply_first_block_cache,
858
+ apply_layer_skip,
859
+ apply_pyramid_attention_broadcast,
860
+ )
861
+ from .models import (
862
+ AllegroTransformer3DModel,
863
+ AsymmetricAutoencoderKL,
864
+ AttentionBackendName,
865
+ AuraFlowTransformer2DModel,
866
+ AutoencoderDC,
867
+ AutoencoderKL,
868
+ AutoencoderKLAllegro,
869
+ AutoencoderKLCogVideoX,
870
+ AutoencoderKLCosmos,
871
+ AutoencoderKLHunyuanVideo,
872
+ AutoencoderKLLTXVideo,
873
+ AutoencoderKLMagvit,
874
+ AutoencoderKLMochi,
875
+ AutoencoderKLQwenImage,
876
+ AutoencoderKLTemporalDecoder,
877
+ AutoencoderKLWan,
878
+ AutoencoderOobleck,
879
+ AutoencoderTiny,
880
+ AutoModel,
881
+ BriaTransformer2DModel,
882
+ CacheMixin,
883
+ ChromaTransformer2DModel,
884
+ CogVideoXTransformer3DModel,
885
+ CogView3PlusTransformer2DModel,
886
+ CogView4Transformer2DModel,
887
+ ConsisIDTransformer3DModel,
888
+ ConsistencyDecoderVAE,
889
+ ControlNetModel,
890
+ ControlNetUnionModel,
891
+ ControlNetXSAdapter,
892
+ CosmosTransformer3DModel,
893
+ DiTTransformer2DModel,
894
+ EasyAnimateTransformer3DModel,
895
+ FluxControlNetModel,
896
+ FluxMultiControlNetModel,
897
+ FluxTransformer2DModel,
898
+ HiDreamImageTransformer2DModel,
899
+ HunyuanDiT2DControlNetModel,
900
+ HunyuanDiT2DModel,
901
+ HunyuanDiT2DMultiControlNetModel,
902
+ HunyuanVideoFramepackTransformer3DModel,
903
+ HunyuanVideoTransformer3DModel,
904
+ I2VGenXLUNet,
905
+ Kandinsky3UNet,
906
+ LatteTransformer3DModel,
907
+ LTXVideoTransformer3DModel,
908
+ Lumina2Transformer2DModel,
909
+ LuminaNextDiT2DModel,
910
+ MochiTransformer3DModel,
911
+ ModelMixin,
912
+ MotionAdapter,
913
+ MultiAdapter,
914
+ MultiControlNetModel,
915
+ OmniGenTransformer2DModel,
916
+ PixArtTransformer2DModel,
917
+ PriorTransformer,
918
+ QwenImageControlNetModel,
919
+ QwenImageMultiControlNetModel,
920
+ QwenImageTransformer2DModel,
921
+ SanaControlNetModel,
922
+ SanaTransformer2DModel,
923
+ SD3ControlNetModel,
924
+ SD3MultiControlNetModel,
925
+ SD3Transformer2DModel,
926
+ SkyReelsV2Transformer3DModel,
927
+ SparseControlNetModel,
928
+ StableAudioDiTModel,
929
+ T2IAdapter,
930
+ T5FilmDecoder,
931
+ Transformer2DModel,
932
+ TransformerTemporalModel,
933
+ UNet1DModel,
934
+ UNet2DConditionModel,
935
+ UNet2DModel,
936
+ UNet3DConditionModel,
937
+ UNetControlNetXSModel,
938
+ UNetMotionModel,
939
+ UNetSpatioTemporalConditionModel,
940
+ UVit2DModel,
941
+ VQModel,
942
+ WanTransformer3DModel,
943
+ WanVACETransformer3DModel,
944
+ attention_backend,
945
+ )
946
+ from .modular_pipelines import ComponentsManager, ComponentSpec, ModularPipeline, ModularPipelineBlocks
947
+ from .optimization import (
948
+ get_constant_schedule,
949
+ get_constant_schedule_with_warmup,
950
+ get_cosine_schedule_with_warmup,
951
+ get_cosine_with_hard_restarts_schedule_with_warmup,
952
+ get_linear_schedule_with_warmup,
953
+ get_polynomial_decay_schedule_with_warmup,
954
+ get_scheduler,
955
+ )
956
+ from .pipelines import (
957
+ AudioPipelineOutput,
958
+ AutoPipelineForImage2Image,
959
+ AutoPipelineForInpainting,
960
+ AutoPipelineForText2Image,
961
+ BlipDiffusionControlNetPipeline,
962
+ BlipDiffusionPipeline,
963
+ CLIPImageProjection,
964
+ ConsistencyModelPipeline,
965
+ DanceDiffusionPipeline,
966
+ DDIMPipeline,
967
+ DDPMPipeline,
968
+ DiffusionPipeline,
969
+ DiTPipeline,
970
+ ImagePipelineOutput,
971
+ KarrasVePipeline,
972
+ LDMPipeline,
973
+ LDMSuperResolutionPipeline,
974
+ PNDMPipeline,
975
+ RePaintPipeline,
976
+ ScoreSdeVePipeline,
977
+ StableDiffusionMixin,
978
+ )
979
+ from .quantizers import DiffusersQuantizer
980
+ from .schedulers import (
981
+ AmusedScheduler,
982
+ CMStochasticIterativeScheduler,
983
+ CogVideoXDDIMScheduler,
984
+ CogVideoXDPMScheduler,
985
+ DDIMInverseScheduler,
986
+ DDIMParallelScheduler,
987
+ DDIMScheduler,
988
+ DDPMParallelScheduler,
989
+ DDPMScheduler,
990
+ DDPMWuerstchenScheduler,
991
+ DEISMultistepScheduler,
992
+ DPMSolverMultistepInverseScheduler,
993
+ DPMSolverMultistepScheduler,
994
+ DPMSolverSinglestepScheduler,
995
+ EDMDPMSolverMultistepScheduler,
996
+ EDMEulerScheduler,
997
+ EulerAncestralDiscreteScheduler,
998
+ EulerDiscreteScheduler,
999
+ FlowMatchEulerDiscreteScheduler,
1000
+ FlowMatchHeunDiscreteScheduler,
1001
+ FlowMatchLCMScheduler,
1002
+ HeunDiscreteScheduler,
1003
+ IPNDMScheduler,
1004
+ KarrasVeScheduler,
1005
+ KDPM2AncestralDiscreteScheduler,
1006
+ KDPM2DiscreteScheduler,
1007
+ LCMScheduler,
1008
+ PNDMScheduler,
1009
+ RePaintScheduler,
1010
+ SASolverScheduler,
1011
+ SchedulerMixin,
1012
+ SCMScheduler,
1013
+ ScoreSdeVeScheduler,
1014
+ TCDScheduler,
1015
+ UnCLIPScheduler,
1016
+ UniPCMultistepScheduler,
1017
+ VQDiffusionScheduler,
1018
+ )
1019
+ from .training_utils import EMAModel
1020
+
1021
+ try:
1022
+ if not (is_torch_available() and is_scipy_available()):
1023
+ raise OptionalDependencyNotAvailable()
1024
+ except OptionalDependencyNotAvailable:
1025
+ from .utils.dummy_torch_and_scipy_objects import * # noqa F403
1026
+ else:
1027
+ from .schedulers import LMSDiscreteScheduler
1028
+
1029
+ try:
1030
+ if not (is_torch_available() and is_torchsde_available()):
1031
+ raise OptionalDependencyNotAvailable()
1032
+ except OptionalDependencyNotAvailable:
1033
+ from .utils.dummy_torch_and_torchsde_objects import * # noqa F403
1034
+ else:
1035
+ from .schedulers import CosineDPMSolverMultistepScheduler, DPMSolverSDEScheduler
1036
+
1037
+ try:
1038
+ if not (is_torch_available() and is_transformers_available()):
1039
+ raise OptionalDependencyNotAvailable()
1040
+ except OptionalDependencyNotAvailable:
1041
+ from .utils.dummy_torch_and_transformers_objects import * # noqa F403
1042
+ else:
1043
+ from .modular_pipelines import (
1044
+ FluxAutoBlocks,
1045
+ FluxModularPipeline,
1046
+ QwenImageAutoBlocks,
1047
+ QwenImageEditAutoBlocks,
1048
+ QwenImageEditModularPipeline,
1049
+ QwenImageModularPipeline,
1050
+ StableDiffusionXLAutoBlocks,
1051
+ StableDiffusionXLModularPipeline,
1052
+ WanAutoBlocks,
1053
+ WanModularPipeline,
1054
+ )
1055
+ from .pipelines import (
1056
+ AllegroPipeline,
1057
+ AltDiffusionImg2ImgPipeline,
1058
+ AltDiffusionPipeline,
1059
+ AmusedImg2ImgPipeline,
1060
+ AmusedInpaintPipeline,
1061
+ AmusedPipeline,
1062
+ AnimateDiffControlNetPipeline,
1063
+ AnimateDiffPAGPipeline,
1064
+ AnimateDiffPipeline,
1065
+ AnimateDiffSDXLPipeline,
1066
+ AnimateDiffSparseControlNetPipeline,
1067
+ AnimateDiffVideoToVideoControlNetPipeline,
1068
+ AnimateDiffVideoToVideoPipeline,
1069
+ AudioLDM2Pipeline,
1070
+ AudioLDM2ProjectionModel,
1071
+ AudioLDM2UNet2DConditionModel,
1072
+ AudioLDMPipeline,
1073
+ AuraFlowPipeline,
1074
+ BriaPipeline,
1075
+ ChromaImg2ImgPipeline,
1076
+ ChromaPipeline,
1077
+ CLIPImageProjection,
1078
+ CogVideoXFunControlPipeline,
1079
+ CogVideoXImageToVideoPipeline,
1080
+ CogVideoXPipeline,
1081
+ CogVideoXVideoToVideoPipeline,
1082
+ CogView3PlusPipeline,
1083
+ CogView4ControlPipeline,
1084
+ CogView4Pipeline,
1085
+ ConsisIDPipeline,
1086
+ Cosmos2TextToImagePipeline,
1087
+ Cosmos2VideoToWorldPipeline,
1088
+ CosmosTextToWorldPipeline,
1089
+ CosmosVideoToWorldPipeline,
1090
+ CycleDiffusionPipeline,
1091
+ EasyAnimateControlPipeline,
1092
+ EasyAnimateInpaintPipeline,
1093
+ EasyAnimatePipeline,
1094
+ FluxControlImg2ImgPipeline,
1095
+ FluxControlInpaintPipeline,
1096
+ FluxControlNetImg2ImgPipeline,
1097
+ FluxControlNetInpaintPipeline,
1098
+ FluxControlNetPipeline,
1099
+ FluxControlPipeline,
1100
+ FluxFillPipeline,
1101
+ FluxImg2ImgPipeline,
1102
+ FluxInpaintPipeline,
1103
+ FluxKontextInpaintPipeline,
1104
+ FluxKontextPipeline,
1105
+ FluxPipeline,
1106
+ FluxPriorReduxPipeline,
1107
+ HiDreamImagePipeline,
1108
+ HunyuanDiTControlNetPipeline,
1109
+ HunyuanDiTPAGPipeline,
1110
+ HunyuanDiTPipeline,
1111
+ HunyuanSkyreelsImageToVideoPipeline,
1112
+ HunyuanVideoFramepackPipeline,
1113
+ HunyuanVideoImageToVideoPipeline,
1114
+ HunyuanVideoPipeline,
1115
+ I2VGenXLPipeline,
1116
+ IFImg2ImgPipeline,
1117
+ IFImg2ImgSuperResolutionPipeline,
1118
+ IFInpaintingPipeline,
1119
+ IFInpaintingSuperResolutionPipeline,
1120
+ IFPipeline,
1121
+ IFSuperResolutionPipeline,
1122
+ ImageTextPipelineOutput,
1123
+ Kandinsky3Img2ImgPipeline,
1124
+ Kandinsky3Pipeline,
1125
+ KandinskyCombinedPipeline,
1126
+ KandinskyImg2ImgCombinedPipeline,
1127
+ KandinskyImg2ImgPipeline,
1128
+ KandinskyInpaintCombinedPipeline,
1129
+ KandinskyInpaintPipeline,
1130
+ KandinskyPipeline,
1131
+ KandinskyPriorPipeline,
1132
+ KandinskyV22CombinedPipeline,
1133
+ KandinskyV22ControlnetImg2ImgPipeline,
1134
+ KandinskyV22ControlnetPipeline,
1135
+ KandinskyV22Img2ImgCombinedPipeline,
1136
+ KandinskyV22Img2ImgPipeline,
1137
+ KandinskyV22InpaintCombinedPipeline,
1138
+ KandinskyV22InpaintPipeline,
1139
+ KandinskyV22Pipeline,
1140
+ KandinskyV22PriorEmb2EmbPipeline,
1141
+ KandinskyV22PriorPipeline,
1142
+ LatentConsistencyModelImg2ImgPipeline,
1143
+ LatentConsistencyModelPipeline,
1144
+ LattePipeline,
1145
+ LDMTextToImagePipeline,
1146
+ LEditsPPPipelineStableDiffusion,
1147
+ LEditsPPPipelineStableDiffusionXL,
1148
+ LTXConditionPipeline,
1149
+ LTXImageToVideoPipeline,
1150
+ LTXLatentUpsamplePipeline,
1151
+ LTXPipeline,
1152
+ Lumina2Pipeline,
1153
+ Lumina2Text2ImgPipeline,
1154
+ LuminaPipeline,
1155
+ LuminaText2ImgPipeline,
1156
+ MarigoldDepthPipeline,
1157
+ MarigoldIntrinsicsPipeline,
1158
+ MarigoldNormalsPipeline,
1159
+ MochiPipeline,
1160
+ MusicLDMPipeline,
1161
+ OmniGenPipeline,
1162
+ PaintByExamplePipeline,
1163
+ PIAPipeline,
1164
+ PixArtAlphaPipeline,
1165
+ PixArtSigmaPAGPipeline,
1166
+ PixArtSigmaPipeline,
1167
+ QwenImageControlNetInpaintPipeline,
1168
+ QwenImageControlNetPipeline,
1169
+ QwenImageEditInpaintPipeline,
1170
+ QwenImageEditPipeline,
1171
+ QwenImageImg2ImgPipeline,
1172
+ QwenImageInpaintPipeline,
1173
+ QwenImagePipeline,
1174
+ ReduxImageEncoder,
1175
+ SanaControlNetPipeline,
1176
+ SanaPAGPipeline,
1177
+ SanaPipeline,
1178
+ SanaSprintImg2ImgPipeline,
1179
+ SanaSprintPipeline,
1180
+ SemanticStableDiffusionPipeline,
1181
+ ShapEImg2ImgPipeline,
1182
+ ShapEPipeline,
1183
+ SkyReelsV2DiffusionForcingImageToVideoPipeline,
1184
+ SkyReelsV2DiffusionForcingPipeline,
1185
+ SkyReelsV2DiffusionForcingVideoToVideoPipeline,
1186
+ SkyReelsV2ImageToVideoPipeline,
1187
+ SkyReelsV2Pipeline,
1188
+ StableAudioPipeline,
1189
+ StableAudioProjectionModel,
1190
+ StableCascadeCombinedPipeline,
1191
+ StableCascadeDecoderPipeline,
1192
+ StableCascadePriorPipeline,
1193
+ StableDiffusion3ControlNetInpaintingPipeline,
1194
+ StableDiffusion3ControlNetPipeline,
1195
+ StableDiffusion3Img2ImgPipeline,
1196
+ StableDiffusion3InpaintPipeline,
1197
+ StableDiffusion3PAGImg2ImgPipeline,
1198
+ StableDiffusion3PAGPipeline,
1199
+ StableDiffusion3Pipeline,
1200
+ StableDiffusionAdapterPipeline,
1201
+ StableDiffusionAttendAndExcitePipeline,
1202
+ StableDiffusionControlNetImg2ImgPipeline,
1203
+ StableDiffusionControlNetInpaintPipeline,
1204
+ StableDiffusionControlNetPAGInpaintPipeline,
1205
+ StableDiffusionControlNetPAGPipeline,
1206
+ StableDiffusionControlNetPipeline,
1207
+ StableDiffusionControlNetXSPipeline,
1208
+ StableDiffusionDepth2ImgPipeline,
1209
+ StableDiffusionDiffEditPipeline,
1210
+ StableDiffusionGLIGENPipeline,
1211
+ StableDiffusionGLIGENTextImagePipeline,
1212
+ StableDiffusionImageVariationPipeline,
1213
+ StableDiffusionImg2ImgPipeline,
1214
+ StableDiffusionInpaintPipeline,
1215
+ StableDiffusionInpaintPipelineLegacy,
1216
+ StableDiffusionInstructPix2PixPipeline,
1217
+ StableDiffusionLatentUpscalePipeline,
1218
+ StableDiffusionLDM3DPipeline,
1219
+ StableDiffusionModelEditingPipeline,
1220
+ StableDiffusionPAGImg2ImgPipeline,
1221
+ StableDiffusionPAGInpaintPipeline,
1222
+ StableDiffusionPAGPipeline,
1223
+ StableDiffusionPanoramaPipeline,
1224
+ StableDiffusionParadigmsPipeline,
1225
+ StableDiffusionPipeline,
1226
+ StableDiffusionPipelineSafe,
1227
+ StableDiffusionPix2PixZeroPipeline,
1228
+ StableDiffusionSAGPipeline,
1229
+ StableDiffusionUpscalePipeline,
1230
+ StableDiffusionXLAdapterPipeline,
1231
+ StableDiffusionXLControlNetImg2ImgPipeline,
1232
+ StableDiffusionXLControlNetInpaintPipeline,
1233
+ StableDiffusionXLControlNetPAGImg2ImgPipeline,
1234
+ StableDiffusionXLControlNetPAGPipeline,
1235
+ StableDiffusionXLControlNetPipeline,
1236
+ StableDiffusionXLControlNetUnionImg2ImgPipeline,
1237
+ StableDiffusionXLControlNetUnionInpaintPipeline,
1238
+ StableDiffusionXLControlNetUnionPipeline,
1239
+ StableDiffusionXLControlNetXSPipeline,
1240
+ StableDiffusionXLImg2ImgPipeline,
1241
+ StableDiffusionXLInpaintPipeline,
1242
+ StableDiffusionXLInstructPix2PixPipeline,
1243
+ StableDiffusionXLPAGImg2ImgPipeline,
1244
+ StableDiffusionXLPAGInpaintPipeline,
1245
+ StableDiffusionXLPAGPipeline,
1246
+ StableDiffusionXLPipeline,
1247
+ StableUnCLIPImg2ImgPipeline,
1248
+ StableUnCLIPPipeline,
1249
+ StableVideoDiffusionPipeline,
1250
+ TextToVideoSDPipeline,
1251
+ TextToVideoZeroPipeline,
1252
+ TextToVideoZeroSDXLPipeline,
1253
+ UnCLIPImageVariationPipeline,
1254
+ UnCLIPPipeline,
1255
+ UniDiffuserModel,
1256
+ UniDiffuserPipeline,
1257
+ UniDiffuserTextDecoder,
1258
+ VersatileDiffusionDualGuidedPipeline,
1259
+ VersatileDiffusionImageVariationPipeline,
1260
+ VersatileDiffusionPipeline,
1261
+ VersatileDiffusionTextToImagePipeline,
1262
+ VideoToVideoSDPipeline,
1263
+ VisualClozeGenerationPipeline,
1264
+ VisualClozePipeline,
1265
+ VQDiffusionPipeline,
1266
+ WanImageToVideoPipeline,
1267
+ WanPipeline,
1268
+ WanVACEPipeline,
1269
+ WanVideoToVideoPipeline,
1270
+ WuerstchenCombinedPipeline,
1271
+ WuerstchenDecoderPipeline,
1272
+ WuerstchenPriorPipeline,
1273
+ )
1274
+
1275
+ try:
1276
+ if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
1277
+ raise OptionalDependencyNotAvailable()
1278
+ except OptionalDependencyNotAvailable:
1279
+ from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
1280
+ else:
1281
+ from .pipelines import StableDiffusionKDiffusionPipeline, StableDiffusionXLKDiffusionPipeline
1282
+
1283
+ try:
1284
+ if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()):
1285
+ raise OptionalDependencyNotAvailable()
1286
+ except OptionalDependencyNotAvailable:
1287
+ from .utils.dummy_torch_and_transformers_and_sentencepiece_objects import * # noqa F403
1288
+ else:
1289
+ from .pipelines import KolorsImg2ImgPipeline, KolorsPAGPipeline, KolorsPipeline
1290
+
1291
+ try:
1292
+ if not (is_torch_available() and is_transformers_available() and is_opencv_available()):
1293
+ raise OptionalDependencyNotAvailable()
1294
+ except OptionalDependencyNotAvailable:
1295
+ from .utils.dummy_torch_and_transformers_and_opencv_objects import * # noqa F403
1296
+ else:
1297
+ from .pipelines import ConsisIDPipeline
1298
+
1299
+ try:
1300
+ if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
1301
+ raise OptionalDependencyNotAvailable()
1302
+ except OptionalDependencyNotAvailable:
1303
+ from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
1304
+ else:
1305
+ from .pipelines import (
1306
+ OnnxStableDiffusionImg2ImgPipeline,
1307
+ OnnxStableDiffusionInpaintPipeline,
1308
+ OnnxStableDiffusionInpaintPipelineLegacy,
1309
+ OnnxStableDiffusionPipeline,
1310
+ OnnxStableDiffusionUpscalePipeline,
1311
+ StableDiffusionOnnxPipeline,
1312
+ )
1313
+
1314
+ try:
1315
+ if not (is_torch_available() and is_librosa_available()):
1316
+ raise OptionalDependencyNotAvailable()
1317
+ except OptionalDependencyNotAvailable:
1318
+ from .utils.dummy_torch_and_librosa_objects import * # noqa F403
1319
+ else:
1320
+ from .pipelines import AudioDiffusionPipeline, Mel
1321
+
1322
+ try:
1323
+ if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
1324
+ raise OptionalDependencyNotAvailable()
1325
+ except OptionalDependencyNotAvailable:
1326
+ from .utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
1327
+ else:
1328
+ from .pipelines import SpectrogramDiffusionPipeline
1329
+
1330
+ try:
1331
+ if not is_flax_available():
1332
+ raise OptionalDependencyNotAvailable()
1333
+ except OptionalDependencyNotAvailable:
1334
+ from .utils.dummy_flax_objects import * # noqa F403
1335
+ else:
1336
+ from .models.controlnets.controlnet_flax import FlaxControlNetModel
1337
+ from .models.modeling_flax_utils import FlaxModelMixin
1338
+ from .models.unets.unet_2d_condition_flax import FlaxUNet2DConditionModel
1339
+ from .models.vae_flax import FlaxAutoencoderKL
1340
+ from .pipelines import FlaxDiffusionPipeline
1341
+ from .schedulers import (
1342
+ FlaxDDIMScheduler,
1343
+ FlaxDDPMScheduler,
1344
+ FlaxDPMSolverMultistepScheduler,
1345
+ FlaxEulerDiscreteScheduler,
1346
+ FlaxKarrasVeScheduler,
1347
+ FlaxLMSDiscreteScheduler,
1348
+ FlaxPNDMScheduler,
1349
+ FlaxSchedulerMixin,
1350
+ FlaxScoreSdeVeScheduler,
1351
+ )
1352
+
1353
+ try:
1354
+ if not (is_flax_available() and is_transformers_available()):
1355
+ raise OptionalDependencyNotAvailable()
1356
+ except OptionalDependencyNotAvailable:
1357
+ from .utils.dummy_flax_and_transformers_objects import * # noqa F403
1358
+ else:
1359
+ from .pipelines import (
1360
+ FlaxStableDiffusionControlNetPipeline,
1361
+ FlaxStableDiffusionImg2ImgPipeline,
1362
+ FlaxStableDiffusionInpaintPipeline,
1363
+ FlaxStableDiffusionPipeline,
1364
+ FlaxStableDiffusionXLPipeline,
1365
+ )
1366
+
1367
+ try:
1368
+ if not (is_note_seq_available()):
1369
+ raise OptionalDependencyNotAvailable()
1370
+ except OptionalDependencyNotAvailable:
1371
+ from .utils.dummy_note_seq_objects import * # noqa F403
1372
+ else:
1373
+ from .pipelines import MidiProcessor
1374
+
1375
+ else:
1376
+ import sys
1377
+
1378
+ sys.modules[__name__] = _LazyModule(
1379
+ __name__,
1380
+ globals()["__file__"],
1381
+ _import_structure,
1382
+ module_spec=__spec__,
1383
+ extra_objects={"__version__": __version__},
1384
+ )
pythonProject/diffusers-main/build/lib/diffusers/commands/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from abc import ABC, abstractmethod
16
+ from argparse import ArgumentParser
17
+
18
+
19
+ class BaseDiffusersCLICommand(ABC):
20
+ @staticmethod
21
+ @abstractmethod
22
+ def register_subcommand(parser: ArgumentParser):
23
+ raise NotImplementedError()
24
+
25
+ @abstractmethod
26
+ def run(self):
27
+ raise NotImplementedError()
pythonProject/diffusers-main/build/lib/diffusers/commands/custom_blocks.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Usage example:
17
+ TODO
18
+ """
19
+
20
+ import ast
21
+ import importlib.util
22
+ import os
23
+ from argparse import ArgumentParser, Namespace
24
+ from pathlib import Path
25
+
26
+ from ..utils import logging
27
+ from . import BaseDiffusersCLICommand
28
+
29
+
30
+ EXPECTED_PARENT_CLASSES = ["ModularPipelineBlocks"]
31
+ CONFIG = "config.json"
32
+
33
+
34
+ def conversion_command_factory(args: Namespace):
35
+ return CustomBlocksCommand(args.block_module_name, args.block_class_name)
36
+
37
+
38
+ class CustomBlocksCommand(BaseDiffusersCLICommand):
39
+ @staticmethod
40
+ def register_subcommand(parser: ArgumentParser):
41
+ conversion_parser = parser.add_parser("custom_blocks")
42
+ conversion_parser.add_argument(
43
+ "--block_module_name",
44
+ type=str,
45
+ default="block.py",
46
+ help="Module filename in which the custom block will be implemented.",
47
+ )
48
+ conversion_parser.add_argument(
49
+ "--block_class_name",
50
+ type=str,
51
+ default=None,
52
+ help="Name of the custom block. If provided None, we will try to infer it.",
53
+ )
54
+ conversion_parser.set_defaults(func=conversion_command_factory)
55
+
56
+ def __init__(self, block_module_name: str = "block.py", block_class_name: str = None):
57
+ self.logger = logging.get_logger("diffusers-cli/custom_blocks")
58
+ self.block_module_name = Path(block_module_name)
59
+ self.block_class_name = block_class_name
60
+
61
+ def run(self):
62
+ # determine the block to be saved.
63
+ out = self._get_class_names(self.block_module_name)
64
+ classes_found = list({cls for cls, _ in out})
65
+
66
+ if self.block_class_name is not None:
67
+ child_class, parent_class = self._choose_block(out, self.block_class_name)
68
+ if child_class is None and parent_class is None:
69
+ raise ValueError(
70
+ "`block_class_name` could not be retrieved. Available classes from "
71
+ f"{self.block_module_name}:\n{classes_found}"
72
+ )
73
+ else:
74
+ self.logger.info(
75
+ f"Found classes: {classes_found} will be using {classes_found[0]}. "
76
+ "If this needs to be changed, re-run the command specifying `block_class_name`."
77
+ )
78
+ child_class, parent_class = out[0][0], out[0][1]
79
+
80
+ # dynamically get the custom block and initialize it to call `save_pretrained` in the current directory.
81
+ # the user is responsible for running it, so I guess that is safe?
82
+ module_name = f"__dynamic__{self.block_module_name.stem}"
83
+ spec = importlib.util.spec_from_file_location(module_name, str(self.block_module_name))
84
+ module = importlib.util.module_from_spec(spec)
85
+ spec.loader.exec_module(module)
86
+ getattr(module, child_class)().save_pretrained(os.getcwd())
87
+
88
+ # or, we could create it manually.
89
+ # automap = self._create_automap(parent_class=parent_class, child_class=child_class)
90
+ # with open(CONFIG, "w") as f:
91
+ # json.dump(automap, f)
92
+ with open("requirements.txt", "w") as f:
93
+ f.write("")
94
+
95
+ def _choose_block(self, candidates, chosen=None):
96
+ for cls, base in candidates:
97
+ if cls == chosen:
98
+ return cls, base
99
+ return None, None
100
+
101
+ def _get_class_names(self, file_path):
102
+ source = file_path.read_text(encoding="utf-8")
103
+ try:
104
+ tree = ast.parse(source, filename=file_path)
105
+ except SyntaxError as e:
106
+ raise ValueError(f"Could not parse {file_path!r}: {e}") from e
107
+
108
+ results: list[tuple[str, str]] = []
109
+ for node in tree.body:
110
+ if not isinstance(node, ast.ClassDef):
111
+ continue
112
+
113
+ # extract all base names for this class
114
+ base_names = [bname for b in node.bases if (bname := self._get_base_name(b)) is not None]
115
+
116
+ # for each allowed base that appears in the class's bases, emit a tuple
117
+ for allowed in EXPECTED_PARENT_CLASSES:
118
+ if allowed in base_names:
119
+ results.append((node.name, allowed))
120
+
121
+ return results
122
+
123
+ def _get_base_name(self, node: ast.expr):
124
+ if isinstance(node, ast.Name):
125
+ return node.id
126
+ elif isinstance(node, ast.Attribute):
127
+ val = self._get_base_name(node.value)
128
+ return f"{val}.{node.attr}" if val else node.attr
129
+ return None
130
+
131
+ def _create_automap(self, parent_class, child_class):
132
+ module = str(self.block_module_name).replace(".py", "").rsplit(".", 1)[-1]
133
+ auto_map = {f"{parent_class}": f"{module}.{child_class}"}
134
+ return {"auto_map": auto_map}
pythonProject/diffusers-main/build/lib/diffusers/commands/diffusers_cli.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from argparse import ArgumentParser
17
+
18
+ from .custom_blocks import CustomBlocksCommand
19
+ from .env import EnvironmentCommand
20
+ from .fp16_safetensors import FP16SafetensorsCommand
21
+
22
+
23
+ def main():
24
+ parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli <command> [<args>]")
25
+ commands_parser = parser.add_subparsers(help="diffusers-cli command helpers")
26
+
27
+ # Register commands
28
+ EnvironmentCommand.register_subcommand(commands_parser)
29
+ FP16SafetensorsCommand.register_subcommand(commands_parser)
30
+ CustomBlocksCommand.register_subcommand(commands_parser)
31
+
32
+ # Let's go
33
+ args = parser.parse_args()
34
+
35
+ if not hasattr(args, "func"):
36
+ parser.print_help()
37
+ exit(1)
38
+
39
+ # Run
40
+ service = args.func(args)
41
+ service.run()
42
+
43
+
44
+ if __name__ == "__main__":
45
+ main()
pythonProject/diffusers-main/build/lib/diffusers/commands/env.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import platform
16
+ import subprocess
17
+ from argparse import ArgumentParser
18
+
19
+ import huggingface_hub
20
+
21
+ from .. import __version__ as version
22
+ from ..utils import (
23
+ is_accelerate_available,
24
+ is_bitsandbytes_available,
25
+ is_flax_available,
26
+ is_google_colab,
27
+ is_peft_available,
28
+ is_safetensors_available,
29
+ is_torch_available,
30
+ is_transformers_available,
31
+ is_xformers_available,
32
+ )
33
+ from . import BaseDiffusersCLICommand
34
+
35
+
36
+ def info_command_factory(_):
37
+ return EnvironmentCommand()
38
+
39
+
40
+ class EnvironmentCommand(BaseDiffusersCLICommand):
41
+ @staticmethod
42
+ def register_subcommand(parser: ArgumentParser) -> None:
43
+ download_parser = parser.add_parser("env")
44
+ download_parser.set_defaults(func=info_command_factory)
45
+
46
+ def run(self) -> dict:
47
+ hub_version = huggingface_hub.__version__
48
+
49
+ safetensors_version = "not installed"
50
+ if is_safetensors_available():
51
+ import safetensors
52
+
53
+ safetensors_version = safetensors.__version__
54
+
55
+ pt_version = "not installed"
56
+ pt_cuda_available = "NA"
57
+ if is_torch_available():
58
+ import torch
59
+
60
+ pt_version = torch.__version__
61
+ pt_cuda_available = torch.cuda.is_available()
62
+
63
+ flax_version = "not installed"
64
+ jax_version = "not installed"
65
+ jaxlib_version = "not installed"
66
+ jax_backend = "NA"
67
+ if is_flax_available():
68
+ import flax
69
+ import jax
70
+ import jaxlib
71
+
72
+ flax_version = flax.__version__
73
+ jax_version = jax.__version__
74
+ jaxlib_version = jaxlib.__version__
75
+ jax_backend = jax.lib.xla_bridge.get_backend().platform
76
+
77
+ transformers_version = "not installed"
78
+ if is_transformers_available():
79
+ import transformers
80
+
81
+ transformers_version = transformers.__version__
82
+
83
+ accelerate_version = "not installed"
84
+ if is_accelerate_available():
85
+ import accelerate
86
+
87
+ accelerate_version = accelerate.__version__
88
+
89
+ peft_version = "not installed"
90
+ if is_peft_available():
91
+ import peft
92
+
93
+ peft_version = peft.__version__
94
+
95
+ bitsandbytes_version = "not installed"
96
+ if is_bitsandbytes_available():
97
+ import bitsandbytes
98
+
99
+ bitsandbytes_version = bitsandbytes.__version__
100
+
101
+ xformers_version = "not installed"
102
+ if is_xformers_available():
103
+ import xformers
104
+
105
+ xformers_version = xformers.__version__
106
+
107
+ platform_info = platform.platform()
108
+
109
+ is_google_colab_str = "Yes" if is_google_colab() else "No"
110
+
111
+ accelerator = "NA"
112
+ if platform.system() in {"Linux", "Windows"}:
113
+ try:
114
+ sp = subprocess.Popen(
115
+ ["nvidia-smi", "--query-gpu=gpu_name,memory.total", "--format=csv,noheader"],
116
+ stdout=subprocess.PIPE,
117
+ stderr=subprocess.PIPE,
118
+ )
119
+ out_str, _ = sp.communicate()
120
+ out_str = out_str.decode("utf-8")
121
+
122
+ if len(out_str) > 0:
123
+ accelerator = out_str.strip()
124
+ except FileNotFoundError:
125
+ pass
126
+ elif platform.system() == "Darwin": # Mac OS
127
+ try:
128
+ sp = subprocess.Popen(
129
+ ["system_profiler", "SPDisplaysDataType"],
130
+ stdout=subprocess.PIPE,
131
+ stderr=subprocess.PIPE,
132
+ )
133
+ out_str, _ = sp.communicate()
134
+ out_str = out_str.decode("utf-8")
135
+
136
+ start = out_str.find("Chipset Model:")
137
+ if start != -1:
138
+ start += len("Chipset Model:")
139
+ end = out_str.find("\n", start)
140
+ accelerator = out_str[start:end].strip()
141
+
142
+ start = out_str.find("VRAM (Total):")
143
+ if start != -1:
144
+ start += len("VRAM (Total):")
145
+ end = out_str.find("\n", start)
146
+ accelerator += " VRAM: " + out_str[start:end].strip()
147
+ except FileNotFoundError:
148
+ pass
149
+ else:
150
+ print("It seems you are running an unusual OS. Could you fill in the accelerator manually?")
151
+
152
+ info = {
153
+ "🤗 Diffusers version": version,
154
+ "Platform": platform_info,
155
+ "Running on Google Colab?": is_google_colab_str,
156
+ "Python version": platform.python_version(),
157
+ "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
158
+ "Flax version (CPU?/GPU?/TPU?)": f"{flax_version} ({jax_backend})",
159
+ "Jax version": jax_version,
160
+ "JaxLib version": jaxlib_version,
161
+ "Huggingface_hub version": hub_version,
162
+ "Transformers version": transformers_version,
163
+ "Accelerate version": accelerate_version,
164
+ "PEFT version": peft_version,
165
+ "Bitsandbytes version": bitsandbytes_version,
166
+ "Safetensors version": safetensors_version,
167
+ "xFormers version": xformers_version,
168
+ "Accelerator": accelerator,
169
+ "Using GPU in script?": "<fill in>",
170
+ "Using distributed or parallel set-up in script?": "<fill in>",
171
+ }
172
+
173
+ print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
174
+ print(self.format_dict(info))
175
+
176
+ return info
177
+
178
+ @staticmethod
179
+ def format_dict(d: dict) -> str:
180
+ return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
pythonProject/diffusers-main/build/lib/diffusers/commands/fp16_safetensors.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Usage example:
17
+ diffusers-cli fp16_safetensors --ckpt_id=openai/shap-e --fp16 --use_safetensors
18
+ """
19
+
20
+ import glob
21
+ import json
22
+ import warnings
23
+ from argparse import ArgumentParser, Namespace
24
+ from importlib import import_module
25
+
26
+ import huggingface_hub
27
+ import torch
28
+ from huggingface_hub import hf_hub_download
29
+ from packaging import version
30
+
31
+ from ..utils import logging
32
+ from . import BaseDiffusersCLICommand
33
+
34
+
35
def conversion_command_factory(args: Namespace):
    """Build an `FP16SafetensorsCommand` from parsed CLI args, warning on the deprecated auth flag."""
    deprecation_message = (
        "The `--use_auth_token` flag is deprecated and will be removed in a future version. Authentication is now"
        " handled automatically if user is logged in."
    )
    if args.use_auth_token:
        warnings.warn(deprecation_message)
    return FP16SafetensorsCommand(args.ckpt_id, args.fp16, args.use_safetensors)
42
+
43
+
44
class FP16SafetensorsCommand(BaseDiffusersCLICommand):
    """`diffusers-cli fp16_safetensors` subcommand.

    Downloads a pipeline checkpoint from the Hub, re-serializes it (in FP16 precision and/or the
    safetensors format), saves it locally, and opens a PR on the source repo with the converted files.
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Register the `fp16_safetensors` subparser and its CLI flags on `parser`."""
        conversion_parser = parser.add_parser("fp16_safetensors")
        conversion_parser.add_argument(
            "--ckpt_id",
            type=str,
            help="Repo id of the checkpoints on which to run the conversion. Example: 'openai/shap-e'.",
        )
        conversion_parser.add_argument(
            "--fp16", action="store_true", help="If serializing the variables in FP16 precision."
        )
        conversion_parser.add_argument(
            "--use_safetensors", action="store_true", help="If serializing in the safetensors format."
        )
        conversion_parser.add_argument(
            "--use_auth_token",
            action="store_true",
            help="When working with checkpoints having private visibility. When used `hf auth login` needs to be run beforehand.",
        )
        # Wire this subcommand to the factory so `args.func(args)` builds the command object.
        conversion_parser.set_defaults(func=conversion_command_factory)

    def __init__(self, ckpt_id: str, fp16: bool, use_safetensors: bool):
        """
        Args:
            ckpt_id: Hub repo id of the checkpoint to convert (e.g. `'openai/shap-e'`).
            fp16: Whether to serialize weights in FP16 precision.
            use_safetensors: Whether to serialize in the safetensors format.

        Raises:
            NotImplementedError: If both `fp16` and `use_safetensors` are False (no conversion to do).
        """
        self.logger = logging.get_logger("diffusers-cli/fp16_safetensors")
        self.ckpt_id = ckpt_id
        # NOTE(review): hard-coded /tmp path — assumes a POSIX system with a writable /tmp.
        self.local_ckpt_dir = f"/tmp/{ckpt_id}"
        self.fp16 = fp16

        self.use_safetensors = use_safetensors

        if not self.use_safetensors and not self.fp16:
            raise NotImplementedError(
                "When `use_safetensors` and `fp16` both are False, then this command is of no use."
            )

    def run(self):
        """Download, convert, save locally, and open a Hub PR with the converted weights."""
        if version.parse(huggingface_hub.__version__) < version.parse("0.9.0"):
            raise ImportError(
                "The huggingface_hub version must be >= 0.9.0 to use this command. Please update your huggingface_hub"
                " installation."
            )
        else:
            # Deferred imports: these APIs only exist on sufficiently new huggingface_hub versions.
            from huggingface_hub import create_commit
            from huggingface_hub._commit_api import CommitOperationAdd

        # Infer the concrete pipeline class from the repo's model_index.json.
        model_index = hf_hub_download(repo_id=self.ckpt_id, filename="model_index.json")
        with open(model_index, "r") as f:
            pipeline_class_name = json.load(f)["_class_name"]
        pipeline_class = getattr(import_module("diffusers"), pipeline_class_name)
        self.logger.info(f"Pipeline class imported: {pipeline_class_name}.")

        # Load the appropriate pipeline. We could have use `DiffusionPipeline`
        # here, but just to avoid any rough edge cases.
        pipeline = pipeline_class.from_pretrained(
            self.ckpt_id, torch_dtype=torch.float16 if self.fp16 else torch.float32
        )
        pipeline.save_pretrained(
            self.local_ckpt_dir,
            safe_serialization=True if self.use_safetensors else False,
            variant="fp16" if self.fp16 else None,
        )
        self.logger.info(f"Pipeline locally saved to {self.local_ckpt_dir}.")

        # Fetch all the paths.
        # NOTE(review): when both flags are set, only the `*.fp16.*` files are collected — the
        # `elif` skips the plain safetensors glob. Confirm this is intended.
        if self.fp16:
            modified_paths = glob.glob(f"{self.local_ckpt_dir}/*/*.fp16.*")
        elif self.use_safetensors:
            modified_paths = glob.glob(f"{self.local_ckpt_dir}/*/*.safetensors")

        # Prepare for the PR.
        commit_message = f"Serialize variables with FP16: {self.fp16} and safetensors: {self.use_safetensors}."
        operations = []
        for path in modified_paths:
            # Drop the first 4 path components ("", "tmp", org, repo) to get the in-repo path —
            # assumes `local_ckpt_dir` keeps the "/tmp/{org}/{repo}" shape set in __init__.
            operations.append(CommitOperationAdd(path_in_repo="/".join(path.split("/")[4:]), path_or_fileobj=path))

        # Open the PR.
        commit_description = (
            "Variables converted by the [`diffusers`' `fp16_safetensors`"
            " CLI](https://github.com/huggingface/diffusers/blob/main/src/diffusers/commands/fp16_safetensors.py)."
        )
        hub_pr_url = create_commit(
            repo_id=self.ckpt_id,
            operations=operations,
            commit_message=commit_message,
            commit_description=commit_description,
            repo_type="model",
            create_pr=True,
        ).pr_url
        self.logger.info(f"PR created here: {hub_pr_url}.")
pythonProject/diffusers-main/build/lib/diffusers/experimental/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .rl import ValueGuidedRLPipeline
pythonProject/diffusers-main/build/lib/diffusers/experimental/rl/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .value_guided_sampling import ValueGuidedRLPipeline
pythonProject/diffusers-main/build/lib/diffusers/experimental/rl/value_guided_sampling.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import numpy as np
16
+ import torch
17
+ import tqdm
18
+
19
+ from ...models.unets.unet_1d import UNet1DModel
20
+ from ...pipelines import DiffusionPipeline
21
+ from ...utils.dummy_pt_objects import DDPMScheduler
22
+ from ...utils.torch_utils import randn_tensor
23
+
24
+
25
class ValueGuidedRLPipeline(DiffusionPipeline):
    r"""
    Pipeline for value-guided sampling from a diffusion model trained to predict sequences of states.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Parameters:
        value_function ([`UNet1DModel`]):
            A specialized UNet for fine-tuning trajectories base on reward.
        unet ([`UNet1DModel`]):
            UNet architecture to denoise the encoded trajectories.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded trajectories. Default for this
            application is [`DDPMScheduler`].
        env ():
            An environment following the OpenAI gym API to act in. For now only Hopper has pretrained models.
    """

    def __init__(
        self,
        value_function: UNet1DModel,
        unet: UNet1DModel,
        scheduler: DDPMScheduler,
        env,
    ):
        super().__init__()

        self.register_modules(value_function=value_function, unet=unet, scheduler=scheduler, env=env)

        self.data = env.get_dataset()
        # Per-key normalization statistics computed over the offline dataset. Entries that don't
        # support .mean()/.std() (non-array metadata) are silently skipped.
        self.means = {}
        for key in self.data.keys():
            try:
                self.means[key] = self.data[key].mean()
            except:  # noqa: E722
                pass
        self.stds = {}
        for key in self.data.keys():
            try:
                self.stds[key] = self.data[key].std()
            except:  # noqa: E722
                pass
        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.shape[0]

    def normalize(self, x_in, key):
        """Standardize `x_in` with the dataset mean/std stored for `key`."""
        return (x_in - self.means[key]) / self.stds[key]

    def de_normalize(self, x_in, key):
        """Invert `normalize`: map standardized values back to the original scale of `key`."""
        return x_in * self.stds[key] + self.means[key]

    def to_torch(self, x_in):
        """Recursively convert/move `x_in` (dict, tensor, or array-like) to a tensor on the UNet's device."""
        if isinstance(x_in, dict):
            return {k: self.to_torch(v) for k, v in x_in.items()}
        elif torch.is_tensor(x_in):
            return x_in.to(self.unet.device)
        return torch.tensor(x_in, device=self.unet.device)

    def reset_x0(self, x_in, cond, act_dim):
        """Pin the state portion of trajectories `x_in` at the timestep indices given by `cond`.

        Keys of `cond` are timestep indices; values are the states to write. Actions occupy the
        first `act_dim` channels and are left untouched.
        """
        for key, val in cond.items():
            x_in[:, key, act_dim:] = val.clone()
        return x_in

    def run_diffusion(self, x, conditions, n_guide_steps, scale):
        """Denoise trajectories `x`, nudging them along the value-function gradient at each step.

        Returns the denoised trajectories and the most recent value estimates `y`
        (`y` stays None if the scheduler yields no timesteps).
        """
        batch_size = x.shape[0]
        y = None
        for i in tqdm.tqdm(self.scheduler.timesteps):
            # create batch of timesteps to pass into model
            timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
            for _ in range(n_guide_steps):
                with torch.enable_grad():
                    x.requires_grad_()

                    # permute to match dimension for pre-trained models
                    y = self.value_function(x.permute(0, 2, 1), timesteps).sample
                    grad = torch.autograd.grad([y.sum()], [x])[0]

                    # scale the gradient by the posterior std so guidance strength tracks the noise level
                    posterior_variance = self.scheduler._get_variance(i)
                    model_std = torch.exp(0.5 * posterior_variance)
                    grad = model_std * grad

                # disable value guidance on the last, nearly noise-free timesteps
                grad[timesteps < 2] = 0
                x = x.detach()
                x = x + scale * grad
                x = self.reset_x0(x, conditions, self.action_dim)

            prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)

            # TODO: verify deprecation of this kwarg
            x = self.scheduler.step(prev_x, i, x)["prev_sample"]

            # apply conditions to the trajectory (set the initial state)
            x = self.reset_x0(x, conditions, self.action_dim)
            x = self.to_torch(x)
        return x, y

    def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
        """Plan `batch_size` candidate trajectories from observation `obs` and return the best first action.

        Args:
            obs: Current environment observation (un-normalized).
            batch_size: Number of candidate trajectories to sample.
            planning_horizon: Number of timesteps per trajectory.
            n_guide_steps: Value-gradient guidance steps per denoising step.
            scale: Step size for the value-gradient updates.
        """
        # normalize the observations and create batch dimension
        obs = self.normalize(obs, "observations")
        # NOTE(review): `repeat(..., axis=0)` implies `obs` is a NumPy array here — confirm callers.
        obs = obs[None].repeat(batch_size, axis=0)

        # condition index 0: every trajectory must start at the current state
        conditions = {0: self.to_torch(obs)}
        shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)

        # generate initial noise and apply our conditions (to make the trajectories start at current state)
        x1 = randn_tensor(shape, device=self.unet.device)
        x = self.reset_x0(x1, conditions, self.action_dim)
        x = self.to_torch(x)

        # run the diffusion process
        x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)

        # sort output trajectories by value
        sorted_idx = y.argsort(0, descending=True).squeeze()
        sorted_values = x[sorted_idx]
        actions = sorted_values[:, :, : self.action_dim]
        actions = actions.detach().cpu().numpy()
        denorm_actions = self.de_normalize(actions, key="actions")

        # select the action with the highest value
        if y is not None:
            selected_index = 0
        else:
            # if we didn't run value guiding, select a random action
            selected_index = np.random.randint(0, batch_size)

        # first action of the selected trajectory
        denorm_actions = denorm_actions[selected_index, 0]
        return denorm_actions
pythonProject/diffusers-main/build/lib/diffusers/guiders/adaptive_projected_guidance.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+
20
+ from ..configuration_utils import register_to_config
21
+ from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
22
+
23
+
24
+ if TYPE_CHECKING:
25
+ from ..modular_pipelines.modular_pipeline import BlockState
26
+
27
+
28
class AdaptiveProjectedGuidance(BaseGuidance):
    """
    Adaptive Projected Guidance (APG): https://huggingface.co/papers/2410.02416

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
            prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
            deterioration of image quality.
        adaptive_projected_guidance_momentum (`float`, defaults to `None`):
            The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
        adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
            The norm threshold for adaptive projected guidance: the per-sample L2 norm of the guidance update is
            clamped to this value before the projection step. Disabled if set to `0.0`.
        eta (`float`, defaults to `1.0`):
            Weight applied to the component of the guidance update parallel to the conditional prediction; `0.0`
            keeps only the orthogonal component.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
    """

    _input_predictions = ["pred_cond", "pred_uncond"]

    @register_to_config
    def __init__(
        self,
        guidance_scale: float = 7.5,
        adaptive_projected_guidance_momentum: Optional[float] = None,
        adaptive_projected_guidance_rescale: float = 15.0,
        eta: float = 1.0,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__(start, stop)

        self.guidance_scale = guidance_scale
        self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
        self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
        self.eta = eta
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation
        # (Re)created at the first denoising step of each sampling run when momentum is enabled.
        self.momentum_buffer = None

    def prepare_inputs(
        self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
    ) -> List["BlockState"]:
        """Split `data` into one BlockState per prediction (conditional; plus unconditional when APG is active)."""
        if input_fields is None:
            input_fields = self._input_fields

        if self._step == 0:
            # fresh momentum buffer at the start of every sampling run
            if self.adaptive_projected_guidance_momentum is not None:
                self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        data_batches = []
        for i in range(self.num_conditions):
            data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
            data_batches.append(data_batch)
        return data_batches

    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
        """Combine the model predictions into the final guided prediction for this step."""
        pred = None

        if not self._is_apg_enabled():
            # guidance disabled or outside the [start, stop) window: pass through the conditional prediction
            pred = pred_cond
        else:
            pred = normalized_guidance(
                pred_cond,
                pred_uncond,
                self.guidance_scale,
                self.momentum_buffer,
                self.eta,
                self.adaptive_projected_guidance_rescale,
                self.use_original_formulation,
            )

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

    @property
    def is_conditional(self) -> bool:
        # True while only the conditional batch has been prepared in the current step.
        return self._count_prepared == 1

    @property
    def num_conditions(self) -> int:
        # 1 (conditional only) when APG is inactive, otherwise 2 (conditional + unconditional).
        num_conditions = 1
        if self._is_apg_enabled():
            num_conditions += 1
        return num_conditions

    def _is_apg_enabled(self) -> bool:
        """Whether APG should run at the current step (enabled, inside the window, scale not a no-op)."""
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        # A scale of 0.0 (original formulation) or 1.0 (diffusers formulation) is equivalent to no guidance.
        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close
+
145
+
146
+ class MomentumBuffer:
147
+ def __init__(self, momentum: float):
148
+ self.momentum = momentum
149
+ self.running_average = 0
150
+
151
+ def update(self, update_value: torch.Tensor):
152
+ new_average = self.momentum * self.running_average
153
+ self.running_average = update_value + new_average
154
+
155
+
156
+ def normalized_guidance(
157
+ pred_cond: torch.Tensor,
158
+ pred_uncond: torch.Tensor,
159
+ guidance_scale: float,
160
+ momentum_buffer: Optional[MomentumBuffer] = None,
161
+ eta: float = 1.0,
162
+ norm_threshold: float = 0.0,
163
+ use_original_formulation: bool = False,
164
+ ):
165
+ diff = pred_cond - pred_uncond
166
+ dim = [-i for i in range(1, len(diff.shape))]
167
+
168
+ if momentum_buffer is not None:
169
+ momentum_buffer.update(diff)
170
+ diff = momentum_buffer.running_average
171
+
172
+ if norm_threshold > 0:
173
+ ones = torch.ones_like(diff)
174
+ diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
175
+ scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
176
+ diff = diff * scale_factor
177
+
178
+ v0, v1 = diff.double(), pred_cond.double()
179
+ v1 = torch.nn.functional.normalize(v1, dim=dim)
180
+ v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
181
+ v0_orthogonal = v0 - v0_parallel
182
+ diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
183
+ normalized_update = diff_orthogonal + eta * diff_parallel
184
+
185
+ pred = pred_cond if use_original_formulation else pred_uncond
186
+ pred = pred + guidance_scale * normalized_update
187
+
188
+ return pred
pythonProject/diffusers-main/build/lib/diffusers/guiders/auto_guidance.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+
20
+ from ..configuration_utils import register_to_config
21
+ from ..hooks import HookRegistry, LayerSkipConfig
22
+ from ..hooks.layer_skip import _apply_layer_skip_hook
23
+ from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
24
+
25
+
26
+ if TYPE_CHECKING:
27
+ from ..modular_pipelines.modular_pipeline import BlockState
28
+
29
+
30
class AutoGuidance(BaseGuidance):
    """
    AutoGuidance: https://huggingface.co/papers/2406.02507

    Guides generation by contrasting the full denoiser (conditional pass) against a degraded version
    of itself in which the configured layers are skipped/dropped (unconditional pass).

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
            prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
            deterioration of image quality.
        auto_guidance_layers (`int` or `List[int]`, *optional*):
            The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
            provided, `auto_guidance_config` must be provided.
        auto_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
            The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
            `LayerSkipConfig`. If not provided, `auto_guidance_layers` must be provided.
        dropout (`float`, *optional*):
            The dropout probability for autoguidance on the enabled skip layers (either with `auto_guidance_layers` or
            `auto_guidance_config`). If not provided, the dropout probability will be set to 1.0.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
    """

    _input_predictions = ["pred_cond", "pred_uncond"]

    @register_to_config
    def __init__(
        self,
        guidance_scale: float = 7.5,
        auto_guidance_layers: Optional[Union[int, List[int]]] = None,
        auto_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
        dropout: Optional[float] = None,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__(start, stop)

        self.guidance_scale = guidance_scale
        self.auto_guidance_layers = auto_guidance_layers
        self.auto_guidance_config = auto_guidance_config
        self.dropout = dropout
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

        # Validate that exactly one of `auto_guidance_layers` / `auto_guidance_config` was given.
        is_layer_or_config_provided = auto_guidance_layers is not None or auto_guidance_config is not None
        is_layer_and_config_provided = auto_guidance_layers is not None and auto_guidance_config is not None
        if not is_layer_or_config_provided:
            raise ValueError(
                "Either `auto_guidance_layers` or `auto_guidance_config` must be provided to enable AutoGuidance."
            )
        if is_layer_and_config_provided:
            raise ValueError("Only one of `auto_guidance_layers` or `auto_guidance_config` can be provided.")
        if auto_guidance_config is None and dropout is None:
            raise ValueError("`dropout` must be provided if `auto_guidance_layers` is provided.")

        # Normalize `auto_guidance_layers` into a list of LayerSkipConfig.
        if auto_guidance_layers is not None:
            if isinstance(auto_guidance_layers, int):
                auto_guidance_layers = [auto_guidance_layers]
            if not isinstance(auto_guidance_layers, list):
                raise ValueError(
                    f"Expected `auto_guidance_layers` to be an int or a list of ints, but got {type(auto_guidance_layers)}."
                )
            auto_guidance_config = [
                LayerSkipConfig(layer, fqn="auto", dropout=dropout) for layer in auto_guidance_layers
            ]

        # Normalize `auto_guidance_config` (dict, single config, or list of either) into a list of LayerSkipConfig.
        if isinstance(auto_guidance_config, dict):
            auto_guidance_config = LayerSkipConfig.from_dict(auto_guidance_config)

        if isinstance(auto_guidance_config, LayerSkipConfig):
            auto_guidance_config = [auto_guidance_config]

        if not isinstance(auto_guidance_config, list):
            raise ValueError(
                f"Expected `auto_guidance_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(auto_guidance_config)}."
            )
        elif isinstance(next(iter(auto_guidance_config), None), dict):
            auto_guidance_config = [LayerSkipConfig.from_dict(config) for config in auto_guidance_config]

        self.auto_guidance_config = auto_guidance_config
        # One uniquely named hook per config so they can be removed individually in cleanup_models.
        self._auto_guidance_hook_names = [f"AutoGuidance_{i}" for i in range(len(self.auto_guidance_config))]

    def prepare_models(self, denoiser: torch.nn.Module) -> None:
        """Attach layer-skip hooks to `denoiser`, but only for the unconditional (degraded) pass."""
        self._count_prepared += 1
        if self._is_ag_enabled() and self.is_unconditional:
            for name, config in zip(self._auto_guidance_hook_names, self.auto_guidance_config):
                _apply_layer_skip_hook(denoiser, config, name=name)

    def cleanup_models(self, denoiser: torch.nn.Module) -> None:
        """Remove the layer-skip hooks installed by `prepare_models`."""
        if self._is_ag_enabled() and self.is_unconditional:
            for name in self._auto_guidance_hook_names:
                registry = HookRegistry.check_if_exists_or_initialize(denoiser)
                registry.remove_hook(name, recurse=True)

    def prepare_inputs(
        self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
    ) -> List["BlockState"]:
        """Split `data` into one BlockState per prediction (conditional; plus unconditional when AG is active)."""
        if input_fields is None:
            input_fields = self._input_fields

        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        data_batches = []
        for i in range(self.num_conditions):
            data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
            data_batches.append(data_batch)
        return data_batches

    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
        """Combine the model predictions into the final guided prediction for this step."""
        pred = None

        if not self._is_ag_enabled():
            # guidance disabled or outside the [start, stop) window: pass through the conditional prediction
            pred = pred_cond
        else:
            shift = pred_cond - pred_uncond
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

    @property
    def is_conditional(self) -> bool:
        # True while only the conditional batch has been prepared in the current step.
        return self._count_prepared == 1

    @property
    def num_conditions(self) -> int:
        # 1 (conditional only) when AG is inactive, otherwise 2 (conditional + unconditional).
        num_conditions = 1
        if self._is_ag_enabled():
            num_conditions += 1
        return num_conditions

    def _is_ag_enabled(self) -> bool:
        """Whether AutoGuidance should run at the current step (enabled, inside the window, scale not a no-op)."""
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        # A scale of 0.0 (original formulation) or 1.0 (diffusers formulation) is equivalent to no guidance.
        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close
pythonProject/diffusers-main/build/lib/diffusers/guiders/classifier_free_guidance.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+
20
+ from ..configuration_utils import register_to_config
21
+ from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
22
+
23
+
24
+ if TYPE_CHECKING:
25
+ from ..modular_pipelines.modular_pipeline import BlockState
26
+
27
+
28
class ClassifierFreeGuidance(BaseGuidance):
    """
    Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598

    CFG improves generation quality and condition-following in diffusion models. A model trained
    jointly on conditional and unconditional data is queried twice at inference time, and the two
    predictions are combined with a weighted difference, trading off sample quality against
    diversity.

    Two equivalent formulations are supported:

    - diffusers-native (default, following the Imagen paper https://huggingface.co/papers/2205.11487):
      ``x_pred = x_uncond + scale * (x_cond - x_uncond)`` — the unconditional estimate is pushed
      toward the conditional one, suppressing negative-prompt features.
    - original paper formulation (``use_original_formulation=True``):
      ``x_pred = x_cond + scale * (x_cond - x_uncond)`` — the conditional estimate is pushed away
      from the unconditional one.

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            Guidance strength. Larger values condition more strongly on the prompt but may
            over-saturate and degrade image quality.
        guidance_rescale (`float`, defaults to `0.0`):
            Rescale factor applied to the guided prediction to fix overexposure, per Section 3.4 of
            [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            If `True`, use the original CFG formulation from the paper instead of the
            diffusers-native one (see class docstring).
        start (`float`, defaults to `0.0`):
            Fraction of the total denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            Fraction of the total denoising steps after which guidance stops.
    """

    _input_predictions = ["pred_cond", "pred_uncond"]

    @register_to_config
    def __init__(
        self,
        guidance_scale: float = 7.5,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__(start, stop)

        self.guidance_scale = guidance_scale
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

    def prepare_inputs(
        self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
    ) -> List["BlockState"]:
        if input_fields is None:
            input_fields = self._input_fields

        # One batch per required prediction: index 0 is the conditional pass, index 1 (when CFG is
        # active) is the unconditional pass.
        return [
            self._prepare_batch(input_fields, data, idx, self._input_predictions[idx])
            for idx in range(self.num_conditions)
        ]

    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
        if not self._is_cfg_enabled():
            pred = pred_cond
        else:
            base = pred_cond if self.use_original_formulation else pred_uncond
            pred = base + self.guidance_scale * (pred_cond - pred_uncond)

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

    @property
    def is_conditional(self) -> bool:
        # True exactly while the first (conditional) prediction of the step is being prepared.
        return self._count_prepared == 1

    @property
    def num_conditions(self) -> int:
        return 2 if self._is_cfg_enabled() else 1

    def _is_cfg_enabled(self) -> bool:
        if not self._enabled:
            return False

        # Only apply guidance inside the configured [start, stop) step window.
        if self._num_inference_steps is not None:
            begin = int(self._start * self._num_inference_steps)
            end = int(self._stop * self._num_inference_steps)
            if not (begin <= self._step < end):
                return False

        # A scale at its neutral value (0.0 original / 1.0 diffusers-native) is a no-op, so the
        # unconditional pass can be skipped entirely.
        neutral = 0.0 if self.use_original_formulation else 1.0
        return not math.isclose(self.guidance_scale, neutral)
pythonProject/diffusers-main/build/lib/diffusers/guiders/classifier_free_zero_star_guidance.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+
20
+ from ..configuration_utils import register_to_config
21
+ from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
22
+
23
+
24
+ if TYPE_CHECKING:
25
+ from ..modular_pipelines.modular_pipeline import BlockState
26
+
27
+
28
class ClassifierFreeZeroStarGuidance(BaseGuidance):
    """
    Classifier-free Zero* (CFG-Zero*): https://huggingface.co/papers/2503.18886

    This is an implementation of the Classifier-Free Zero* guidance technique, which is a variant of classifier-free
    guidance. It proposes zero initialization of the noise predictions for the first few steps of the diffusion
    process, and also introduces an optimal rescaling factor for the noise predictions, which can help in improving the
    quality of generated images.

    The authors of the paper suggest setting zero initialization in the first 4% of the inference steps.

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
            prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
            deterioration of image quality.
        zero_init_steps (`int`, defaults to `1`):
            The number of inference steps for which the noise predictions are zeroed out (see Section 4.2).
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
    """

    _input_predictions = ["pred_cond", "pred_uncond"]

    @register_to_config
    def __init__(
        self,
        guidance_scale: float = 7.5,
        zero_init_steps: int = 1,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__(start, stop)

        self.guidance_scale = guidance_scale
        self.zero_init_steps = zero_init_steps
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

    def prepare_inputs(
        self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
    ) -> List["BlockState"]:
        if input_fields is None:
            input_fields = self._input_fields

        # One batch per required prediction: index 0 -> conditional, index 1 -> unconditional.
        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        data_batches = []
        for i in range(self.num_conditions):
            data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
            data_batches.append(data_batch)
        return data_batches

    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
        """Combine conditional/unconditional predictions using the CFG-Zero* rule."""
        pred = None

        if self._step < self.zero_init_steps:
            # Zero-init: return an all-zero prediction for the first `zero_init_steps` steps (Section 4.2).
            pred = torch.zeros_like(pred_cond)
        elif not self._is_cfg_enabled():
            pred = pred_cond
        else:
            # Rescale the unconditional branch by the optimal per-sample factor st_star before the
            # standard CFG update (Section 4.1 of the paper).
            pred_cond_flat = pred_cond.flatten(1)
            pred_uncond_flat = pred_uncond.flatten(1)
            alpha = cfg_zero_star_scale(pred_cond_flat, pred_uncond_flat)
            alpha = alpha.view(-1, *(1,) * (len(pred_cond.shape) - 1))
            pred_uncond = pred_uncond * alpha
            shift = pred_cond - pred_uncond
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

    @property
    def is_conditional(self) -> bool:
        # True exactly while the first (conditional) prediction of the step is being prepared.
        return self._count_prepared == 1

    @property
    def num_conditions(self) -> int:
        num_conditions = 1
        if self._is_cfg_enabled():
            num_conditions += 1
        return num_conditions

    def _is_cfg_enabled(self) -> bool:
        if not self._enabled:
            return False

        # Only apply guidance inside the configured [start, stop) step window.
        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        # A scale at its neutral value (0.0 original / 1.0 diffusers-native) makes CFG a no-op.
        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close
142
+
143
+
144
def cfg_zero_star_scale(cond: torch.Tensor, uncond: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    Compute the CFG-Zero* optimal rescaling factor for the unconditional prediction.

    Implements ``st_star = <v_cond, v_uncond> / ||v_uncond||^2`` per batch element (reduction over
    dim 1). The math is done in float32 for stability and the result is cast back to the input
    dtype.

    Args:
        cond: Flattened conditional prediction of shape [B, D].
        uncond: Flattened unconditional prediction of shape [B, D].
        eps: Small constant keeping the denominator non-zero.

    Returns:
        Tensor of shape [B, 1] holding the per-sample scale, in `cond`'s original dtype.
    """
    original_dtype = cond.dtype
    cond_f = cond.float()
    uncond_f = uncond.float()
    numerator = (cond_f * uncond_f).sum(dim=1, keepdim=True)
    denominator = uncond_f.pow(2).sum(dim=1, keepdim=True) + eps
    return (numerator / denominator).to(dtype=original_dtype)
pythonProject/diffusers-main/build/lib/diffusers/guiders/frequency_decoupled_guidance.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+
20
+ from ..configuration_utils import register_to_config
21
+ from ..utils import is_kornia_available
22
+ from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
23
+
24
+
25
+ if TYPE_CHECKING:
26
+ from ..modular_pipelines.modular_pipeline import BlockState
27
+
28
+
29
+ _CAN_USE_KORNIA = is_kornia_available()
30
+
31
+
32
+ if _CAN_USE_KORNIA:
33
+ from kornia.geometry import pyrup as upsample_and_blur_func
34
+ from kornia.geometry.transform import build_laplacian_pyramid as build_laplacian_pyramid_func
35
+ else:
36
+ upsample_and_blur_func = None
37
+ build_laplacian_pyramid_func = None
38
+
39
+
40
def project(v0: torch.Tensor, v1: torch.Tensor, upcast_to_double: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Decompose `v0` into its components parallel and orthogonal to `v1` (paper Algorithm 2).

    Dim 0 is treated as the batch dimension; all remaining dims are reduced over when computing the
    projection, so each batch element is projected independently.

    Args:
        v0: Tensor of shape [B, ...] to decompose.
        v1: Tensor of shape [B, ...] giving the projection direction per batch element.
        upcast_to_double: If `True`, perform the math in float64 for accuracy, then cast back.

    Returns:
        Tuple `(v0_parallel, v0_orthogonal)` with `v0 == v0_parallel + v0_orthogonal`.
    """
    reduce_dims = list(range(1, v0.ndim))
    original_dtype = v0.dtype
    if upcast_to_double:
        v0, v1 = v0.double(), v1.double()
    unit_v1 = torch.nn.functional.normalize(v1, dim=reduce_dims)
    parallel = (v0 * unit_v1).sum(dim=reduce_dims, keepdim=True) * unit_v1
    orthogonal = v0 - parallel
    if upcast_to_double:
        parallel = parallel.to(original_dtype)
        orthogonal = orthogonal.to(original_dtype)
    return parallel, orthogonal
59
+
60
+
61
def build_image_from_pyramid(pyramid: List[torch.Tensor]) -> torch.Tensor:
    """
    Collapse a Laplacian pyramid back into a data-space tensor (paper Algorithm 2).

    Starting from the coarsest level, each step upsamples-and-blurs the running image and adds the
    next finer residual, until the full-resolution tensor is recovered.
    """
    # pyramid shapes: [[B, C, H, W], [B, C, H/2, W/2], ...] (finest level first)
    image = pyramid[-1]
    for residual in reversed(pyramid[:-1]):
        image = upsample_and_blur_func(image) + residual
    return image
71
+
72
+
73
class FrequencyDecoupledGuidance(BaseGuidance):
    """
    Frequency-Decoupled Guidance (FDG): https://huggingface.co/papers/2506.19713

    FDG is a technique similar to (and based on) classifier-free guidance (CFG) which is used to improve generation
    quality and condition-following in diffusion models. Like CFG, during training we jointly train the model on both
    conditional and unconditional data, and use a combination of the two during inference. (If you want more details on
    how CFG works, you can check out the CFG guider.)

    FDG differs from CFG in that the normal CFG prediction is instead decoupled into low- and high-frequency components
    using a frequency transform (such as a Laplacian pyramid). The CFG update is then performed in frequency space
    separately for the low- and high-frequency components with different guidance scales. Finally, the inverse
    frequency transform is used to map the CFG frequency predictions back to data space (e.g. pixel space for images)
    to form the final FDG prediction.

    For images, the FDG authors found that using low guidance scales for the low-frequency components retains sample
    diversity and realistic color composition, while using high guidance scales for high-frequency components enhances
    sample quality (such as better visual details). Therefore, they recommend using low guidance scales (low w_low) for
    the low-frequency components and high guidance scales (high w_high) for the high-frequency components. As an
    example, they suggest w_low = 5.0 and w_high = 10.0 for Stable Diffusion XL (see Table 8 in the paper).

    As with CFG, Diffusers implements the scaling and shifting on the unconditional prediction based on the [Imagen
    paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original CFG paper proposed in
    theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]

    The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
    paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.

    Args:
        guidance_scales (`List[float]`, defaults to `[10.0, 5.0]`):
            The scale parameter for frequency-decoupled guidance for each frequency component, listed from highest
            frequency level to lowest. Higher values result in stronger conditioning on the text prompt, while lower
            values allow for more freedom in generation. Higher values may lead to saturation and deterioration of
            image quality. The FDG authors recommend using higher guidance scales for higher frequency components and
            lower guidance scales for lower frequency components (so `guidance_scales` should typically be sorted in
            descending order).
        guidance_rescale (`float` or `List[float]`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891). If a list is supplied, it should be the same length as
            `guidance_scales`.
        parallel_weights (`float` or `List[float]`, *optional*):
            Optional weights for the parallel component of each frequency component of the projected CFG shift. If not
            set, the weights will default to `1.0` for all components, which corresponds to using the normal CFG shift
            (that is, equal weights for the parallel and orthogonal components). If set, a value in `[0, 1]` is
            recommended. If a list is supplied, it should be the same length as `guidance_scales`.
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float` or `List[float]`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts. If a list is supplied, it
            should be the same length as `guidance_scales`.
        stop (`float` or `List[float]`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops. If a list is supplied, it
            should be the same length as `guidance_scales`.
        guidance_rescale_space (`str`, defaults to `"data"`):
            Whether to performance guidance rescaling in `"data"` space (after the full FDG update in data space) or in
            `"freq"` space (right after the CFG update, for each freq level). Note that frequency space rescaling is
            speculative and may not produce expected results. If `"data"` is set, the first `guidance_rescale` value
            will be used; otherwise, per-frequency-level guidance rescale values will be used if available.
        upcast_to_double (`bool`, defaults to `True`):
            Whether to upcast certain operations, such as the projection operation when using `parallel_weights`, to
            float64 when performing guidance. This may result in better performance at the cost of increased runtime.
    """

    _input_predictions = ["pred_cond", "pred_uncond"]

    @register_to_config
    def __init__(
        self,
        guidance_scales: Union[List[float], Tuple[float]] = [10.0, 5.0],
        guidance_rescale: Union[float, List[float], Tuple[float]] = 0.0,
        parallel_weights: Optional[Union[float, List[float], Tuple[float]]] = None,
        use_original_formulation: bool = False,
        start: Union[float, List[float], Tuple[float]] = 0.0,
        stop: Union[float, List[float], Tuple[float]] = 1.0,
        guidance_rescale_space: str = "data",
        upcast_to_double: bool = True,
    ):
        if not _CAN_USE_KORNIA:
            raise ImportError(
                "The `FrequencyDecoupledGuidance` guider cannot be instantiated because the `kornia` library on which "
                "it depends is not available in the current environment. You can install `kornia` with `pip install "
                "kornia`."
            )

        # Set start to earliest start for any freq component and stop to latest stop for any freq component
        min_start = start if isinstance(start, float) else min(start)
        max_stop = stop if isinstance(stop, float) else max(stop)
        super().__init__(min_start, max_stop)

        self.guidance_scales = guidance_scales
        self.levels = len(guidance_scales)

        # Broadcast scalar `guidance_rescale` to one value per frequency level, or validate length.
        if isinstance(guidance_rescale, float):
            self.guidance_rescale = [guidance_rescale] * self.levels
        elif len(guidance_rescale) == self.levels:
            self.guidance_rescale = guidance_rescale
        else:
            raise ValueError(
                f"`guidance_rescale` has length {len(guidance_rescale)} but should have the same length as "
                f"`guidance_scales` ({len(self.guidance_scales)})"
            )
        # Whether to perform guidance rescaling in frequency space (right after the CFG update) or data space (after
        # transforming from frequency space back to data space)
        if guidance_rescale_space not in ["data", "freq"]:
            raise ValueError(
                f"Guidance rescale space is {guidance_rescale_space} but must be one of `data` or `freq`."
            )
        self.guidance_rescale_space = guidance_rescale_space

        if parallel_weights is None:
            # Use normal CFG shift (equal weights for parallel and orthogonal components)
            self.parallel_weights = [1.0] * self.levels
        elif isinstance(parallel_weights, float):
            self.parallel_weights = [parallel_weights] * self.levels
        elif len(parallel_weights) == self.levels:
            self.parallel_weights = parallel_weights
        else:
            raise ValueError(
                f"`parallel_weights` has length {len(parallel_weights)} but should have the same length as "
                f"`guidance_scales` ({len(self.guidance_scales)})"
            )

        self.use_original_formulation = use_original_formulation
        self.upcast_to_double = upcast_to_double

        # Broadcast scalar start/stop to per-level windows, or validate length.
        if isinstance(start, float):
            self.guidance_start = [start] * self.levels
        elif len(start) == self.levels:
            self.guidance_start = start
        else:
            raise ValueError(
                f"`start` has length {len(start)} but should have the same length as `guidance_scales` "
                f"({len(self.guidance_scales)})"
            )
        if isinstance(stop, float):
            self.guidance_stop = [stop] * self.levels
        elif len(stop) == self.levels:
            self.guidance_stop = stop
        else:
            raise ValueError(
                f"`stop` has length {len(stop)} but should have the same length as `guidance_scales` "
                f"({len(self.guidance_scales)})"
            )

    def prepare_inputs(
        self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
    ) -> List["BlockState"]:
        if input_fields is None:
            input_fields = self._input_fields

        # One batch per required prediction: index 0 -> conditional, index 1 -> unconditional.
        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        data_batches = []
        for i in range(self.num_conditions):
            data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
            data_batches.append(data_batch)
        return data_batches

    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
        pred = None

        if not self._is_fdg_enabled():
            pred = pred_cond
        else:
            # Apply the frequency transform (e.g. Laplacian pyramid) to the conditional and unconditional predictions.
            pred_cond_pyramid = build_laplacian_pyramid_func(pred_cond, self.levels)
            pred_uncond_pyramid = build_laplacian_pyramid_func(pred_uncond, self.levels)

            # From high frequencies to low frequencies, following the paper implementation
            pred_guided_pyramid = []
            parameters = zip(self.guidance_scales, self.parallel_weights, self.guidance_rescale)
            for level, (guidance_scale, parallel_weight, guidance_rescale) in enumerate(parameters):
                # Get the cond/uncond preds (in freq space) at the current frequency level.
                # NOTE: these are read before the enabled check so that a skipped level still has
                # valid data; the previous version read `pred_cond_freq` only inside the enabled
                # branch, so a skipped level raised NameError (level 0) or reused a stale tensor
                # from the previous level.
                pred_cond_freq = pred_cond_pyramid[level]
                pred_uncond_freq = pred_uncond_pyramid[level]
                if self._is_fdg_enabled_for_level(level):
                    shift = pred_cond_freq - pred_uncond_freq

                    # Apply parallel weights, if used (1.0 corresponds to using the normal CFG shift)
                    if not math.isclose(parallel_weight, 1.0):
                        shift_parallel, shift_orthogonal = project(shift, pred_cond_freq, self.upcast_to_double)
                        shift = parallel_weight * shift_parallel + shift_orthogonal

                    # Apply CFG update for the current frequency level
                    pred = pred_cond_freq if self.use_original_formulation else pred_uncond_freq
                    pred = pred + guidance_scale * shift

                    if self.guidance_rescale_space == "freq" and guidance_rescale > 0.0:
                        pred = rescale_noise_cfg(pred, pred_cond_freq, guidance_rescale)

                    # Add the current FDG guided level to the FDG prediction pyramid
                    pred_guided_pyramid.append(pred)
                else:
                    # Add the current pred_cond_pyramid level as the "non-FDG" prediction
                    pred_guided_pyramid.append(pred_cond_freq)

            # Convert from frequency space back to data (e.g. pixel) space by applying inverse freq transform
            pred = build_image_from_pyramid(pred_guided_pyramid)

            # If rescaling in data space, use the first elem of self.guidance_rescale as the "global" rescale value
            # across all freq levels
            if self.guidance_rescale_space == "data" and self.guidance_rescale[0] > 0.0:
                pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale[0])

        return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

    @property
    def is_conditional(self) -> bool:
        # True exactly while the first (conditional) prediction of the step is being prepared.
        return self._count_prepared == 1

    @property
    def num_conditions(self) -> int:
        num_conditions = 1
        if self._is_fdg_enabled():
            num_conditions += 1
        return num_conditions

    def _is_fdg_enabled(self) -> bool:
        """Whether FDG is active at all for the current step (global [start, stop) window)."""
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        # FDG is a no-op only when every level's scale is at its neutral value.
        is_close = False
        if self.use_original_formulation:
            is_close = all(math.isclose(guidance_scale, 0.0) for guidance_scale in self.guidance_scales)
        else:
            is_close = all(math.isclose(guidance_scale, 1.0) for guidance_scale in self.guidance_scales)

        return is_within_range and not is_close

    def _is_fdg_enabled_for_level(self, level: int) -> bool:
        """Whether guidance is active for one frequency level (per-level window and scale)."""
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self.guidance_start[level] * self._num_inference_steps)
            skip_stop_step = int(self.guidance_stop[level] * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scales[level], 0.0)
        else:
            is_close = math.isclose(self.guidance_scales[level], 1.0)

        return is_within_range and not is_close
pythonProject/diffusers-main/build/lib/diffusers/guiders/guider_utils.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ from huggingface_hub.utils import validate_hf_hub_args
20
+ from typing_extensions import Self
21
+
22
+ from ..configuration_utils import ConfigMixin
23
+ from ..utils import BaseOutput, PushToHubMixin, get_logger
24
+
25
+
26
+ if TYPE_CHECKING:
27
+ from ..modular_pipelines.modular_pipeline import BlockState
28
+
29
+
30
+ GUIDER_CONFIG_NAME = "guider_config.json"
31
+
32
+
33
+ logger = get_logger(__name__) # pylint: disable=invalid-name
34
+
35
+
36
+ class BaseGuidance(ConfigMixin, PushToHubMixin):
37
+ r"""Base class providing the skeleton for implementing guidance techniques."""
38
+
39
+ config_name = GUIDER_CONFIG_NAME
40
+ _input_predictions = None
41
+ _identifier_key = "__guidance_identifier__"
42
+
43
    def __init__(self, start: float = 0.0, stop: float = 1.0):
        """
        Args:
            start (`float`, defaults to `0.0`):
                Fraction of the total denoising steps after which guidance becomes active.
            stop (`float`, defaults to `1.0`):
                Fraction of the total denoising steps after which guidance is no longer applied.

        Raises:
            ValueError: If `start`/`stop` are outside their valid ranges, or if the subclass did not
                define `_input_predictions` as a list.
        """
        self._start = start
        self._stop = stop
        # Per-denoising-step state; populated by `set_state` before each step.
        self._step: Optional[int] = None
        self._num_inference_steps: Optional[int] = None
        self._timestep: Optional[torch.LongTensor] = None
        # Number of times `prepare_models` has been called during the current step.
        self._count_prepared = 0
        # Mapping of output field name -> source field name(s); set via `set_input_fields`.
        self._input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
        self._enabled = True

        if not (0.0 <= start < 1.0):
            raise ValueError(f"Expected `start` to be between 0.0 and 1.0, but got {start}.")
        if not (start <= stop <= 1.0):
            raise ValueError(f"Expected `stop` to be between {start} and 1.0, but got {stop}.")

        # `_input_predictions` is a class attribute each concrete guider must override with the
        # list of prediction names it consumes (e.g. ["pred_cond", "pred_uncond"]).
        if self._input_predictions is None or not isinstance(self._input_predictions, list):
            raise ValueError(
                "`_input_predictions` must be a list of required prediction names for the guidance technique."
            )
62
+
63
    def disable(self):
        """Turn guidance off; subclasses consult `_enabled` before applying their technique."""
        self._enabled = False
65
+
66
    def enable(self):
        """Turn guidance back on after a `disable` call."""
        self._enabled = True
68
+
69
    def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
        """Record the current denoising position; called by the pipeline before every step."""
        self._step = step
        self._num_inference_steps = num_inference_steps
        self._timestep = timestep
        # Reset the prepare counter so `is_conditional`-style checks start fresh for this step.
        self._count_prepared = 0
74
+
75
+ def set_input_fields(self, **kwargs: Dict[str, Union[str, Tuple[str, str]]]) -> None:
76
+ """
77
+ Set the input fields for the guidance technique. The input fields are used to specify the names of the returned
78
+ attributes containing the prepared data after `prepare_inputs` is called. The prepared data is obtained from
79
+ the values of the provided keyword arguments to this method.
80
+
81
+ Args:
82
+ **kwargs (`Dict[str, Union[str, Tuple[str, str]]]`):
83
+ A dictionary where the keys are the names of the fields that will be used to store the data once it is
84
+ prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
85
+ to look up the required data provided for preparation.
86
+
87
+ If a string is provided, it will be used as the conditional data (or unconditional if used with a
88
+ guidance method that requires it). If a tuple of length 2 is provided, the first element must be the
89
+ conditional data identifier and the second element must be the unconditional data identifier or None.
90
+
91
+ Example:
92
+ ```
93
+ data = {"prompt_embeds": <some tensor>, "negative_prompt_embeds": <some tensor>, "latents": <some tensor>}
94
+
95
+ BaseGuidance.set_input_fields(
96
+ latents="latents",
97
+ prompt_embeds=("prompt_embeds", "negative_prompt_embeds"),
98
+ )
99
+ ```
100
+ """
101
+ for key, value in kwargs.items():
102
+ is_string = isinstance(value, str)
103
+ is_tuple_of_str_with_len_2 = (
104
+ isinstance(value, tuple) and len(value) == 2 and all(isinstance(v, str) for v in value)
105
+ )
106
+ if not (is_string or is_tuple_of_str_with_len_2):
107
+ raise ValueError(
108
+ f"Expected `set_input_fields` to be called with a string or a tuple of string with length 2, but got {type(value)} for key {key}."
109
+ )
110
+ self._input_fields = kwargs
111
+
112
    def prepare_models(self, denoiser: torch.nn.Module) -> None:
        """
        Prepares the models for the guidance technique on a given batch of data. This method should be overridden in
        subclasses to implement specific model preparation logic.
        """
        # Count how many condition batches have been prepared this step; subclasses consult this counter to
        # decide which branch (conditional / unconditional / perturbed) the next denoiser call belongs to.
        self._count_prepared += 1
118
+
119
    def cleanup_models(self, denoiser: torch.nn.Module) -> None:
        """
        Cleans up the models for the guidance technique after a given batch of data. This method should be overridden
        in subclasses to implement specific model cleanup logic. It is useful for removing any hooks or other stateful
        modifications made during `prepare_models`.
        """
        # Intentionally a no-op in the base class: techniques without hooks have nothing to undo.
        pass
126
+
127
    def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
        """Split `data` into one `BlockState` per required condition batch. Must be implemented by subclasses."""
        raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")
129
+
130
    def __call__(self, data: List["BlockState"]) -> Any:
        """
        Combine the denoiser predictions collected for each prepared batch into the final guided output.

        Each item in `data` must carry a `noise_pred` attribute and the identifier attribute that
        `_prepare_batch` stored under `_identifier_key`; that identifier becomes the keyword-argument name
        when dispatching to `forward`.

        Raises:
            ValueError: If any item lacks `noise_pred`, or if the number of items does not match
                `num_conditions`.
        """
        if not all(hasattr(d, "noise_pred") for d in data):
            raise ValueError("Expected all data to have `noise_pred` attribute.")
        if len(data) != self.num_conditions:
            raise ValueError(
                f"Expected {self.num_conditions} data items, but got {len(data)}. Please check the input data."
            )
        # Route each prediction to the forward() parameter named by the batch's identifier (e.g. "pred_cond").
        forward_inputs = {getattr(d, self._identifier_key): d.noise_pred for d in data}
        return self.forward(**forward_inputs)
139
+
140
    def forward(self, *args, **kwargs) -> Any:
        """Fuse the individual noise predictions into the guided prediction. Must be implemented by subclasses."""
        raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.")
142
+
143
    @property
    def is_conditional(self) -> bool:
        """Whether the next denoiser call uses the conditional batch. Must be implemented by subclasses."""
        raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.")
146
+
147
    @property
    def is_unconditional(self) -> bool:
        """Logical complement of `is_conditional`."""
        return not self.is_conditional
150
+
151
    @property
    def num_conditions(self) -> int:
        """Number of denoiser forward passes required per denoising step. Must be implemented by subclasses."""
        raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.")
154
+
155
    @classmethod
    def _prepare_batch(
        cls,
        input_fields: Dict[str, Union[str, Tuple[str, str]]],
        data: "BlockState",
        tuple_index: int,
        identifier: str,
    ) -> "BlockState":
        """
        Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of the
        `BaseGuidance` class. It prepares the batch based on the provided tuple index.

        Args:
            input_fields (`Dict[str, Union[str, Tuple[str, str]]]`):
                A dictionary where the keys are the names of the fields that will be used to store the data once it is
                prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
                to look up the required data provided for preparation. If a string is provided, it will be used as the
                conditional data (or unconditional if used with a guidance method that requires it). If a tuple of
                length 2 is provided, the first element must be the conditional data identifier and the second element
                must be the unconditional data identifier. (Note: `set_input_fields` validates that both tuple
                elements are strings, so a `None` second element would fail validation there.)
            data (`BlockState`):
                The input data to be prepared.
            tuple_index (`int`):
                The index to use when accessing input fields that are tuples (0 = conditional, 1 = unconditional).
            identifier (`str`):
                Prediction name (e.g. `"pred_cond"`) stored under `cls._identifier_key`; `__call__` later uses it
                to route this batch's prediction to the matching `forward` keyword argument.

        Returns:
            `BlockState`: The prepared batch of data.
        """
        # Imported lazily to avoid a circular import between guiders and modular pipelines.
        from ..modular_pipelines.modular_pipeline import BlockState

        if input_fields is None:
            raise ValueError(
                "Input fields cannot be None. Please pass `input_fields` to `prepare_inputs` or call `set_input_fields` before preparing inputs."
            )
        data_batch = {}
        for key, value in input_fields.items():
            try:
                if isinstance(value, str):
                    data_batch[key] = getattr(data, value)
                elif isinstance(value, tuple):
                    data_batch[key] = getattr(data, value[tuple_index])
                else:
                    # We've already checked that value is a string or a tuple of strings with length 2
                    pass
            except AttributeError:
                # Missing attributes are skipped (best-effort): the batch simply omits that field.
                logger.debug(f"`data` does not have attribute(s) {value}, skipping.")
        data_batch[cls._identifier_key] = identifier
        return BlockState(**data_batch)
203
+
204
    @classmethod
    @validate_hf_hub_args
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
        subfolder: Optional[str] = None,
        return_unused_kwargs=False,
        **kwargs,
    ) -> Self:
        r"""
        Instantiate a guider from a pre-defined JSON configuration file in a local directory or Hub repository.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the guider configuration
                      saved with [`~BaseGuidance.save_pretrained`].
            subfolder (`str`, *optional*):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                Whether kwargs that are not consumed by the Python class should be returned or not.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.

            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info(`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            local_files_only(`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.

        <Tip>

        To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with `hf
        auth login`. You can also activate the special
        ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
        firewalled environment.

        </Tip>

        """
        # `commit_hash` is returned for parity with other config loaders but is not needed here.
        config, kwargs, commit_hash = cls.load_config(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            subfolder=subfolder,
            return_unused_kwargs=True,
            return_commit_hash=True,
            **kwargs,
        )
        return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)
268
+
269
    def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
        Save a guider configuration object to a directory so that it can be reloaded using the
        [`~BaseGuidance.from_pretrained`] class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the configuration JSON file will be saved (will be created if it does not exist).
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
        # Guiders are pure configuration (no weights), so persisting the config JSON is sufficient.
        self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
285
+
286
+
287
class GuiderOutput(BaseOutput):
    """
    Output of a guidance technique's `forward` pass.
    """

    # Final guided noise prediction.
    pred: torch.Tensor
    # Raw conditional prediction, when the technique produced one separately.
    pred_cond: Optional[torch.Tensor]
    # Raw unconditional prediction; `None` when the technique ran without an unconditional batch.
    pred_uncond: Optional[torch.Tensor]
291
+
292
+
293
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    r"""
    Rescale `noise_cfg` toward the per-sample standard deviation of `noise_pred_text` to fix the overexposure
    that strong classifier-free guidance can cause. Based on Section 3.4 from [Common Diffusion Noise Schedules
    and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).

    Args:
        noise_cfg (`torch.Tensor`):
            The predicted noise tensor for the guided diffusion process.
        noise_pred_text (`torch.Tensor`):
            The predicted noise tensor for the text-guided diffusion process.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Blend factor: `0.0` returns `noise_cfg` unchanged, `1.0` returns the fully rescaled prediction.
    Returns:
        noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
    """
    # Per-sample std over all non-batch dimensions.
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # Match the guided prediction's std to the text-conditioned prediction's std (fixes overexposure).
    rescaled = noise_cfg * (std_text / std_cfg)
    # Linear blend between the original and rescaled predictions to avoid "plain looking" images.
    return torch.lerp(noise_cfg, rescaled, guidance_rescale)
pythonProject/diffusers-main/build/lib/diffusers/guiders/perturbed_attention_guidance.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+
20
+ from ..configuration_utils import register_to_config
21
+ from ..hooks import HookRegistry, LayerSkipConfig
22
+ from ..hooks.layer_skip import _apply_layer_skip_hook
23
+ from ..utils import get_logger
24
+ from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
25
+
26
+
27
+ if TYPE_CHECKING:
28
+ from ..modular_pipelines.modular_pipeline import BlockState
29
+
30
+
31
+ logger = get_logger(__name__) # pylint: disable=invalid-name
32
+
33
+
34
class PerturbedAttentionGuidance(BaseGuidance):
    """
    Perturbed Attention Guidance (PAG): https://huggingface.co/papers/2403.17377

    The intuition behind PAG can be thought of as moving the CFG predicted distribution estimates further away from
    worse versions of the conditional distribution estimates. PAG was one of the first techniques to introduce the idea
    of using a worse version of the trained model for better guiding itself in the denoising process. It perturbs the
    attention scores of the latent stream by replacing the score matrix with an identity matrix for selectively chosen
    layers.

    Additional reading:
    - [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)

    PAG is implemented with similar implementation to SkipLayerGuidance due to overlap in the configuration parameters
    and implementation details.

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
            prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
            deterioration of image quality.
        perturbed_guidance_scale (`float`, defaults to `2.8`):
            The scale parameter for perturbed attention guidance.
        perturbed_guidance_start (`float`, defaults to `0.01`):
            The fraction of the total number of denoising steps after which perturbed attention guidance starts.
        perturbed_guidance_stop (`float`, defaults to `0.2`):
            The fraction of the total number of denoising steps after which perturbed attention guidance stops.
        perturbed_guidance_layers (`int` or `List[int]`, *optional*):
            The layer indices to apply perturbed attention guidance to. Can be a single integer or a list of integers.
            If not provided, `perturbed_guidance_config` must be provided.
        perturbed_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
            The configuration for the perturbed attention guidance. Can be a single `LayerSkipConfig` or a list of
            `LayerSkipConfig`. If not provided, `perturbed_guidance_layers` must be provided.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float`, defaults to `0.01`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `0.2`):
            The fraction of the total number of denoising steps after which guidance stops.
    """

    # NOTE: The current implementation does not account for joint latent conditioning (text + image/video tokens in
    # the same latent stream). It assumes the entire latent is a single stream of visual tokens. It would be very
    # complex to support joint latent conditioning in a model-agnostic manner without specializing the implementation
    # for each model architecture.

    _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]

    @register_to_config
    def __init__(
        self,
        guidance_scale: float = 7.5,
        perturbed_guidance_scale: float = 2.8,
        perturbed_guidance_start: float = 0.01,
        perturbed_guidance_stop: float = 0.2,
        perturbed_guidance_layers: Optional[Union[int, List[int]]] = None,
        perturbed_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__(start, stop)

        self.guidance_scale = guidance_scale
        # PAG reuses the SkipLayerGuidance attribute names so that the `# Copied from` methods below stay in sync.
        self.skip_layer_guidance_scale = perturbed_guidance_scale
        self.skip_layer_guidance_start = perturbed_guidance_start
        self.skip_layer_guidance_stop = perturbed_guidance_stop
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

        if perturbed_guidance_config is None:
            if perturbed_guidance_layers is None:
                raise ValueError(
                    "`perturbed_guidance_layers` must be provided if `perturbed_guidance_config` is not specified."
                )
            perturbed_guidance_config = LayerSkipConfig(
                indices=perturbed_guidance_layers,
                fqn="auto",
                skip_attention=False,
                skip_attention_scores=True,
                skip_ff=False,
            )
        else:
            if perturbed_guidance_layers is not None:
                raise ValueError(
                    "`perturbed_guidance_layers` should not be provided if `perturbed_guidance_config` is specified."
                )

        # Normalize the config into a list of LayerSkipConfig.
        if isinstance(perturbed_guidance_config, dict):
            perturbed_guidance_config = LayerSkipConfig.from_dict(perturbed_guidance_config)

        if isinstance(perturbed_guidance_config, LayerSkipConfig):
            perturbed_guidance_config = [perturbed_guidance_config]

        if not isinstance(perturbed_guidance_config, list):
            raise ValueError(
                "`perturbed_guidance_config` must be a `LayerSkipConfig`, a list of `LayerSkipConfig`, or a dict that can be converted to a `LayerSkipConfig`."
            )
        elif isinstance(next(iter(perturbed_guidance_config), None), dict):
            perturbed_guidance_config = [LayerSkipConfig.from_dict(config) for config in perturbed_guidance_config]

        # PAG perturbs attention scores only; force-correct any config that deviates (with a warning).
        for config in perturbed_guidance_config:
            if config.skip_attention or not config.skip_attention_scores or config.skip_ff:
                logger.warning(
                    "Perturbed Attention Guidance is designed to perturb attention scores, so `skip_attention` should be False, `skip_attention_scores` should be True, and `skip_ff` should be False. "
                    "Please check your configuration. Modifying the config to match the expected values."
                )
                config.skip_attention = False
                config.skip_attention_scores = True
                config.skip_ff = False

        self.skip_layer_config = perturbed_guidance_config
        self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]

    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_models
    def prepare_models(self, denoiser: torch.nn.Module) -> None:
        """Attach the layer-skip hooks before the perturbed (second conditional) denoiser pass."""
        self._count_prepared += 1
        # Hooks are only needed for the perturbed conditional pass (i.e. after the first prepared batch).
        if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
            for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
                _apply_layer_skip_hook(denoiser, config, name=name)

    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.cleanup_models
    def cleanup_models(self, denoiser: torch.nn.Module) -> None:
        """Remove the layer-skip hooks installed by `prepare_models`."""
        if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
            registry = HookRegistry.check_if_exists_or_initialize(denoiser)
            # Remove the hooks after inference
            for hook_name in self._skip_layer_hook_names:
                registry.remove_hook(hook_name, recurse=True)

    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_inputs
    def prepare_inputs(
        self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
    ) -> List["BlockState"]:
        """Build one batch per active condition (conditional / unconditional / perturbed-conditional)."""
        if input_fields is None:
            input_fields = self._input_fields

        if self.num_conditions == 1:
            tuple_indices = [0]
            input_predictions = ["pred_cond"]
        elif self.num_conditions == 2:
            tuple_indices = [0, 1]
            input_predictions = (
                ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
            )
        else:
            # Three passes: the perturbed pass reuses the conditional inputs (tuple index 0).
            tuple_indices = [0, 1, 0]
            input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
        data_batches = []
        for i in range(self.num_conditions):
            data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
            data_batches.append(data_batch)
        return data_batches

    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.forward
    def forward(
        self,
        pred_cond: torch.Tensor,
        pred_uncond: Optional[torch.Tensor] = None,
        pred_cond_skip: Optional[torch.Tensor] = None,
    ) -> GuiderOutput:
        """Combine the available predictions into the PAG-guided prediction."""
        pred = None

        if not self._is_cfg_enabled() and not self._is_slg_enabled():
            # Neither guidance active: plain conditional prediction.
            pred = pred_cond
        elif not self._is_cfg_enabled():
            # PAG only: push away from the perturbed conditional prediction.
            shift = pred_cond - pred_cond_skip
            pred = pred_cond if self.use_original_formulation else pred_cond_skip
            pred = pred + self.skip_layer_guidance_scale * shift
        elif not self._is_slg_enabled():
            # CFG only.
            shift = pred_cond - pred_uncond
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift
        else:
            # CFG + PAG: apply both shifts.
            shift = pred_cond - pred_uncond
            shift_skip = pred_cond - pred_cond_skip
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

    @property
    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.is_conditional
    def is_conditional(self) -> bool:
        """True on the 1st (conditional) and 3rd (perturbed conditional) prepared batches."""
        return self._count_prepared == 1 or self._count_prepared == 3

    @property
    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.num_conditions
    def num_conditions(self) -> int:
        """1 (cond) + 1 if CFG is active + 1 if the perturbed pass is active."""
        num_conditions = 1
        if self._is_cfg_enabled():
            num_conditions += 1
        if self._is_slg_enabled():
            num_conditions += 1
        return num_conditions

    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_cfg_enabled
    def _is_cfg_enabled(self) -> bool:
        """Whether classifier-free guidance is active at the current step."""
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        # CFG is a no-op at scale 0 (original formulation) or scale 1 (diffusers-native formulation).
        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close

    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_slg_enabled
    def _is_slg_enabled(self) -> bool:
        """Whether the perturbed-attention (skip-layer) pass is active at the current step."""
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
            skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
            # NOTE: strict lower bound here (`<`), unlike `_is_cfg_enabled` which uses `<=`.
            is_within_range = skip_start_step < self._step < skip_stop_step

        is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)

        return is_within_range and not is_zero
pythonProject/diffusers-main/build/lib/diffusers/guiders/skip_layer_guidance.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+
20
+ from ..configuration_utils import register_to_config
21
+ from ..hooks import HookRegistry, LayerSkipConfig
22
+ from ..hooks.layer_skip import _apply_layer_skip_hook
23
+ from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
24
+
25
+
26
+ if TYPE_CHECKING:
27
+ from ..modular_pipelines.modular_pipeline import BlockState
28
+
29
+
30
+ class SkipLayerGuidance(BaseGuidance):
31
+ """
32
+ Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5
33
+
34
+ Spatio-Temporal Guidance (STG): https://huggingface.co/papers/2411.18664
35
+
36
+ SLG was introduced by StabilityAI for improving structure and anatomy coherence in generated images. It works by
37
+ skipping the forward pass of specified transformer blocks during the denoising process on an additional conditional
38
+ batch of data, apart from the conditional and unconditional batches already used in CFG
39
+ ([~guiders.classifier_free_guidance.ClassifierFreeGuidance]), and then scaling and shifting the CFG predictions
40
+ based on the difference between conditional without skipping and conditional with skipping predictions.
41
+
42
+ The intuition behind SLG can be thought of as moving the CFG predicted distribution estimates further away from
43
+ worse versions of the conditional distribution estimates (because skipping layers is equivalent to using a worse
44
+ version of the model for the conditional prediction).
45
+
46
+ STG is an improvement and follow-up work combining ideas from SLG, PAG and similar techniques for improving
47
+ generation quality in video diffusion models.
48
+
49
+ Additional reading:
50
+ - [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)
51
+
52
+ The values for `skip_layer_guidance_scale`, `skip_layer_guidance_start`, and `skip_layer_guidance_stop` are
53
+ defaulted to the recommendations by StabilityAI for Stable Diffusion 3.5 Medium.
54
+
55
+ Args:
56
+ guidance_scale (`float`, defaults to `7.5`):
57
+ The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
58
+ prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
59
+ deterioration of image quality.
60
+ skip_layer_guidance_scale (`float`, defaults to `2.8`):
61
+ The scale parameter for skip layer guidance. Anatomy and structure coherence may improve with higher
62
+ values, but it may also lead to overexposure and saturation.
63
+ skip_layer_guidance_start (`float`, defaults to `0.01`):
64
+ The fraction of the total number of denoising steps after which skip layer guidance starts.
65
+ skip_layer_guidance_stop (`float`, defaults to `0.2`):
66
+ The fraction of the total number of denoising steps after which skip layer guidance stops.
67
+ skip_layer_guidance_layers (`int` or `List[int]`, *optional*):
68
+ The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
69
+ provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion
70
+ 3.5 Medium.
71
+ skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
72
+ The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
73
+ `LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided.
74
+ guidance_rescale (`float`, defaults to `0.0`):
75
+ The rescale factor applied to the noise predictions. This is used to improve image quality and fix
76
+ overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
77
+ Flawed](https://huggingface.co/papers/2305.08891).
78
+ use_original_formulation (`bool`, defaults to `False`):
79
+ Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
80
+ we use the diffusers-native implementation that has been in the codebase for a long time. See
81
+ [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
82
+ start (`float`, defaults to `0.01`):
83
+ The fraction of the total number of denoising steps after which guidance starts.
84
+ stop (`float`, defaults to `0.2`):
85
+ The fraction of the total number of denoising steps after which guidance stops.
86
+ """
87
+
88
+ _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
89
+
90
@register_to_config
def __init__(
    self,
    guidance_scale: float = 7.5,
    skip_layer_guidance_scale: float = 2.8,
    skip_layer_guidance_start: float = 0.01,
    skip_layer_guidance_stop: float = 0.2,
    skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None,
    skip_layer_config: Optional[Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]]] = None,
    guidance_rescale: float = 0.0,
    use_original_formulation: bool = False,
    start: float = 0.0,
    stop: float = 1.0,
):
    """Configure classifier-free + skip-layer guidance.

    Exactly one of `skip_layer_guidance_layers` (layer indices, expanded into
    `LayerSkipConfig` objects with `fqn="auto"`) or `skip_layer_config` (explicit
    config object(s)/dict(s)) must be provided. See the class docstring for the
    meaning of each parameter.
    """
    super().__init__(start, stop)

    self.guidance_scale = guidance_scale
    self.skip_layer_guidance_scale = skip_layer_guidance_scale
    self.skip_layer_guidance_start = skip_layer_guidance_start
    self.skip_layer_guidance_stop = skip_layer_guidance_stop
    self.guidance_rescale = guidance_rescale
    self.use_original_formulation = use_original_formulation

    # Validate the SLG schedule fractions: start in [0, 1), stop in [start, 1].
    if not (0.0 <= skip_layer_guidance_start < 1.0):
        raise ValueError(
            f"Expected `skip_layer_guidance_start` to be between 0.0 and 1.0, but got {skip_layer_guidance_start}."
        )
    if not (skip_layer_guidance_start <= skip_layer_guidance_stop <= 1.0):
        raise ValueError(
            f"Expected `skip_layer_guidance_stop` to be between 0.0 and 1.0, but got {skip_layer_guidance_stop}."
        )

    # The two ways of specifying layer-skip targets are mutually exclusive.
    if skip_layer_guidance_layers is None and skip_layer_config is None:
        raise ValueError(
            "Either `skip_layer_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance."
        )
    if skip_layer_guidance_layers is not None and skip_layer_config is not None:
        raise ValueError("Only one of `skip_layer_guidance_layers` or `skip_layer_config` can be provided.")

    # Expand bare layer indices into LayerSkipConfig objects.
    if skip_layer_guidance_layers is not None:
        if isinstance(skip_layer_guidance_layers, int):
            skip_layer_guidance_layers = [skip_layer_guidance_layers]
        if not isinstance(skip_layer_guidance_layers, list):
            raise ValueError(
                f"Expected `skip_layer_guidance_layers` to be an int or a list of ints, but got {type(skip_layer_guidance_layers)}."
            )
        skip_layer_config = [LayerSkipConfig(layer, fqn="auto") for layer in skip_layer_guidance_layers]

    # Normalize `skip_layer_config` to a list of LayerSkipConfig.
    if isinstance(skip_layer_config, dict):
        skip_layer_config = LayerSkipConfig.from_dict(skip_layer_config)

    if isinstance(skip_layer_config, LayerSkipConfig):
        skip_layer_config = [skip_layer_config]

    if not isinstance(skip_layer_config, list):
        raise ValueError(
            f"Expected `skip_layer_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(skip_layer_config)}."
        )
    # A list of plain dicts (e.g. deserialized from a config file) is converted
    # element-wise; only the first element is inspected to decide.
    elif isinstance(next(iter(skip_layer_config), None), dict):
        skip_layer_config = [LayerSkipConfig.from_dict(config) for config in skip_layer_config]

    self.skip_layer_config = skip_layer_config
    # One uniquely named hook per config so they can be removed individually later.
    self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]
153
+
154
def prepare_models(self, denoiser: torch.nn.Module) -> None:
    """Install the layer-skip hooks on `denoiser` for the perturbed conditional pass.

    The prepare counter is bumped on every call; hooks are only attached from the
    second preparation of a step onward, and only for a conditional pass while SLG
    is active.
    """
    self._count_prepared += 1
    should_attach_hooks = self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1
    if should_attach_hooks:
        for hook_name, layer_config in zip(self._skip_layer_hook_names, self.skip_layer_config):
            _apply_layer_skip_hook(denoiser, layer_config, name=hook_name)
159
+
160
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
    """Remove the layer-skip hooks that `prepare_models` attached to `denoiser`."""
    if not (self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1):
        return
    registry = HookRegistry.check_if_exists_or_initialize(denoiser)
    # Remove the hooks after inference.
    for hook_name in self._skip_layer_hook_names:
        registry.remove_hook(hook_name, recurse=True)
166
+
167
def prepare_inputs(
    self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
    """Build one input batch per required denoiser pass.

    Depending on how many conditions are active (1-3), each batch pairs a tuple
    index (0 = conditional inputs, 1 = unconditional inputs) with the name of the
    prediction it will produce.
    """
    fields = self._input_fields if input_fields is None else input_fields

    num_conditions = self.num_conditions
    if num_conditions == 1:
        plan = [(0, "pred_cond")]
    elif num_conditions == 2:
        # Two passes: CFG pairs cond/uncond, otherwise cond/cond-with-skipped-layers.
        second_prediction = "pred_uncond" if self._is_cfg_enabled() else "pred_cond_skip"
        plan = [(0, "pred_cond"), (1, second_prediction)]
    else:
        # Full CFG + SLG: the skip pass reuses the conditional inputs (index 0).
        plan = [(0, "pred_cond"), (1, "pred_uncond"), (0, "pred_cond_skip")]

    return [self._prepare_batch(fields, data, tuple_index, prediction) for tuple_index, prediction in plan]
189
+
190
def forward(
    self,
    pred_cond: torch.Tensor,
    pred_uncond: Optional[torch.Tensor] = None,
    pred_cond_skip: Optional[torch.Tensor] = None,
) -> GuiderOutput:
    """Combine the collected predictions into the final guided prediction.

    Four regimes: neither guidance active (pass-through), SLG only, CFG only,
    or both combined. With `use_original_formulation` the conditional prediction
    is the baseline; otherwise the diffusers-native baseline is used.
    """
    cfg_active = self._is_cfg_enabled()
    slg_active = self._is_slg_enabled()

    if not cfg_active and not slg_active:
        pred = pred_cond
    elif not cfg_active:
        # Skip-layer guidance only.
        baseline = pred_cond if self.use_original_formulation else pred_cond_skip
        pred = baseline + self.skip_layer_guidance_scale * (pred_cond - pred_cond_skip)
    elif not slg_active:
        # Classifier-free guidance only.
        baseline = pred_cond if self.use_original_formulation else pred_uncond
        pred = baseline + self.guidance_scale * (pred_cond - pred_uncond)
    else:
        # Both guidance terms are applied on top of the same baseline.
        baseline = pred_cond if self.use_original_formulation else pred_uncond
        pred = (
            baseline
            + self.guidance_scale * (pred_cond - pred_uncond)
            + self.skip_layer_guidance_scale * (pred_cond - pred_cond_skip)
        )

    if self.guidance_rescale > 0.0:
        pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

    return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
218
+
219
@property
def is_conditional(self) -> bool:
    """True on the first and third preparations of a step (the conditional passes)."""
    return self._count_prepared in (1, 3)
222
+
223
@property
def num_conditions(self) -> int:
    """Number of denoiser passes needed this step: 1 base + 1 per active guidance."""
    return 1 + int(self._is_cfg_enabled()) + int(self._is_slg_enabled())
231
+
232
+ def _is_cfg_enabled(self) -> bool:
233
+ if not self._enabled:
234
+ return False
235
+
236
+ is_within_range = True
237
+ if self._num_inference_steps is not None:
238
+ skip_start_step = int(self._start * self._num_inference_steps)
239
+ skip_stop_step = int(self._stop * self._num_inference_steps)
240
+ is_within_range = skip_start_step <= self._step < skip_stop_step
241
+
242
+ is_close = False
243
+ if self.use_original_formulation:
244
+ is_close = math.isclose(self.guidance_scale, 0.0)
245
+ else:
246
+ is_close = math.isclose(self.guidance_scale, 1.0)
247
+
248
+ return is_within_range and not is_close
249
+
250
+ def _is_slg_enabled(self) -> bool:
251
+ if not self._enabled:
252
+ return False
253
+
254
+ is_within_range = True
255
+ if self._num_inference_steps is not None:
256
+ skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
257
+ skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
258
+ is_within_range = skip_start_step < self._step < skip_stop_step
259
+
260
+ is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
261
+
262
+ return is_within_range and not is_zero
pythonProject/diffusers-main/build/lib/diffusers/guiders/smoothed_energy_guidance.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+
20
+ from ..configuration_utils import register_to_config
21
+ from ..hooks import HookRegistry
22
+ from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook
23
+ from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
24
+
25
+
26
+ if TYPE_CHECKING:
27
+ from ..modular_pipelines.modular_pipeline import BlockState
28
+
29
+
30
class SmoothedEnergyGuidance(BaseGuidance):
    """
    Smoothed Energy Guidance (SEG): https://huggingface.co/papers/2408.00760

    SEG is only supported as an experimental prototype feature for now, so the implementation may be modified in the
    future without warning or guarantee of reproducibility. This implementation assumes:
    - Generated images are square (height == width)
    - The model does not combine different modalities together (e.g., text and image latent streams are not combined
      together such as Flux)

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
            prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
            deterioration of image quality.
        seg_guidance_scale (`float`, defaults to `2.8`):
            The scale parameter for smoothed energy guidance. Anatomy and structure coherence may improve with higher
            values, but it may also lead to overexposure and saturation.
        seg_blur_sigma (`float`, defaults to `9999999.0`):
            The amount by which we blur the attention weights. Setting this value greater than 9999.0 results in
            infinite blur, which means uniform queries. Controlling it exponentially is empirically effective.
        seg_blur_threshold_inf (`float`, defaults to `9999.0`):
            The threshold above which the blur is considered infinite.
        seg_guidance_start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which smoothed energy guidance starts.
        seg_guidance_stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which smoothed energy guidance stops.
        seg_guidance_layers (`int` or `List[int]`, *optional*):
            The layer indices to apply smoothed energy guidance to. Can be a single integer or a list of integers. If
            not provided, `seg_guidance_config` must be provided. The recommended values are `[7, 8, 9]` for Stable
            Diffusion 3.5 Medium.
        seg_guidance_config (`SmoothedEnergyGuidanceConfig` or `List[SmoothedEnergyGuidanceConfig]`, *optional*):
            The configuration for the smoothed energy layer guidance. Can be a single `SmoothedEnergyGuidanceConfig` or
            a list of `SmoothedEnergyGuidanceConfig`. If not provided, `seg_guidance_layers` must be provided.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
    """

    _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]

    @register_to_config
    def __init__(
        self,
        guidance_scale: float = 7.5,
        seg_guidance_scale: float = 2.8,
        seg_blur_sigma: float = 9999999.0,
        seg_blur_threshold_inf: float = 9999.0,
        seg_guidance_start: float = 0.0,
        seg_guidance_stop: float = 1.0,
        seg_guidance_layers: Optional[Union[int, List[int]]] = None,
        seg_guidance_config: Optional[Union[SmoothedEnergyGuidanceConfig, List[SmoothedEnergyGuidanceConfig]]] = None,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__(start, stop)

        self.guidance_scale = guidance_scale
        self.seg_guidance_scale = seg_guidance_scale
        self.seg_blur_sigma = seg_blur_sigma
        self.seg_blur_threshold_inf = seg_blur_threshold_inf
        self.seg_guidance_start = seg_guidance_start
        self.seg_guidance_stop = seg_guidance_stop
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

        # Validate the SEG schedule fractions: start in [0, 1), stop in [start, 1].
        if not (0.0 <= seg_guidance_start < 1.0):
            raise ValueError(f"Expected `seg_guidance_start` to be between 0.0 and 1.0, but got {seg_guidance_start}.")
        if not (seg_guidance_start <= seg_guidance_stop <= 1.0):
            raise ValueError(f"Expected `seg_guidance_stop` to be between 0.0 and 1.0, but got {seg_guidance_stop}.")

        # The two ways of specifying SEG targets are mutually exclusive.
        if seg_guidance_layers is None and seg_guidance_config is None:
            raise ValueError(
                "Either `seg_guidance_layers` or `seg_guidance_config` must be provided to enable Smoothed Energy Guidance."
            )
        if seg_guidance_layers is not None and seg_guidance_config is not None:
            raise ValueError("Only one of `seg_guidance_layers` or `seg_guidance_config` can be provided.")

        # Expand bare layer indices into SmoothedEnergyGuidanceConfig objects.
        if seg_guidance_layers is not None:
            if isinstance(seg_guidance_layers, int):
                seg_guidance_layers = [seg_guidance_layers]
            if not isinstance(seg_guidance_layers, list):
                raise ValueError(
                    f"Expected `seg_guidance_layers` to be an int or a list of ints, but got {type(seg_guidance_layers)}."
                )
            seg_guidance_config = [SmoothedEnergyGuidanceConfig(layer, fqn="auto") for layer in seg_guidance_layers]

        # Normalize `seg_guidance_config` to a list of SmoothedEnergyGuidanceConfig.
        if isinstance(seg_guidance_config, dict):
            seg_guidance_config = SmoothedEnergyGuidanceConfig.from_dict(seg_guidance_config)

        if isinstance(seg_guidance_config, SmoothedEnergyGuidanceConfig):
            seg_guidance_config = [seg_guidance_config]

        if not isinstance(seg_guidance_config, list):
            raise ValueError(
                f"Expected `seg_guidance_config` to be a SmoothedEnergyGuidanceConfig or a list of SmoothedEnergyGuidanceConfig, but got {type(seg_guidance_config)}."
            )
        # A list of plain dicts (e.g. deserialized from a config file) is converted element-wise.
        elif isinstance(next(iter(seg_guidance_config), None), dict):
            seg_guidance_config = [SmoothedEnergyGuidanceConfig.from_dict(config) for config in seg_guidance_config]

        self.seg_guidance_config = seg_guidance_config
        # One uniquely named hook per config so they can be removed individually later.
        self._seg_layer_hook_names = [f"SmoothedEnergyGuidance_{i}" for i in range(len(self.seg_guidance_config))]

    def prepare_models(self, denoiser: torch.nn.Module) -> None:
        """Install the smoothed-energy hooks on `denoiser` for the perturbed conditional pass."""
        # FIX: increment the prepare counter, consistent with SkipLayerGuidance.prepare_models.
        # Without this, `_count_prepared > 1` below can never hold and the SEG hooks
        # would never be applied (and `is_conditional` would be stuck).
        self._count_prepared += 1
        if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
            for name, config in zip(self._seg_layer_hook_names, self.seg_guidance_config):
                _apply_smoothed_energy_guidance_hook(denoiser, config, self.seg_blur_sigma, name=name)

    def cleanup_models(self, denoiser: torch.nn.Module):
        """Remove the smoothed-energy hooks that `prepare_models` attached to `denoiser`."""
        if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
            registry = HookRegistry.check_if_exists_or_initialize(denoiser)
            # Remove the hooks after inference
            for hook_name in self._seg_layer_hook_names:
                registry.remove_hook(hook_name, recurse=True)

    def prepare_inputs(
        self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
    ) -> List["BlockState"]:
        """Build one input batch per required denoiser pass (1-3 depending on active guidance)."""
        if input_fields is None:
            input_fields = self._input_fields

        if self.num_conditions == 1:
            tuple_indices = [0]
            input_predictions = ["pred_cond"]
        elif self.num_conditions == 2:
            # Two passes: CFG pairs cond/uncond, otherwise cond/cond-with-seg-blur.
            tuple_indices = [0, 1]
            input_predictions = (
                ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_seg"]
            )
        else:
            # Full CFG + SEG: the SEG pass reuses the conditional inputs (index 0).
            tuple_indices = [0, 1, 0]
            input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
        data_batches = []
        for i in range(self.num_conditions):
            data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
            data_batches.append(data_batch)
        return data_batches

    def forward(
        self,
        pred_cond: torch.Tensor,
        pred_uncond: Optional[torch.Tensor] = None,
        pred_cond_seg: Optional[torch.Tensor] = None,
    ) -> GuiderOutput:
        """Combine the collected predictions into the final guided prediction.

        Four regimes: neither guidance active (pass-through), SEG only, CFG only,
        or both combined.
        """
        pred = None

        if not self._is_cfg_enabled() and not self._is_seg_enabled():
            pred = pred_cond
        elif not self._is_cfg_enabled():
            # Smoothed energy guidance only.
            shift = pred_cond - pred_cond_seg
            pred = pred_cond if self.use_original_formulation else pred_cond_seg
            pred = pred + self.seg_guidance_scale * shift
        elif not self._is_seg_enabled():
            # Classifier-free guidance only.
            shift = pred_cond - pred_uncond
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift
        else:
            # Both guidance terms are applied on top of the same baseline.
            shift = pred_cond - pred_uncond
            shift_seg = pred_cond - pred_cond_seg
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift + self.seg_guidance_scale * shift_seg

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

    @property
    def is_conditional(self) -> bool:
        """True on the first and third preparations of a step (the conditional passes)."""
        return self._count_prepared == 1 or self._count_prepared == 3

    @property
    def num_conditions(self) -> int:
        """Number of denoiser passes needed this step: 1 base + 1 per active guidance."""
        num_conditions = 1
        if self._is_cfg_enabled():
            num_conditions += 1
        if self._is_seg_enabled():
            num_conditions += 1
        return num_conditions

    def _is_cfg_enabled(self) -> bool:
        """Whether classifier-free guidance is active at the current step."""
        if not self._enabled:
            return False

        # Outside the [start, stop) step window, CFG is off.
        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        # CFG is a no-op at its neutral scale: 0.0 in the original formulation,
        # 1.0 in the diffusers-native formulation.
        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close

    def _is_seg_enabled(self) -> bool:
        """Whether smoothed energy guidance is active at the current step."""
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self.seg_guidance_start * self._num_inference_steps)
            skip_stop_step = int(self.seg_guidance_stop * self._num_inference_steps)
            # NOTE: unlike `_is_cfg_enabled`, the lower bound is exclusive here.
            is_within_range = skip_start_step < self._step < skip_stop_step

        is_zero = math.isclose(self.seg_guidance_scale, 0.0)

        return is_within_range and not is_zero
pythonProject/diffusers-main/build/lib/diffusers/training_utils.py ADDED
@@ -0,0 +1,730 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import copy
3
+ import gc
4
+ import math
5
+ import random
6
+ import re
7
+ import warnings
8
+ from contextlib import contextmanager
9
+ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
10
+
11
+ import numpy as np
12
+ import torch
13
+
14
+ from .models import UNet2DConditionModel
15
+ from .pipelines import DiffusionPipeline
16
+ from .schedulers import SchedulerMixin
17
+ from .utils import (
18
+ convert_state_dict_to_diffusers,
19
+ convert_state_dict_to_peft,
20
+ deprecate,
21
+ is_peft_available,
22
+ is_torch_npu_available,
23
+ is_torchvision_available,
24
+ is_transformers_available,
25
+ )
26
+
27
+
28
+ if is_transformers_available():
29
+ import transformers
30
+
31
+ if transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
32
+ import deepspeed
33
+
34
+ if is_peft_available():
35
+ from peft import set_peft_model_state_dict
36
+
37
+ if is_torchvision_available():
38
+ from torchvision import transforms
39
+
40
+ if is_torch_npu_available():
41
+ import torch_npu # noqa: F401
42
+
43
+
44
def set_seed(seed: int):
    """Seed `random`, `numpy`, and `torch` RNGs for reproducible behavior.

    Args:
        seed (`int`): The seed to set.

    Returns:
        `None`
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if is_torch_npu_available():
        torch.npu.manual_seed_all(seed)
    else:
        # Safe to call even when CUDA is unavailable.
        torch.cuda.manual_seed_all(seed)
62
+
63
+
64
def compute_snr(noise_scheduler, timesteps):
    """
    Computes SNR as per
    https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
    for the given timesteps using the provided noise scheduler.

    Args:
        noise_scheduler (`NoiseScheduler`):
            An object exposing `alphas_cumprod`, used to compute the SNR values.
        timesteps (`torch.Tensor`):
            A tensor of timesteps for which the SNR is computed.

    Returns:
        `torch.Tensor`: The SNR value for each timestep, shaped like `timesteps`.
    """
    alphas_cumprod = noise_scheduler.alphas_cumprod
    sqrt_alphas = alphas_cumprod**0.5
    sqrt_one_minus_alphas = (1.0 - alphas_cumprod) ** 0.5

    # Gather per-timestep values and broadcast them to the shape of `timesteps`.
    # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
    alpha = sqrt_alphas.to(device=timesteps.device)[timesteps].float()
    while alpha.ndim < timesteps.ndim:
        alpha = alpha.unsqueeze(-1)
    alpha = alpha.expand(timesteps.shape)

    sigma = sqrt_one_minus_alphas.to(device=timesteps.device)[timesteps].float()
    while sigma.ndim < timesteps.ndim:
        sigma = sigma.unsqueeze(-1)
    sigma = sigma.expand(timesteps.shape)

    # SNR = (alpha / sigma) ** 2 = alpha_bar / (1 - alpha_bar).
    return (alpha / sigma) ** 2
99
+
100
+
101
def resolve_interpolation_mode(interpolation_type: str):
    """
    Maps a string describing an interpolation function to the corresponding torchvision `InterpolationMode` enum. The
    full list of supported enums is documented at
    https://pytorch.org/vision/0.9/transforms.html#torchvision.transforms.functional.InterpolationMode.

    Args:
        interpolation_type (`str`):
            A string describing an interpolation method. Currently, `bilinear`, `bicubic`, `box`, `nearest`,
            `nearest_exact`, `hamming`, and `lanczos` are supported, corresponding to the supported interpolation modes
            in torchvision.

    Returns:
        `torchvision.transforms.InterpolationMode`: an `InterpolationMode` enum used by torchvision's `resize`
        transform.
    """
    if not is_torchvision_available():
        raise ImportError(
            "Please make sure to install `torchvision` to be able to use the `resolve_interpolation_mode()` function."
        )

    # Map the string to the enum attribute name; only the requested attribute is
    # accessed, so unsupported strings fail with the ValueError below.
    mode_attr_names = {
        "bilinear": "BILINEAR",
        "bicubic": "BICUBIC",
        "box": "BOX",
        "nearest": "NEAREST",
        "nearest_exact": "NEAREST_EXACT",
        "hamming": "HAMMING",
        "lanczos": "LANCZOS",
    }
    if interpolation_type not in mode_attr_names:
        raise ValueError(
            f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation"
            f" modes are `bilinear`, `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
        )

    return getattr(transforms.InterpolationMode, mode_attr_names[interpolation_type])
143
+
144
+
145
def compute_dream_and_update_latents(
    unet: UNet2DConditionModel,
    noise_scheduler: SchedulerMixin,
    timesteps: torch.Tensor,
    noise: torch.Tensor,
    noisy_latents: torch.Tensor,
    target: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    dream_detail_preservation: float = 1.0,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Implements "DREAM (Diffusion Rectification and Estimation-Adaptive Models)" from
    https://huggingface.co/papers/2312.00210. DREAM helps align training with sampling to help training be more
    efficient and accurate at the cost of an extra forward step without gradients.

    Args:
        `unet`: The state unet to use to make a prediction.
        `noise_scheduler`: The noise scheduler used to add noise for the given timestep.
        `timesteps`: The timesteps for the noise_scheduler to user.
        `noise`: A tensor of noise in the shape of noisy_latents.
        `noisy_latents`: Previously noise latents from the training loop.
        `target`: The ground-truth tensor to predict after eps is removed.
        `encoder_hidden_states`: Text embeddings from the text model.
        `dream_detail_preservation`: A float value that indicates detail preservation level.
          See reference.

    Raises:
        NotImplementedError: if the scheduler uses "v_prediction" (not supported by DREAM here).
        ValueError: for any other unknown `prediction_type`.

    Returns:
        `tuple[torch.Tensor, torch.Tensor]`: Adjusted noisy_latents and target.
    """
    # Gather per-sample alpha_bar and reshape to broadcast over (B, C, H, W) latents.
    # NOTE(review): the trailing `[..., None, None, None]` indexing assumes 4D latents — confirm for other modalities.
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)[timesteps, None, None, None]
    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

    # The paper uses lambda = sqrt(1 - alpha) ** p, with p = 1 in their experiments.
    dream_lambda = sqrt_one_minus_alphas_cumprod**dream_detail_preservation

    pred = None
    # Extra gradient-free forward pass: the rectification term must not backprop.
    with torch.no_grad():
        pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    _noisy_latents, _target = (None, None)
    if noise_scheduler.config.prediction_type == "epsilon":
        predicted_noise = pred
        # delta = (true noise - predicted noise), detached and scaled by lambda.
        delta_noise = (noise - predicted_noise).detach()
        delta_noise.mul_(dream_lambda)
        # Shift both the input latents and the target by the (scaled) residual.
        _noisy_latents = noisy_latents.add(sqrt_one_minus_alphas_cumprod * delta_noise)
        _target = target.add(delta_noise)
    elif noise_scheduler.config.prediction_type == "v_prediction":
        raise NotImplementedError("DREAM has not been implemented for v-prediction")
    else:
        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

    return _noisy_latents, _target
197
+
198
+
199
def unet_lora_state_dict(unet: "UNet2DConditionModel") -> Dict[str, torch.Tensor]:
    r"""
    Returns:
        A state dict containing just the LoRA parameters.
    """
    lora_state_dict = {}

    for module_name, module in unet.named_modules():
        # Only modules that expose `set_lora_layer` can carry a LoRA layer.
        if not hasattr(module, "set_lora_layer"):
            continue
        lora_layer = module.lora_layer
        if lora_layer is None:
            continue
        # The matrix name can either be "down" or "up".
        for matrix_name, lora_param in lora_layer.state_dict().items():
            lora_state_dict[f"{module_name}.lora.{matrix_name}"] = lora_param

    return lora_state_dict
216
+
217
+
218
def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32):
    """
    Casts the training parameters of the model to the specified data type.

    Args:
        model: The PyTorch model (or list of models) whose parameters will be cast.
        dtype: The data type to which the model parameters will be cast.
    """
    models = model if isinstance(model, list) else [model]
    for current_model in models:
        for param in current_model.parameters():
            # Only trainable parameters are upcast (e.g. LoRA weights in fp32
            # while the frozen base weights stay in half precision).
            if param.requires_grad:
                param.data = param.to(dtype)
233
+
234
+
235
def _set_state_dict_into_text_encoder(
    lora_state_dict: Dict[str, torch.Tensor], prefix: str, text_encoder: torch.nn.Module
):
    """
    Sets the `lora_state_dict` into `text_encoder` coming from `transformers`.

    Args:
        lora_state_dict: The state dictionary to be set.
        prefix: String identifier to retrieve the portion of the state dict that belongs to `text_encoder`.
        text_encoder: Where the `lora_state_dict` is to be set.
    """

    # Keep only the prefixed entries and strip the prefix from their keys.
    scoped_state_dict = {
        key.replace(prefix, ""): value for key, value in lora_state_dict.items() if key.startswith(prefix)
    }
    # diffusers-format -> PEFT-format keys before handing the dict to PEFT.
    scoped_state_dict = convert_state_dict_to_peft(convert_state_dict_to_diffusers(scoped_state_dict))
    set_peft_model_state_dict(text_encoder, scoped_state_dict, adapter_name="default")
252
+
253
+
254
+ def _collate_lora_metadata(modules_to_save: Dict[str, torch.nn.Module]) -> Dict[str, Any]:
255
+ metadatas = {}
256
+ for module_name, module in modules_to_save.items():
257
+ if module is not None:
258
+ metadatas[f"{module_name}_lora_adapter_metadata"] = module.peft_config["default"].to_dict()
259
+ return metadatas
260
+
261
+
262
def compute_density_for_timestep_sampling(
    weighting_scheme: str,
    batch_size: int,
    logit_mean: float = None,
    logit_std: float = None,
    mode_scale: float = None,
    device: Union[torch.device, str] = "cpu",
    generator: Optional[torch.Generator] = None,
):
    """
    Compute the density for sampling the timesteps when doing SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://huggingface.co/papers/2403.03206v1.
    """
    if weighting_scheme == "logit_normal":
        # Sigmoid of a Gaussian sample concentrates timesteps around the middle.
        normal_samples = torch.normal(
            mean=logit_mean, std=logit_std, size=(batch_size,), device=device, generator=generator
        )
        return torch.nn.functional.sigmoid(normal_samples)

    uniform = torch.rand(size=(batch_size,), device=device, generator=generator)
    if weighting_scheme == "mode":
        # "Mode" sampling with heavy tails, Eq. (19)-style reweighting of u.
        return 1 - uniform - mode_scale * (torch.cos(math.pi * uniform / 2) ** 2 - 1 + uniform)
    # Default: plain uniform sampling.
    return uniform
287
+
288
+
289
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
    """
    Computes loss weighting scheme for SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://huggingface.co/papers/2403.03206v1.
    """
    if weighting_scheme == "sigma_sqrt":
        # w = 1 / sigma^2.
        return (sigmas**-2.0).float()
    if weighting_scheme == "cosmap":
        # w = 2 / (pi * (1 - 2*sigma + 2*sigma^2)).
        denominator = 1 - 2 * sigmas + 2 * sigmas**2
        return 2 / (math.pi * denominator)
    # Default: uniform weighting.
    return torch.ones_like(sigmas)
305
+
306
+
307
def free_memory():
    """
    Runs garbage collection. Then clears the cache of the available accelerator.
    """
    gc.collect()

    # Only the first available backend is cleared (same precedence as the
    # original elif chain: CUDA > MPS > NPU > XPU).
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        return
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
        return
    if is_torch_npu_available():
        torch_npu.npu.empty_cache()
        return
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()
321
+
322
+
323
@contextmanager
def offload_models(
    *modules: Union[torch.nn.Module, DiffusionPipeline], device: Union[str, torch.device], offload: bool = True
):
    """
    Context manager that, if offload=True, moves each module to `device` on enter, then moves it back to its original
    device on exit.

    Args:
        modules (`torch.nn.Module` or `DiffusionPipeline`): One or more modules, or a single pipeline, to move.
            When any argument is a `DiffusionPipeline`, exactly one positional argument must be given.
        device (`str` or `torch.Device`): Device to move the `modules` to.
        offload (`bool`): Flag to enable offloading. When `False`, this context manager is a no-op.

    Raises:
        ValueError: If more than one positional argument is passed together with a `DiffusionPipeline`.
    """
    if offload:
        is_model = not any(isinstance(m, DiffusionPipeline) for m in modules)
        # record where each module was so it can be restored on exit
        if is_model:
            # NOTE(review): assumes every module has at least one parameter, all on one device — unchanged
            # from the original behavior.
            original_devices = [next(m.parameters()).device for m in modules]
        else:
            # Explicit ValueError instead of `assert`: asserts are stripped under `python -O`,
            # which would let multiple pipelines through silently.
            if len(modules) != 1:
                raise ValueError(f"Expected a single `DiffusionPipeline`, got {len(modules)} positional arguments.")
            # For DiffusionPipeline, wrap the device in a list to make it iterable
            original_devices = [modules[0].device]
        # move to target device
        for m in modules:
            m.to(device)

    try:
        yield
    finally:
        if offload:
            # move back to original devices
            for m, orig_dev in zip(modules, original_devices):
                m.to(orig_dev)
355
+
356
+
357
def parse_buckets_string(buckets_str):
    """Parses a string defining buckets into a list of (height, width) tuples.

    Args:
        buckets_str (`str`): Semicolon-separated `height,width` pairs, e.g. `"512,512;768,576"`. Whitespace
            around numbers and commas is tolerated.

    Returns:
        `list[tuple[int, int]]`: Parsed `(height, width)` pairs, in input order.

    Raises:
        ValueError: If the string is empty, a pair is malformed, or a dimension is not a positive integer.
    """
    if not buckets_str:
        raise ValueError("Bucket string cannot be empty.")

    parsed_buckets = []
    for pair_str in buckets_str.strip().split(";"):
        match = re.match(r"^\s*(\d+)\s*,\s*(\d+)\s*$", pair_str)
        if not match:
            raise ValueError(f"Invalid bucket format: '{pair_str}'. Expected 'height,width'.")
        # The regex guarantees both groups are digit strings, so int() cannot fail here — the previous
        # try/except ValueError only ever caught its *own* positivity error below and re-wrapped it in a
        # misleading "Invalid integer in bucket pair" message.
        height = int(match.group(1))
        width = int(match.group(2))
        if height <= 0 or width <= 0:
            raise ValueError(f"Bucket dimensions must be positive integers, got ({height},{width}).")
        if height % 8 != 0 or width % 8 != 0:
            # Warn (not raise): many diffusion models require spatial dims divisible by 8.
            warnings.warn(f"Bucket dimension ({height},{width}) not divisible by 8. This might cause issues.")
        parsed_buckets.append((height, width))

    if not parsed_buckets:
        raise ValueError("No valid buckets found in the provided string.")

    return parsed_buckets
383
+
384
+
385
def find_nearest_bucket(h, w, bucket_options):
    """Finds the closest bucket to the given height and width.

    The mismatch metric is the cross product `abs(h * bucket_w - w * bucket_h)`, i.e. how far the bucket's
    aspect ratio is from the input's. Ties are resolved in favor of the *last* tying bucket.

    Args:
        h: Input height.
        w: Input width.
        bucket_options: Iterable of `(bucket_height, bucket_width)` pairs.

    Returns:
        Index of the best bucket, or `None` if `bucket_options` is empty.
    """
    best_bucket_idx = None
    best_metric = float("inf")
    for candidate_idx, (candidate_h, candidate_w) in enumerate(bucket_options):
        mismatch = abs(h * candidate_w - w * candidate_h)
        # `<=` (not `<`) keeps the original tie-breaking: later buckets win ties.
        if mismatch <= best_metric:
            best_metric, best_bucket_idx = mismatch, candidate_idx
    return best_bucket_idx
395
+
396
+
397
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMAModel:
    """
    Exponential Moving Average of models weights
    """

    def __init__(
        self,
        parameters: Iterable[torch.nn.Parameter],
        decay: float = 0.9999,
        min_decay: float = 0.0,
        update_after_step: int = 0,
        use_ema_warmup: bool = False,
        inv_gamma: Union[float, int] = 1.0,
        power: Union[float, int] = 2 / 3,
        foreach: bool = False,
        model_cls: Optional[Any] = None,
        model_config: Dict[str, Any] = None,
        **kwargs,
    ):
        """
        Args:
            parameters (Iterable[torch.nn.Parameter]): The parameters to track.
            decay (float): The decay factor for the exponential moving average.
            min_decay (float): The minimum decay factor for the exponential moving average.
            update_after_step (int): The number of steps to wait before starting to update the EMA weights.
            use_ema_warmup (bool): Whether to use EMA warmup.
            inv_gamma (float):
                Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.
            power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.
            foreach (bool): Use torch._foreach functions for updating shadow parameters. Should be faster.
            device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA
                weights will be stored on CPU.

        @crowsonkb's notes on EMA Warmup:
            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
            to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
            gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
            at 215.4k steps).
        """

        if isinstance(parameters, torch.nn.Module):
            deprecation_message = (
                "Passing a `torch.nn.Module` to `ExponentialMovingAverage` is deprecated. "
                "Please pass the parameters of the module instead."
            )
            deprecate(
                "passing a `torch.nn.Module` to `ExponentialMovingAverage`",
                "1.0.0",
                deprecation_message,
                standard_warn=False,
            )
            parameters = parameters.parameters()

            # set use_ema_warmup to True if a torch.nn.Module is passed for backwards compatibility
            use_ema_warmup = True

        # Deprecated aliases: `max_value`/`min_value` map onto `decay`/`min_decay`.
        if kwargs.get("max_value", None) is not None:
            deprecation_message = "The `max_value` argument is deprecated. Please use `decay` instead."
            deprecate("max_value", "1.0.0", deprecation_message, standard_warn=False)
            decay = kwargs["max_value"]

        if kwargs.get("min_value", None) is not None:
            deprecation_message = "The `min_value` argument is deprecated. Please use `min_decay` instead."
            deprecate("min_value", "1.0.0", deprecation_message, standard_warn=False)
            min_decay = kwargs["min_value"]

        parameters = list(parameters)
        # EMA copy of every tracked parameter, detached from the autograd graph.
        self.shadow_params = [p.clone().detach() for p in parameters]

        if kwargs.get("device", None) is not None:
            deprecation_message = "The `device` argument is deprecated. Please use `to` instead."
            deprecate("device", "1.0.0", deprecation_message, standard_warn=False)
            self.to(device=kwargs["device"])

        # Scratch space used by `store()`/`restore()`; populated on demand.
        self.temp_stored_params = None

        self.decay = decay
        self.min_decay = min_decay
        self.update_after_step = update_after_step
        self.use_ema_warmup = use_ema_warmup
        self.inv_gamma = inv_gamma
        self.power = power
        self.optimization_step = 0
        self.cur_decay_value = None  # set in `step()`
        self.foreach = foreach

        # Retained so `save_pretrained()` can rebuild the underlying model.
        self.model_cls = model_cls
        self.model_config = model_config

    @classmethod
    def from_pretrained(cls, path, model_cls, foreach=False) -> "EMAModel":
        """Load an `EMAModel` from a checkpoint written by `save_pretrained`.

        The EMA hyperparameters travel as unused config kwargs; the model weights provide the shadow params.
        """
        _, ema_kwargs = model_cls.from_config(path, return_unused_kwargs=True)
        model = model_cls.from_pretrained(path)

        ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config, foreach=foreach)

        ema_model.load_state_dict(ema_kwargs)
        return ema_model

    def save_pretrained(self, path):
        """Save the EMA weights (and EMA hyperparameters) as a regular model checkpoint at `path`."""
        if self.model_cls is None:
            raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.")

        if self.model_config is None:
            raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.")

        model = self.model_cls.from_config(self.model_config)
        state_dict = self.state_dict()
        # Shadow params are written into the model weights via `copy_to`, not stored in the config.
        state_dict.pop("shadow_params", None)

        model.register_to_config(**state_dict)
        self.copy_to(model.parameters())
        model.save_pretrained(path)

    def get_decay(self, optimization_step: int) -> float:
        """
        Compute the decay factor for the exponential moving average.
        """
        # No EMA updates until `update_after_step` steps have passed.
        step = max(0, optimization_step - self.update_after_step - 1)

        if step <= 0:
            return 0.0

        if self.use_ema_warmup:
            # crowsonkb warmup schedule: 1 - (1 + step/inv_gamma)^-power
            cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power
        else:
            cur_decay_value = (1 + step) / (10 + step)

        # Clamp into [min_decay, decay].
        cur_decay_value = min(cur_decay_value, self.decay)
        # make sure decay is not smaller than min_decay
        cur_decay_value = max(cur_decay_value, self.min_decay)
        return cur_decay_value

    @torch.no_grad()
    def step(self, parameters: Iterable[torch.nn.Parameter]):
        """Update the shadow params toward `parameters` with the current decay factor.

        Params with `requires_grad=False` are copied verbatim instead of averaged.
        """
        if isinstance(parameters, torch.nn.Module):
            deprecation_message = (
                "Passing a `torch.nn.Module` to `ExponentialMovingAverage.step` is deprecated. "
                "Please pass the parameters of the module instead."
            )
            deprecate(
                "passing a `torch.nn.Module` to `ExponentialMovingAverage.step`",
                "1.0.0",
                deprecation_message,
                standard_warn=False,
            )
            parameters = parameters.parameters()

        parameters = list(parameters)

        self.optimization_step += 1

        # Compute the decay factor for the exponential moving average.
        decay = self.get_decay(self.optimization_step)
        self.cur_decay_value = decay
        one_minus_decay = 1 - decay

        context_manager = contextlib.nullcontext()

        if self.foreach:
            # Under DeepSpeed ZeRO-3, params are sharded; gather them before reading.
            if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
                context_manager = deepspeed.zero.GatheredParameters(parameters, modifier_rank=None)

            with context_manager:
                params_grad = [param for param in parameters if param.requires_grad]
                s_params_grad = [
                    s_param for s_param, param in zip(self.shadow_params, parameters) if param.requires_grad
                ]

                # Non-trainable params are mirrored directly into the shadow copy.
                if len(params_grad) < len(parameters):
                    torch._foreach_copy_(
                        [s_param for s_param, param in zip(self.shadow_params, parameters) if not param.requires_grad],
                        [param for param in parameters if not param.requires_grad],
                        non_blocking=True,
                    )

                # Batched EMA update: s -= (1 - decay) * (s - p)
                torch._foreach_sub_(
                    s_params_grad, torch._foreach_sub(s_params_grad, params_grad), alpha=one_minus_decay
                )

        else:
            for s_param, param in zip(self.shadow_params, parameters):
                # Gather each (possibly ZeRO-3-sharded) param individually.
                if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
                    context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)

                with context_manager:
                    if param.requires_grad:
                        # EMA update: s -= (1 - decay) * (s - p)
                        s_param.sub_(one_minus_decay * (s_param - param))
                    else:
                        s_param.copy_(param)

    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        """
        Copy current averaged parameters into given collection of parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored moving averages. If `None`, the parameters with which this
                `ExponentialMovingAverage` was initialized will be used.
        """
        parameters = list(parameters)
        if self.foreach:
            torch._foreach_copy_(
                [param.data for param in parameters],
                [s_param.to(param.device).data for s_param, param in zip(self.shadow_params, parameters)],
            )
        else:
            for s_param, param in zip(self.shadow_params, parameters):
                param.data.copy_(s_param.to(param.device).data)

    def pin_memory(self) -> None:
        r"""
        Move internal buffers of the ExponentialMovingAverage to pinned memory. Useful for non-blocking transfers for
        offloading EMA params to the host.
        """

        self.shadow_params = [p.pin_memory() for p in self.shadow_params]

    def to(self, device=None, dtype=None, non_blocking=False) -> None:
        r"""
        Move internal buffers of the ExponentialMovingAverage to `device`.

        Args:
            device: like `device` argument to `torch.Tensor.to`
        """
        # .to() on the tensors handles None correctly
        # dtype is only applied to floating-point buffers; integer buffers keep their dtype.
        self.shadow_params = [
            p.to(device=device, dtype=dtype, non_blocking=non_blocking)
            if p.is_floating_point()
            else p.to(device=device, non_blocking=non_blocking)
            for p in self.shadow_params
        ]

    def state_dict(self) -> dict:
        r"""
        Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during
        checkpointing to save the ema state dict.
        """
        # Following PyTorch conventions, references to tensors are returned:
        # "returns a reference to the state and not its copy!" -
        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
        return {
            "decay": self.decay,
            "min_decay": self.min_decay,
            "optimization_step": self.optimization_step,
            "update_after_step": self.update_after_step,
            "use_ema_warmup": self.use_ema_warmup,
            "inv_gamma": self.inv_gamma,
            "power": self.power,
            "shadow_params": self.shadow_params,
        }

    def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        r"""
        Saves the current parameters for restoring later.

        Args:
            parameters: Iterable of `torch.nn.Parameter`. The parameters to be temporarily stored.
        """
        # Stored on CPU to keep accelerator memory free during evaluation.
        self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]

    def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        r"""
        Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters
        without: affecting the original optimization process. Store the parameters before the `copy_to()` method. After
        validation (or model saving), use this to restore the former parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters. If `None`, the parameters with which this
                `ExponentialMovingAverage` was initialized will be used.
        """

        if self.temp_stored_params is None:
            raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`")
        if self.foreach:
            torch._foreach_copy_(
                [param.data for param in parameters], [c_param.data for c_param in self.temp_stored_params]
            )
        else:
            for c_param, param in zip(self.temp_stored_params, parameters):
                param.data.copy_(c_param.data)

        # Better memory-wise.
        self.temp_stored_params = None

    def load_state_dict(self, state_dict: dict) -> None:
        r"""
        Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the
        ema state dict.

        Args:
            state_dict (dict): EMA state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # deepcopy, to be consistent with module API
        state_dict = copy.deepcopy(state_dict)

        # Each field falls back to the current value when absent, then is validated.
        self.decay = state_dict.get("decay", self.decay)
        if self.decay < 0.0 or self.decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")

        self.min_decay = state_dict.get("min_decay", self.min_decay)
        if not isinstance(self.min_decay, float):
            raise ValueError("Invalid min_decay")

        self.optimization_step = state_dict.get("optimization_step", self.optimization_step)
        if not isinstance(self.optimization_step, int):
            raise ValueError("Invalid optimization_step")

        self.update_after_step = state_dict.get("update_after_step", self.update_after_step)
        if not isinstance(self.update_after_step, int):
            raise ValueError("Invalid update_after_step")

        self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
        if not isinstance(self.use_ema_warmup, bool):
            raise ValueError("Invalid use_ema_warmup")

        self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
        if not isinstance(self.inv_gamma, (float, int)):
            raise ValueError("Invalid inv_gamma")

        self.power = state_dict.get("power", self.power)
        if not isinstance(self.power, (float, int)):
            raise ValueError("Invalid power")

        shadow_params = state_dict.get("shadow_params", None)
        if shadow_params is not None:
            self.shadow_params = shadow_params
            if not isinstance(self.shadow_params, list):
                raise ValueError("shadow_params must be a list")
            if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
                raise ValueError("shadow_params must all be Tensors")
pythonProject/diffusers-main/build/lib/diffusers/video_processor.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import warnings
16
+ from typing import List, Optional, Union
17
+
18
+ import numpy as np
19
+ import PIL
20
+ import torch
21
+
22
+ from .image_processor import VaeImageProcessor, is_valid_image, is_valid_image_imagelist
23
+
24
+
25
class VideoProcessor(VaeImageProcessor):
    r"""Simple video processor."""

    def preprocess_video(self, video, height: Optional[int] = None, width: Optional[int] = None) -> torch.Tensor:
        r"""
        Preprocesses input video(s).

        Args:
            video (`List[PIL.Image]`, `List[List[PIL.Image]]`, `torch.Tensor`, `np.array`, `List[torch.Tensor]`, `List[np.array]`):
                The input video. It can be one of the following:
                * List of the PIL images.
                * List of list of PIL images.
                * 4D Torch tensors (expected shape for each tensor `(num_frames, num_channels, height, width)`).
                * 4D NumPy arrays (expected shape for each array `(num_frames, height, width, num_channels)`).
                * List of 4D Torch tensors (expected shape for each tensor `(num_frames, num_channels, height,
                  width)`).
                * List of 4D NumPy arrays (expected shape for each array `(num_frames, height, width, num_channels)`).
                * 5D NumPy arrays: expected shape for each array `(batch_size, num_frames, height, width,
                  num_channels)`.
                * 5D Torch tensors: expected shape for each array `(batch_size, num_frames, num_channels, height,
                  width)`.
            height (`int`, *optional*, defaults to `None`):
                The height in preprocessed frames of the video. If `None`, will use the `get_default_height_width()` to
                get default height.
            width (`int`, *optional*`, defaults to `None`):
                The width in preprocessed frames of the video. If `None`, will use get_default_height_width()` to get
                the default width.

        Returns:
            `torch.Tensor`: A 5D tensor laid out as `(batch_size, num_channels, num_frames, height, width)`.
        """
        # Deprecated input shapes: a list of 5D arrays/tensors is merged into one 5D batch first.
        if isinstance(video, list) and isinstance(video[0], np.ndarray) and video[0].ndim == 5:
            warnings.warn(
                "Passing `video` as a list of 5d np.ndarray is deprecated."
                "Please concatenate the list along the batch dimension and pass it as a single 5d np.ndarray",
                FutureWarning,
            )
            video = np.concatenate(video, axis=0)
        if isinstance(video, list) and isinstance(video[0], torch.Tensor) and video[0].ndim == 5:
            warnings.warn(
                "Passing `video` as a list of 5d torch.Tensor is deprecated."
                "Please concatenate the list along the batch dimension and pass it as a single 5d torch.Tensor",
                FutureWarning,
            )
            video = torch.cat(video, axis=0)

        # ensure the input is a list of videos:
        # - if it is a batch of videos (5d torch.Tensor or np.ndarray), it is converted to a list of videos (a list of 4d torch.Tensor or np.ndarray)
        # - if it is a single video, it is converted to a list of one video.
        if isinstance(video, (np.ndarray, torch.Tensor)) and video.ndim == 5:
            video = list(video)
        # NOTE: `and` binds tighter than `or`, so this reads as
        # (list whose first item is an image) OR (single image / flat list of images).
        elif isinstance(video, list) and is_valid_image(video[0]) or is_valid_image_imagelist(video):
            video = [video]
        elif isinstance(video, list) and is_valid_image_imagelist(video[0]):
            video = video
        else:
            raise ValueError(
                "Input is in incorrect format. Currently, we only support numpy.ndarray, torch.Tensor, PIL.Image.Image"
            )

        # Each per-video `preprocess` yields `(num_frames, num_channels, height, width)`;
        # stacking adds the batch dimension in front.
        video = torch.stack([self.preprocess(img, height=height, width=width) for img in video], dim=0)

        # move the number of channels before the number of frames.
        video = video.permute(0, 2, 1, 3, 4)

        return video

    def postprocess_video(
        self, video: torch.Tensor, output_type: str = "np"
    ) -> Union[np.ndarray, torch.Tensor, List[PIL.Image.Image]]:
        r"""
        Converts a video tensor to a list of frames for export.

        Args:
            video (`torch.Tensor`): The video as a tensor. Expected layout is the `preprocess_video` output:
                `(batch_size, num_channels, num_frames, height, width)`.
            output_type (`str`, defaults to `"np"`): Output type of the postprocessed `video` tensor.
                One of `"np"`, `"pt"`, or `"pil"`.
        """
        batch_size = video.shape[0]
        outputs = []
        for batch_idx in range(batch_size):
            # `(num_channels, num_frames, h, w)` -> `(num_frames, num_channels, h, w)` so each
            # frame can go through the image `postprocess`.
            batch_vid = video[batch_idx].permute(1, 0, 2, 3)
            batch_output = self.postprocess(batch_vid, output_type)
            outputs.append(batch_output)

        # Re-batch the per-video outputs; "pil" stays a list of lists of frames.
        if output_type == "np":
            outputs = np.stack(outputs)
        elif output_type == "pt":
            outputs = torch.stack(outputs)
        elif not output_type == "pil":
            raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

        return outputs