Add files using upload-large-folder tool
Browse files- pythonProject/diffusers-main/build/lib/diffusers/__init__.py +1384 -0
- pythonProject/diffusers-main/build/lib/diffusers/commands/__init__.py +27 -0
- pythonProject/diffusers-main/build/lib/diffusers/commands/custom_blocks.py +134 -0
- pythonProject/diffusers-main/build/lib/diffusers/commands/diffusers_cli.py +45 -0
- pythonProject/diffusers-main/build/lib/diffusers/commands/env.py +180 -0
- pythonProject/diffusers-main/build/lib/diffusers/commands/fp16_safetensors.py +132 -0
- pythonProject/diffusers-main/build/lib/diffusers/experimental/__init__.py +1 -0
- pythonProject/diffusers-main/build/lib/diffusers/experimental/rl/__init__.py +1 -0
- pythonProject/diffusers-main/build/lib/diffusers/experimental/rl/value_guided_sampling.py +153 -0
- pythonProject/diffusers-main/build/lib/diffusers/guiders/adaptive_projected_guidance.py +188 -0
- pythonProject/diffusers-main/build/lib/diffusers/guiders/auto_guidance.py +190 -0
- pythonProject/diffusers-main/build/lib/diffusers/guiders/classifier_free_guidance.py +141 -0
- pythonProject/diffusers-main/build/lib/diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
- pythonProject/diffusers-main/build/lib/diffusers/guiders/frequency_decoupled_guidance.py +327 -0
- pythonProject/diffusers-main/build/lib/diffusers/guiders/guider_utils.py +315 -0
- pythonProject/diffusers-main/build/lib/diffusers/guiders/perturbed_attention_guidance.py +271 -0
- pythonProject/diffusers-main/build/lib/diffusers/guiders/skip_layer_guidance.py +262 -0
- pythonProject/diffusers-main/build/lib/diffusers/guiders/smoothed_energy_guidance.py +251 -0
- pythonProject/diffusers-main/build/lib/diffusers/training_utils.py +730 -0
- pythonProject/diffusers-main/build/lib/diffusers/video_processor.py +113 -0
pythonProject/diffusers-main/build/lib/diffusers/__init__.py
ADDED
|
@@ -0,0 +1,1384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__version__ = "0.36.0.dev0"
|
| 2 |
+
|
| 3 |
+
from typing import TYPE_CHECKING
|
| 4 |
+
|
| 5 |
+
from .utils import (
|
| 6 |
+
DIFFUSERS_SLOW_IMPORT,
|
| 7 |
+
OptionalDependencyNotAvailable,
|
| 8 |
+
_LazyModule,
|
| 9 |
+
is_accelerate_available,
|
| 10 |
+
is_bitsandbytes_available,
|
| 11 |
+
is_flax_available,
|
| 12 |
+
is_gguf_available,
|
| 13 |
+
is_k_diffusion_available,
|
| 14 |
+
is_librosa_available,
|
| 15 |
+
is_note_seq_available,
|
| 16 |
+
is_nvidia_modelopt_available,
|
| 17 |
+
is_onnx_available,
|
| 18 |
+
is_opencv_available,
|
| 19 |
+
is_optimum_quanto_available,
|
| 20 |
+
is_scipy_available,
|
| 21 |
+
is_sentencepiece_available,
|
| 22 |
+
is_torch_available,
|
| 23 |
+
is_torchao_available,
|
| 24 |
+
is_torchsde_available,
|
| 25 |
+
is_transformers_available,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Lazy Import based on
|
| 30 |
+
# https://github.com/huggingface/transformers/blob/main/src/transformers/__init__.py
|
| 31 |
+
|
| 32 |
+
# When adding a new object to this init, please add it to `_import_structure`. The `_import_structure` is a dictionary submodule to list of object names,
|
| 33 |
+
# and is used to defer the actual importing for when the objects are requested.
|
| 34 |
+
# This way `import diffusers` provides the names in the namespace without actually importing anything (and especially none of the backends).
|
| 35 |
+
|
| 36 |
+
_import_structure = {
|
| 37 |
+
"configuration_utils": ["ConfigMixin"],
|
| 38 |
+
"guiders": [],
|
| 39 |
+
"hooks": [],
|
| 40 |
+
"loaders": ["FromOriginalModelMixin"],
|
| 41 |
+
"models": [],
|
| 42 |
+
"modular_pipelines": [],
|
| 43 |
+
"pipelines": [],
|
| 44 |
+
"quantizers.pipe_quant_config": ["PipelineQuantizationConfig"],
|
| 45 |
+
"quantizers.quantization_config": [],
|
| 46 |
+
"schedulers": [],
|
| 47 |
+
"utils": [
|
| 48 |
+
"OptionalDependencyNotAvailable",
|
| 49 |
+
"is_flax_available",
|
| 50 |
+
"is_inflect_available",
|
| 51 |
+
"is_invisible_watermark_available",
|
| 52 |
+
"is_k_diffusion_available",
|
| 53 |
+
"is_k_diffusion_version",
|
| 54 |
+
"is_librosa_available",
|
| 55 |
+
"is_note_seq_available",
|
| 56 |
+
"is_onnx_available",
|
| 57 |
+
"is_scipy_available",
|
| 58 |
+
"is_torch_available",
|
| 59 |
+
"is_torchsde_available",
|
| 60 |
+
"is_transformers_available",
|
| 61 |
+
"is_transformers_version",
|
| 62 |
+
"is_unidecode_available",
|
| 63 |
+
"logging",
|
| 64 |
+
],
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
try:
|
| 68 |
+
if not is_torch_available() and not is_accelerate_available() and not is_bitsandbytes_available():
|
| 69 |
+
raise OptionalDependencyNotAvailable()
|
| 70 |
+
except OptionalDependencyNotAvailable:
|
| 71 |
+
from .utils import dummy_bitsandbytes_objects
|
| 72 |
+
|
| 73 |
+
_import_structure["utils.dummy_bitsandbytes_objects"] = [
|
| 74 |
+
name for name in dir(dummy_bitsandbytes_objects) if not name.startswith("_")
|
| 75 |
+
]
|
| 76 |
+
else:
|
| 77 |
+
_import_structure["quantizers.quantization_config"].append("BitsAndBytesConfig")
|
| 78 |
+
|
| 79 |
+
try:
|
| 80 |
+
if not is_torch_available() and not is_accelerate_available() and not is_gguf_available():
|
| 81 |
+
raise OptionalDependencyNotAvailable()
|
| 82 |
+
except OptionalDependencyNotAvailable:
|
| 83 |
+
from .utils import dummy_gguf_objects
|
| 84 |
+
|
| 85 |
+
_import_structure["utils.dummy_gguf_objects"] = [
|
| 86 |
+
name for name in dir(dummy_gguf_objects) if not name.startswith("_")
|
| 87 |
+
]
|
| 88 |
+
else:
|
| 89 |
+
_import_structure["quantizers.quantization_config"].append("GGUFQuantizationConfig")
|
| 90 |
+
|
| 91 |
+
try:
|
| 92 |
+
if not is_torch_available() and not is_accelerate_available() and not is_torchao_available():
|
| 93 |
+
raise OptionalDependencyNotAvailable()
|
| 94 |
+
except OptionalDependencyNotAvailable:
|
| 95 |
+
from .utils import dummy_torchao_objects
|
| 96 |
+
|
| 97 |
+
_import_structure["utils.dummy_torchao_objects"] = [
|
| 98 |
+
name for name in dir(dummy_torchao_objects) if not name.startswith("_")
|
| 99 |
+
]
|
| 100 |
+
else:
|
| 101 |
+
_import_structure["quantizers.quantization_config"].append("TorchAoConfig")
|
| 102 |
+
|
| 103 |
+
try:
|
| 104 |
+
if not is_torch_available() and not is_accelerate_available() and not is_optimum_quanto_available():
|
| 105 |
+
raise OptionalDependencyNotAvailable()
|
| 106 |
+
except OptionalDependencyNotAvailable:
|
| 107 |
+
from .utils import dummy_optimum_quanto_objects
|
| 108 |
+
|
| 109 |
+
_import_structure["utils.dummy_optimum_quanto_objects"] = [
|
| 110 |
+
name for name in dir(dummy_optimum_quanto_objects) if not name.startswith("_")
|
| 111 |
+
]
|
| 112 |
+
else:
|
| 113 |
+
_import_structure["quantizers.quantization_config"].append("QuantoConfig")
|
| 114 |
+
|
| 115 |
+
try:
|
| 116 |
+
if not is_torch_available() and not is_accelerate_available() and not is_nvidia_modelopt_available():
|
| 117 |
+
raise OptionalDependencyNotAvailable()
|
| 118 |
+
except OptionalDependencyNotAvailable:
|
| 119 |
+
from .utils import dummy_nvidia_modelopt_objects
|
| 120 |
+
|
| 121 |
+
_import_structure["utils.dummy_nvidia_modelopt_objects"] = [
|
| 122 |
+
name for name in dir(dummy_nvidia_modelopt_objects) if not name.startswith("_")
|
| 123 |
+
]
|
| 124 |
+
else:
|
| 125 |
+
_import_structure["quantizers.quantization_config"].append("NVIDIAModelOptConfig")
|
| 126 |
+
|
| 127 |
+
try:
|
| 128 |
+
if not is_onnx_available():
|
| 129 |
+
raise OptionalDependencyNotAvailable()
|
| 130 |
+
except OptionalDependencyNotAvailable:
|
| 131 |
+
from .utils import dummy_onnx_objects # noqa F403
|
| 132 |
+
|
| 133 |
+
_import_structure["utils.dummy_onnx_objects"] = [
|
| 134 |
+
name for name in dir(dummy_onnx_objects) if not name.startswith("_")
|
| 135 |
+
]
|
| 136 |
+
|
| 137 |
+
else:
|
| 138 |
+
_import_structure["pipelines"].extend(["OnnxRuntimeModel"])
|
| 139 |
+
|
| 140 |
+
try:
|
| 141 |
+
if not is_torch_available():
|
| 142 |
+
raise OptionalDependencyNotAvailable()
|
| 143 |
+
except OptionalDependencyNotAvailable:
|
| 144 |
+
from .utils import dummy_pt_objects # noqa F403
|
| 145 |
+
|
| 146 |
+
_import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
|
| 147 |
+
|
| 148 |
+
else:
|
| 149 |
+
_import_structure["guiders"].extend(
|
| 150 |
+
[
|
| 151 |
+
"AdaptiveProjectedGuidance",
|
| 152 |
+
"AutoGuidance",
|
| 153 |
+
"ClassifierFreeGuidance",
|
| 154 |
+
"ClassifierFreeZeroStarGuidance",
|
| 155 |
+
"FrequencyDecoupledGuidance",
|
| 156 |
+
"PerturbedAttentionGuidance",
|
| 157 |
+
"SkipLayerGuidance",
|
| 158 |
+
"SmoothedEnergyGuidance",
|
| 159 |
+
"TangentialClassifierFreeGuidance",
|
| 160 |
+
]
|
| 161 |
+
)
|
| 162 |
+
_import_structure["hooks"].extend(
|
| 163 |
+
[
|
| 164 |
+
"FasterCacheConfig",
|
| 165 |
+
"FirstBlockCacheConfig",
|
| 166 |
+
"HookRegistry",
|
| 167 |
+
"LayerSkipConfig",
|
| 168 |
+
"PyramidAttentionBroadcastConfig",
|
| 169 |
+
"SmoothedEnergyGuidanceConfig",
|
| 170 |
+
"apply_faster_cache",
|
| 171 |
+
"apply_first_block_cache",
|
| 172 |
+
"apply_layer_skip",
|
| 173 |
+
"apply_pyramid_attention_broadcast",
|
| 174 |
+
]
|
| 175 |
+
)
|
| 176 |
+
_import_structure["models"].extend(
|
| 177 |
+
[
|
| 178 |
+
"AllegroTransformer3DModel",
|
| 179 |
+
"AsymmetricAutoencoderKL",
|
| 180 |
+
"AttentionBackendName",
|
| 181 |
+
"AuraFlowTransformer2DModel",
|
| 182 |
+
"AutoencoderDC",
|
| 183 |
+
"AutoencoderKL",
|
| 184 |
+
"AutoencoderKLAllegro",
|
| 185 |
+
"AutoencoderKLCogVideoX",
|
| 186 |
+
"AutoencoderKLCosmos",
|
| 187 |
+
"AutoencoderKLHunyuanVideo",
|
| 188 |
+
"AutoencoderKLLTXVideo",
|
| 189 |
+
"AutoencoderKLMagvit",
|
| 190 |
+
"AutoencoderKLMochi",
|
| 191 |
+
"AutoencoderKLQwenImage",
|
| 192 |
+
"AutoencoderKLTemporalDecoder",
|
| 193 |
+
"AutoencoderKLWan",
|
| 194 |
+
"AutoencoderOobleck",
|
| 195 |
+
"AutoencoderTiny",
|
| 196 |
+
"AutoModel",
|
| 197 |
+
"BriaTransformer2DModel",
|
| 198 |
+
"CacheMixin",
|
| 199 |
+
"ChromaTransformer2DModel",
|
| 200 |
+
"CogVideoXTransformer3DModel",
|
| 201 |
+
"CogView3PlusTransformer2DModel",
|
| 202 |
+
"CogView4Transformer2DModel",
|
| 203 |
+
"ConsisIDTransformer3DModel",
|
| 204 |
+
"ConsistencyDecoderVAE",
|
| 205 |
+
"ControlNetModel",
|
| 206 |
+
"ControlNetUnionModel",
|
| 207 |
+
"ControlNetXSAdapter",
|
| 208 |
+
"CosmosTransformer3DModel",
|
| 209 |
+
"DiTTransformer2DModel",
|
| 210 |
+
"EasyAnimateTransformer3DModel",
|
| 211 |
+
"FluxControlNetModel",
|
| 212 |
+
"FluxMultiControlNetModel",
|
| 213 |
+
"FluxTransformer2DModel",
|
| 214 |
+
"HiDreamImageTransformer2DModel",
|
| 215 |
+
"HunyuanDiT2DControlNetModel",
|
| 216 |
+
"HunyuanDiT2DModel",
|
| 217 |
+
"HunyuanDiT2DMultiControlNetModel",
|
| 218 |
+
"HunyuanVideoFramepackTransformer3DModel",
|
| 219 |
+
"HunyuanVideoTransformer3DModel",
|
| 220 |
+
"I2VGenXLUNet",
|
| 221 |
+
"Kandinsky3UNet",
|
| 222 |
+
"LatteTransformer3DModel",
|
| 223 |
+
"LTXVideoTransformer3DModel",
|
| 224 |
+
"Lumina2Transformer2DModel",
|
| 225 |
+
"LuminaNextDiT2DModel",
|
| 226 |
+
"MochiTransformer3DModel",
|
| 227 |
+
"ModelMixin",
|
| 228 |
+
"MotionAdapter",
|
| 229 |
+
"MultiAdapter",
|
| 230 |
+
"MultiControlNetModel",
|
| 231 |
+
"OmniGenTransformer2DModel",
|
| 232 |
+
"PixArtTransformer2DModel",
|
| 233 |
+
"PriorTransformer",
|
| 234 |
+
"QwenImageControlNetModel",
|
| 235 |
+
"QwenImageMultiControlNetModel",
|
| 236 |
+
"QwenImageTransformer2DModel",
|
| 237 |
+
"SanaControlNetModel",
|
| 238 |
+
"SanaTransformer2DModel",
|
| 239 |
+
"SD3ControlNetModel",
|
| 240 |
+
"SD3MultiControlNetModel",
|
| 241 |
+
"SD3Transformer2DModel",
|
| 242 |
+
"SkyReelsV2Transformer3DModel",
|
| 243 |
+
"SparseControlNetModel",
|
| 244 |
+
"StableAudioDiTModel",
|
| 245 |
+
"StableCascadeUNet",
|
| 246 |
+
"T2IAdapter",
|
| 247 |
+
"T5FilmDecoder",
|
| 248 |
+
"Transformer2DModel",
|
| 249 |
+
"TransformerTemporalModel",
|
| 250 |
+
"UNet1DModel",
|
| 251 |
+
"UNet2DConditionModel",
|
| 252 |
+
"UNet2DModel",
|
| 253 |
+
"UNet3DConditionModel",
|
| 254 |
+
"UNetControlNetXSModel",
|
| 255 |
+
"UNetMotionModel",
|
| 256 |
+
"UNetSpatioTemporalConditionModel",
|
| 257 |
+
"UVit2DModel",
|
| 258 |
+
"VQModel",
|
| 259 |
+
"WanTransformer3DModel",
|
| 260 |
+
"WanVACETransformer3DModel",
|
| 261 |
+
"attention_backend",
|
| 262 |
+
]
|
| 263 |
+
)
|
| 264 |
+
_import_structure["modular_pipelines"].extend(
|
| 265 |
+
[
|
| 266 |
+
"ComponentsManager",
|
| 267 |
+
"ComponentSpec",
|
| 268 |
+
"ModularPipeline",
|
| 269 |
+
"ModularPipelineBlocks",
|
| 270 |
+
]
|
| 271 |
+
)
|
| 272 |
+
_import_structure["optimization"] = [
|
| 273 |
+
"get_constant_schedule",
|
| 274 |
+
"get_constant_schedule_with_warmup",
|
| 275 |
+
"get_cosine_schedule_with_warmup",
|
| 276 |
+
"get_cosine_with_hard_restarts_schedule_with_warmup",
|
| 277 |
+
"get_linear_schedule_with_warmup",
|
| 278 |
+
"get_polynomial_decay_schedule_with_warmup",
|
| 279 |
+
"get_scheduler",
|
| 280 |
+
]
|
| 281 |
+
_import_structure["pipelines"].extend(
|
| 282 |
+
[
|
| 283 |
+
"AudioPipelineOutput",
|
| 284 |
+
"AutoPipelineForImage2Image",
|
| 285 |
+
"AutoPipelineForInpainting",
|
| 286 |
+
"AutoPipelineForText2Image",
|
| 287 |
+
"ConsistencyModelPipeline",
|
| 288 |
+
"DanceDiffusionPipeline",
|
| 289 |
+
"DDIMPipeline",
|
| 290 |
+
"DDPMPipeline",
|
| 291 |
+
"DiffusionPipeline",
|
| 292 |
+
"DiTPipeline",
|
| 293 |
+
"ImagePipelineOutput",
|
| 294 |
+
"KarrasVePipeline",
|
| 295 |
+
"LDMPipeline",
|
| 296 |
+
"LDMSuperResolutionPipeline",
|
| 297 |
+
"PNDMPipeline",
|
| 298 |
+
"RePaintPipeline",
|
| 299 |
+
"ScoreSdeVePipeline",
|
| 300 |
+
"StableDiffusionMixin",
|
| 301 |
+
]
|
| 302 |
+
)
|
| 303 |
+
_import_structure["quantizers"] = ["DiffusersQuantizer"]
|
| 304 |
+
_import_structure["schedulers"].extend(
|
| 305 |
+
[
|
| 306 |
+
"AmusedScheduler",
|
| 307 |
+
"CMStochasticIterativeScheduler",
|
| 308 |
+
"CogVideoXDDIMScheduler",
|
| 309 |
+
"CogVideoXDPMScheduler",
|
| 310 |
+
"DDIMInverseScheduler",
|
| 311 |
+
"DDIMParallelScheduler",
|
| 312 |
+
"DDIMScheduler",
|
| 313 |
+
"DDPMParallelScheduler",
|
| 314 |
+
"DDPMScheduler",
|
| 315 |
+
"DDPMWuerstchenScheduler",
|
| 316 |
+
"DEISMultistepScheduler",
|
| 317 |
+
"DPMSolverMultistepInverseScheduler",
|
| 318 |
+
"DPMSolverMultistepScheduler",
|
| 319 |
+
"DPMSolverSinglestepScheduler",
|
| 320 |
+
"EDMDPMSolverMultistepScheduler",
|
| 321 |
+
"EDMEulerScheduler",
|
| 322 |
+
"EulerAncestralDiscreteScheduler",
|
| 323 |
+
"EulerDiscreteScheduler",
|
| 324 |
+
"FlowMatchEulerDiscreteScheduler",
|
| 325 |
+
"FlowMatchHeunDiscreteScheduler",
|
| 326 |
+
"FlowMatchLCMScheduler",
|
| 327 |
+
"HeunDiscreteScheduler",
|
| 328 |
+
"IPNDMScheduler",
|
| 329 |
+
"KarrasVeScheduler",
|
| 330 |
+
"KDPM2AncestralDiscreteScheduler",
|
| 331 |
+
"KDPM2DiscreteScheduler",
|
| 332 |
+
"LCMScheduler",
|
| 333 |
+
"PNDMScheduler",
|
| 334 |
+
"RePaintScheduler",
|
| 335 |
+
"SASolverScheduler",
|
| 336 |
+
"SchedulerMixin",
|
| 337 |
+
"SCMScheduler",
|
| 338 |
+
"ScoreSdeVeScheduler",
|
| 339 |
+
"TCDScheduler",
|
| 340 |
+
"UnCLIPScheduler",
|
| 341 |
+
"UniPCMultistepScheduler",
|
| 342 |
+
"VQDiffusionScheduler",
|
| 343 |
+
]
|
| 344 |
+
)
|
| 345 |
+
_import_structure["training_utils"] = ["EMAModel"]
|
| 346 |
+
|
| 347 |
+
try:
|
| 348 |
+
if not (is_torch_available() and is_scipy_available()):
|
| 349 |
+
raise OptionalDependencyNotAvailable()
|
| 350 |
+
except OptionalDependencyNotAvailable:
|
| 351 |
+
from .utils import dummy_torch_and_scipy_objects # noqa F403
|
| 352 |
+
|
| 353 |
+
_import_structure["utils.dummy_torch_and_scipy_objects"] = [
|
| 354 |
+
name for name in dir(dummy_torch_and_scipy_objects) if not name.startswith("_")
|
| 355 |
+
]
|
| 356 |
+
|
| 357 |
+
else:
|
| 358 |
+
_import_structure["schedulers"].extend(["LMSDiscreteScheduler"])
|
| 359 |
+
|
| 360 |
+
try:
|
| 361 |
+
if not (is_torch_available() and is_torchsde_available()):
|
| 362 |
+
raise OptionalDependencyNotAvailable()
|
| 363 |
+
except OptionalDependencyNotAvailable:
|
| 364 |
+
from .utils import dummy_torch_and_torchsde_objects # noqa F403
|
| 365 |
+
|
| 366 |
+
_import_structure["utils.dummy_torch_and_torchsde_objects"] = [
|
| 367 |
+
name for name in dir(dummy_torch_and_torchsde_objects) if not name.startswith("_")
|
| 368 |
+
]
|
| 369 |
+
|
| 370 |
+
else:
|
| 371 |
+
_import_structure["schedulers"].extend(["CosineDPMSolverMultistepScheduler", "DPMSolverSDEScheduler"])
|
| 372 |
+
|
| 373 |
+
try:
|
| 374 |
+
if not (is_torch_available() and is_transformers_available()):
|
| 375 |
+
raise OptionalDependencyNotAvailable()
|
| 376 |
+
except OptionalDependencyNotAvailable:
|
| 377 |
+
from .utils import dummy_torch_and_transformers_objects # noqa F403
|
| 378 |
+
|
| 379 |
+
_import_structure["utils.dummy_torch_and_transformers_objects"] = [
|
| 380 |
+
name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_")
|
| 381 |
+
]
|
| 382 |
+
|
| 383 |
+
else:
|
| 384 |
+
_import_structure["modular_pipelines"].extend(
|
| 385 |
+
[
|
| 386 |
+
"FluxAutoBlocks",
|
| 387 |
+
"FluxModularPipeline",
|
| 388 |
+
"QwenImageAutoBlocks",
|
| 389 |
+
"QwenImageEditAutoBlocks",
|
| 390 |
+
"QwenImageEditModularPipeline",
|
| 391 |
+
"QwenImageModularPipeline",
|
| 392 |
+
"StableDiffusionXLAutoBlocks",
|
| 393 |
+
"StableDiffusionXLModularPipeline",
|
| 394 |
+
"WanAutoBlocks",
|
| 395 |
+
"WanModularPipeline",
|
| 396 |
+
]
|
| 397 |
+
)
|
| 398 |
+
_import_structure["pipelines"].extend(
|
| 399 |
+
[
|
| 400 |
+
"AllegroPipeline",
|
| 401 |
+
"AltDiffusionImg2ImgPipeline",
|
| 402 |
+
"AltDiffusionPipeline",
|
| 403 |
+
"AmusedImg2ImgPipeline",
|
| 404 |
+
"AmusedInpaintPipeline",
|
| 405 |
+
"AmusedPipeline",
|
| 406 |
+
"AnimateDiffControlNetPipeline",
|
| 407 |
+
"AnimateDiffPAGPipeline",
|
| 408 |
+
"AnimateDiffPipeline",
|
| 409 |
+
"AnimateDiffSDXLPipeline",
|
| 410 |
+
"AnimateDiffSparseControlNetPipeline",
|
| 411 |
+
"AnimateDiffVideoToVideoControlNetPipeline",
|
| 412 |
+
"AnimateDiffVideoToVideoPipeline",
|
| 413 |
+
"AudioLDM2Pipeline",
|
| 414 |
+
"AudioLDM2ProjectionModel",
|
| 415 |
+
"AudioLDM2UNet2DConditionModel",
|
| 416 |
+
"AudioLDMPipeline",
|
| 417 |
+
"AuraFlowPipeline",
|
| 418 |
+
"BlipDiffusionControlNetPipeline",
|
| 419 |
+
"BlipDiffusionPipeline",
|
| 420 |
+
"BriaPipeline",
|
| 421 |
+
"ChromaImg2ImgPipeline",
|
| 422 |
+
"ChromaPipeline",
|
| 423 |
+
"CLIPImageProjection",
|
| 424 |
+
"CogVideoXFunControlPipeline",
|
| 425 |
+
"CogVideoXImageToVideoPipeline",
|
| 426 |
+
"CogVideoXPipeline",
|
| 427 |
+
"CogVideoXVideoToVideoPipeline",
|
| 428 |
+
"CogView3PlusPipeline",
|
| 429 |
+
"CogView4ControlPipeline",
|
| 430 |
+
"CogView4Pipeline",
|
| 431 |
+
"ConsisIDPipeline",
|
| 432 |
+
"Cosmos2TextToImagePipeline",
|
| 433 |
+
"Cosmos2VideoToWorldPipeline",
|
| 434 |
+
"CosmosTextToWorldPipeline",
|
| 435 |
+
"CosmosVideoToWorldPipeline",
|
| 436 |
+
"CycleDiffusionPipeline",
|
| 437 |
+
"EasyAnimateControlPipeline",
|
| 438 |
+
"EasyAnimateInpaintPipeline",
|
| 439 |
+
"EasyAnimatePipeline",
|
| 440 |
+
"FluxControlImg2ImgPipeline",
|
| 441 |
+
"FluxControlInpaintPipeline",
|
| 442 |
+
"FluxControlNetImg2ImgPipeline",
|
| 443 |
+
"FluxControlNetInpaintPipeline",
|
| 444 |
+
"FluxControlNetPipeline",
|
| 445 |
+
"FluxControlPipeline",
|
| 446 |
+
"FluxFillPipeline",
|
| 447 |
+
"FluxImg2ImgPipeline",
|
| 448 |
+
"FluxInpaintPipeline",
|
| 449 |
+
"FluxKontextInpaintPipeline",
|
| 450 |
+
"FluxKontextPipeline",
|
| 451 |
+
"FluxPipeline",
|
| 452 |
+
"FluxPriorReduxPipeline",
|
| 453 |
+
"HiDreamImagePipeline",
|
| 454 |
+
"HunyuanDiTControlNetPipeline",
|
| 455 |
+
"HunyuanDiTPAGPipeline",
|
| 456 |
+
"HunyuanDiTPipeline",
|
| 457 |
+
"HunyuanSkyreelsImageToVideoPipeline",
|
| 458 |
+
"HunyuanVideoFramepackPipeline",
|
| 459 |
+
"HunyuanVideoImageToVideoPipeline",
|
| 460 |
+
"HunyuanVideoPipeline",
|
| 461 |
+
"I2VGenXLPipeline",
|
| 462 |
+
"IFImg2ImgPipeline",
|
| 463 |
+
"IFImg2ImgSuperResolutionPipeline",
|
| 464 |
+
"IFInpaintingPipeline",
|
| 465 |
+
"IFInpaintingSuperResolutionPipeline",
|
| 466 |
+
"IFPipeline",
|
| 467 |
+
"IFSuperResolutionPipeline",
|
| 468 |
+
"ImageTextPipelineOutput",
|
| 469 |
+
"Kandinsky3Img2ImgPipeline",
|
| 470 |
+
"Kandinsky3Pipeline",
|
| 471 |
+
"KandinskyCombinedPipeline",
|
| 472 |
+
"KandinskyImg2ImgCombinedPipeline",
|
| 473 |
+
"KandinskyImg2ImgPipeline",
|
| 474 |
+
"KandinskyInpaintCombinedPipeline",
|
| 475 |
+
"KandinskyInpaintPipeline",
|
| 476 |
+
"KandinskyPipeline",
|
| 477 |
+
"KandinskyPriorPipeline",
|
| 478 |
+
"KandinskyV22CombinedPipeline",
|
| 479 |
+
"KandinskyV22ControlnetImg2ImgPipeline",
|
| 480 |
+
"KandinskyV22ControlnetPipeline",
|
| 481 |
+
"KandinskyV22Img2ImgCombinedPipeline",
|
| 482 |
+
"KandinskyV22Img2ImgPipeline",
|
| 483 |
+
"KandinskyV22InpaintCombinedPipeline",
|
| 484 |
+
"KandinskyV22InpaintPipeline",
|
| 485 |
+
"KandinskyV22Pipeline",
|
| 486 |
+
"KandinskyV22PriorEmb2EmbPipeline",
|
| 487 |
+
"KandinskyV22PriorPipeline",
|
| 488 |
+
"LatentConsistencyModelImg2ImgPipeline",
|
| 489 |
+
"LatentConsistencyModelPipeline",
|
| 490 |
+
"LattePipeline",
|
| 491 |
+
"LDMTextToImagePipeline",
|
| 492 |
+
"LEditsPPPipelineStableDiffusion",
|
| 493 |
+
"LEditsPPPipelineStableDiffusionXL",
|
| 494 |
+
"LTXConditionPipeline",
|
| 495 |
+
"LTXImageToVideoPipeline",
|
| 496 |
+
"LTXLatentUpsamplePipeline",
|
| 497 |
+
"LTXPipeline",
|
| 498 |
+
"Lumina2Pipeline",
|
| 499 |
+
"Lumina2Text2ImgPipeline",
|
| 500 |
+
"LuminaPipeline",
|
| 501 |
+
"LuminaText2ImgPipeline",
|
| 502 |
+
"MarigoldDepthPipeline",
|
| 503 |
+
"MarigoldIntrinsicsPipeline",
|
| 504 |
+
"MarigoldNormalsPipeline",
|
| 505 |
+
"MochiPipeline",
|
| 506 |
+
"MusicLDMPipeline",
|
| 507 |
+
"OmniGenPipeline",
|
| 508 |
+
"PaintByExamplePipeline",
|
| 509 |
+
"PIAPipeline",
|
| 510 |
+
"PixArtAlphaPipeline",
|
| 511 |
+
"PixArtSigmaPAGPipeline",
|
| 512 |
+
"PixArtSigmaPipeline",
|
| 513 |
+
"QwenImageControlNetInpaintPipeline",
|
| 514 |
+
"QwenImageControlNetPipeline",
|
| 515 |
+
"QwenImageEditInpaintPipeline",
|
| 516 |
+
"QwenImageEditPipeline",
|
| 517 |
+
"QwenImageImg2ImgPipeline",
|
| 518 |
+
"QwenImageInpaintPipeline",
|
| 519 |
+
"QwenImagePipeline",
|
| 520 |
+
"ReduxImageEncoder",
|
| 521 |
+
"SanaControlNetPipeline",
|
| 522 |
+
"SanaPAGPipeline",
|
| 523 |
+
"SanaPipeline",
|
| 524 |
+
"SanaSprintImg2ImgPipeline",
|
| 525 |
+
"SanaSprintPipeline",
|
| 526 |
+
"SemanticStableDiffusionPipeline",
|
| 527 |
+
"ShapEImg2ImgPipeline",
|
| 528 |
+
"ShapEPipeline",
|
| 529 |
+
"SkyReelsV2DiffusionForcingImageToVideoPipeline",
|
| 530 |
+
"SkyReelsV2DiffusionForcingPipeline",
|
| 531 |
+
"SkyReelsV2DiffusionForcingVideoToVideoPipeline",
|
| 532 |
+
"SkyReelsV2ImageToVideoPipeline",
|
| 533 |
+
"SkyReelsV2Pipeline",
|
| 534 |
+
"StableAudioPipeline",
|
| 535 |
+
"StableAudioProjectionModel",
|
| 536 |
+
"StableCascadeCombinedPipeline",
|
| 537 |
+
"StableCascadeDecoderPipeline",
|
| 538 |
+
"StableCascadePriorPipeline",
|
| 539 |
+
"StableDiffusion3ControlNetInpaintingPipeline",
|
| 540 |
+
"StableDiffusion3ControlNetPipeline",
|
| 541 |
+
"StableDiffusion3Img2ImgPipeline",
|
| 542 |
+
"StableDiffusion3InpaintPipeline",
|
| 543 |
+
"StableDiffusion3PAGImg2ImgPipeline",
|
| 544 |
+
"StableDiffusion3PAGImg2ImgPipeline",
|
| 545 |
+
"StableDiffusion3PAGPipeline",
|
| 546 |
+
"StableDiffusion3Pipeline",
|
| 547 |
+
"StableDiffusionAdapterPipeline",
|
| 548 |
+
"StableDiffusionAttendAndExcitePipeline",
|
| 549 |
+
"StableDiffusionControlNetImg2ImgPipeline",
|
| 550 |
+
"StableDiffusionControlNetInpaintPipeline",
|
| 551 |
+
"StableDiffusionControlNetPAGInpaintPipeline",
|
| 552 |
+
"StableDiffusionControlNetPAGPipeline",
|
| 553 |
+
"StableDiffusionControlNetPipeline",
|
| 554 |
+
"StableDiffusionControlNetXSPipeline",
|
| 555 |
+
"StableDiffusionDepth2ImgPipeline",
|
| 556 |
+
"StableDiffusionDiffEditPipeline",
|
| 557 |
+
"StableDiffusionGLIGENPipeline",
|
| 558 |
+
"StableDiffusionGLIGENTextImagePipeline",
|
| 559 |
+
"StableDiffusionImageVariationPipeline",
|
| 560 |
+
"StableDiffusionImg2ImgPipeline",
|
| 561 |
+
"StableDiffusionInpaintPipeline",
|
| 562 |
+
"StableDiffusionInpaintPipelineLegacy",
|
| 563 |
+
"StableDiffusionInstructPix2PixPipeline",
|
| 564 |
+
"StableDiffusionLatentUpscalePipeline",
|
| 565 |
+
"StableDiffusionLDM3DPipeline",
|
| 566 |
+
"StableDiffusionModelEditingPipeline",
|
| 567 |
+
"StableDiffusionPAGImg2ImgPipeline",
|
| 568 |
+
"StableDiffusionPAGInpaintPipeline",
|
| 569 |
+
"StableDiffusionPAGPipeline",
|
| 570 |
+
"StableDiffusionPanoramaPipeline",
|
| 571 |
+
"StableDiffusionParadigmsPipeline",
|
| 572 |
+
"StableDiffusionPipeline",
|
| 573 |
+
"StableDiffusionPipelineSafe",
|
| 574 |
+
"StableDiffusionPix2PixZeroPipeline",
|
| 575 |
+
"StableDiffusionSAGPipeline",
|
| 576 |
+
"StableDiffusionUpscalePipeline",
|
| 577 |
+
"StableDiffusionXLAdapterPipeline",
|
| 578 |
+
"StableDiffusionXLControlNetImg2ImgPipeline",
|
| 579 |
+
"StableDiffusionXLControlNetInpaintPipeline",
|
| 580 |
+
"StableDiffusionXLControlNetPAGImg2ImgPipeline",
|
| 581 |
+
"StableDiffusionXLControlNetPAGPipeline",
|
| 582 |
+
"StableDiffusionXLControlNetPipeline",
|
| 583 |
+
"StableDiffusionXLControlNetUnionImg2ImgPipeline",
|
| 584 |
+
"StableDiffusionXLControlNetUnionInpaintPipeline",
|
| 585 |
+
"StableDiffusionXLControlNetUnionPipeline",
|
| 586 |
+
"StableDiffusionXLControlNetXSPipeline",
|
| 587 |
+
"StableDiffusionXLImg2ImgPipeline",
|
| 588 |
+
"StableDiffusionXLInpaintPipeline",
|
| 589 |
+
"StableDiffusionXLInstructPix2PixPipeline",
|
| 590 |
+
"StableDiffusionXLPAGImg2ImgPipeline",
|
| 591 |
+
"StableDiffusionXLPAGInpaintPipeline",
|
| 592 |
+
"StableDiffusionXLPAGPipeline",
|
| 593 |
+
"StableDiffusionXLPipeline",
|
| 594 |
+
"StableUnCLIPImg2ImgPipeline",
|
| 595 |
+
"StableUnCLIPPipeline",
|
| 596 |
+
"StableVideoDiffusionPipeline",
|
| 597 |
+
"TextToVideoSDPipeline",
|
| 598 |
+
"TextToVideoZeroPipeline",
|
| 599 |
+
"TextToVideoZeroSDXLPipeline",
|
| 600 |
+
"UnCLIPImageVariationPipeline",
|
| 601 |
+
"UnCLIPPipeline",
|
| 602 |
+
"UniDiffuserModel",
|
| 603 |
+
"UniDiffuserPipeline",
|
| 604 |
+
"UniDiffuserTextDecoder",
|
| 605 |
+
"VersatileDiffusionDualGuidedPipeline",
|
| 606 |
+
"VersatileDiffusionImageVariationPipeline",
|
| 607 |
+
"VersatileDiffusionPipeline",
|
| 608 |
+
"VersatileDiffusionTextToImagePipeline",
|
| 609 |
+
"VideoToVideoSDPipeline",
|
| 610 |
+
"VisualClozeGenerationPipeline",
|
| 611 |
+
"VisualClozePipeline",
|
| 612 |
+
"VQDiffusionPipeline",
|
| 613 |
+
"WanImageToVideoPipeline",
|
| 614 |
+
"WanPipeline",
|
| 615 |
+
"WanVACEPipeline",
|
| 616 |
+
"WanVideoToVideoPipeline",
|
| 617 |
+
"WuerstchenCombinedPipeline",
|
| 618 |
+
"WuerstchenDecoderPipeline",
|
| 619 |
+
"WuerstchenPriorPipeline",
|
| 620 |
+
]
|
| 621 |
+
)
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
try:
|
| 625 |
+
if not (is_torch_available() and is_transformers_available() and is_opencv_available()):
|
| 626 |
+
raise OptionalDependencyNotAvailable()
|
| 627 |
+
except OptionalDependencyNotAvailable:
|
| 628 |
+
from .utils import dummy_torch_and_transformers_and_opencv_objects # noqa F403
|
| 629 |
+
|
| 630 |
+
_import_structure["utils.dummy_torch_and_transformers_and_opencv_objects"] = [
|
| 631 |
+
name for name in dir(dummy_torch_and_transformers_and_opencv_objects) if not name.startswith("_")
|
| 632 |
+
]
|
| 633 |
+
|
| 634 |
+
else:
|
| 635 |
+
_import_structure["pipelines"].extend(["ConsisIDPipeline"])
|
| 636 |
+
|
| 637 |
+
try:
|
| 638 |
+
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
|
| 639 |
+
raise OptionalDependencyNotAvailable()
|
| 640 |
+
except OptionalDependencyNotAvailable:
|
| 641 |
+
from .utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403
|
| 642 |
+
|
| 643 |
+
_import_structure["utils.dummy_torch_and_transformers_and_k_diffusion_objects"] = [
|
| 644 |
+
name for name in dir(dummy_torch_and_transformers_and_k_diffusion_objects) if not name.startswith("_")
|
| 645 |
+
]
|
| 646 |
+
|
| 647 |
+
else:
|
| 648 |
+
_import_structure["pipelines"].extend(["StableDiffusionKDiffusionPipeline", "StableDiffusionXLKDiffusionPipeline"])
|
| 649 |
+
|
| 650 |
+
try:
|
| 651 |
+
if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()):
|
| 652 |
+
raise OptionalDependencyNotAvailable()
|
| 653 |
+
except OptionalDependencyNotAvailable:
|
| 654 |
+
from .utils import dummy_torch_and_transformers_and_sentencepiece_objects # noqa F403
|
| 655 |
+
|
| 656 |
+
_import_structure["utils.dummy_torch_and_transformers_and_sentencepiece_objects"] = [
|
| 657 |
+
name for name in dir(dummy_torch_and_transformers_and_sentencepiece_objects) if not name.startswith("_")
|
| 658 |
+
]
|
| 659 |
+
|
| 660 |
+
else:
|
| 661 |
+
_import_structure["pipelines"].extend(["KolorsImg2ImgPipeline", "KolorsPAGPipeline", "KolorsPipeline"])
|
| 662 |
+
|
| 663 |
+
try:
|
| 664 |
+
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
|
| 665 |
+
raise OptionalDependencyNotAvailable()
|
| 666 |
+
except OptionalDependencyNotAvailable:
|
| 667 |
+
from .utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403
|
| 668 |
+
|
| 669 |
+
_import_structure["utils.dummy_torch_and_transformers_and_onnx_objects"] = [
|
| 670 |
+
name for name in dir(dummy_torch_and_transformers_and_onnx_objects) if not name.startswith("_")
|
| 671 |
+
]
|
| 672 |
+
|
| 673 |
+
else:
|
| 674 |
+
_import_structure["pipelines"].extend(
|
| 675 |
+
[
|
| 676 |
+
"OnnxStableDiffusionImg2ImgPipeline",
|
| 677 |
+
"OnnxStableDiffusionInpaintPipeline",
|
| 678 |
+
"OnnxStableDiffusionInpaintPipelineLegacy",
|
| 679 |
+
"OnnxStableDiffusionPipeline",
|
| 680 |
+
"OnnxStableDiffusionUpscalePipeline",
|
| 681 |
+
"StableDiffusionOnnxPipeline",
|
| 682 |
+
]
|
| 683 |
+
)
|
| 684 |
+
|
| 685 |
+
try:
|
| 686 |
+
if not (is_torch_available() and is_librosa_available()):
|
| 687 |
+
raise OptionalDependencyNotAvailable()
|
| 688 |
+
except OptionalDependencyNotAvailable:
|
| 689 |
+
from .utils import dummy_torch_and_librosa_objects # noqa F403
|
| 690 |
+
|
| 691 |
+
_import_structure["utils.dummy_torch_and_librosa_objects"] = [
|
| 692 |
+
name for name in dir(dummy_torch_and_librosa_objects) if not name.startswith("_")
|
| 693 |
+
]
|
| 694 |
+
|
| 695 |
+
else:
|
| 696 |
+
_import_structure["pipelines"].extend(["AudioDiffusionPipeline", "Mel"])
|
| 697 |
+
|
| 698 |
+
try:
|
| 699 |
+
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
|
| 700 |
+
raise OptionalDependencyNotAvailable()
|
| 701 |
+
except OptionalDependencyNotAvailable:
|
| 702 |
+
from .utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403
|
| 703 |
+
|
| 704 |
+
_import_structure["utils.dummy_transformers_and_torch_and_note_seq_objects"] = [
|
| 705 |
+
name for name in dir(dummy_transformers_and_torch_and_note_seq_objects) if not name.startswith("_")
|
| 706 |
+
]
|
| 707 |
+
|
| 708 |
+
|
| 709 |
+
else:
|
| 710 |
+
_import_structure["pipelines"].extend(["SpectrogramDiffusionPipeline"])
|
| 711 |
+
|
| 712 |
+
try:
|
| 713 |
+
if not is_flax_available():
|
| 714 |
+
raise OptionalDependencyNotAvailable()
|
| 715 |
+
except OptionalDependencyNotAvailable:
|
| 716 |
+
from .utils import dummy_flax_objects # noqa F403
|
| 717 |
+
|
| 718 |
+
_import_structure["utils.dummy_flax_objects"] = [
|
| 719 |
+
name for name in dir(dummy_flax_objects) if not name.startswith("_")
|
| 720 |
+
]
|
| 721 |
+
|
| 722 |
+
|
| 723 |
+
else:
|
| 724 |
+
_import_structure["models.controlnets.controlnet_flax"] = ["FlaxControlNetModel"]
|
| 725 |
+
_import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"]
|
| 726 |
+
_import_structure["models.unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
|
| 727 |
+
_import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
|
| 728 |
+
_import_structure["pipelines"].extend(["FlaxDiffusionPipeline"])
|
| 729 |
+
_import_structure["schedulers"].extend(
|
| 730 |
+
[
|
| 731 |
+
"FlaxDDIMScheduler",
|
| 732 |
+
"FlaxDDPMScheduler",
|
| 733 |
+
"FlaxDPMSolverMultistepScheduler",
|
| 734 |
+
"FlaxEulerDiscreteScheduler",
|
| 735 |
+
"FlaxKarrasVeScheduler",
|
| 736 |
+
"FlaxLMSDiscreteScheduler",
|
| 737 |
+
"FlaxPNDMScheduler",
|
| 738 |
+
"FlaxSchedulerMixin",
|
| 739 |
+
"FlaxScoreSdeVeScheduler",
|
| 740 |
+
]
|
| 741 |
+
)
|
| 742 |
+
|
| 743 |
+
|
| 744 |
+
try:
|
| 745 |
+
if not (is_flax_available() and is_transformers_available()):
|
| 746 |
+
raise OptionalDependencyNotAvailable()
|
| 747 |
+
except OptionalDependencyNotAvailable:
|
| 748 |
+
from .utils import dummy_flax_and_transformers_objects # noqa F403
|
| 749 |
+
|
| 750 |
+
_import_structure["utils.dummy_flax_and_transformers_objects"] = [
|
| 751 |
+
name for name in dir(dummy_flax_and_transformers_objects) if not name.startswith("_")
|
| 752 |
+
]
|
| 753 |
+
|
| 754 |
+
|
| 755 |
+
else:
|
| 756 |
+
_import_structure["pipelines"].extend(
|
| 757 |
+
[
|
| 758 |
+
"FlaxStableDiffusionControlNetPipeline",
|
| 759 |
+
"FlaxStableDiffusionImg2ImgPipeline",
|
| 760 |
+
"FlaxStableDiffusionInpaintPipeline",
|
| 761 |
+
"FlaxStableDiffusionPipeline",
|
| 762 |
+
"FlaxStableDiffusionXLPipeline",
|
| 763 |
+
]
|
| 764 |
+
)
|
| 765 |
+
|
| 766 |
+
try:
|
| 767 |
+
if not (is_note_seq_available()):
|
| 768 |
+
raise OptionalDependencyNotAvailable()
|
| 769 |
+
except OptionalDependencyNotAvailable:
|
| 770 |
+
from .utils import dummy_note_seq_objects # noqa F403
|
| 771 |
+
|
| 772 |
+
_import_structure["utils.dummy_note_seq_objects"] = [
|
| 773 |
+
name for name in dir(dummy_note_seq_objects) if not name.startswith("_")
|
| 774 |
+
]
|
| 775 |
+
|
| 776 |
+
|
| 777 |
+
else:
|
| 778 |
+
_import_structure["pipelines"].extend(["MidiProcessor"])
|
| 779 |
+
|
| 780 |
+
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
| 781 |
+
from .configuration_utils import ConfigMixin
|
| 782 |
+
from .quantizers import PipelineQuantizationConfig
|
| 783 |
+
|
| 784 |
+
try:
|
| 785 |
+
if not is_bitsandbytes_available():
|
| 786 |
+
raise OptionalDependencyNotAvailable()
|
| 787 |
+
except OptionalDependencyNotAvailable:
|
| 788 |
+
from .utils.dummy_bitsandbytes_objects import *
|
| 789 |
+
else:
|
| 790 |
+
from .quantizers.quantization_config import BitsAndBytesConfig
|
| 791 |
+
|
| 792 |
+
try:
|
| 793 |
+
if not is_gguf_available():
|
| 794 |
+
raise OptionalDependencyNotAvailable()
|
| 795 |
+
except OptionalDependencyNotAvailable:
|
| 796 |
+
from .utils.dummy_gguf_objects import *
|
| 797 |
+
else:
|
| 798 |
+
from .quantizers.quantization_config import GGUFQuantizationConfig
|
| 799 |
+
|
| 800 |
+
try:
|
| 801 |
+
if not is_torchao_available():
|
| 802 |
+
raise OptionalDependencyNotAvailable()
|
| 803 |
+
except OptionalDependencyNotAvailable:
|
| 804 |
+
from .utils.dummy_torchao_objects import *
|
| 805 |
+
else:
|
| 806 |
+
from .quantizers.quantization_config import TorchAoConfig
|
| 807 |
+
|
| 808 |
+
try:
|
| 809 |
+
if not is_optimum_quanto_available():
|
| 810 |
+
raise OptionalDependencyNotAvailable()
|
| 811 |
+
except OptionalDependencyNotAvailable:
|
| 812 |
+
from .utils.dummy_optimum_quanto_objects import *
|
| 813 |
+
else:
|
| 814 |
+
from .quantizers.quantization_config import QuantoConfig
|
| 815 |
+
|
| 816 |
+
try:
|
| 817 |
+
if not is_nvidia_modelopt_available():
|
| 818 |
+
raise OptionalDependencyNotAvailable()
|
| 819 |
+
except OptionalDependencyNotAvailable:
|
| 820 |
+
from .utils.dummy_nvidia_modelopt_objects import *
|
| 821 |
+
else:
|
| 822 |
+
from .quantizers.quantization_config import NVIDIAModelOptConfig
|
| 823 |
+
|
| 824 |
+
try:
|
| 825 |
+
if not is_onnx_available():
|
| 826 |
+
raise OptionalDependencyNotAvailable()
|
| 827 |
+
except OptionalDependencyNotAvailable:
|
| 828 |
+
from .utils.dummy_onnx_objects import * # noqa F403
|
| 829 |
+
else:
|
| 830 |
+
from .pipelines import OnnxRuntimeModel
|
| 831 |
+
|
| 832 |
+
try:
|
| 833 |
+
if not is_torch_available():
|
| 834 |
+
raise OptionalDependencyNotAvailable()
|
| 835 |
+
except OptionalDependencyNotAvailable:
|
| 836 |
+
from .utils.dummy_pt_objects import * # noqa F403
|
| 837 |
+
else:
|
| 838 |
+
from .guiders import (
|
| 839 |
+
AdaptiveProjectedGuidance,
|
| 840 |
+
AutoGuidance,
|
| 841 |
+
ClassifierFreeGuidance,
|
| 842 |
+
ClassifierFreeZeroStarGuidance,
|
| 843 |
+
FrequencyDecoupledGuidance,
|
| 844 |
+
PerturbedAttentionGuidance,
|
| 845 |
+
SkipLayerGuidance,
|
| 846 |
+
SmoothedEnergyGuidance,
|
| 847 |
+
TangentialClassifierFreeGuidance,
|
| 848 |
+
)
|
| 849 |
+
from .hooks import (
|
| 850 |
+
FasterCacheConfig,
|
| 851 |
+
FirstBlockCacheConfig,
|
| 852 |
+
HookRegistry,
|
| 853 |
+
LayerSkipConfig,
|
| 854 |
+
PyramidAttentionBroadcastConfig,
|
| 855 |
+
SmoothedEnergyGuidanceConfig,
|
| 856 |
+
apply_faster_cache,
|
| 857 |
+
apply_first_block_cache,
|
| 858 |
+
apply_layer_skip,
|
| 859 |
+
apply_pyramid_attention_broadcast,
|
| 860 |
+
)
|
| 861 |
+
from .models import (
|
| 862 |
+
AllegroTransformer3DModel,
|
| 863 |
+
AsymmetricAutoencoderKL,
|
| 864 |
+
AttentionBackendName,
|
| 865 |
+
AuraFlowTransformer2DModel,
|
| 866 |
+
AutoencoderDC,
|
| 867 |
+
AutoencoderKL,
|
| 868 |
+
AutoencoderKLAllegro,
|
| 869 |
+
AutoencoderKLCogVideoX,
|
| 870 |
+
AutoencoderKLCosmos,
|
| 871 |
+
AutoencoderKLHunyuanVideo,
|
| 872 |
+
AutoencoderKLLTXVideo,
|
| 873 |
+
AutoencoderKLMagvit,
|
| 874 |
+
AutoencoderKLMochi,
|
| 875 |
+
AutoencoderKLQwenImage,
|
| 876 |
+
AutoencoderKLTemporalDecoder,
|
| 877 |
+
AutoencoderKLWan,
|
| 878 |
+
AutoencoderOobleck,
|
| 879 |
+
AutoencoderTiny,
|
| 880 |
+
AutoModel,
|
| 881 |
+
BriaTransformer2DModel,
|
| 882 |
+
CacheMixin,
|
| 883 |
+
ChromaTransformer2DModel,
|
| 884 |
+
CogVideoXTransformer3DModel,
|
| 885 |
+
CogView3PlusTransformer2DModel,
|
| 886 |
+
CogView4Transformer2DModel,
|
| 887 |
+
ConsisIDTransformer3DModel,
|
| 888 |
+
ConsistencyDecoderVAE,
|
| 889 |
+
ControlNetModel,
|
| 890 |
+
ControlNetUnionModel,
|
| 891 |
+
ControlNetXSAdapter,
|
| 892 |
+
CosmosTransformer3DModel,
|
| 893 |
+
DiTTransformer2DModel,
|
| 894 |
+
EasyAnimateTransformer3DModel,
|
| 895 |
+
FluxControlNetModel,
|
| 896 |
+
FluxMultiControlNetModel,
|
| 897 |
+
FluxTransformer2DModel,
|
| 898 |
+
HiDreamImageTransformer2DModel,
|
| 899 |
+
HunyuanDiT2DControlNetModel,
|
| 900 |
+
HunyuanDiT2DModel,
|
| 901 |
+
HunyuanDiT2DMultiControlNetModel,
|
| 902 |
+
HunyuanVideoFramepackTransformer3DModel,
|
| 903 |
+
HunyuanVideoTransformer3DModel,
|
| 904 |
+
I2VGenXLUNet,
|
| 905 |
+
Kandinsky3UNet,
|
| 906 |
+
LatteTransformer3DModel,
|
| 907 |
+
LTXVideoTransformer3DModel,
|
| 908 |
+
Lumina2Transformer2DModel,
|
| 909 |
+
LuminaNextDiT2DModel,
|
| 910 |
+
MochiTransformer3DModel,
|
| 911 |
+
ModelMixin,
|
| 912 |
+
MotionAdapter,
|
| 913 |
+
MultiAdapter,
|
| 914 |
+
MultiControlNetModel,
|
| 915 |
+
OmniGenTransformer2DModel,
|
| 916 |
+
PixArtTransformer2DModel,
|
| 917 |
+
PriorTransformer,
|
| 918 |
+
QwenImageControlNetModel,
|
| 919 |
+
QwenImageMultiControlNetModel,
|
| 920 |
+
QwenImageTransformer2DModel,
|
| 921 |
+
SanaControlNetModel,
|
| 922 |
+
SanaTransformer2DModel,
|
| 923 |
+
SD3ControlNetModel,
|
| 924 |
+
SD3MultiControlNetModel,
|
| 925 |
+
SD3Transformer2DModel,
|
| 926 |
+
SkyReelsV2Transformer3DModel,
|
| 927 |
+
SparseControlNetModel,
|
| 928 |
+
StableAudioDiTModel,
|
| 929 |
+
T2IAdapter,
|
| 930 |
+
T5FilmDecoder,
|
| 931 |
+
Transformer2DModel,
|
| 932 |
+
TransformerTemporalModel,
|
| 933 |
+
UNet1DModel,
|
| 934 |
+
UNet2DConditionModel,
|
| 935 |
+
UNet2DModel,
|
| 936 |
+
UNet3DConditionModel,
|
| 937 |
+
UNetControlNetXSModel,
|
| 938 |
+
UNetMotionModel,
|
| 939 |
+
UNetSpatioTemporalConditionModel,
|
| 940 |
+
UVit2DModel,
|
| 941 |
+
VQModel,
|
| 942 |
+
WanTransformer3DModel,
|
| 943 |
+
WanVACETransformer3DModel,
|
| 944 |
+
attention_backend,
|
| 945 |
+
)
|
| 946 |
+
from .modular_pipelines import ComponentsManager, ComponentSpec, ModularPipeline, ModularPipelineBlocks
|
| 947 |
+
from .optimization import (
|
| 948 |
+
get_constant_schedule,
|
| 949 |
+
get_constant_schedule_with_warmup,
|
| 950 |
+
get_cosine_schedule_with_warmup,
|
| 951 |
+
get_cosine_with_hard_restarts_schedule_with_warmup,
|
| 952 |
+
get_linear_schedule_with_warmup,
|
| 953 |
+
get_polynomial_decay_schedule_with_warmup,
|
| 954 |
+
get_scheduler,
|
| 955 |
+
)
|
| 956 |
+
from .pipelines import (
|
| 957 |
+
AudioPipelineOutput,
|
| 958 |
+
AutoPipelineForImage2Image,
|
| 959 |
+
AutoPipelineForInpainting,
|
| 960 |
+
AutoPipelineForText2Image,
|
| 961 |
+
BlipDiffusionControlNetPipeline,
|
| 962 |
+
BlipDiffusionPipeline,
|
| 963 |
+
CLIPImageProjection,
|
| 964 |
+
ConsistencyModelPipeline,
|
| 965 |
+
DanceDiffusionPipeline,
|
| 966 |
+
DDIMPipeline,
|
| 967 |
+
DDPMPipeline,
|
| 968 |
+
DiffusionPipeline,
|
| 969 |
+
DiTPipeline,
|
| 970 |
+
ImagePipelineOutput,
|
| 971 |
+
KarrasVePipeline,
|
| 972 |
+
LDMPipeline,
|
| 973 |
+
LDMSuperResolutionPipeline,
|
| 974 |
+
PNDMPipeline,
|
| 975 |
+
RePaintPipeline,
|
| 976 |
+
ScoreSdeVePipeline,
|
| 977 |
+
StableDiffusionMixin,
|
| 978 |
+
)
|
| 979 |
+
from .quantizers import DiffusersQuantizer
|
| 980 |
+
from .schedulers import (
|
| 981 |
+
AmusedScheduler,
|
| 982 |
+
CMStochasticIterativeScheduler,
|
| 983 |
+
CogVideoXDDIMScheduler,
|
| 984 |
+
CogVideoXDPMScheduler,
|
| 985 |
+
DDIMInverseScheduler,
|
| 986 |
+
DDIMParallelScheduler,
|
| 987 |
+
DDIMScheduler,
|
| 988 |
+
DDPMParallelScheduler,
|
| 989 |
+
DDPMScheduler,
|
| 990 |
+
DDPMWuerstchenScheduler,
|
| 991 |
+
DEISMultistepScheduler,
|
| 992 |
+
DPMSolverMultistepInverseScheduler,
|
| 993 |
+
DPMSolverMultistepScheduler,
|
| 994 |
+
DPMSolverSinglestepScheduler,
|
| 995 |
+
EDMDPMSolverMultistepScheduler,
|
| 996 |
+
EDMEulerScheduler,
|
| 997 |
+
EulerAncestralDiscreteScheduler,
|
| 998 |
+
EulerDiscreteScheduler,
|
| 999 |
+
FlowMatchEulerDiscreteScheduler,
|
| 1000 |
+
FlowMatchHeunDiscreteScheduler,
|
| 1001 |
+
FlowMatchLCMScheduler,
|
| 1002 |
+
HeunDiscreteScheduler,
|
| 1003 |
+
IPNDMScheduler,
|
| 1004 |
+
KarrasVeScheduler,
|
| 1005 |
+
KDPM2AncestralDiscreteScheduler,
|
| 1006 |
+
KDPM2DiscreteScheduler,
|
| 1007 |
+
LCMScheduler,
|
| 1008 |
+
PNDMScheduler,
|
| 1009 |
+
RePaintScheduler,
|
| 1010 |
+
SASolverScheduler,
|
| 1011 |
+
SchedulerMixin,
|
| 1012 |
+
SCMScheduler,
|
| 1013 |
+
ScoreSdeVeScheduler,
|
| 1014 |
+
TCDScheduler,
|
| 1015 |
+
UnCLIPScheduler,
|
| 1016 |
+
UniPCMultistepScheduler,
|
| 1017 |
+
VQDiffusionScheduler,
|
| 1018 |
+
)
|
| 1019 |
+
from .training_utils import EMAModel
|
| 1020 |
+
|
| 1021 |
+
try:
|
| 1022 |
+
if not (is_torch_available() and is_scipy_available()):
|
| 1023 |
+
raise OptionalDependencyNotAvailable()
|
| 1024 |
+
except OptionalDependencyNotAvailable:
|
| 1025 |
+
from .utils.dummy_torch_and_scipy_objects import * # noqa F403
|
| 1026 |
+
else:
|
| 1027 |
+
from .schedulers import LMSDiscreteScheduler
|
| 1028 |
+
|
| 1029 |
+
try:
|
| 1030 |
+
if not (is_torch_available() and is_torchsde_available()):
|
| 1031 |
+
raise OptionalDependencyNotAvailable()
|
| 1032 |
+
except OptionalDependencyNotAvailable:
|
| 1033 |
+
from .utils.dummy_torch_and_torchsde_objects import * # noqa F403
|
| 1034 |
+
else:
|
| 1035 |
+
from .schedulers import CosineDPMSolverMultistepScheduler, DPMSolverSDEScheduler
|
| 1036 |
+
|
| 1037 |
+
try:
|
| 1038 |
+
if not (is_torch_available() and is_transformers_available()):
|
| 1039 |
+
raise OptionalDependencyNotAvailable()
|
| 1040 |
+
except OptionalDependencyNotAvailable:
|
| 1041 |
+
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
|
| 1042 |
+
else:
|
| 1043 |
+
from .modular_pipelines import (
|
| 1044 |
+
FluxAutoBlocks,
|
| 1045 |
+
FluxModularPipeline,
|
| 1046 |
+
QwenImageAutoBlocks,
|
| 1047 |
+
QwenImageEditAutoBlocks,
|
| 1048 |
+
QwenImageEditModularPipeline,
|
| 1049 |
+
QwenImageModularPipeline,
|
| 1050 |
+
StableDiffusionXLAutoBlocks,
|
| 1051 |
+
StableDiffusionXLModularPipeline,
|
| 1052 |
+
WanAutoBlocks,
|
| 1053 |
+
WanModularPipeline,
|
| 1054 |
+
)
|
| 1055 |
+
from .pipelines import (
|
| 1056 |
+
AllegroPipeline,
|
| 1057 |
+
AltDiffusionImg2ImgPipeline,
|
| 1058 |
+
AltDiffusionPipeline,
|
| 1059 |
+
AmusedImg2ImgPipeline,
|
| 1060 |
+
AmusedInpaintPipeline,
|
| 1061 |
+
AmusedPipeline,
|
| 1062 |
+
AnimateDiffControlNetPipeline,
|
| 1063 |
+
AnimateDiffPAGPipeline,
|
| 1064 |
+
AnimateDiffPipeline,
|
| 1065 |
+
AnimateDiffSDXLPipeline,
|
| 1066 |
+
AnimateDiffSparseControlNetPipeline,
|
| 1067 |
+
AnimateDiffVideoToVideoControlNetPipeline,
|
| 1068 |
+
AnimateDiffVideoToVideoPipeline,
|
| 1069 |
+
AudioLDM2Pipeline,
|
| 1070 |
+
AudioLDM2ProjectionModel,
|
| 1071 |
+
AudioLDM2UNet2DConditionModel,
|
| 1072 |
+
AudioLDMPipeline,
|
| 1073 |
+
AuraFlowPipeline,
|
| 1074 |
+
BriaPipeline,
|
| 1075 |
+
ChromaImg2ImgPipeline,
|
| 1076 |
+
ChromaPipeline,
|
| 1077 |
+
CLIPImageProjection,
|
| 1078 |
+
CogVideoXFunControlPipeline,
|
| 1079 |
+
CogVideoXImageToVideoPipeline,
|
| 1080 |
+
CogVideoXPipeline,
|
| 1081 |
+
CogVideoXVideoToVideoPipeline,
|
| 1082 |
+
CogView3PlusPipeline,
|
| 1083 |
+
CogView4ControlPipeline,
|
| 1084 |
+
CogView4Pipeline,
|
| 1085 |
+
ConsisIDPipeline,
|
| 1086 |
+
Cosmos2TextToImagePipeline,
|
| 1087 |
+
Cosmos2VideoToWorldPipeline,
|
| 1088 |
+
CosmosTextToWorldPipeline,
|
| 1089 |
+
CosmosVideoToWorldPipeline,
|
| 1090 |
+
CycleDiffusionPipeline,
|
| 1091 |
+
EasyAnimateControlPipeline,
|
| 1092 |
+
EasyAnimateInpaintPipeline,
|
| 1093 |
+
EasyAnimatePipeline,
|
| 1094 |
+
FluxControlImg2ImgPipeline,
|
| 1095 |
+
FluxControlInpaintPipeline,
|
| 1096 |
+
FluxControlNetImg2ImgPipeline,
|
| 1097 |
+
FluxControlNetInpaintPipeline,
|
| 1098 |
+
FluxControlNetPipeline,
|
| 1099 |
+
FluxControlPipeline,
|
| 1100 |
+
FluxFillPipeline,
|
| 1101 |
+
FluxImg2ImgPipeline,
|
| 1102 |
+
FluxInpaintPipeline,
|
| 1103 |
+
FluxKontextInpaintPipeline,
|
| 1104 |
+
FluxKontextPipeline,
|
| 1105 |
+
FluxPipeline,
|
| 1106 |
+
FluxPriorReduxPipeline,
|
| 1107 |
+
HiDreamImagePipeline,
|
| 1108 |
+
HunyuanDiTControlNetPipeline,
|
| 1109 |
+
HunyuanDiTPAGPipeline,
|
| 1110 |
+
HunyuanDiTPipeline,
|
| 1111 |
+
HunyuanSkyreelsImageToVideoPipeline,
|
| 1112 |
+
HunyuanVideoFramepackPipeline,
|
| 1113 |
+
HunyuanVideoImageToVideoPipeline,
|
| 1114 |
+
HunyuanVideoPipeline,
|
| 1115 |
+
I2VGenXLPipeline,
|
| 1116 |
+
IFImg2ImgPipeline,
|
| 1117 |
+
IFImg2ImgSuperResolutionPipeline,
|
| 1118 |
+
IFInpaintingPipeline,
|
| 1119 |
+
IFInpaintingSuperResolutionPipeline,
|
| 1120 |
+
IFPipeline,
|
| 1121 |
+
IFSuperResolutionPipeline,
|
| 1122 |
+
ImageTextPipelineOutput,
|
| 1123 |
+
Kandinsky3Img2ImgPipeline,
|
| 1124 |
+
Kandinsky3Pipeline,
|
| 1125 |
+
KandinskyCombinedPipeline,
|
| 1126 |
+
KandinskyImg2ImgCombinedPipeline,
|
| 1127 |
+
KandinskyImg2ImgPipeline,
|
| 1128 |
+
KandinskyInpaintCombinedPipeline,
|
| 1129 |
+
KandinskyInpaintPipeline,
|
| 1130 |
+
KandinskyPipeline,
|
| 1131 |
+
KandinskyPriorPipeline,
|
| 1132 |
+
KandinskyV22CombinedPipeline,
|
| 1133 |
+
KandinskyV22ControlnetImg2ImgPipeline,
|
| 1134 |
+
KandinskyV22ControlnetPipeline,
|
| 1135 |
+
KandinskyV22Img2ImgCombinedPipeline,
|
| 1136 |
+
KandinskyV22Img2ImgPipeline,
|
| 1137 |
+
KandinskyV22InpaintCombinedPipeline,
|
| 1138 |
+
KandinskyV22InpaintPipeline,
|
| 1139 |
+
KandinskyV22Pipeline,
|
| 1140 |
+
KandinskyV22PriorEmb2EmbPipeline,
|
| 1141 |
+
KandinskyV22PriorPipeline,
|
| 1142 |
+
LatentConsistencyModelImg2ImgPipeline,
|
| 1143 |
+
LatentConsistencyModelPipeline,
|
| 1144 |
+
LattePipeline,
|
| 1145 |
+
LDMTextToImagePipeline,
|
| 1146 |
+
LEditsPPPipelineStableDiffusion,
|
| 1147 |
+
LEditsPPPipelineStableDiffusionXL,
|
| 1148 |
+
LTXConditionPipeline,
|
| 1149 |
+
LTXImageToVideoPipeline,
|
| 1150 |
+
LTXLatentUpsamplePipeline,
|
| 1151 |
+
LTXPipeline,
|
| 1152 |
+
Lumina2Pipeline,
|
| 1153 |
+
Lumina2Text2ImgPipeline,
|
| 1154 |
+
LuminaPipeline,
|
| 1155 |
+
LuminaText2ImgPipeline,
|
| 1156 |
+
MarigoldDepthPipeline,
|
| 1157 |
+
MarigoldIntrinsicsPipeline,
|
| 1158 |
+
MarigoldNormalsPipeline,
|
| 1159 |
+
MochiPipeline,
|
| 1160 |
+
MusicLDMPipeline,
|
| 1161 |
+
OmniGenPipeline,
|
| 1162 |
+
PaintByExamplePipeline,
|
| 1163 |
+
PIAPipeline,
|
| 1164 |
+
PixArtAlphaPipeline,
|
| 1165 |
+
PixArtSigmaPAGPipeline,
|
| 1166 |
+
PixArtSigmaPipeline,
|
| 1167 |
+
QwenImageControlNetInpaintPipeline,
|
| 1168 |
+
QwenImageControlNetPipeline,
|
| 1169 |
+
QwenImageEditInpaintPipeline,
|
| 1170 |
+
QwenImageEditPipeline,
|
| 1171 |
+
QwenImageImg2ImgPipeline,
|
| 1172 |
+
QwenImageInpaintPipeline,
|
| 1173 |
+
QwenImagePipeline,
|
| 1174 |
+
ReduxImageEncoder,
|
| 1175 |
+
SanaControlNetPipeline,
|
| 1176 |
+
SanaPAGPipeline,
|
| 1177 |
+
SanaPipeline,
|
| 1178 |
+
SanaSprintImg2ImgPipeline,
|
| 1179 |
+
SanaSprintPipeline,
|
| 1180 |
+
SemanticStableDiffusionPipeline,
|
| 1181 |
+
ShapEImg2ImgPipeline,
|
| 1182 |
+
ShapEPipeline,
|
| 1183 |
+
SkyReelsV2DiffusionForcingImageToVideoPipeline,
|
| 1184 |
+
SkyReelsV2DiffusionForcingPipeline,
|
| 1185 |
+
SkyReelsV2DiffusionForcingVideoToVideoPipeline,
|
| 1186 |
+
SkyReelsV2ImageToVideoPipeline,
|
| 1187 |
+
SkyReelsV2Pipeline,
|
| 1188 |
+
StableAudioPipeline,
|
| 1189 |
+
StableAudioProjectionModel,
|
| 1190 |
+
StableCascadeCombinedPipeline,
|
| 1191 |
+
StableCascadeDecoderPipeline,
|
| 1192 |
+
StableCascadePriorPipeline,
|
| 1193 |
+
StableDiffusion3ControlNetInpaintingPipeline,
|
| 1194 |
+
StableDiffusion3ControlNetPipeline,
|
| 1195 |
+
StableDiffusion3Img2ImgPipeline,
|
| 1196 |
+
StableDiffusion3InpaintPipeline,
|
| 1197 |
+
StableDiffusion3PAGImg2ImgPipeline,
|
| 1198 |
+
StableDiffusion3PAGPipeline,
|
| 1199 |
+
StableDiffusion3Pipeline,
|
| 1200 |
+
StableDiffusionAdapterPipeline,
|
| 1201 |
+
StableDiffusionAttendAndExcitePipeline,
|
| 1202 |
+
StableDiffusionControlNetImg2ImgPipeline,
|
| 1203 |
+
StableDiffusionControlNetInpaintPipeline,
|
| 1204 |
+
StableDiffusionControlNetPAGInpaintPipeline,
|
| 1205 |
+
StableDiffusionControlNetPAGPipeline,
|
| 1206 |
+
StableDiffusionControlNetPipeline,
|
| 1207 |
+
StableDiffusionControlNetXSPipeline,
|
| 1208 |
+
StableDiffusionDepth2ImgPipeline,
|
| 1209 |
+
StableDiffusionDiffEditPipeline,
|
| 1210 |
+
StableDiffusionGLIGENPipeline,
|
| 1211 |
+
StableDiffusionGLIGENTextImagePipeline,
|
| 1212 |
+
StableDiffusionImageVariationPipeline,
|
| 1213 |
+
StableDiffusionImg2ImgPipeline,
|
| 1214 |
+
StableDiffusionInpaintPipeline,
|
| 1215 |
+
StableDiffusionInpaintPipelineLegacy,
|
| 1216 |
+
StableDiffusionInstructPix2PixPipeline,
|
| 1217 |
+
StableDiffusionLatentUpscalePipeline,
|
| 1218 |
+
StableDiffusionLDM3DPipeline,
|
| 1219 |
+
StableDiffusionModelEditingPipeline,
|
| 1220 |
+
StableDiffusionPAGImg2ImgPipeline,
|
| 1221 |
+
StableDiffusionPAGInpaintPipeline,
|
| 1222 |
+
StableDiffusionPAGPipeline,
|
| 1223 |
+
StableDiffusionPanoramaPipeline,
|
| 1224 |
+
StableDiffusionParadigmsPipeline,
|
| 1225 |
+
StableDiffusionPipeline,
|
| 1226 |
+
StableDiffusionPipelineSafe,
|
| 1227 |
+
StableDiffusionPix2PixZeroPipeline,
|
| 1228 |
+
StableDiffusionSAGPipeline,
|
| 1229 |
+
StableDiffusionUpscalePipeline,
|
| 1230 |
+
StableDiffusionXLAdapterPipeline,
|
| 1231 |
+
StableDiffusionXLControlNetImg2ImgPipeline,
|
| 1232 |
+
StableDiffusionXLControlNetInpaintPipeline,
|
| 1233 |
+
StableDiffusionXLControlNetPAGImg2ImgPipeline,
|
| 1234 |
+
StableDiffusionXLControlNetPAGPipeline,
|
| 1235 |
+
StableDiffusionXLControlNetPipeline,
|
| 1236 |
+
StableDiffusionXLControlNetUnionImg2ImgPipeline,
|
| 1237 |
+
StableDiffusionXLControlNetUnionInpaintPipeline,
|
| 1238 |
+
StableDiffusionXLControlNetUnionPipeline,
|
| 1239 |
+
StableDiffusionXLControlNetXSPipeline,
|
| 1240 |
+
StableDiffusionXLImg2ImgPipeline,
|
| 1241 |
+
StableDiffusionXLInpaintPipeline,
|
| 1242 |
+
StableDiffusionXLInstructPix2PixPipeline,
|
| 1243 |
+
StableDiffusionXLPAGImg2ImgPipeline,
|
| 1244 |
+
StableDiffusionXLPAGInpaintPipeline,
|
| 1245 |
+
StableDiffusionXLPAGPipeline,
|
| 1246 |
+
StableDiffusionXLPipeline,
|
| 1247 |
+
StableUnCLIPImg2ImgPipeline,
|
| 1248 |
+
StableUnCLIPPipeline,
|
| 1249 |
+
StableVideoDiffusionPipeline,
|
| 1250 |
+
TextToVideoSDPipeline,
|
| 1251 |
+
TextToVideoZeroPipeline,
|
| 1252 |
+
TextToVideoZeroSDXLPipeline,
|
| 1253 |
+
UnCLIPImageVariationPipeline,
|
| 1254 |
+
UnCLIPPipeline,
|
| 1255 |
+
UniDiffuserModel,
|
| 1256 |
+
UniDiffuserPipeline,
|
| 1257 |
+
UniDiffuserTextDecoder,
|
| 1258 |
+
VersatileDiffusionDualGuidedPipeline,
|
| 1259 |
+
VersatileDiffusionImageVariationPipeline,
|
| 1260 |
+
VersatileDiffusionPipeline,
|
| 1261 |
+
VersatileDiffusionTextToImagePipeline,
|
| 1262 |
+
VideoToVideoSDPipeline,
|
| 1263 |
+
VisualClozeGenerationPipeline,
|
| 1264 |
+
VisualClozePipeline,
|
| 1265 |
+
VQDiffusionPipeline,
|
| 1266 |
+
WanImageToVideoPipeline,
|
| 1267 |
+
WanPipeline,
|
| 1268 |
+
WanVACEPipeline,
|
| 1269 |
+
WanVideoToVideoPipeline,
|
| 1270 |
+
WuerstchenCombinedPipeline,
|
| 1271 |
+
WuerstchenDecoderPipeline,
|
| 1272 |
+
WuerstchenPriorPipeline,
|
| 1273 |
+
)
|
| 1274 |
+
|
| 1275 |
+
try:
|
| 1276 |
+
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
|
| 1277 |
+
raise OptionalDependencyNotAvailable()
|
| 1278 |
+
except OptionalDependencyNotAvailable:
|
| 1279 |
+
from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
|
| 1280 |
+
else:
|
| 1281 |
+
from .pipelines import StableDiffusionKDiffusionPipeline, StableDiffusionXLKDiffusionPipeline
|
| 1282 |
+
|
| 1283 |
+
try:
|
| 1284 |
+
if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()):
|
| 1285 |
+
raise OptionalDependencyNotAvailable()
|
| 1286 |
+
except OptionalDependencyNotAvailable:
|
| 1287 |
+
from .utils.dummy_torch_and_transformers_and_sentencepiece_objects import * # noqa F403
|
| 1288 |
+
else:
|
| 1289 |
+
from .pipelines import KolorsImg2ImgPipeline, KolorsPAGPipeline, KolorsPipeline
|
| 1290 |
+
|
| 1291 |
+
try:
|
| 1292 |
+
if not (is_torch_available() and is_transformers_available() and is_opencv_available()):
|
| 1293 |
+
raise OptionalDependencyNotAvailable()
|
| 1294 |
+
except OptionalDependencyNotAvailable:
|
| 1295 |
+
from .utils.dummy_torch_and_transformers_and_opencv_objects import * # noqa F403
|
| 1296 |
+
else:
|
| 1297 |
+
from .pipelines import ConsisIDPipeline
|
| 1298 |
+
|
| 1299 |
+
try:
|
| 1300 |
+
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
|
| 1301 |
+
raise OptionalDependencyNotAvailable()
|
| 1302 |
+
except OptionalDependencyNotAvailable:
|
| 1303 |
+
from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
|
| 1304 |
+
else:
|
| 1305 |
+
from .pipelines import (
|
| 1306 |
+
OnnxStableDiffusionImg2ImgPipeline,
|
| 1307 |
+
OnnxStableDiffusionInpaintPipeline,
|
| 1308 |
+
OnnxStableDiffusionInpaintPipelineLegacy,
|
| 1309 |
+
OnnxStableDiffusionPipeline,
|
| 1310 |
+
OnnxStableDiffusionUpscalePipeline,
|
| 1311 |
+
StableDiffusionOnnxPipeline,
|
| 1312 |
+
)
|
| 1313 |
+
|
| 1314 |
+
try:
|
| 1315 |
+
if not (is_torch_available() and is_librosa_available()):
|
| 1316 |
+
raise OptionalDependencyNotAvailable()
|
| 1317 |
+
except OptionalDependencyNotAvailable:
|
| 1318 |
+
from .utils.dummy_torch_and_librosa_objects import * # noqa F403
|
| 1319 |
+
else:
|
| 1320 |
+
from .pipelines import AudioDiffusionPipeline, Mel
|
| 1321 |
+
|
| 1322 |
+
try:
|
| 1323 |
+
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
|
| 1324 |
+
raise OptionalDependencyNotAvailable()
|
| 1325 |
+
except OptionalDependencyNotAvailable:
|
| 1326 |
+
from .utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
|
| 1327 |
+
else:
|
| 1328 |
+
from .pipelines import SpectrogramDiffusionPipeline
|
| 1329 |
+
|
| 1330 |
+
try:
|
| 1331 |
+
if not is_flax_available():
|
| 1332 |
+
raise OptionalDependencyNotAvailable()
|
| 1333 |
+
except OptionalDependencyNotAvailable:
|
| 1334 |
+
from .utils.dummy_flax_objects import * # noqa F403
|
| 1335 |
+
else:
|
| 1336 |
+
from .models.controlnets.controlnet_flax import FlaxControlNetModel
|
| 1337 |
+
from .models.modeling_flax_utils import FlaxModelMixin
|
| 1338 |
+
from .models.unets.unet_2d_condition_flax import FlaxUNet2DConditionModel
|
| 1339 |
+
from .models.vae_flax import FlaxAutoencoderKL
|
| 1340 |
+
from .pipelines import FlaxDiffusionPipeline
|
| 1341 |
+
from .schedulers import (
|
| 1342 |
+
FlaxDDIMScheduler,
|
| 1343 |
+
FlaxDDPMScheduler,
|
| 1344 |
+
FlaxDPMSolverMultistepScheduler,
|
| 1345 |
+
FlaxEulerDiscreteScheduler,
|
| 1346 |
+
FlaxKarrasVeScheduler,
|
| 1347 |
+
FlaxLMSDiscreteScheduler,
|
| 1348 |
+
FlaxPNDMScheduler,
|
| 1349 |
+
FlaxSchedulerMixin,
|
| 1350 |
+
FlaxScoreSdeVeScheduler,
|
| 1351 |
+
)
|
| 1352 |
+
|
| 1353 |
+
try:
|
| 1354 |
+
if not (is_flax_available() and is_transformers_available()):
|
| 1355 |
+
raise OptionalDependencyNotAvailable()
|
| 1356 |
+
except OptionalDependencyNotAvailable:
|
| 1357 |
+
from .utils.dummy_flax_and_transformers_objects import * # noqa F403
|
| 1358 |
+
else:
|
| 1359 |
+
from .pipelines import (
|
| 1360 |
+
FlaxStableDiffusionControlNetPipeline,
|
| 1361 |
+
FlaxStableDiffusionImg2ImgPipeline,
|
| 1362 |
+
FlaxStableDiffusionInpaintPipeline,
|
| 1363 |
+
FlaxStableDiffusionPipeline,
|
| 1364 |
+
FlaxStableDiffusionXLPipeline,
|
| 1365 |
+
)
|
| 1366 |
+
|
| 1367 |
+
try:
|
| 1368 |
+
if not (is_note_seq_available()):
|
| 1369 |
+
raise OptionalDependencyNotAvailable()
|
| 1370 |
+
except OptionalDependencyNotAvailable:
|
| 1371 |
+
from .utils.dummy_note_seq_objects import * # noqa F403
|
| 1372 |
+
else:
|
| 1373 |
+
from .pipelines import MidiProcessor
|
| 1374 |
+
|
| 1375 |
+
else:
|
| 1376 |
+
import sys
|
| 1377 |
+
|
| 1378 |
+
sys.modules[__name__] = _LazyModule(
|
| 1379 |
+
__name__,
|
| 1380 |
+
globals()["__file__"],
|
| 1381 |
+
_import_structure,
|
| 1382 |
+
module_spec=__spec__,
|
| 1383 |
+
extra_objects={"__version__": __version__},
|
| 1384 |
+
)
|
pythonProject/diffusers-main/build/lib/diffusers/commands/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from abc import ABC, abstractmethod
|
| 16 |
+
from argparse import ArgumentParser
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class BaseDiffusersCLICommand(ABC):
    """Abstract interface every `diffusers-cli` sub-command implements.

    Concrete commands must provide a static `register_subcommand` that wires
    their sub-parser into the main argparse parser, and an instance `run`
    that performs the actual work.
    """

    @staticmethod
    @abstractmethod
    def register_subcommand(parser: ArgumentParser):
        """Attach this command's sub-parser (and its arguments) to `parser`."""
        raise NotImplementedError()

    @abstractmethod
    def run(self):
        """Execute the command."""
        raise NotImplementedError()
|
pythonProject/diffusers-main/build/lib/diffusers/commands/custom_blocks.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Usage example:
|
| 17 |
+
TODO
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import ast
|
| 21 |
+
import importlib.util
|
| 22 |
+
import os
|
| 23 |
+
from argparse import ArgumentParser, Namespace
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
|
| 26 |
+
from ..utils import logging
|
| 27 |
+
from . import BaseDiffusersCLICommand
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
EXPECTED_PARENT_CLASSES = ["ModularPipelineBlocks"]
|
| 31 |
+
CONFIG = "config.json"
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def conversion_command_factory(args: Namespace):
    """Build a `CustomBlocksCommand` from the parsed `custom_blocks` CLI arguments."""
    return CustomBlocksCommand(args.block_module_name, args.block_class_name)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class CustomBlocksCommand(BaseDiffusersCLICommand):
    """`diffusers-cli custom_blocks` command.

    Loads a user-provided Python module, finds a class deriving from one of
    `EXPECTED_PARENT_CLASSES`, instantiates it, and serializes it into the
    current working directory via `save_pretrained` (plus an empty
    requirements.txt).
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Attach the `custom_blocks` sub-parser and its arguments to `parser`."""
        conversion_parser = parser.add_parser("custom_blocks")
        conversion_parser.add_argument(
            "--block_module_name",
            type=str,
            default="block.py",
            help="Module filename in which the custom block will be implemented.",
        )
        conversion_parser.add_argument(
            "--block_class_name",
            type=str,
            default=None,
            help="Name of the custom block. If provided None, we will try to infer it.",
        )
        conversion_parser.set_defaults(func=conversion_command_factory)

    def __init__(self, block_module_name: str = "block.py", block_class_name: str = None):
        # block_class_name may stay None, in which case `run` picks the first
        # matching class it finds in the module.
        self.logger = logging.get_logger("diffusers-cli/custom_blocks")
        self.block_module_name = Path(block_module_name)
        self.block_class_name = block_class_name

    def run(self):
        """Locate, import, and serialize the chosen custom block class."""
        # determine the block to be saved.
        out = self._get_class_names(self.block_module_name)
        classes_found = list({cls for cls, _ in out})

        if self.block_class_name is not None:
            child_class, parent_class = self._choose_block(out, self.block_class_name)
            if child_class is None and parent_class is None:
                raise ValueError(
                    "`block_class_name` could not be retrieved. Available classes from "
                    f"{self.block_module_name}:\n{classes_found}"
                )
        else:
            # No explicit class requested: fall back to the first match.
            # NOTE(review): if `out` is empty this raises IndexError rather than a
            # friendly error — assumes the module defines at least one block.
            self.logger.info(
                f"Found classes: {classes_found} will be using {classes_found[0]}. "
                "If this needs to be changed, re-run the command specifying `block_class_name`."
            )
            child_class, parent_class = out[0][0], out[0][1]

        # dynamically get the custom block and initialize it to call `save_pretrained` in the current directory.
        # the user is responsible for running it, so I guess that is safe?
        module_name = f"__dynamic__{self.block_module_name.stem}"
        spec = importlib.util.spec_from_file_location(module_name, str(self.block_module_name))
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        getattr(module, child_class)().save_pretrained(os.getcwd())

        # or, we could create it manually.
        # automap = self._create_automap(parent_class=parent_class, child_class=child_class)
        # with open(CONFIG, "w") as f:
        #     json.dump(automap, f)
        with open("requirements.txt", "w") as f:
            f.write("")

    def _choose_block(self, candidates, chosen=None):
        """Return the (class, base) pair whose class name equals `chosen`, else (None, None)."""
        for cls, base in candidates:
            if cls == chosen:
                return cls, base
        return None, None

    def _get_class_names(self, file_path):
        """Statically parse `file_path` and return (class_name, allowed_base) pairs.

        Only top-level classes whose bases include a name from
        `EXPECTED_PARENT_CLASSES` are returned; the file is never executed here.
        """
        source = file_path.read_text(encoding="utf-8")
        try:
            tree = ast.parse(source, filename=file_path)
        except SyntaxError as e:
            raise ValueError(f"Could not parse {file_path!r}: {e}") from e

        results: list[tuple[str, str]] = []
        for node in tree.body:
            if not isinstance(node, ast.ClassDef):
                continue

            # extract all base names for this class
            base_names = [bname for b in node.bases if (bname := self._get_base_name(b)) is not None]

            # for each allowed base that appears in the class's bases, emit a tuple
            for allowed in EXPECTED_PARENT_CLASSES:
                if allowed in base_names:
                    results.append((node.name, allowed))

        return results

    def _get_base_name(self, node: ast.expr):
        """Render a base-class AST node as a dotted name string (e.g. `pkg.Cls`), or None."""
        if isinstance(node, ast.Name):
            return node.id
        elif isinstance(node, ast.Attribute):
            val = self._get_base_name(node.value)
            return f"{val}.{node.attr}" if val else node.attr
        return None

    def _create_automap(self, parent_class, child_class):
        """Build the `auto_map` config entry mapping parent class to `module.child_class`."""
        module = str(self.block_module_name).replace(".py", "").rsplit(".", 1)[-1]
        auto_map = {f"{parent_class}": f"{module}.{child_class}"}
        return {"auto_map": auto_map}
|
pythonProject/diffusers-main/build/lib/diffusers/commands/diffusers_cli.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from argparse import ArgumentParser
|
| 17 |
+
|
| 18 |
+
from .custom_blocks import CustomBlocksCommand
|
| 19 |
+
from .env import EnvironmentCommand
|
| 20 |
+
from .fp16_safetensors import FP16SafetensorsCommand
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def main():
    """Entry point for `diffusers-cli`: parse argv and run the selected sub-command."""
    parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli <command> [<args>]")
    subcommands = parser.add_subparsers(help="diffusers-cli command helpers")

    # Each command class attaches its own sub-parser; order is preserved.
    for command_cls in (EnvironmentCommand, FP16SafetensorsCommand, CustomBlocksCommand):
        command_cls.register_subcommand(subcommands)

    args = parser.parse_args()

    # No sub-command chosen: show usage and exit with a failure status.
    if not hasattr(args, "func"):
        parser.print_help()
        exit(1)

    # The factory stored in `func` builds the command object; run it.
    args.func(args).run()


if __name__ == "__main__":
    main()
|
pythonProject/diffusers-main/build/lib/diffusers/commands/env.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import platform
|
| 16 |
+
import subprocess
|
| 17 |
+
from argparse import ArgumentParser
|
| 18 |
+
|
| 19 |
+
import huggingface_hub
|
| 20 |
+
|
| 21 |
+
from .. import __version__ as version
|
| 22 |
+
from ..utils import (
|
| 23 |
+
is_accelerate_available,
|
| 24 |
+
is_bitsandbytes_available,
|
| 25 |
+
is_flax_available,
|
| 26 |
+
is_google_colab,
|
| 27 |
+
is_peft_available,
|
| 28 |
+
is_safetensors_available,
|
| 29 |
+
is_torch_available,
|
| 30 |
+
is_transformers_available,
|
| 31 |
+
is_xformers_available,
|
| 32 |
+
)
|
| 33 |
+
from . import BaseDiffusersCLICommand
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def info_command_factory(_):
    """Factory stored in argparse defaults; ignores the parsed args and returns the command."""
    return EnvironmentCommand()
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class EnvironmentCommand(BaseDiffusersCLICommand):
    """`diffusers-cli env` command.

    Gathers version and platform information useful for bug reports, prints it
    in a copy-pasteable form, and returns it as a dict.
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser) -> None:
        """Attach the `env` sub-parser to the main CLI parser."""
        download_parser = parser.add_parser("env")
        download_parser.set_defaults(func=info_command_factory)

    @staticmethod
    def _module_version(module_name: str, available: bool) -> str:
        """Return `module_name.__version__` when `available`, else "not installed".

        Centralizes the probe boilerplate previously repeated per library; the
        import stays lazy so unavailable libraries are never imported.
        """
        if not available:
            return "not installed"
        import importlib

        return importlib.import_module(module_name).__version__

    @staticmethod
    def _detect_accelerator() -> str:
        """Best-effort detection of the machine's accelerator (GPU) as a string.

        Uses `nvidia-smi` on Linux/Windows and `system_profiler` on macOS;
        returns "NA" when nothing can be detected.
        """
        accelerator = "NA"
        if platform.system() in {"Linux", "Windows"}:
            try:
                sp = subprocess.Popen(
                    ["nvidia-smi", "--query-gpu=gpu_name,memory.total", "--format=csv,noheader"],
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE,
                )
                out_str, _ = sp.communicate()
                out_str = out_str.decode("utf-8")

                if len(out_str) > 0:
                    accelerator = out_str.strip()
            except FileNotFoundError:
                # nvidia-smi not on PATH: no NVIDIA GPU (or drivers) present.
                pass
        elif platform.system() == "Darwin":  # Mac OS
            try:
                sp = subprocess.Popen(
                    ["system_profiler", "SPDisplaysDataType"],
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE,
                )
                out_str, _ = sp.communicate()
                out_str = out_str.decode("utf-8")

                # Scrape the chipset name and, if present, the VRAM size.
                start = out_str.find("Chipset Model:")
                if start != -1:
                    start += len("Chipset Model:")
                    end = out_str.find("\n", start)
                    accelerator = out_str[start:end].strip()

                start = out_str.find("VRAM (Total):")
                if start != -1:
                    start += len("VRAM (Total):")
                    end = out_str.find("\n", start)
                    accelerator += " VRAM: " + out_str[start:end].strip()
            except FileNotFoundError:
                pass
        else:
            print("It seems you are running an unusual OS. Could you fill in the accelerator manually?")
        return accelerator

    def run(self) -> dict:
        """Collect environment info, print it for GitHub issues, and return it."""
        hub_version = huggingface_hub.__version__

        safetensors_version = self._module_version("safetensors", is_safetensors_available())

        # torch needs extra CUDA info, so it is probed inline rather than via the helper.
        pt_version = "not installed"
        pt_cuda_available = "NA"
        if is_torch_available():
            import torch

            pt_version = torch.__version__
            pt_cuda_available = torch.cuda.is_available()

        # flax spans three packages plus the active backend, also probed inline.
        flax_version = "not installed"
        jax_version = "not installed"
        jaxlib_version = "not installed"
        jax_backend = "NA"
        if is_flax_available():
            import flax
            import jax
            import jaxlib

            flax_version = flax.__version__
            jax_version = jax.__version__
            jaxlib_version = jaxlib.__version__
            jax_backend = jax.lib.xla_bridge.get_backend().platform

        transformers_version = self._module_version("transformers", is_transformers_available())
        accelerate_version = self._module_version("accelerate", is_accelerate_available())
        peft_version = self._module_version("peft", is_peft_available())
        bitsandbytes_version = self._module_version("bitsandbytes", is_bitsandbytes_available())
        xformers_version = self._module_version("xformers", is_xformers_available())

        platform_info = platform.platform()

        is_google_colab_str = "Yes" if is_google_colab() else "No"

        accelerator = self._detect_accelerator()

        info = {
            "🤗 Diffusers version": version,
            "Platform": platform_info,
            "Running on Google Colab?": is_google_colab_str,
            "Python version": platform.python_version(),
            "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
            "Flax version (CPU?/GPU?/TPU?)": f"{flax_version} ({jax_backend})",
            "Jax version": jax_version,
            "JaxLib version": jaxlib_version,
            "Huggingface_hub version": hub_version,
            "Transformers version": transformers_version,
            "Accelerate version": accelerate_version,
            "PEFT version": peft_version,
            "Bitsandbytes version": bitsandbytes_version,
            "Safetensors version": safetensors_version,
            "xFormers version": xformers_version,
            "Accelerator": accelerator,
            "Using GPU in script?": "<fill in>",
            "Using distributed or parallel set-up in script?": "<fill in>",
        }

        print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
        print(self.format_dict(info))

        return info

    @staticmethod
    def format_dict(d: dict) -> str:
        """Render `d` as a markdown-style bullet list, one `- key: value` per line."""
        return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
|
pythonProject/diffusers-main/build/lib/diffusers/commands/fp16_safetensors.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Usage example:
|
| 17 |
+
diffusers-cli fp16_safetensors --ckpt_id=openai/shap-e --fp16 --use_safetensors
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import glob
|
| 21 |
+
import json
|
| 22 |
+
import warnings
|
| 23 |
+
from argparse import ArgumentParser, Namespace
|
| 24 |
+
from importlib import import_module
|
| 25 |
+
|
| 26 |
+
import huggingface_hub
|
| 27 |
+
import torch
|
| 28 |
+
from huggingface_hub import hf_hub_download
|
| 29 |
+
from packaging import version
|
| 30 |
+
|
| 31 |
+
from ..utils import logging
|
| 32 |
+
from . import BaseDiffusersCLICommand
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def conversion_command_factory(args: Namespace):
    """Build an `FP16SafetensorsCommand` from the parsed CLI args.

    Emits a deprecation warning when the obsolete `--use_auth_token` flag is
    passed; the flag's value is otherwise ignored (login state is used instead).
    """
    if args.use_auth_token:
        warnings.warn(
            "The `--use_auth_token` flag is deprecated and will be removed in a future version. Authentication is now"
            " handled automatically if user is logged in."
        )
    return FP16SafetensorsCommand(args.ckpt_id, args.fp16, args.use_safetensors)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class FP16SafetensorsCommand(BaseDiffusersCLICommand):
    """`diffusers-cli fp16_safetensors` command.

    Downloads a pipeline checkpoint from the Hub, re-serializes it (FP16 variant
    and/or safetensors format), and opens a PR on the source repo with the
    converted files.
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Attach the `fp16_safetensors` sub-parser and its arguments to `parser`."""
        conversion_parser = parser.add_parser("fp16_safetensors")
        conversion_parser.add_argument(
            "--ckpt_id",
            type=str,
            help="Repo id of the checkpoints on which to run the conversion. Example: 'openai/shap-e'.",
        )
        conversion_parser.add_argument(
            "--fp16", action="store_true", help="If serializing the variables in FP16 precision."
        )
        conversion_parser.add_argument(
            "--use_safetensors", action="store_true", help="If serializing in the safetensors format."
        )
        conversion_parser.add_argument(
            "--use_auth_token",
            action="store_true",
            help="When working with checkpoints having private visibility. When used `hf auth login` needs to be run beforehand.",
        )
        conversion_parser.set_defaults(func=conversion_command_factory)

    def __init__(self, ckpt_id: str, fp16: bool, use_safetensors: bool):
        self.logger = logging.get_logger("diffusers-cli/fp16_safetensors")
        self.ckpt_id = ckpt_id
        # NOTE(review): hard-codes a POSIX temp path; presumably not intended for
        # Windows — confirm before relying on it there.
        self.local_ckpt_dir = f"/tmp/{ckpt_id}"
        self.fp16 = fp16

        self.use_safetensors = use_safetensors

        # At least one conversion mode must be requested, otherwise there is nothing to do.
        if not self.use_safetensors and not self.fp16:
            raise NotImplementedError(
                "When `use_safetensors` and `fp16` both are False, then this command is of no use."
            )

    def run(self):
        """Convert the checkpoint locally, then open a PR on the Hub with the new files."""
        if version.parse(huggingface_hub.__version__) < version.parse("0.9.0"):
            raise ImportError(
                "The huggingface_hub version must be >= 0.9.0 to use this command. Please update your huggingface_hub"
                " installation."
            )
        else:
            # Imported lazily so the version guard above runs first.
            from huggingface_hub import create_commit
            from huggingface_hub._commit_api import CommitOperationAdd

        # Resolve the concrete pipeline class from the repo's model_index.json.
        model_index = hf_hub_download(repo_id=self.ckpt_id, filename="model_index.json")
        with open(model_index, "r") as f:
            pipeline_class_name = json.load(f)["_class_name"]
        pipeline_class = getattr(import_module("diffusers"), pipeline_class_name)
        self.logger.info(f"Pipeline class imported: {pipeline_class_name}.")

        # Load the appropriate pipeline. We could have use `DiffusionPipeline`
        # here, but just to avoid any rough edge cases.
        pipeline = pipeline_class.from_pretrained(
            self.ckpt_id, torch_dtype=torch.float16 if self.fp16 else torch.float32
        )
        pipeline.save_pretrained(
            self.local_ckpt_dir,
            safe_serialization=True if self.use_safetensors else False,
            variant="fp16" if self.fp16 else None,
        )
        self.logger.info(f"Pipeline locally saved to {self.local_ckpt_dir}.")

        # Fetch all the paths.
        if self.fp16:
            modified_paths = glob.glob(f"{self.local_ckpt_dir}/*/*.fp16.*")
        elif self.use_safetensors:
            modified_paths = glob.glob(f"{self.local_ckpt_dir}/*/*.safetensors")

        # Prepare for the PR.
        commit_message = f"Serialize variables with FP16: {self.fp16} and safetensors: {self.use_safetensors}."
        operations = []
        for path in modified_paths:
            # NOTE(review): dropping the first 4 "/"-separated components assumes the
            # exact `/tmp/<org>/<name>/...` layout produced above — verify if
            # `local_ckpt_dir` ever changes.
            operations.append(CommitOperationAdd(path_in_repo="/".join(path.split("/")[4:]), path_or_fileobj=path))

        # Open the PR.
        commit_description = (
            "Variables converted by the [`diffusers`' `fp16_safetensors`"
            " CLI](https://github.com/huggingface/diffusers/blob/main/src/diffusers/commands/fp16_safetensors.py)."
        )
        hub_pr_url = create_commit(
            repo_id=self.ckpt_id,
            operations=operations,
            commit_message=commit_message,
            commit_description=commit_description,
            repo_type="model",
            create_pr=True,
        ).pr_url
        self.logger.info(f"PR created here: {hub_pr_url}.")
|
pythonProject/diffusers-main/build/lib/diffusers/experimental/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .rl import ValueGuidedRLPipeline
|
pythonProject/diffusers-main/build/lib/diffusers/experimental/rl/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .value_guided_sampling import ValueGuidedRLPipeline
|
pythonProject/diffusers-main/build/lib/diffusers/experimental/rl/value_guided_sampling.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
import torch
|
| 17 |
+
import tqdm
|
| 18 |
+
|
| 19 |
+
from ...models.unets.unet_1d import UNet1DModel
|
| 20 |
+
from ...pipelines import DiffusionPipeline
|
| 21 |
+
from ...utils.dummy_pt_objects import DDPMScheduler
|
| 22 |
+
from ...utils.torch_utils import randn_tensor
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class ValueGuidedRLPipeline(DiffusionPipeline):
    r"""
    Pipeline for value-guided sampling from a diffusion model trained to predict sequences of states.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Parameters:
        value_function ([`UNet1DModel`]):
            A specialized UNet for fine-tuning trajectories base on reward.
        unet ([`UNet1DModel`]):
            UNet architecture to denoise the encoded trajectories.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded trajectories. Default for this
            application is [`DDPMScheduler`].
        env ():
            An environment following the OpenAI gym API to act in. For now only Hopper has pretrained models.
    """

    def __init__(
        self,
        value_function: UNet1DModel,
        unet: UNet1DModel,
        scheduler: DDPMScheduler,
        env,
    ):
        super().__init__()

        self.register_modules(value_function=value_function, unet=unet, scheduler=scheduler, env=env)

        self.data = env.get_dataset()
        # Per-key dataset statistics used for (de)normalization. Best-effort:
        # entries that do not support mean()/std() (non-numeric metadata) are
        # skipped. `except Exception` (not a bare `except:`) so that
        # KeyboardInterrupt/SystemExit still propagate.
        self.means = {}
        self.stds = {}
        for key, value in self.data.items():
            try:
                self.means[key] = value.mean()
                self.stds[key] = value.std()
            except Exception:
                pass
        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.shape[0]

    def normalize(self, x_in, key):
        """Standardize `x_in` using the dataset statistics recorded for `key`."""
        return (x_in - self.means[key]) / self.stds[key]

    def de_normalize(self, x_in, key):
        """Invert `normalize`: map standardized values back to the dataset scale."""
        return x_in * self.stds[key] + self.means[key]

    def to_torch(self, x_in):
        """Recursively move `x_in` (tensor, dict of tensors, or array-like) onto the unet's device."""
        if isinstance(x_in, dict):
            return {k: self.to_torch(v) for k, v in x_in.items()}
        elif torch.is_tensor(x_in):
            return x_in.to(self.unet.device)
        return torch.tensor(x_in, device=self.unet.device)

    def reset_x0(self, x_in, cond, act_dim):
        """Overwrite the state portion of trajectory `x_in` at each conditioned timestep.

        `cond` maps a timestep index to the state values to pin; the first
        `act_dim` channels (the actions) are left untouched.
        """
        for key, val in cond.items():
            x_in[:, key, act_dim:] = val.clone()
        return x_in

    def run_diffusion(self, x, conditions, n_guide_steps, scale):
        """Denoise trajectories `x`, nudging them up the value-function gradient.

        At every scheduler timestep, takes `n_guide_steps` gradient-ascent steps
        of size `scale` (scaled by the posterior std) on the value function
        before the regular denoising step. Returns the final trajectories and
        the last value estimates (used by `__call__` to rank them).
        """
        batch_size = x.shape[0]
        y = None
        for i in tqdm.tqdm(self.scheduler.timesteps):
            # create batch of timesteps to pass into model
            timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
            for _ in range(n_guide_steps):
                with torch.enable_grad():
                    x.requires_grad_()

                    # permute to match dimension for pre-trained models
                    y = self.value_function(x.permute(0, 2, 1), timesteps).sample
                    grad = torch.autograd.grad([y.sum()], [x])[0]

                    posterior_variance = self.scheduler._get_variance(i)
                    model_std = torch.exp(0.5 * posterior_variance)
                    grad = model_std * grad

                # no guidance on the last couple of (nearly noise-free) steps
                grad[timesteps < 2] = 0
                x = x.detach()
                x = x + scale * grad
                x = self.reset_x0(x, conditions, self.action_dim)

            prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)

            # TODO: verify deprecation of this kwarg
            x = self.scheduler.step(prev_x, i, x)["prev_sample"]

            # apply conditions to the trajectory (set the initial state)
            x = self.reset_x0(x, conditions, self.action_dim)
            x = self.to_torch(x)
        return x, y

    def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
        """Plan from observation `obs` and return the de-normalized first action of the best trajectory."""
        # normalize the observations and create batch dimension
        obs = self.normalize(obs, "observations")
        obs = obs[None].repeat(batch_size, axis=0)

        conditions = {0: self.to_torch(obs)}
        shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)

        # generate initial noise and apply our conditions (to make the trajectories start at current state)
        x1 = randn_tensor(shape, device=self.unet.device)
        x = self.reset_x0(x1, conditions, self.action_dim)
        x = self.to_torch(x)

        # run the diffusion process
        x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)

        # sort output trajectories by value
        sorted_idx = y.argsort(0, descending=True).squeeze()
        sorted_values = x[sorted_idx]
        actions = sorted_values[:, :, : self.action_dim]
        actions = actions.detach().cpu().numpy()
        denorm_actions = self.de_normalize(actions, key="actions")

        # select the action with the highest value
        if y is not None:
            selected_index = 0
        else:
            # if we didn't run value guiding, select a random action
            selected_index = np.random.randint(0, batch_size)

        denorm_actions = denorm_actions[selected_index, 0]
        return denorm_actions
|
pythonProject/diffusers-main/build/lib/diffusers/guiders/adaptive_projected_guidance.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from ..configuration_utils import register_to_config
|
| 21 |
+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
if TYPE_CHECKING:
|
| 25 |
+
from ..modular_pipelines.modular_pipeline import BlockState
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class AdaptiveProjectedGuidance(BaseGuidance):
    """
    Adaptive Projected Guidance (APG): https://huggingface.co/papers/2410.02416

    APG decomposes the classifier-free guidance update into components parallel and orthogonal to the
    conditional prediction, optionally smooths the update with a momentum buffer, and caps its norm.
    This mitigates the oversaturation seen with large guidance scales.

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
            prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
            deterioration of image quality.
        adaptive_projected_guidance_momentum (`float`, defaults to `None`):
            The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
        adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
            The norm threshold applied to the guidance update: updates whose L2 norm exceeds this value
            are scaled down to it before projection (passed as `norm_threshold` to `normalized_guidance`).
        eta (`float`, defaults to `1.0`):
            Scale applied to the component of the guidance update parallel to the conditional prediction.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
    """

    _input_predictions = ["pred_cond", "pred_uncond"]

    @register_to_config
    def __init__(
        self,
        guidance_scale: float = 7.5,
        adaptive_projected_guidance_momentum: Optional[float] = None,
        adaptive_projected_guidance_rescale: float = 15.0,
        eta: float = 1.0,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__(start, stop)

        self.guidance_scale = guidance_scale
        self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
        self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
        self.eta = eta
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation
        # Lazily (re)created at the start of each sampling run; see prepare_inputs.
        self.momentum_buffer = None

    def prepare_inputs(
        self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
    ) -> List["BlockState"]:
        """Split the block state into one batch per prediction (conditional, and unconditional if enabled)."""
        if input_fields is None:
            input_fields = self._input_fields

        # Reset the momentum buffer at the first denoising step so state never leaks across runs.
        if self._step == 0:
            if self.adaptive_projected_guidance_momentum is not None:
                self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        data_batches = []
        for i in range(self.num_conditions):
            data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
            data_batches.append(data_batch)
        return data_batches

    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
        """Combine the conditional/unconditional predictions into the final guided prediction."""
        pred = None

        if not self._is_apg_enabled():
            pred = pred_cond
        else:
            pred = normalized_guidance(
                pred_cond,
                pred_uncond,
                self.guidance_scale,
                self.momentum_buffer,
                self.eta,
                self.adaptive_projected_guidance_rescale,
                self.use_original_formulation,
            )

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

    @property
    def is_conditional(self) -> bool:
        # The first prepared batch of a step is the conditional one.
        return self._count_prepared == 1

    @property
    def num_conditions(self) -> int:
        # One forward pass when guidance is disabled, two (cond + uncond) when enabled.
        num_conditions = 1
        if self._is_apg_enabled():
            num_conditions += 1
        return num_conditions

    def _is_apg_enabled(self) -> bool:
        """APG is active only when enabled, within the [start, stop) step window, and with a non-trivial scale."""
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        # A scale of 0 (original formulation) or 1 (diffusers formulation) is a no-op.
        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class MomentumBuffer:
    """Exponentially-weighted accumulator for guidance updates (APG momentum term)."""

    def __init__(self, momentum: float):
        # Decay factor applied to the previous accumulated value on each update.
        self.momentum = momentum
        self.running_average = 0

    def update(self, update_value: torch.Tensor):
        """Fold a new guidance difference into the running average in place."""
        self.running_average = update_value + self.momentum * self.running_average
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def normalized_guidance(
    pred_cond: torch.Tensor,
    pred_uncond: torch.Tensor,
    guidance_scale: float,
    momentum_buffer: Optional[MomentumBuffer] = None,
    eta: float = 1.0,
    norm_threshold: float = 0.0,
    use_original_formulation: bool = False,
):
    """
    Compute the adaptive projected guidance update.

    The cond/uncond difference is optionally momentum-smoothed, clipped to `norm_threshold`,
    then split into components parallel and orthogonal to the conditional prediction; the
    parallel part is down-weighted by `eta` before being applied with `guidance_scale`.
    """
    diff = pred_cond - pred_uncond
    # All dimensions except the batch dimension, expressed as negative indices.
    dim = [-i for i in range(1, len(diff.shape))]

    # Optionally smooth the update over denoising steps.
    if momentum_buffer is not None:
        momentum_buffer.update(diff)
        diff = momentum_buffer.running_average

    # Clip the update so its per-sample L2 norm never exceeds the threshold.
    if norm_threshold > 0:
        diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
        diff = diff * torch.minimum(torch.ones_like(diff), norm_threshold / diff_norm)

    # Project the update onto the (unit-normalized) conditional prediction in float64 for stability.
    v0 = diff.double()
    v1 = torch.nn.functional.normalize(pred_cond.double(), dim=dim)
    v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    diff_parallel = v0_parallel.type_as(diff)
    diff_orthogonal = v0_orthogonal.type_as(diff)
    normalized_update = diff_orthogonal + eta * diff_parallel

    base = pred_cond if use_original_formulation else pred_uncond
    return base + guidance_scale * normalized_update
|
pythonProject/diffusers-main/build/lib/diffusers/guiders/auto_guidance.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from ..configuration_utils import register_to_config
|
| 21 |
+
from ..hooks import HookRegistry, LayerSkipConfig
|
| 22 |
+
from ..hooks.layer_skip import _apply_layer_skip_hook
|
| 23 |
+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
if TYPE_CHECKING:
|
| 27 |
+
from ..modular_pipelines.modular_pipeline import BlockState
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class AutoGuidance(BaseGuidance):
    """
    AutoGuidance: https://huggingface.co/papers/2406.02507

    Guides generation by contrasting the full denoiser against a degraded version of itself: during the
    "unconditional" pass, layer-skip hooks drop out the configured layers of the denoiser.

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
            prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
            deterioration of image quality.
        auto_guidance_layers (`int` or `List[int]`, *optional*):
            The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
            provided, `auto_guidance_config` must be provided.
        auto_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
            The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
            `LayerSkipConfig`. If not provided, `auto_guidance_layers` must be provided.
        dropout (`float`, *optional*):
            The dropout probability for autoguidance on the enabled skip layers (either with `auto_guidance_layers` or
            `auto_guidance_config`). If not provided, the dropout probability will be set to 1.0.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
    """

    _input_predictions = ["pred_cond", "pred_uncond"]

    @register_to_config
    def __init__(
        self,
        guidance_scale: float = 7.5,
        auto_guidance_layers: Optional[Union[int, List[int]]] = None,
        auto_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
        dropout: Optional[float] = None,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__(start, stop)

        self.guidance_scale = guidance_scale
        self.auto_guidance_layers = auto_guidance_layers
        self.dropout = dropout
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

        # Exactly one of `auto_guidance_layers` / `auto_guidance_config` must be given, and
        # `dropout` is required when only layer indices are given.
        is_layer_or_config_provided = auto_guidance_layers is not None or auto_guidance_config is not None
        is_layer_and_config_provided = auto_guidance_layers is not None and auto_guidance_config is not None
        if not is_layer_or_config_provided:
            raise ValueError(
                "Either `auto_guidance_layers` or `auto_guidance_config` must be provided to enable AutoGuidance."
            )
        if is_layer_and_config_provided:
            raise ValueError("Only one of `auto_guidance_layers` or `auto_guidance_config` can be provided.")
        if auto_guidance_config is None and dropout is None:
            raise ValueError("`dropout` must be provided if `auto_guidance_layers` is provided.")

        # Normalize all accepted input forms to a list of LayerSkipConfig.
        if auto_guidance_layers is not None:
            if isinstance(auto_guidance_layers, int):
                auto_guidance_layers = [auto_guidance_layers]
            if not isinstance(auto_guidance_layers, list):
                raise ValueError(
                    f"Expected `auto_guidance_layers` to be an int or a list of ints, but got {type(auto_guidance_layers)}."
                )
            auto_guidance_config = [
                LayerSkipConfig(layer, fqn="auto", dropout=dropout) for layer in auto_guidance_layers
            ]

        if isinstance(auto_guidance_config, dict):
            auto_guidance_config = LayerSkipConfig.from_dict(auto_guidance_config)

        if isinstance(auto_guidance_config, LayerSkipConfig):
            auto_guidance_config = [auto_guidance_config]

        if not isinstance(auto_guidance_config, list):
            raise ValueError(
                f"Expected `auto_guidance_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(auto_guidance_config)}."
            )
        elif isinstance(next(iter(auto_guidance_config), None), dict):
            auto_guidance_config = [LayerSkipConfig.from_dict(config) for config in auto_guidance_config]

        self.auto_guidance_config = auto_guidance_config
        # One uniquely-named hook per config so they can be removed individually later.
        self._auto_guidance_hook_names = [f"AutoGuidance_{i}" for i in range(len(self.auto_guidance_config))]

    def prepare_models(self, denoiser: torch.nn.Module) -> None:
        """Attach layer-skip hooks for the degraded ("unconditional") forward pass."""
        self._count_prepared += 1
        if self._is_ag_enabled() and self.is_unconditional:
            for name, config in zip(self._auto_guidance_hook_names, self.auto_guidance_config):
                _apply_layer_skip_hook(denoiser, config, name=name)

    def cleanup_models(self, denoiser: torch.nn.Module) -> None:
        """Remove the layer-skip hooks installed by `prepare_models`."""
        if self._is_ag_enabled() and self.is_unconditional:
            for name in self._auto_guidance_hook_names:
                registry = HookRegistry.check_if_exists_or_initialize(denoiser)
                registry.remove_hook(name, recurse=True)

    def prepare_inputs(
        self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
    ) -> List["BlockState"]:
        """Split the block state into one batch per prediction (conditional, and degraded if enabled)."""
        if input_fields is None:
            input_fields = self._input_fields

        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        data_batches = []
        for i in range(self.num_conditions):
            data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
            data_batches.append(data_batch)
        return data_batches

    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
        """Combine the full-model and degraded-model predictions into the final guided prediction."""
        pred = None

        if not self._is_ag_enabled():
            pred = pred_cond
        else:
            shift = pred_cond - pred_uncond
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

    @property
    def is_conditional(self) -> bool:
        # The first prepared pass of a step is the conditional (full-model) one.
        return self._count_prepared == 1

    @property
    def num_conditions(self) -> int:
        # One forward pass when guidance is disabled, two (full + degraded) when enabled.
        num_conditions = 1
        if self._is_ag_enabled():
            num_conditions += 1
        return num_conditions

    def _is_ag_enabled(self) -> bool:
        """AutoGuidance is active only when enabled, within [start, stop), and with a non-trivial scale."""
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        # A scale of 0 (original formulation) or 1 (diffusers formulation) is a no-op.
        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close
|
pythonProject/diffusers-main/build/lib/diffusers/guiders/classifier_free_guidance.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from ..configuration_utils import register_to_config
|
| 21 |
+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
if TYPE_CHECKING:
|
| 25 |
+
from ..modular_pipelines.modular_pipeline import BlockState
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class ClassifierFreeGuidance(BaseGuidance):
    """
    Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598

    CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by
    jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during
    inference. This allows the model to tradeoff between generation quality and sample diversity. The original paper
    proposes scaling and shifting the conditional distribution based on the difference between conditional and
    unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)]

    Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen
    paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in
    theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]

    The intuition behind the original formulation can be thought of as moving the conditional distribution estimates
    further away from the unconditional distribution estimates, while the diffusers-native implementation can be
    thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of
    the unconditional predictions (usually negative features like "bad quality, bad anatomy, watermarks", etc.)

    The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
    paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
            prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
            deterioration of image quality.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
    """

    _input_predictions = ["pred_cond", "pred_uncond"]

    @register_to_config
    def __init__(
        self,
        guidance_scale: float = 7.5,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__(start, stop)

        self.guidance_scale = guidance_scale
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

    def prepare_inputs(
        self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
    ) -> List["BlockState"]:
        """Split the block state into one batch per prediction (conditional, and unconditional if enabled)."""
        if input_fields is None:
            input_fields = self._input_fields

        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        data_batches = []
        for i in range(self.num_conditions):
            data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
            data_batches.append(data_batch)
        return data_batches

    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
        """Combine the conditional/unconditional predictions into the final guided prediction."""
        pred = None

        if not self._is_cfg_enabled():
            pred = pred_cond
        else:
            shift = pred_cond - pred_uncond
            # Original paper shifts from pred_cond; diffusers-native shifts from pred_uncond.
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

    @property
    def is_conditional(self) -> bool:
        # The first prepared batch of a step is the conditional one.
        return self._count_prepared == 1

    @property
    def num_conditions(self) -> int:
        # One forward pass when guidance is disabled, two (cond + uncond) when enabled.
        num_conditions = 1
        if self._is_cfg_enabled():
            num_conditions += 1
        return num_conditions

    def _is_cfg_enabled(self) -> bool:
        """CFG is active only when enabled, within the [start, stop) step window, and with a non-trivial scale."""
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        # A scale of 0 (original formulation) or 1 (diffusers formulation) is a no-op.
        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close
|
pythonProject/diffusers-main/build/lib/diffusers/guiders/classifier_free_zero_star_guidance.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from ..configuration_utils import register_to_config
|
| 21 |
+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
if TYPE_CHECKING:
|
| 25 |
+
from ..modular_pipelines.modular_pipeline import BlockState
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class ClassifierFreeZeroStarGuidance(BaseGuidance):
    """
    Classifier-free Zero* (CFG-Zero*): https://huggingface.co/papers/2503.18886

    This is an implementation of the Classifier-Free Zero* guidance technique, which is a variant of classifier-free
    guidance. It proposes zero initialization of the noise predictions for the first few steps of the diffusion
    process, and also introduces an optimal rescaling factor for the noise predictions, which can help in improving the
    quality of generated images.

    The authors of the paper suggest setting zero initialization in the first 4% of the inference steps.

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
            prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
            deterioration of image quality.
        zero_init_steps (`int`, defaults to `1`):
            The number of inference steps for which the noise predictions are zeroed out (see Section 4.2).
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
    """

    # Names of the model predictions this guider consumes, in preparation order.
    _input_predictions = ["pred_cond", "pred_uncond"]

    @register_to_config
    def __init__(
        self,
        guidance_scale: float = 7.5,
        zero_init_steps: int = 1,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__(start, stop)

        self.guidance_scale = guidance_scale
        self.zero_init_steps = zero_init_steps
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

    def prepare_inputs(
        self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
    ) -> List["BlockState"]:
        """
        Split the pipeline state into one batch per required condition (conditional, and, when CFG is active,
        unconditional), resolving each prediction's fields via `input_fields`.
        """
        if input_fields is None:
            input_fields = self._input_fields

        # Index 0 selects the conditional entry of a (cond, uncond) field tuple, index 1 the unconditional one.
        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        data_batches = []
        for i in range(self.num_conditions):
            data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
            data_batches.append(data_batch)
        return data_batches

    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
        """
        Combine the conditional and unconditional predictions with the CFG-Zero* update.

        For the first `zero_init_steps` steps the prediction is all zeros (Section 4.2 of the paper); when CFG is
        disabled the conditional prediction is returned as-is; otherwise the unconditional prediction is rescaled by
        the per-sample optimal factor before the standard CFG combination.
        """
        pred = None

        if self._step < self.zero_init_steps:
            # Zero-initialization phase: emit an all-zero prediction.
            pred = torch.zeros_like(pred_cond)
        elif not self._is_cfg_enabled():
            pred = pred_cond
        else:
            # Compute the optimal per-sample rescale factor over flattened (non-batch) dims.
            pred_cond_flat = pred_cond.flatten(1)
            pred_uncond_flat = pred_uncond.flatten(1)
            alpha = cfg_zero_star_scale(pred_cond_flat, pred_uncond_flat)
            # Reshape [B, 1] -> [B, 1, 1, ...] so it broadcasts against the full prediction shape.
            alpha = alpha.view(-1, *(1,) * (len(pred_cond.shape) - 1))
            pred_uncond = pred_uncond * alpha
            shift = pred_cond - pred_uncond
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

    @property
    def is_conditional(self) -> bool:
        # True while only the conditional batch has been prepared in the current step.
        return self._count_prepared == 1

    @property
    def num_conditions(self) -> int:
        # 1 (conditional only) or 2 (conditional + unconditional) depending on whether CFG is active.
        num_conditions = 1
        if self._is_cfg_enabled():
            num_conditions += 1
        return num_conditions

    def _is_cfg_enabled(self) -> bool:
        """Whether CFG is active for the current step: guider enabled, step within [start, stop), and the guidance
        scale is not the identity value (0.0 for the original formulation, 1.0 otherwise)."""
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def cfg_zero_star_scale(cond: torch.Tensor, uncond: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    Compute the CFG-Zero* optimal rescale factor for the unconditional prediction.

    Args:
        cond: Conditional prediction, shape `[B, D]` (batch first, features flattened).
        uncond: Unconditional prediction, same shape as `cond`.
        eps: Small constant added to the squared norm to avoid division by zero.

    Returns:
        Per-sample scale of shape `[B, 1]`, in the same dtype as `cond`.
    """
    original_dtype = cond.dtype
    # Upcast to float32 for a numerically stable dot product / norm.
    cond_fp32, uncond_fp32 = cond.float(), uncond.float()
    # st_star = v_cond^T * v_uncond / ||v_uncond||^2
    numerator = (cond_fp32 * uncond_fp32).sum(dim=1, keepdim=True)
    denominator = uncond_fp32.pow(2).sum(dim=1, keepdim=True) + eps
    return (numerator / denominator).to(dtype=original_dtype)
|
pythonProject/diffusers-main/build/lib/diffusers/guiders/frequency_decoupled_guidance.py
ADDED
|
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from ..configuration_utils import register_to_config
|
| 21 |
+
from ..utils import is_kornia_available
|
| 22 |
+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
if TYPE_CHECKING:
|
| 26 |
+
from ..modular_pipelines.modular_pipeline import BlockState
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
_CAN_USE_KORNIA = is_kornia_available()
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
if _CAN_USE_KORNIA:
|
| 33 |
+
from kornia.geometry import pyrup as upsample_and_blur_func
|
| 34 |
+
from kornia.geometry.transform import build_laplacian_pyramid as build_laplacian_pyramid_func
|
| 35 |
+
else:
|
| 36 |
+
upsample_and_blur_func = None
|
| 37 |
+
build_laplacian_pyramid_func = None
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def project(v0: torch.Tensor, v1: torch.Tensor, upcast_to_double: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Project vector v0 onto vector v1, returning the parallel and orthogonal components of v0. Implementation from paper
    (Algorithm 2).

    The first dimension is treated as the batch dimension; the projection is computed over all remaining
    (channel/spatial) dimensions. When `upcast_to_double` is set, the computation runs in float64 and the results are
    cast back to the input dtype.
    """
    # All non-batch dimensions participate in the inner product / normalization.
    reduce_dims = list(range(1, len(v0.shape)))
    if upcast_to_double:
        original_dtype = v0.dtype
        v0, v1 = v0.double(), v1.double()
    # Unit vector along v1 (per batch element).
    unit_v1 = torch.nn.functional.normalize(v1, dim=reduce_dims)
    # Parallel component: <v0, v1_hat> * v1_hat; orthogonal is the remainder.
    parallel = (v0 * unit_v1).sum(dim=reduce_dims, keepdim=True) * unit_v1
    orthogonal = v0 - parallel
    if upcast_to_double:
        parallel = parallel.to(original_dtype)
        orthogonal = orthogonal.to(original_dtype)
    return parallel, orthogonal
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def build_image_from_pyramid(pyramid: List[torch.Tensor]) -> torch.Tensor:
    """
    Recovers the data space latents from the Laplacian pyramid frequency space. Implementation from the paper
    (Algorithm 2).
    """
    # pyramid shapes: [[B, C, H, W], [B, C, H/2, W/2], ...]
    # Start from the coarsest level and repeatedly upsample, adding back each finer level's residual.
    reconstructed = pyramid[-1]
    for level in reversed(range(len(pyramid) - 1)):
        reconstructed = upsample_and_blur_func(reconstructed) + pyramid[level]
    return reconstructed
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class FrequencyDecoupledGuidance(BaseGuidance):
    """
    Frequency-Decoupled Guidance (FDG): https://huggingface.co/papers/2506.19713

    FDG is a technique similar to (and based on) classifier-free guidance (CFG) which is used to improve generation
    quality and condition-following in diffusion models. Like CFG, during training we jointly train the model on both
    conditional and unconditional data, and use a combination of the two during inference. (If you want more details on
    how CFG works, you can check out the CFG guider.)

    FDG differs from CFG in that the normal CFG prediction is instead decoupled into low- and high-frequency components
    using a frequency transform (such as a Laplacian pyramid). The CFG update is then performed in frequency space
    separately for the low- and high-frequency components with different guidance scales. Finally, the inverse
    frequency transform is used to map the CFG frequency predictions back to data space (e.g. pixel space for images)
    to form the final FDG prediction.

    For images, the FDG authors found that using low guidance scales for the low-frequency components retains sample
    diversity and realistic color composition, while using high guidance scales for high-frequency components enhances
    sample quality (such as better visual details). Therefore, they recommend using low guidance scales (low w_low) for
    the low-frequency components and high guidance scales (high w_high) for the high-frequency components. As an
    example, they suggest w_low = 5.0 and w_high = 10.0 for Stable Diffusion XL (see Table 8 in the paper).

    As with CFG, Diffusers implements the scaling and shifting on the unconditional prediction based on the [Imagen
    paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original CFG paper proposed in
    theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]

    The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
    paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.

    Args:
        guidance_scales (`List[float]`, defaults to `[10.0, 5.0]`):
            The scale parameter for frequency-decoupled guidance for each frequency component, listed from highest
            frequency level to lowest. Higher values result in stronger conditioning on the text prompt, while lower
            values allow for more freedom in generation. Higher values may lead to saturation and deterioration of
            image quality. The FDG authors recommend using higher guidance scales for higher frequency components and
            lower guidance scales for lower frequency components (so `guidance_scales` should typically be sorted in
            descending order).
        guidance_rescale (`float` or `List[float]`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891). If a list is supplied, it should be the same length as
            `guidance_scales`.
        parallel_weights (`float` or `List[float]`, *optional*):
            Optional weights for the parallel component of each frequency component of the projected CFG shift. If not
            set, the weights will default to `1.0` for all components, which corresponds to using the normal CFG shift
            (that is, equal weights for the parallel and orthogonal components). If set, a value in `[0, 1]` is
            recommended. If a list is supplied, it should be the same length as `guidance_scales`.
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float` or `List[float]`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts. If a list is supplied, it
            should be the same length as `guidance_scales`.
        stop (`float` or `List[float]`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops. If a list is supplied, it
            should be the same length as `guidance_scales`.
        guidance_rescale_space (`str`, defaults to `"data"`):
            Whether to performance guidance rescaling in `"data"` space (after the full FDG update in data space) or in
            `"freq"` space (right after the CFG update, for each freq level). Note that frequency space rescaling is
            speculative and may not produce expected results. If `"data"` is set, the first `guidance_rescale` value
            will be used; otherwise, per-frequency-level guidance rescale values will be used if available.
        upcast_to_double (`bool`, defaults to `True`):
            Whether to upcast certain operations, such as the projection operation when using `parallel_weights`, to
            float64 when performing guidance. This may result in better performance at the cost of increased runtime.

    Raises:
        ImportError: If `kornia` (used for the Laplacian pyramid transforms) is not installed.
        ValueError: If any of the per-level list arguments does not match the length of `guidance_scales`, or if
            `guidance_rescale_space` is not one of `"data"` or `"freq"`.
    """

    # Names of the model predictions this guider consumes, in preparation order.
    _input_predictions = ["pred_cond", "pred_uncond"]

    @register_to_config
    def __init__(
        self,
        guidance_scales: Union[List[float], Tuple[float]] = [10.0, 5.0],
        guidance_rescale: Union[float, List[float], Tuple[float]] = 0.0,
        parallel_weights: Optional[Union[float, List[float], Tuple[float]]] = None,
        use_original_formulation: bool = False,
        start: Union[float, List[float], Tuple[float]] = 0.0,
        stop: Union[float, List[float], Tuple[float]] = 1.0,
        guidance_rescale_space: str = "data",
        upcast_to_double: bool = True,
    ):
        if not _CAN_USE_KORNIA:
            raise ImportError(
                "The `FrequencyDecoupledGuidance` guider cannot be instantiated because the `kornia` library on which "
                "it depends is not available in the current environment. You can install `kornia` with `pip install "
                "kornia`."
            )

        # Set start to earliest start for any freq component and stop to latest stop for any freq component
        min_start = start if isinstance(start, float) else min(start)
        max_stop = stop if isinstance(stop, float) else max(stop)
        super().__init__(min_start, max_stop)

        self.guidance_scales = guidance_scales
        # Number of frequency levels in the Laplacian pyramid (one guidance scale per level).
        self.levels = len(guidance_scales)

        # Broadcast scalar `guidance_rescale` to all levels, or validate a per-level list.
        if isinstance(guidance_rescale, float):
            self.guidance_rescale = [guidance_rescale] * self.levels
        elif len(guidance_rescale) == self.levels:
            self.guidance_rescale = guidance_rescale
        else:
            raise ValueError(
                f"`guidance_rescale` has length {len(guidance_rescale)} but should have the same length as "
                f"`guidance_scales` ({len(self.guidance_scales)})"
            )
        # Whether to perform guidance rescaling in frequency space (right after the CFG update) or data space (after
        # transforming from frequency space back to data space)
        if guidance_rescale_space not in ["data", "freq"]:
            raise ValueError(
                f"Guidance rescale space is {guidance_rescale_space} but must be one of `data` or `freq`."
            )
        self.guidance_rescale_space = guidance_rescale_space

        if parallel_weights is None:
            # Use normal CFG shift (equal weights for parallel and orthogonal components)
            self.parallel_weights = [1.0] * self.levels
        elif isinstance(parallel_weights, float):
            self.parallel_weights = [parallel_weights] * self.levels
        elif len(parallel_weights) == self.levels:
            self.parallel_weights = parallel_weights
        else:
            raise ValueError(
                f"`parallel_weights` has length {len(parallel_weights)} but should have the same length as "
                f"`guidance_scales` ({len(self.guidance_scales)})"
            )

        self.use_original_formulation = use_original_formulation
        self.upcast_to_double = upcast_to_double

        # Per-level guidance start/stop fractions (broadcast scalars, validate list lengths).
        if isinstance(start, float):
            self.guidance_start = [start] * self.levels
        elif len(start) == self.levels:
            self.guidance_start = start
        else:
            raise ValueError(
                f"`start` has length {len(start)} but should have the same length as `guidance_scales` "
                f"({len(self.guidance_scales)})"
            )
        if isinstance(stop, float):
            self.guidance_stop = [stop] * self.levels
        elif len(stop) == self.levels:
            self.guidance_stop = stop
        else:
            raise ValueError(
                f"`stop` has length {len(stop)} but should have the same length as `guidance_scales` "
                f"({len(self.guidance_scales)})"
            )

    def prepare_inputs(
        self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
    ) -> List["BlockState"]:
        """
        Split the pipeline state into one batch per required condition (conditional, and, when FDG is active,
        unconditional), resolving each prediction's fields via `input_fields`.
        """
        if input_fields is None:
            input_fields = self._input_fields

        # Index 0 selects the conditional entry of a (cond, uncond) field tuple, index 1 the unconditional one.
        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        data_batches = []
        for i in range(self.num_conditions):
            data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
            data_batches.append(data_batch)
        return data_batches

    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
        """
        Combine the conditional and unconditional predictions with a per-frequency-level CFG update and transform the
        result back to data space.
        """
        pred = None

        if not self._is_fdg_enabled():
            pred = pred_cond
        else:
            # Apply the frequency transform (e.g. Laplacian pyramid) to the conditional and unconditional predictions.
            pred_cond_pyramid = build_laplacian_pyramid_func(pred_cond, self.levels)
            pred_uncond_pyramid = build_laplacian_pyramid_func(pred_uncond, self.levels)

            # From high frequencies to low frequencies, following the paper implementation
            pred_guided_pyramid = []
            parameters = zip(self.guidance_scales, self.parallel_weights, self.guidance_rescale)
            for level, (guidance_scale, parallel_weight, guidance_rescale) in enumerate(parameters):
                if self._is_fdg_enabled_for_level(level):
                    # Get the cond/uncond preds (in freq space) at the current frequency level
                    pred_cond_freq = pred_cond_pyramid[level]
                    pred_uncond_freq = pred_uncond_pyramid[level]

                    shift = pred_cond_freq - pred_uncond_freq

                    # Apply parallel weights, if used (1.0 corresponds to using the normal CFG shift)
                    if not math.isclose(parallel_weight, 1.0):
                        shift_parallel, shift_orthogonal = project(shift, pred_cond_freq, self.upcast_to_double)
                        shift = parallel_weight * shift_parallel + shift_orthogonal

                    # Apply CFG update for the current frequency level
                    pred = pred_cond_freq if self.use_original_formulation else pred_uncond_freq
                    pred = pred + guidance_scale * shift

                    if self.guidance_rescale_space == "freq" and guidance_rescale > 0.0:
                        pred = rescale_noise_cfg(pred, pred_cond_freq, guidance_rescale)

                    # Add the current FDG guided level to the FDG prediction pyramid
                    pred_guided_pyramid.append(pred)
                else:
                    # Add the current pred_cond_pyramid level as the "non-FDG" prediction.
                    # FIX: index the pyramid directly instead of referencing `pred_cond_freq`, which is only
                    # assigned inside the guided branch above (previously a NameError when the first level was
                    # disabled, or a stale earlier level for later disabled levels).
                    pred_guided_pyramid.append(pred_cond_pyramid[level])

            # Convert from frequency space back to data (e.g. pixel) space by applying inverse freq transform
            pred = build_image_from_pyramid(pred_guided_pyramid)

            # If rescaling in data space, use the first elem of self.guidance_rescale as the "global" rescale value
            # across all freq levels
            if self.guidance_rescale_space == "data" and self.guidance_rescale[0] > 0.0:
                pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale[0])

        return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

    @property
    def is_conditional(self) -> bool:
        # True while only the conditional batch has been prepared in the current step.
        return self._count_prepared == 1

    @property
    def num_conditions(self) -> int:
        # 1 (conditional only) or 2 (conditional + unconditional) depending on whether FDG is active.
        num_conditions = 1
        if self._is_fdg_enabled():
            num_conditions += 1
        return num_conditions

    def _is_fdg_enabled(self) -> bool:
        """Whether FDG is active at all for the current step: guider enabled, step within the global
        [start, stop) window, and not every level's scale is the identity value."""
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        # Identity scale is 0.0 for the original formulation, 1.0 for the diffusers-native one.
        is_close = False
        if self.use_original_formulation:
            is_close = all(math.isclose(guidance_scale, 0.0) for guidance_scale in self.guidance_scales)
        else:
            is_close = all(math.isclose(guidance_scale, 1.0) for guidance_scale in self.guidance_scales)

        return is_within_range and not is_close

    def _is_fdg_enabled_for_level(self, level: int) -> bool:
        """Whether FDG is active for a single frequency level, using that level's start/stop window and scale."""
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self.guidance_start[level] * self._num_inference_steps)
            skip_stop_step = int(self.guidance_stop[level] * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scales[level], 0.0)
        else:
            is_close = math.isclose(self.guidance_scales[level], 1.0)

        return is_within_range and not is_close
|
pythonProject/diffusers-main/build/lib/diffusers/guiders/guider_utils.py
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
from huggingface_hub.utils import validate_hf_hub_args
|
| 20 |
+
from typing_extensions import Self
|
| 21 |
+
|
| 22 |
+
from ..configuration_utils import ConfigMixin
|
| 23 |
+
from ..utils import BaseOutput, PushToHubMixin, get_logger
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
if TYPE_CHECKING:
|
| 27 |
+
from ..modular_pipelines.modular_pipeline import BlockState
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
GUIDER_CONFIG_NAME = "guider_config.json"
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
logger = get_logger(__name__) # pylint: disable=invalid-name
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class BaseGuidance(ConfigMixin, PushToHubMixin):
|
| 37 |
+
r"""Base class providing the skeleton for implementing guidance techniques."""
|
| 38 |
+
|
| 39 |
+
config_name = GUIDER_CONFIG_NAME
|
| 40 |
+
_input_predictions = None
|
| 41 |
+
_identifier_key = "__guidance_identifier__"
|
| 42 |
+
|
| 43 |
+
def __init__(self, start: float = 0.0, stop: float = 1.0):
|
| 44 |
+
self._start = start
|
| 45 |
+
self._stop = stop
|
| 46 |
+
self._step: int = None
|
| 47 |
+
self._num_inference_steps: int = None
|
| 48 |
+
self._timestep: torch.LongTensor = None
|
| 49 |
+
self._count_prepared = 0
|
| 50 |
+
self._input_fields: Dict[str, Union[str, Tuple[str, str]]] = None
|
| 51 |
+
self._enabled = True
|
| 52 |
+
|
| 53 |
+
if not (0.0 <= start < 1.0):
|
| 54 |
+
raise ValueError(f"Expected `start` to be between 0.0 and 1.0, but got {start}.")
|
| 55 |
+
if not (start <= stop <= 1.0):
|
| 56 |
+
raise ValueError(f"Expected `stop` to be between {start} and 1.0, but got {stop}.")
|
| 57 |
+
|
| 58 |
+
if self._input_predictions is None or not isinstance(self._input_predictions, list):
|
| 59 |
+
raise ValueError(
|
| 60 |
+
"`_input_predictions` must be a list of required prediction names for the guidance technique."
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
def disable(self):
|
| 64 |
+
self._enabled = False
|
| 65 |
+
|
| 66 |
+
def enable(self):
|
| 67 |
+
self._enabled = True
|
| 68 |
+
|
| 69 |
+
def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
|
| 70 |
+
self._step = step
|
| 71 |
+
self._num_inference_steps = num_inference_steps
|
| 72 |
+
self._timestep = timestep
|
| 73 |
+
self._count_prepared = 0
|
| 74 |
+
|
| 75 |
+
def set_input_fields(self, **kwargs: Dict[str, Union[str, Tuple[str, str]]]) -> None:
|
| 76 |
+
"""
|
| 77 |
+
Set the input fields for the guidance technique. The input fields are used to specify the names of the returned
|
| 78 |
+
attributes containing the prepared data after `prepare_inputs` is called. The prepared data is obtained from
|
| 79 |
+
the values of the provided keyword arguments to this method.
|
| 80 |
+
|
| 81 |
+
Args:
|
| 82 |
+
**kwargs (`Dict[str, Union[str, Tuple[str, str]]]`):
|
| 83 |
+
A dictionary where the keys are the names of the fields that will be used to store the data once it is
|
| 84 |
+
prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
|
| 85 |
+
to look up the required data provided for preparation.
|
| 86 |
+
|
| 87 |
+
If a string is provided, it will be used as the conditional data (or unconditional if used with a
|
| 88 |
+
guidance method that requires it). If a tuple of length 2 is provided, the first element must be the
|
| 89 |
+
conditional data identifier and the second element must be the unconditional data identifier or None.
|
| 90 |
+
|
| 91 |
+
Example:
|
| 92 |
+
```
|
| 93 |
+
data = {"prompt_embeds": <some tensor>, "negative_prompt_embeds": <some tensor>, "latents": <some tensor>}
|
| 94 |
+
|
| 95 |
+
BaseGuidance.set_input_fields(
|
| 96 |
+
latents="latents",
|
| 97 |
+
prompt_embeds=("prompt_embeds", "negative_prompt_embeds"),
|
| 98 |
+
)
|
| 99 |
+
```
|
| 100 |
+
"""
|
| 101 |
+
for key, value in kwargs.items():
|
| 102 |
+
is_string = isinstance(value, str)
|
| 103 |
+
is_tuple_of_str_with_len_2 = (
|
| 104 |
+
isinstance(value, tuple) and len(value) == 2 and all(isinstance(v, str) for v in value)
|
| 105 |
+
)
|
| 106 |
+
if not (is_string or is_tuple_of_str_with_len_2):
|
| 107 |
+
raise ValueError(
|
| 108 |
+
f"Expected `set_input_fields` to be called with a string or a tuple of string with length 2, but got {type(value)} for key {key}."
|
| 109 |
+
)
|
| 110 |
+
self._input_fields = kwargs
|
| 111 |
+
|
| 112 |
+
def prepare_models(self, denoiser: torch.nn.Module) -> None:
|
| 113 |
+
"""
|
| 114 |
+
Prepares the models for the guidance technique on a given batch of data. This method should be overridden in
|
| 115 |
+
subclasses to implement specific model preparation logic.
|
| 116 |
+
"""
|
| 117 |
+
self._count_prepared += 1
|
| 118 |
+
|
| 119 |
+
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
|
| 120 |
+
"""
|
| 121 |
+
Cleans up the models for the guidance technique after a given batch of data. This method should be overridden
|
| 122 |
+
in subclasses to implement specific model cleanup logic. It is useful for removing any hooks or other stateful
|
| 123 |
+
modifications made during `prepare_models`.
|
| 124 |
+
"""
|
| 125 |
+
pass
|
| 126 |
+
|
| 127 |
+
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
|
| 128 |
+
raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")
|
| 129 |
+
|
| 130 |
+
def __call__(self, data: List["BlockState"]) -> Any:
|
| 131 |
+
if not all(hasattr(d, "noise_pred") for d in data):
|
| 132 |
+
raise ValueError("Expected all data to have `noise_pred` attribute.")
|
| 133 |
+
if len(data) != self.num_conditions:
|
| 134 |
+
raise ValueError(
|
| 135 |
+
f"Expected {self.num_conditions} data items, but got {len(data)}. Please check the input data."
|
| 136 |
+
)
|
| 137 |
+
forward_inputs = {getattr(d, self._identifier_key): d.noise_pred for d in data}
|
| 138 |
+
return self.forward(**forward_inputs)
|
| 139 |
+
|
| 140 |
+
def forward(self, *args, **kwargs) -> Any:
|
| 141 |
+
raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.")
|
| 142 |
+
|
| 143 |
+
@property
|
| 144 |
+
def is_conditional(self) -> bool:
|
| 145 |
+
raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.")
|
| 146 |
+
|
| 147 |
+
@property
|
| 148 |
+
def is_unconditional(self) -> bool:
|
| 149 |
+
return not self.is_conditional
|
| 150 |
+
|
| 151 |
+
@property
|
| 152 |
+
def num_conditions(self) -> int:
|
| 153 |
+
raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.")
|
| 154 |
+
|
| 155 |
+
@classmethod
|
| 156 |
+
def _prepare_batch(
|
| 157 |
+
cls,
|
| 158 |
+
input_fields: Dict[str, Union[str, Tuple[str, str]]],
|
| 159 |
+
data: "BlockState",
|
| 160 |
+
tuple_index: int,
|
| 161 |
+
identifier: str,
|
| 162 |
+
) -> "BlockState":
|
| 163 |
+
"""
|
| 164 |
+
Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of the
|
| 165 |
+
`BaseGuidance` class. It prepares the batch based on the provided tuple index.
|
| 166 |
+
|
| 167 |
+
Args:
|
| 168 |
+
input_fields (`Dict[str, Union[str, Tuple[str, str]]]`):
|
| 169 |
+
A dictionary where the keys are the names of the fields that will be used to store the data once it is
|
| 170 |
+
prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
|
| 171 |
+
to look up the required data provided for preparation. If a string is provided, it will be used as the
|
| 172 |
+
conditional data (or unconditional if used with a guidance method that requires it). If a tuple of
|
| 173 |
+
length 2 is provided, the first element must be the conditional data identifier and the second element
|
| 174 |
+
must be the unconditional data identifier or None.
|
| 175 |
+
data (`BlockState`):
|
| 176 |
+
The input data to be prepared.
|
| 177 |
+
tuple_index (`int`):
|
| 178 |
+
The index to use when accessing input fields that are tuples.
|
| 179 |
+
|
| 180 |
+
Returns:
|
| 181 |
+
`BlockState`: The prepared batch of data.
|
| 182 |
+
"""
|
| 183 |
+
from ..modular_pipelines.modular_pipeline import BlockState
|
| 184 |
+
|
| 185 |
+
if input_fields is None:
|
| 186 |
+
raise ValueError(
|
| 187 |
+
"Input fields cannot be None. Please pass `input_fields` to `prepare_inputs` or call `set_input_fields` before preparing inputs."
|
| 188 |
+
)
|
| 189 |
+
data_batch = {}
|
| 190 |
+
for key, value in input_fields.items():
|
| 191 |
+
try:
|
| 192 |
+
if isinstance(value, str):
|
| 193 |
+
data_batch[key] = getattr(data, value)
|
| 194 |
+
elif isinstance(value, tuple):
|
| 195 |
+
data_batch[key] = getattr(data, value[tuple_index])
|
| 196 |
+
else:
|
| 197 |
+
# We've already checked that value is a string or a tuple of strings with length 2
|
| 198 |
+
pass
|
| 199 |
+
except AttributeError:
|
| 200 |
+
logger.debug(f"`data` does not have attribute(s) {value}, skipping.")
|
| 201 |
+
data_batch[cls._identifier_key] = identifier
|
| 202 |
+
return BlockState(**data_batch)
|
| 203 |
+
|
| 204 |
+
@classmethod
|
| 205 |
+
@validate_hf_hub_args
|
| 206 |
+
def from_pretrained(
|
| 207 |
+
cls,
|
| 208 |
+
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
|
| 209 |
+
subfolder: Optional[str] = None,
|
| 210 |
+
return_unused_kwargs=False,
|
| 211 |
+
**kwargs,
|
| 212 |
+
) -> Self:
|
| 213 |
+
r"""
|
| 214 |
+
Instantiate a guider from a pre-defined JSON configuration file in a local directory or Hub repository.
|
| 215 |
+
|
| 216 |
+
Parameters:
|
| 217 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
| 218 |
+
Can be either:
|
| 219 |
+
|
| 220 |
+
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
| 221 |
+
the Hub.
|
| 222 |
+
- A path to a *directory* (for example `./my_model_directory`) containing the guider configuration
|
| 223 |
+
saved with [`~BaseGuidance.save_pretrained`].
|
| 224 |
+
subfolder (`str`, *optional*):
|
| 225 |
+
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
| 226 |
+
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
|
| 227 |
+
Whether kwargs that are not consumed by the Python class should be returned or not.
|
| 228 |
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
| 229 |
+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
| 230 |
+
is not used.
|
| 231 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
| 232 |
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
| 233 |
+
cached versions if they exist.
|
| 234 |
+
|
| 235 |
+
proxies (`Dict[str, str]`, *optional*):
|
| 236 |
+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
| 237 |
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
| 238 |
+
output_loading_info(`bool`, *optional*, defaults to `False`):
|
| 239 |
+
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
| 240 |
+
local_files_only(`bool`, *optional*, defaults to `False`):
|
| 241 |
+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
| 242 |
+
won't be downloaded from the Hub.
|
| 243 |
+
token (`str` or *bool*, *optional*):
|
| 244 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
| 245 |
+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
| 246 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
| 247 |
+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
| 248 |
+
allowed by Git.
|
| 249 |
+
|
| 250 |
+
<Tip>
|
| 251 |
+
|
| 252 |
+
To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with `hf
|
| 253 |
+
auth login`. You can also activate the special
|
| 254 |
+
["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
|
| 255 |
+
firewalled environment.
|
| 256 |
+
|
| 257 |
+
</Tip>
|
| 258 |
+
|
| 259 |
+
"""
|
| 260 |
+
config, kwargs, commit_hash = cls.load_config(
|
| 261 |
+
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
| 262 |
+
subfolder=subfolder,
|
| 263 |
+
return_unused_kwargs=True,
|
| 264 |
+
return_commit_hash=True,
|
| 265 |
+
**kwargs,
|
| 266 |
+
)
|
| 267 |
+
return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)
|
| 268 |
+
|
| 269 |
+
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
|
| 270 |
+
"""
|
| 271 |
+
Save a guider configuration object to a directory so that it can be reloaded using the
|
| 272 |
+
[`~BaseGuidance.from_pretrained`] class method.
|
| 273 |
+
|
| 274 |
+
Args:
|
| 275 |
+
save_directory (`str` or `os.PathLike`):
|
| 276 |
+
Directory where the configuration JSON file will be saved (will be created if it does not exist).
|
| 277 |
+
push_to_hub (`bool`, *optional*, defaults to `False`):
|
| 278 |
+
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
|
| 279 |
+
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
| 280 |
+
namespace).
|
| 281 |
+
kwargs (`Dict[str, Any]`, *optional*):
|
| 282 |
+
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
| 283 |
+
"""
|
| 284 |
+
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
class GuiderOutput(BaseOutput):
|
| 288 |
+
pred: torch.Tensor
|
| 289 |
+
pred_cond: Optional[torch.Tensor]
|
| 290 |
+
pred_uncond: Optional[torch.Tensor]
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
| 294 |
+
r"""
|
| 295 |
+
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
|
| 296 |
+
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
| 297 |
+
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
| 298 |
+
|
| 299 |
+
Args:
|
| 300 |
+
noise_cfg (`torch.Tensor`):
|
| 301 |
+
The predicted noise tensor for the guided diffusion process.
|
| 302 |
+
noise_pred_text (`torch.Tensor`):
|
| 303 |
+
The predicted noise tensor for the text-guided diffusion process.
|
| 304 |
+
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
| 305 |
+
A rescale factor applied to the noise predictions.
|
| 306 |
+
Returns:
|
| 307 |
+
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
|
| 308 |
+
"""
|
| 309 |
+
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
| 310 |
+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
| 311 |
+
# rescale the results from guidance (fixes overexposure)
|
| 312 |
+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
| 313 |
+
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
| 314 |
+
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
| 315 |
+
return noise_cfg
|
pythonProject/diffusers-main/build/lib/diffusers/guiders/perturbed_attention_guidance.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from ..configuration_utils import register_to_config
|
| 21 |
+
from ..hooks import HookRegistry, LayerSkipConfig
|
| 22 |
+
from ..hooks.layer_skip import _apply_layer_skip_hook
|
| 23 |
+
from ..utils import get_logger
|
| 24 |
+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
if TYPE_CHECKING:
|
| 28 |
+
from ..modular_pipelines.modular_pipeline import BlockState
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
logger = get_logger(__name__) # pylint: disable=invalid-name
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class PerturbedAttentionGuidance(BaseGuidance):
|
| 35 |
+
"""
|
| 36 |
+
Perturbed Attention Guidance (PAG): https://huggingface.co/papers/2403.17377
|
| 37 |
+
|
| 38 |
+
The intution behind PAG can be thought of as moving the CFG predicted distribution estimates further away from
|
| 39 |
+
worse versions of the conditional distribution estimates. PAG was one of the first techniques to introduce the idea
|
| 40 |
+
of using a worse version of the trained model for better guiding itself in the denoising process. It perturbs the
|
| 41 |
+
attention scores of the latent stream by replacing the score matrix with an identity matrix for selectively chosen
|
| 42 |
+
layers.
|
| 43 |
+
|
| 44 |
+
Additional reading:
|
| 45 |
+
- [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)
|
| 46 |
+
|
| 47 |
+
PAG is implemented with similar implementation to SkipLayerGuidance due to overlap in the configuration parameters
|
| 48 |
+
and implementation details.
|
| 49 |
+
|
| 50 |
+
Args:
|
| 51 |
+
guidance_scale (`float`, defaults to `7.5`):
|
| 52 |
+
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
| 53 |
+
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
| 54 |
+
deterioration of image quality.
|
| 55 |
+
perturbed_guidance_scale (`float`, defaults to `2.8`):
|
| 56 |
+
The scale parameter for perturbed attention guidance.
|
| 57 |
+
perturbed_guidance_start (`float`, defaults to `0.01`):
|
| 58 |
+
The fraction of the total number of denoising steps after which perturbed attention guidance starts.
|
| 59 |
+
perturbed_guidance_stop (`float`, defaults to `0.2`):
|
| 60 |
+
The fraction of the total number of denoising steps after which perturbed attention guidance stops.
|
| 61 |
+
perturbed_guidance_layers (`int` or `List[int]`, *optional*):
|
| 62 |
+
The layer indices to apply perturbed attention guidance to. Can be a single integer or a list of integers.
|
| 63 |
+
If not provided, `perturbed_guidance_config` must be provided.
|
| 64 |
+
perturbed_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
|
| 65 |
+
The configuration for the perturbed attention guidance. Can be a single `LayerSkipConfig` or a list of
|
| 66 |
+
`LayerSkipConfig`. If not provided, `perturbed_guidance_layers` must be provided.
|
| 67 |
+
guidance_rescale (`float`, defaults to `0.0`):
|
| 68 |
+
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
| 69 |
+
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
| 70 |
+
Flawed](https://huggingface.co/papers/2305.08891).
|
| 71 |
+
use_original_formulation (`bool`, defaults to `False`):
|
| 72 |
+
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
| 73 |
+
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
| 74 |
+
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
| 75 |
+
start (`float`, defaults to `0.01`):
|
| 76 |
+
The fraction of the total number of denoising steps after which guidance starts.
|
| 77 |
+
stop (`float`, defaults to `0.2`):
|
| 78 |
+
The fraction of the total number of denoising steps after which guidance stops.
|
| 79 |
+
"""
|
| 80 |
+
|
| 81 |
+
# NOTE: The current implementation does not account for joint latent conditioning (text + image/video tokens in
|
| 82 |
+
# the same latent stream). It assumes the entire latent is a single stream of visual tokens. It would be very
|
| 83 |
+
# complex to support joint latent conditioning in a model-agnostic manner without specializing the implementation
|
| 84 |
+
# for each model architecture.
|
| 85 |
+
|
| 86 |
+
_input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
|
| 87 |
+
|
| 88 |
+
@register_to_config
|
| 89 |
+
def __init__(
|
| 90 |
+
self,
|
| 91 |
+
guidance_scale: float = 7.5,
|
| 92 |
+
perturbed_guidance_scale: float = 2.8,
|
| 93 |
+
perturbed_guidance_start: float = 0.01,
|
| 94 |
+
perturbed_guidance_stop: float = 0.2,
|
| 95 |
+
perturbed_guidance_layers: Optional[Union[int, List[int]]] = None,
|
| 96 |
+
perturbed_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
|
| 97 |
+
guidance_rescale: float = 0.0,
|
| 98 |
+
use_original_formulation: bool = False,
|
| 99 |
+
start: float = 0.0,
|
| 100 |
+
stop: float = 1.0,
|
| 101 |
+
):
|
| 102 |
+
super().__init__(start, stop)
|
| 103 |
+
|
| 104 |
+
self.guidance_scale = guidance_scale
|
| 105 |
+
self.skip_layer_guidance_scale = perturbed_guidance_scale
|
| 106 |
+
self.skip_layer_guidance_start = perturbed_guidance_start
|
| 107 |
+
self.skip_layer_guidance_stop = perturbed_guidance_stop
|
| 108 |
+
self.guidance_rescale = guidance_rescale
|
| 109 |
+
self.use_original_formulation = use_original_formulation
|
| 110 |
+
|
| 111 |
+
if perturbed_guidance_config is None:
|
| 112 |
+
if perturbed_guidance_layers is None:
|
| 113 |
+
raise ValueError(
|
| 114 |
+
"`perturbed_guidance_layers` must be provided if `perturbed_guidance_config` is not specified."
|
| 115 |
+
)
|
| 116 |
+
perturbed_guidance_config = LayerSkipConfig(
|
| 117 |
+
indices=perturbed_guidance_layers,
|
| 118 |
+
fqn="auto",
|
| 119 |
+
skip_attention=False,
|
| 120 |
+
skip_attention_scores=True,
|
| 121 |
+
skip_ff=False,
|
| 122 |
+
)
|
| 123 |
+
else:
|
| 124 |
+
if perturbed_guidance_layers is not None:
|
| 125 |
+
raise ValueError(
|
| 126 |
+
"`perturbed_guidance_layers` should not be provided if `perturbed_guidance_config` is specified."
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
if isinstance(perturbed_guidance_config, dict):
|
| 130 |
+
perturbed_guidance_config = LayerSkipConfig.from_dict(perturbed_guidance_config)
|
| 131 |
+
|
| 132 |
+
if isinstance(perturbed_guidance_config, LayerSkipConfig):
|
| 133 |
+
perturbed_guidance_config = [perturbed_guidance_config]
|
| 134 |
+
|
| 135 |
+
if not isinstance(perturbed_guidance_config, list):
|
| 136 |
+
raise ValueError(
|
| 137 |
+
"`perturbed_guidance_config` must be a `LayerSkipConfig`, a list of `LayerSkipConfig`, or a dict that can be converted to a `LayerSkipConfig`."
|
| 138 |
+
)
|
| 139 |
+
elif isinstance(next(iter(perturbed_guidance_config), None), dict):
|
| 140 |
+
perturbed_guidance_config = [LayerSkipConfig.from_dict(config) for config in perturbed_guidance_config]
|
| 141 |
+
|
| 142 |
+
for config in perturbed_guidance_config:
|
| 143 |
+
if config.skip_attention or not config.skip_attention_scores or config.skip_ff:
|
| 144 |
+
logger.warning(
|
| 145 |
+
"Perturbed Attention Guidance is designed to perturb attention scores, so `skip_attention` should be False, `skip_attention_scores` should be True, and `skip_ff` should be False. "
|
| 146 |
+
"Please check your configuration. Modifying the config to match the expected values."
|
| 147 |
+
)
|
| 148 |
+
config.skip_attention = False
|
| 149 |
+
config.skip_attention_scores = True
|
| 150 |
+
config.skip_ff = False
|
| 151 |
+
|
| 152 |
+
self.skip_layer_config = perturbed_guidance_config
|
| 153 |
+
self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]
|
| 154 |
+
|
| 155 |
+
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_models
|
| 156 |
+
def prepare_models(self, denoiser: torch.nn.Module) -> None:
|
| 157 |
+
self._count_prepared += 1
|
| 158 |
+
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
|
| 159 |
+
for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
|
| 160 |
+
_apply_layer_skip_hook(denoiser, config, name=name)
|
| 161 |
+
|
| 162 |
+
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.cleanup_models
|
| 163 |
+
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
|
| 164 |
+
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
|
| 165 |
+
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
|
| 166 |
+
# Remove the hooks after inference
|
| 167 |
+
for hook_name in self._skip_layer_hook_names:
|
| 168 |
+
registry.remove_hook(hook_name, recurse=True)
|
| 169 |
+
|
| 170 |
+
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_inputs
|
| 171 |
+
def prepare_inputs(
|
| 172 |
+
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
| 173 |
+
) -> List["BlockState"]:
|
| 174 |
+
if input_fields is None:
|
| 175 |
+
input_fields = self._input_fields
|
| 176 |
+
|
| 177 |
+
if self.num_conditions == 1:
|
| 178 |
+
tuple_indices = [0]
|
| 179 |
+
input_predictions = ["pred_cond"]
|
| 180 |
+
elif self.num_conditions == 2:
|
| 181 |
+
tuple_indices = [0, 1]
|
| 182 |
+
input_predictions = (
|
| 183 |
+
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
|
| 184 |
+
)
|
| 185 |
+
else:
|
| 186 |
+
tuple_indices = [0, 1, 0]
|
| 187 |
+
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
|
| 188 |
+
data_batches = []
|
| 189 |
+
for i in range(self.num_conditions):
|
| 190 |
+
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
|
| 191 |
+
data_batches.append(data_batch)
|
| 192 |
+
return data_batches
|
| 193 |
+
|
| 194 |
+
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.forward
|
| 195 |
+
def forward(
|
| 196 |
+
self,
|
| 197 |
+
pred_cond: torch.Tensor,
|
| 198 |
+
pred_uncond: Optional[torch.Tensor] = None,
|
| 199 |
+
pred_cond_skip: Optional[torch.Tensor] = None,
|
| 200 |
+
) -> GuiderOutput:
|
| 201 |
+
pred = None
|
| 202 |
+
|
| 203 |
+
if not self._is_cfg_enabled() and not self._is_slg_enabled():
|
| 204 |
+
pred = pred_cond
|
| 205 |
+
elif not self._is_cfg_enabled():
|
| 206 |
+
shift = pred_cond - pred_cond_skip
|
| 207 |
+
pred = pred_cond if self.use_original_formulation else pred_cond_skip
|
| 208 |
+
pred = pred + self.skip_layer_guidance_scale * shift
|
| 209 |
+
elif not self._is_slg_enabled():
|
| 210 |
+
shift = pred_cond - pred_uncond
|
| 211 |
+
pred = pred_cond if self.use_original_formulation else pred_uncond
|
| 212 |
+
pred = pred + self.guidance_scale * shift
|
| 213 |
+
else:
|
| 214 |
+
shift = pred_cond - pred_uncond
|
| 215 |
+
shift_skip = pred_cond - pred_cond_skip
|
| 216 |
+
pred = pred_cond if self.use_original_formulation else pred_uncond
|
| 217 |
+
pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip
|
| 218 |
+
|
| 219 |
+
if self.guidance_rescale > 0.0:
|
| 220 |
+
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
| 221 |
+
|
| 222 |
+
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
|
| 223 |
+
|
| 224 |
+
@property
|
| 225 |
+
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.is_conditional
|
| 226 |
+
def is_conditional(self) -> bool:
|
| 227 |
+
return self._count_prepared == 1 or self._count_prepared == 3
|
| 228 |
+
|
| 229 |
+
@property
|
| 230 |
+
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.num_conditions
|
| 231 |
+
def num_conditions(self) -> int:
|
| 232 |
+
num_conditions = 1
|
| 233 |
+
if self._is_cfg_enabled():
|
| 234 |
+
num_conditions += 1
|
| 235 |
+
if self._is_slg_enabled():
|
| 236 |
+
num_conditions += 1
|
| 237 |
+
return num_conditions
|
| 238 |
+
|
| 239 |
+
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_cfg_enabled
|
| 240 |
+
def _is_cfg_enabled(self) -> bool:
|
| 241 |
+
if not self._enabled:
|
| 242 |
+
return False
|
| 243 |
+
|
| 244 |
+
is_within_range = True
|
| 245 |
+
if self._num_inference_steps is not None:
|
| 246 |
+
skip_start_step = int(self._start * self._num_inference_steps)
|
| 247 |
+
skip_stop_step = int(self._stop * self._num_inference_steps)
|
| 248 |
+
is_within_range = skip_start_step <= self._step < skip_stop_step
|
| 249 |
+
|
| 250 |
+
is_close = False
|
| 251 |
+
if self.use_original_formulation:
|
| 252 |
+
is_close = math.isclose(self.guidance_scale, 0.0)
|
| 253 |
+
else:
|
| 254 |
+
is_close = math.isclose(self.guidance_scale, 1.0)
|
| 255 |
+
|
| 256 |
+
return is_within_range and not is_close
|
| 257 |
+
|
| 258 |
+
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_slg_enabled
|
| 259 |
+
def _is_slg_enabled(self) -> bool:
|
| 260 |
+
if not self._enabled:
|
| 261 |
+
return False
|
| 262 |
+
|
| 263 |
+
is_within_range = True
|
| 264 |
+
if self._num_inference_steps is not None:
|
| 265 |
+
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
|
| 266 |
+
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
|
| 267 |
+
is_within_range = skip_start_step < self._step < skip_stop_step
|
| 268 |
+
|
| 269 |
+
is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
|
| 270 |
+
|
| 271 |
+
return is_within_range and not is_zero
|
pythonProject/diffusers-main/build/lib/diffusers/guiders/skip_layer_guidance.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from ..configuration_utils import register_to_config
|
| 21 |
+
from ..hooks import HookRegistry, LayerSkipConfig
|
| 22 |
+
from ..hooks.layer_skip import _apply_layer_skip_hook
|
| 23 |
+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
if TYPE_CHECKING:
|
| 27 |
+
from ..modular_pipelines.modular_pipeline import BlockState
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class SkipLayerGuidance(BaseGuidance):
    """
    Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5

    Spatio-Temporal Guidance (STG): https://huggingface.co/papers/2411.18664

    SLG was introduced by StabilityAI for improving structure and anatomy coherence in generated images. It works by
    skipping the forward pass of specified transformer blocks during the denoising process on an additional conditional
    batch of data, apart from the conditional and unconditional batches already used in CFG
    ([~guiders.classifier_free_guidance.ClassifierFreeGuidance]), and then scaling and shifting the CFG predictions
    based on the difference between conditional without skipping and conditional with skipping predictions.

    The intuition behind SLG can be thought of as moving the CFG predicted distribution estimates further away from
    worse versions of the conditional distribution estimates (because skipping layers is equivalent to using a worse
    version of the model for the conditional prediction).

    STG is an improvement and follow-up work combining ideas from SLG, PAG and similar techniques for improving
    generation quality in video diffusion models.

    Additional reading:
    - [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)

    The values for `skip_layer_guidance_scale`, `skip_layer_guidance_start`, and `skip_layer_guidance_stop` are
    defaulted to the recommendations by StabilityAI for Stable Diffusion 3.5 Medium.

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
            prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
            deterioration of image quality.
        skip_layer_guidance_scale (`float`, defaults to `2.8`):
            The scale parameter for skip layer guidance. Anatomy and structure coherence may improve with higher
            values, but it may also lead to overexposure and saturation.
        skip_layer_guidance_start (`float`, defaults to `0.01`):
            The fraction of the total number of denoising steps after which skip layer guidance starts.
        skip_layer_guidance_stop (`float`, defaults to `0.2`):
            The fraction of the total number of denoising steps after which skip layer guidance stops.
        skip_layer_guidance_layers (`int` or `List[int]`, *optional*):
            The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
            provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion
            3.5 Medium.
        skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
            The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
            `LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
    """

    _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]

    @register_to_config
    def __init__(
        self,
        guidance_scale: float = 7.5,
        skip_layer_guidance_scale: float = 2.8,
        skip_layer_guidance_start: float = 0.01,
        skip_layer_guidance_stop: float = 0.2,
        skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None,
        skip_layer_config: Optional[Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]]] = None,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__(start, stop)

        self.guidance_scale = guidance_scale
        self.skip_layer_guidance_scale = skip_layer_guidance_scale
        self.skip_layer_guidance_start = skip_layer_guidance_start
        self.skip_layer_guidance_stop = skip_layer_guidance_stop
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

        if not (0.0 <= skip_layer_guidance_start < 1.0):
            raise ValueError(
                f"Expected `skip_layer_guidance_start` to be between 0.0 and 1.0, but got {skip_layer_guidance_start}."
            )
        if not (skip_layer_guidance_start <= skip_layer_guidance_stop <= 1.0):
            raise ValueError(
                f"Expected `skip_layer_guidance_stop` to be between 0.0 and 1.0, but got {skip_layer_guidance_stop}."
            )

        # Exactly one of the two layer-selection mechanisms must be supplied.
        if skip_layer_guidance_layers is None and skip_layer_config is None:
            raise ValueError(
                "Either `skip_layer_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance."
            )
        if skip_layer_guidance_layers is not None and skip_layer_config is not None:
            raise ValueError("Only one of `skip_layer_guidance_layers` or `skip_layer_config` can be provided.")

        if skip_layer_guidance_layers is not None:
            if isinstance(skip_layer_guidance_layers, int):
                skip_layer_guidance_layers = [skip_layer_guidance_layers]
            if not isinstance(skip_layer_guidance_layers, list):
                raise ValueError(
                    f"Expected `skip_layer_guidance_layers` to be an int or a list of ints, but got {type(skip_layer_guidance_layers)}."
                )
            skip_layer_config = [LayerSkipConfig(layer, fqn="auto") for layer in skip_layer_guidance_layers]

        # Normalize `skip_layer_config` into a list of `LayerSkipConfig`, accepting dict(s) as serialized forms.
        if isinstance(skip_layer_config, dict):
            skip_layer_config = LayerSkipConfig.from_dict(skip_layer_config)

        if isinstance(skip_layer_config, LayerSkipConfig):
            skip_layer_config = [skip_layer_config]

        if not isinstance(skip_layer_config, list):
            raise ValueError(
                f"Expected `skip_layer_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(skip_layer_config)}."
            )
        elif isinstance(next(iter(skip_layer_config), None), dict):
            skip_layer_config = [LayerSkipConfig.from_dict(config) for config in skip_layer_config]

        self.skip_layer_config = skip_layer_config
        self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]

    def prepare_models(self, denoiser: torch.nn.Module) -> None:
        """Attach layer-skip hooks to `denoiser` when preparing the skipped-conditional pass."""
        self._count_prepared += 1
        # Hooks are only needed for the second conditional pass (the one with skipped layers), hence the
        # `_count_prepared > 1` gate; the first prepared batch is the regular conditional prediction.
        if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
            for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
                _apply_layer_skip_hook(denoiser, config, name=name)

    def cleanup_models(self, denoiser: torch.nn.Module) -> None:
        """Remove any layer-skip hooks previously attached by `prepare_models`."""
        if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
            registry = HookRegistry.check_if_exists_or_initialize(denoiser)
            # Remove the hooks after inference
            for hook_name in self._skip_layer_hook_names:
                registry.remove_hook(hook_name, recurse=True)

    def prepare_inputs(
        self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
    ) -> List["BlockState"]:
        """Split `data` into one batch per active condition (cond / uncond / cond-with-skips)."""
        if input_fields is None:
            input_fields = self._input_fields

        if self.num_conditions == 1:
            tuple_indices = [0]
            input_predictions = ["pred_cond"]
        elif self.num_conditions == 2:
            if self._is_cfg_enabled():
                tuple_indices = [0, 1]
                input_predictions = ["pred_cond", "pred_uncond"]
            else:
                # `pred_cond_skip` is a *conditional* prediction (with layers skipped), so it must use the
                # conditional embeddings (tuple index 0), matching the three-condition branch below. The
                # previous code used index 1 (unconditional) here, which contradicted that branch.
                tuple_indices = [0, 0]
                input_predictions = ["pred_cond", "pred_cond_skip"]
        else:
            tuple_indices = [0, 1, 0]
            input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
        data_batches = []
        for i in range(self.num_conditions):
            data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
            data_batches.append(data_batch)
        return data_batches

    def forward(
        self,
        pred_cond: torch.Tensor,
        pred_uncond: Optional[torch.Tensor] = None,
        pred_cond_skip: Optional[torch.Tensor] = None,
    ) -> GuiderOutput:
        """Combine the available predictions into the final guided noise prediction.

        Args:
            pred_cond: Conditional model prediction.
            pred_uncond: Unconditional prediction; required only when CFG is active.
            pred_cond_skip: Conditional prediction with layers skipped; required only when SLG is active.

        Returns:
            `GuiderOutput` with the guided prediction plus the raw cond/uncond predictions.
        """
        pred = None

        if not self._is_cfg_enabled() and not self._is_slg_enabled():
            pred = pred_cond
        elif not self._is_cfg_enabled():
            # SLG only: push away from the "worse" (layer-skipped) conditional estimate.
            shift = pred_cond - pred_cond_skip
            pred = pred_cond if self.use_original_formulation else pred_cond_skip
            pred = pred + self.skip_layer_guidance_scale * shift
        elif not self._is_slg_enabled():
            # Plain CFG.
            shift = pred_cond - pred_uncond
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift
        else:
            # CFG + SLG combined.
            shift = pred_cond - pred_uncond
            shift_skip = pred_cond - pred_cond_skip
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

    @property
    def is_conditional(self) -> bool:
        # Batches are prepared in the order cond (1), uncond (2), cond-with-skips (3).
        return self._count_prepared == 1 or self._count_prepared == 3

    @property
    def num_conditions(self) -> int:
        num_conditions = 1
        if self._is_cfg_enabled():
            num_conditions += 1
        if self._is_slg_enabled():
            num_conditions += 1
        return num_conditions

    def _is_cfg_enabled(self) -> bool:
        """Whether classifier-free guidance is active at the current step."""
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        # A scale equal to the identity value (0.0 original formulation, 1.0 diffusers formulation) disables CFG.
        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close

    def _is_slg_enabled(self) -> bool:
        """Whether skip layer guidance is active at the current step."""
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
            skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
            is_within_range = skip_start_step < self._step < skip_stop_step

        is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)

        return is_within_range and not is_zero
|
pythonProject/diffusers-main/build/lib/diffusers/guiders/smoothed_energy_guidance.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from ..configuration_utils import register_to_config
|
| 21 |
+
from ..hooks import HookRegistry
|
| 22 |
+
from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook
|
| 23 |
+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
if TYPE_CHECKING:
|
| 27 |
+
from ..modular_pipelines.modular_pipeline import BlockState
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class SmoothedEnergyGuidance(BaseGuidance):
    """
    Smoothed Energy Guidance (SEG): https://huggingface.co/papers/2408.00760

    SEG is only supported as an experimental prototype feature for now, so the implementation may be modified in the
    future without warning or guarantee of reproducibility. This implementation assumes:
    - Generated images are square (height == width)
    - The model does not combine different modalities together (e.g., text and image latent streams are not combined
      together such as Flux)

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
            prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
            deterioration of image quality.
        seg_guidance_scale (`float`, defaults to `2.8`):
            The scale parameter for smoothed energy guidance. Anatomy and structure coherence may improve with higher
            values, but it may also lead to overexposure and saturation.
        seg_blur_sigma (`float`, defaults to `9999999.0`):
            The amount by which we blur the attention weights. Setting this value greater than 9999.0 results in
            infinite blur, which means uniform queries. Controlling it exponentially is empirically effective.
        seg_blur_threshold_inf (`float`, defaults to `9999.0`):
            The threshold above which the blur is considered infinite.
        seg_guidance_start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which smoothed energy guidance starts.
        seg_guidance_stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which smoothed energy guidance stops.
        seg_guidance_layers (`int` or `List[int]`, *optional*):
            The layer indices to apply smoothed energy guidance to. Can be a single integer or a list of integers. If
            not provided, `seg_guidance_config` must be provided. The recommended values are `[7, 8, 9]` for Stable
            Diffusion 3.5 Medium.
        seg_guidance_config (`SmoothedEnergyGuidanceConfig` or `List[SmoothedEnergyGuidanceConfig]`, *optional*):
            The configuration for the smoothed energy layer guidance. Can be a single `SmoothedEnergyGuidanceConfig` or
            a list of `SmoothedEnergyGuidanceConfig`. If not provided, `seg_guidance_layers` must be provided.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
    """

    _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]

    @register_to_config
    def __init__(
        self,
        guidance_scale: float = 7.5,
        seg_guidance_scale: float = 2.8,
        seg_blur_sigma: float = 9999999.0,
        seg_blur_threshold_inf: float = 9999.0,
        seg_guidance_start: float = 0.0,
        seg_guidance_stop: float = 1.0,
        seg_guidance_layers: Optional[Union[int, List[int]]] = None,
        seg_guidance_config: Optional[Union[SmoothedEnergyGuidanceConfig, List[SmoothedEnergyGuidanceConfig]]] = None,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__(start, stop)

        self.guidance_scale = guidance_scale
        self.seg_guidance_scale = seg_guidance_scale
        self.seg_blur_sigma = seg_blur_sigma
        self.seg_blur_threshold_inf = seg_blur_threshold_inf
        self.seg_guidance_start = seg_guidance_start
        self.seg_guidance_stop = seg_guidance_stop
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

        if not (0.0 <= seg_guidance_start < 1.0):
            raise ValueError(f"Expected `seg_guidance_start` to be between 0.0 and 1.0, but got {seg_guidance_start}.")
        if not (seg_guidance_start <= seg_guidance_stop <= 1.0):
            raise ValueError(f"Expected `seg_guidance_stop` to be between 0.0 and 1.0, but got {seg_guidance_stop}.")

        # Exactly one of the two layer-selection mechanisms must be supplied.
        if seg_guidance_layers is None and seg_guidance_config is None:
            raise ValueError(
                "Either `seg_guidance_layers` or `seg_guidance_config` must be provided to enable Smoothed Energy Guidance."
            )
        if seg_guidance_layers is not None and seg_guidance_config is not None:
            raise ValueError("Only one of `seg_guidance_layers` or `seg_guidance_config` can be provided.")

        if seg_guidance_layers is not None:
            if isinstance(seg_guidance_layers, int):
                seg_guidance_layers = [seg_guidance_layers]
            if not isinstance(seg_guidance_layers, list):
                raise ValueError(
                    f"Expected `seg_guidance_layers` to be an int or a list of ints, but got {type(seg_guidance_layers)}."
                )
            seg_guidance_config = [SmoothedEnergyGuidanceConfig(layer, fqn="auto") for layer in seg_guidance_layers]

        # Normalize `seg_guidance_config` into a list of `SmoothedEnergyGuidanceConfig`, accepting dict(s).
        if isinstance(seg_guidance_config, dict):
            seg_guidance_config = SmoothedEnergyGuidanceConfig.from_dict(seg_guidance_config)

        if isinstance(seg_guidance_config, SmoothedEnergyGuidanceConfig):
            seg_guidance_config = [seg_guidance_config]

        if not isinstance(seg_guidance_config, list):
            raise ValueError(
                f"Expected `seg_guidance_config` to be a SmoothedEnergyGuidanceConfig or a list of SmoothedEnergyGuidanceConfig, but got {type(seg_guidance_config)}."
            )
        elif isinstance(next(iter(seg_guidance_config), None), dict):
            seg_guidance_config = [SmoothedEnergyGuidanceConfig.from_dict(config) for config in seg_guidance_config]

        self.seg_guidance_config = seg_guidance_config
        self._seg_layer_hook_names = [f"SmoothedEnergyGuidance_{i}" for i in range(len(self.seg_guidance_config))]

    def prepare_models(self, denoiser: torch.nn.Module) -> None:
        """Attach smoothed-energy-guidance hooks to `denoiser` when preparing the SEG conditional pass."""
        # BUGFIX: the counter was never advanced here, so the `_count_prepared > 1` gates below (and the
        # `is_conditional` property) could never fire and the hooks were never installed or removed.
        # The sibling `SkipLayerGuidance.prepare_models` increments this counter in the same way.
        self._count_prepared += 1
        if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
            for name, config in zip(self._seg_layer_hook_names, self.seg_guidance_config):
                _apply_smoothed_energy_guidance_hook(denoiser, config, self.seg_blur_sigma, name=name)

    def cleanup_models(self, denoiser: torch.nn.Module) -> None:
        """Remove any smoothed-energy-guidance hooks previously attached by `prepare_models`."""
        if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
            registry = HookRegistry.check_if_exists_or_initialize(denoiser)
            # Remove the hooks after inference
            for hook_name in self._seg_layer_hook_names:
                registry.remove_hook(hook_name, recurse=True)

    def prepare_inputs(
        self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
    ) -> List["BlockState"]:
        """Split `data` into one batch per active condition (cond / uncond / cond-with-SEG)."""
        if input_fields is None:
            input_fields = self._input_fields

        if self.num_conditions == 1:
            tuple_indices = [0]
            input_predictions = ["pred_cond"]
        elif self.num_conditions == 2:
            if self._is_cfg_enabled():
                tuple_indices = [0, 1]
                input_predictions = ["pred_cond", "pred_uncond"]
            else:
                # `pred_cond_seg` is a *conditional* prediction (with blurred attention), so it must use the
                # conditional embeddings (tuple index 0), matching the three-condition branch below. The
                # previous code used index 1 (unconditional) here, which contradicted that branch.
                tuple_indices = [0, 0]
                input_predictions = ["pred_cond", "pred_cond_seg"]
        else:
            tuple_indices = [0, 1, 0]
            input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
        data_batches = []
        for i in range(self.num_conditions):
            data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
            data_batches.append(data_batch)
        return data_batches

    def forward(
        self,
        pred_cond: torch.Tensor,
        pred_uncond: Optional[torch.Tensor] = None,
        pred_cond_seg: Optional[torch.Tensor] = None,
    ) -> GuiderOutput:
        """Combine the available predictions into the final guided noise prediction.

        Args:
            pred_cond: Conditional model prediction.
            pred_uncond: Unconditional prediction; required only when CFG is active.
            pred_cond_seg: Conditional prediction with smoothed attention; required only when SEG is active.

        Returns:
            `GuiderOutput` with the guided prediction plus the raw cond/uncond predictions.
        """
        pred = None

        if not self._is_cfg_enabled() and not self._is_seg_enabled():
            pred = pred_cond
        elif not self._is_cfg_enabled():
            # SEG only: push away from the smoothed-attention conditional estimate.
            shift = pred_cond - pred_cond_seg
            pred = pred_cond if self.use_original_formulation else pred_cond_seg
            pred = pred + self.seg_guidance_scale * shift
        elif not self._is_seg_enabled():
            # Plain CFG.
            shift = pred_cond - pred_uncond
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift
        else:
            # CFG + SEG combined.
            shift = pred_cond - pred_uncond
            shift_seg = pred_cond - pred_cond_seg
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift + self.seg_guidance_scale * shift_seg

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

    @property
    def is_conditional(self) -> bool:
        # Batches are prepared in the order cond (1), uncond (2), cond-with-SEG (3).
        return self._count_prepared == 1 or self._count_prepared == 3

    @property
    def num_conditions(self) -> int:
        num_conditions = 1
        if self._is_cfg_enabled():
            num_conditions += 1
        if self._is_seg_enabled():
            num_conditions += 1
        return num_conditions

    def _is_cfg_enabled(self) -> bool:
        """Whether classifier-free guidance is active at the current step."""
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        # A scale equal to the identity value (0.0 original formulation, 1.0 diffusers formulation) disables CFG.
        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close

    def _is_seg_enabled(self) -> bool:
        """Whether smoothed energy guidance is active at the current step."""
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self.seg_guidance_start * self._num_inference_steps)
            skip_stop_step = int(self.seg_guidance_stop * self._num_inference_steps)
            is_within_range = skip_start_step < self._step < skip_stop_step

        is_zero = math.isclose(self.seg_guidance_scale, 0.0)

        return is_within_range and not is_zero
|
pythonProject/diffusers-main/build/lib/diffusers/training_utils.py
ADDED
|
@@ -0,0 +1,730 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import contextlib
|
| 2 |
+
import copy
|
| 3 |
+
import gc
|
| 4 |
+
import math
|
| 5 |
+
import random
|
| 6 |
+
import re
|
| 7 |
+
import warnings
|
| 8 |
+
from contextlib import contextmanager
|
| 9 |
+
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
from .models import UNet2DConditionModel
|
| 15 |
+
from .pipelines import DiffusionPipeline
|
| 16 |
+
from .schedulers import SchedulerMixin
|
| 17 |
+
from .utils import (
|
| 18 |
+
convert_state_dict_to_diffusers,
|
| 19 |
+
convert_state_dict_to_peft,
|
| 20 |
+
deprecate,
|
| 21 |
+
is_peft_available,
|
| 22 |
+
is_torch_npu_available,
|
| 23 |
+
is_torchvision_available,
|
| 24 |
+
is_transformers_available,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
if is_transformers_available():
|
| 29 |
+
import transformers
|
| 30 |
+
|
| 31 |
+
if transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
|
| 32 |
+
import deepspeed
|
| 33 |
+
|
| 34 |
+
if is_peft_available():
|
| 35 |
+
from peft import set_peft_model_state_dict
|
| 36 |
+
|
| 37 |
+
if is_torchvision_available():
|
| 38 |
+
from torchvision import transforms
|
| 39 |
+
|
| 40 |
+
if is_torch_npu_available():
|
| 41 |
+
import torch_npu # noqa: F401
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def set_seed(seed: int):
    """
    Helper function for reproducible behavior: seeds `random`, `numpy`, and `torch`.

    Args:
        seed (`int`): The seed to set.

    Returns:
        `None`
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if is_torch_npu_available():
        torch.npu.manual_seed_all(seed)
    else:
        # Safe to call even when CUDA is not available (it is a no-op then).
        torch.cuda.manual_seed_all(seed)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def compute_snr(noise_scheduler, timesteps):
    """
    Computes SNR as per
    https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
    for the given timesteps using the provided noise scheduler.

    Args:
        noise_scheduler (`NoiseScheduler`):
            An object containing the noise schedule parameters, specifically `alphas_cumprod`, which is used to compute
            the SNR values.
        timesteps (`torch.Tensor`):
            A tensor of timesteps for which the SNR is computed.

    Returns:
        `torch.Tensor`: A tensor containing the computed SNR values for each timestep.
    """
    alphas_cumprod = noise_scheduler.alphas_cumprod
    sqrt_alphas = alphas_cumprod**0.5
    sqrt_one_minus_alphas = (1.0 - alphas_cumprod) ** 0.5

    def _gather_and_expand(values):
        # Gather the per-timestep entries, then right-pad dimensions so the
        # result broadcasts to (and is expanded to) the shape of `timesteps`.
        # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
        gathered = values.to(device=timesteps.device)[timesteps].float()
        while gathered.dim() < timesteps.dim():
            gathered = gathered[..., None]
        return gathered.expand(timesteps.shape)

    alpha = _gather_and_expand(sqrt_alphas)
    sigma = _gather_and_expand(sqrt_one_minus_alphas)

    # SNR = (alpha / sigma)^2
    return (alpha / sigma) ** 2
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def resolve_interpolation_mode(interpolation_type: str):
    """
    Maps a string describing an interpolation function to the corresponding torchvision `InterpolationMode` enum. The
    full list of supported enums is documented at
    https://pytorch.org/vision/0.9/transforms.html#torchvision.transforms.functional.InterpolationMode.

    Args:
        interpolation_type (`str`):
            A string describing an interpolation method. Currently, `bilinear`, `bicubic`, `box`, `nearest`,
            `nearest_exact`, `hamming`, and `lanczos` are supported, corresponding to the supported interpolation modes
            in torchvision.

    Returns:
        `torchvision.transforms.InterpolationMode`: an `InterpolationMode` enum used by torchvision's `resize`
        transform.
    """
    if not is_torchvision_available():
        raise ImportError(
            "Please make sure to install `torchvision` to be able to use the `resolve_interpolation_mode()` function."
        )

    # Built lazily (after the availability check) since `transforms` is only
    # imported when torchvision is installed.
    supported_modes = {
        "bilinear": transforms.InterpolationMode.BILINEAR,
        "bicubic": transforms.InterpolationMode.BICUBIC,
        "box": transforms.InterpolationMode.BOX,
        "nearest": transforms.InterpolationMode.NEAREST,
        "nearest_exact": transforms.InterpolationMode.NEAREST_EXACT,
        "hamming": transforms.InterpolationMode.HAMMING,
        "lanczos": transforms.InterpolationMode.LANCZOS,
    }
    if interpolation_type not in supported_modes:
        raise ValueError(
            f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation"
            f" modes are `bilinear`, `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
        )

    return supported_modes[interpolation_type]
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def compute_dream_and_update_latents(
    unet: UNet2DConditionModel,
    noise_scheduler: SchedulerMixin,
    timesteps: torch.Tensor,
    noise: torch.Tensor,
    noisy_latents: torch.Tensor,
    target: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    dream_detail_preservation: float = 1.0,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Implements "DREAM (Diffusion Rectification and Estimation-Adaptive Models)" from
    https://huggingface.co/papers/2312.00210. DREAM helps align training with sampling to help training be more
    efficient and accurate at the cost of an extra forward step without gradients.

    Args:
        `unet`: The state unet to use to make a prediction.
        `noise_scheduler`: The noise scheduler used to add noise for the given timestep.
        `timesteps`: The timesteps for the noise_scheduler to user.
        `noise`: A tensor of noise in the shape of noisy_latents.
        `noisy_latents`: Previously noise latents from the training loop.
        `target`: The ground-truth tensor to predict after eps is removed.
        `encoder_hidden_states`: Text embeddings from the text model.
        `dream_detail_preservation`: A float value that indicates detail preservation level.
          See reference.

    Returns:
        `tuple[torch.Tensor, torch.Tensor]`: Adjusted noisy_latents and target.
    """
    # Broadcast the per-timestep cumulative alpha to (B, 1, 1, 1).
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)[timesteps, None, None, None]
    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

    # The paper uses lambda = sqrt(1 - alpha) ** p, with p = 1 in their experiments.
    dream_lambda = sqrt_one_minus_alphas_cumprod**dream_detail_preservation

    # Extra gradient-free forward pass used to rectify the training targets.
    with torch.no_grad():
        pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    prediction_type = noise_scheduler.config.prediction_type
    if prediction_type == "v_prediction":
        raise NotImplementedError("DREAM has not been implemented for v-prediction")
    if prediction_type != "epsilon":
        raise ValueError(f"Unknown prediction type {prediction_type}")

    # Scale the (detached) gap between the real and predicted noise by lambda,
    # then shift both the latents and the target by it.
    delta_noise = (noise - pred).detach()
    delta_noise.mul_(dream_lambda)
    adjusted_latents = noisy_latents.add(sqrt_one_minus_alphas_cumprod * delta_noise)
    adjusted_target = target.add(delta_noise)

    return adjusted_latents, adjusted_target
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
    r"""
    Collects the LoRA parameters of `unet` into a flat state dict.

    Returns:
        A state dict containing just the LoRA parameters.
    """
    collected = {}

    for module_name, module in unet.named_modules():
        # Only modules patched for LoRA expose `set_lora_layer`.
        if not hasattr(module, "set_lora_layer"):
            continue
        lora_layer = getattr(module, "lora_layer")
        if lora_layer is None:
            continue
        # The matrix name can either be "down" or "up".
        for matrix_name, matrix in lora_layer.state_dict().items():
            collected[f"{module_name}.lora.{matrix_name}"] = matrix

    return collected
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32):
    """
    Casts the trainable parameters of the model(s) to the specified data type.

    Args:
        model: A PyTorch model, or a list of models, whose parameters will be cast.
        dtype: The data type to which the trainable parameters will be cast.
    """
    models = model if isinstance(model, list) else [model]
    for current_model in models:
        for param in current_model.parameters():
            # Only upcast trainable parameters (e.g. LoRA weights) into fp32;
            # frozen parameters keep their original dtype.
            if param.requires_grad:
                param.data = param.to(dtype)
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def _set_state_dict_into_text_encoder(
    lora_state_dict: Dict[str, torch.Tensor], prefix: str, text_encoder: torch.nn.Module
):
    """
    Sets the `lora_state_dict` into `text_encoder` coming from `transformers`.

    Args:
        lora_state_dict: The state dictionary to be set.
        prefix: String identifier to retrieve the portion of the state dict that belongs to `text_encoder`.
        text_encoder: Where the `lora_state_dict` is to be set.
    """

    # Keep only the entries that belong to the text encoder and strip the prefix.
    filtered_state_dict = {
        key.replace(prefix, ""): value for key, value in lora_state_dict.items() if key.startswith(prefix)
    }
    # Normalize the layout (diffusers -> peft) before handing the weights to peft.
    peft_state_dict = convert_state_dict_to_peft(convert_state_dict_to_diffusers(filtered_state_dict))
    set_peft_model_state_dict(text_encoder, peft_state_dict, adapter_name="default")
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def _collate_lora_metadata(modules_to_save: Dict[str, torch.nn.Module]) -> Dict[str, Any]:
    """Gathers the default-adapter LoRA config of each (non-None) module into one metadata dict."""
    return {
        f"{module_name}_lora_adapter_metadata": module.peft_config["default"].to_dict()
        for module_name, module in modules_to_save.items()
        if module is not None
    }
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def compute_density_for_timestep_sampling(
    weighting_scheme: str,
    batch_size: int,
    logit_mean: float = None,
    logit_std: float = None,
    mode_scale: float = None,
    device: Union[torch.device, str] = "cpu",
    generator: Optional[torch.Generator] = None,
):
    """
    Compute the density for sampling the timesteps when doing SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://huggingface.co/papers/2403.03206v1.
    """
    if weighting_scheme == "logit_normal":
        # Sample from a normal distribution, then squash to (0, 1) with a sigmoid.
        samples = torch.normal(
            mean=logit_mean, std=logit_std, size=(batch_size,), device=device, generator=generator
        )
        return torch.nn.functional.sigmoid(samples)

    samples = torch.rand(size=(batch_size,), device=device, generator=generator)
    if weighting_scheme == "mode":
        # Mode-shifted sampling from the SD3 paper; any other scheme stays uniform.
        samples = 1 - samples - mode_scale * (torch.cos(math.pi * samples / 2) ** 2 - 1 + samples)
    return samples
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
    """
    Computes loss weighting scheme for SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://huggingface.co/papers/2403.03206v1.
    """
    if weighting_scheme == "sigma_sqrt":
        # w = 1 / sigma^2
        return (sigmas**-2.0).float()
    if weighting_scheme == "cosmap":
        denominator = 1 - 2 * sigmas + 2 * sigmas**2
        return 2 / (math.pi * denominator)
    # Any other scheme gets uniform weighting.
    return torch.ones_like(sigmas)
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def free_memory():
    """
    Runs garbage collection, then clears the cache of whichever accelerator is available.
    """
    gc.collect()

    # At most one backend is active; empty its cache and stop.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        return
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
        return
    if is_torch_npu_available():
        torch_npu.npu.empty_cache()
        return
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
@contextmanager
def offload_models(
    *modules: Union[torch.nn.Module, DiffusionPipeline], device: Union[str, torch.device], offload: bool = True
):
    """
    Context manager that, if offload=True, moves each module to `device` on enter, then moves it back to its original
    device on exit.

    Args:
        device (`str` or `torch.Device`): Device to move the `modules` to.
        offload (`bool`): Flag to enable offloading.
    """
    if offload:
        # A DiffusionPipeline exposes `.device` directly; bare modules do not,
        # so their device is read from the first parameter.
        contains_pipeline = any(isinstance(m, DiffusionPipeline) for m in modules)
        if contains_pipeline:
            # Only a single pipeline is supported; wrap its device in a list so
            # the restore loop below can iterate uniformly.
            assert len(modules) == 1
            original_devices = [modules[0].device]
        else:
            original_devices = [next(m.parameters()).device for m in modules]
        # Move everything to the target device.
        for m in modules:
            m.to(device)

    try:
        yield
    finally:
        if offload:
            # Always restore the original placement, even if the body raised.
            for m, previous_device in zip(modules, original_devices):
                m.to(previous_device)
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
def parse_buckets_string(buckets_str):
    """
    Parses a string defining buckets into a list of (height, width) tuples.

    Args:
        buckets_str: Semicolon-separated "height,width" pairs, e.g. `"512,512;768,320"`.
            Whitespace around the numbers is tolerated.

    Returns:
        `list[tuple[int, int]]`: The parsed `(height, width)` pairs, in input order.

    Raises:
        ValueError: If the string is empty, a pair is malformed, or a dimension is not positive.
    """
    if not buckets_str:
        raise ValueError("Bucket string cannot be empty.")

    parsed_buckets = []
    for pair_str in buckets_str.strip().split(";"):
        match = re.match(r"^\s*(\d+)\s*,\s*(\d+)\s*$", pair_str)
        if not match:
            raise ValueError(f"Invalid bucket format: '{pair_str}'. Expected 'height,width'.")
        # The regex guarantees both groups are digit runs, so int() cannot fail here.
        # (The previous implementation wrapped this in a try/except ValueError that
        # also caught — and mislabeled as "Invalid integer" — the deliberate
        # positivity error below.)
        height = int(match.group(1))
        width = int(match.group(2))
        if height <= 0 or width <= 0:
            raise ValueError(f"Bucket dimensions must be positive integers, got ({height},{width}).")
        if height % 8 != 0 or width % 8 != 0:
            # Latent-space models typically downsample by 8; warn but don't fail.
            warnings.warn(f"Bucket dimension ({height},{width}) not divisible by 8. This might cause issues.")
        parsed_buckets.append((height, width))

    if not parsed_buckets:
        raise ValueError("No valid buckets found in the provided string.")

    return parsed_buckets
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
def find_nearest_bucket(h, w, bucket_options):
    """
    Finds the closest bucket to the given height and width.

    The metric is `abs(h * bucket_w - w * bucket_h)`, i.e. how far the bucket's
    aspect ratio is from the image's; among ties, the later bucket wins.
    """
    if not bucket_options:
        return None

    def aspect_gap(indexed_bucket):
        _, (bucket_h, bucket_w) = indexed_bucket
        return abs(h * bucket_w - w * bucket_h)

    # Scan in reverse so that, among equally good buckets, the later one is
    # selected (the original loop used `<=`, keeping the last best match).
    best_bucket_idx, _ = min(reversed(list(enumerate(bucket_options))), key=aspect_gap)
    return best_bucket_idx
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
|
| 398 |
+
class EMAModel:
    """
    Exponential Moving Average of models weights
    """

    def __init__(
        self,
        parameters: Iterable[torch.nn.Parameter],
        decay: float = 0.9999,
        min_decay: float = 0.0,
        update_after_step: int = 0,
        use_ema_warmup: bool = False,
        inv_gamma: Union[float, int] = 1.0,
        power: Union[float, int] = 2 / 3,
        foreach: bool = False,
        model_cls: Optional[Any] = None,
        model_config: Dict[str, Any] = None,
        **kwargs,
    ):
        """
        Args:
            parameters (Iterable[torch.nn.Parameter]): The parameters to track.
            decay (float): The decay factor for the exponential moving average.
            min_decay (float): The minimum decay factor for the exponential moving average.
            update_after_step (int): The number of steps to wait before starting to update the EMA weights.
            use_ema_warmup (bool): Whether to use EMA warmup.
            inv_gamma (float):
                Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.
            power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.
            foreach (bool): Use torch._foreach functions for updating shadow parameters. Should be faster.
            device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA
                weights will be stored on CPU.

        @crowsonkb's notes on EMA Warmup:
            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
            to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
            gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
            at 215.4k steps).
        """

        if isinstance(parameters, torch.nn.Module):
            # Back-compat path: accept a full module, but deprecate it and fall
            # back to its parameter iterator.
            deprecation_message = (
                "Passing a `torch.nn.Module` to `ExponentialMovingAverage` is deprecated. "
                "Please pass the parameters of the module instead."
            )
            deprecate(
                "passing a `torch.nn.Module` to `ExponentialMovingAverage`",
                "1.0.0",
                deprecation_message,
                standard_warn=False,
            )
            parameters = parameters.parameters()

            # set use_ema_warmup to True if a torch.nn.Module is passed for backwards compatibility
            use_ema_warmup = True

        if kwargs.get("max_value", None) is not None:
            # Deprecated alias: `max_value` overrides `decay`.
            deprecation_message = "The `max_value` argument is deprecated. Please use `decay` instead."
            deprecate("max_value", "1.0.0", deprecation_message, standard_warn=False)
            decay = kwargs["max_value"]

        if kwargs.get("min_value", None) is not None:
            # Deprecated alias: `min_value` overrides `min_decay`.
            deprecation_message = "The `min_value` argument is deprecated. Please use `min_decay` instead."
            deprecate("min_value", "1.0.0", deprecation_message, standard_warn=False)
            min_decay = kwargs["min_value"]

        parameters = list(parameters)
        # The EMA ("shadow") copy of every tracked parameter.
        self.shadow_params = [p.clone().detach() for p in parameters]

        if kwargs.get("device", None) is not None:
            # Deprecated alias: `device` — use `.to()` instead.
            deprecation_message = "The `device` argument is deprecated. Please use `to` instead."
            deprecate("device", "1.0.0", deprecation_message, standard_warn=False)
            self.to(device=kwargs["device"])

        # Scratch space used by `store()` / `restore()`.
        self.temp_stored_params = None

        self.decay = decay
        self.min_decay = min_decay
        self.update_after_step = update_after_step
        self.use_ema_warmup = use_ema_warmup
        self.inv_gamma = inv_gamma
        self.power = power
        self.optimization_step = 0
        self.cur_decay_value = None  # set in `step()`
        self.foreach = foreach

        # Needed so `save_pretrained`/`from_pretrained` can round-trip the model.
        self.model_cls = model_cls
        self.model_config = model_config

    @classmethod
    def from_pretrained(cls, path, model_cls, foreach=False) -> "EMAModel":
        """Rebuild an `EMAModel` from a model checkpoint saved with `save_pretrained`."""
        # EMA hyperparameters were stashed in the model config as "unused" kwargs.
        _, ema_kwargs = model_cls.from_config(path, return_unused_kwargs=True)
        model = model_cls.from_pretrained(path)

        ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config, foreach=foreach)

        ema_model.load_state_dict(ema_kwargs)
        return ema_model

    def save_pretrained(self, path):
        """Save the EMA weights (plus the EMA hyperparameters) as a model checkpoint at `path`."""
        if self.model_cls is None:
            raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.")

        if self.model_config is None:
            raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.")

        model = self.model_cls.from_config(self.model_config)
        state_dict = self.state_dict()
        # The shadow params are written as the model's weights (via `copy_to`),
        # not as config entries.
        state_dict.pop("shadow_params", None)

        model.register_to_config(**state_dict)
        self.copy_to(model.parameters())
        model.save_pretrained(path)

    def get_decay(self, optimization_step: int) -> float:
        """
        Compute the decay factor for the exponential moving average.
        """
        step = max(0, optimization_step - self.update_after_step - 1)

        # Before updates begin, the EMA just tracks the raw weights.
        if step <= 0:
            return 0.0

        if self.use_ema_warmup:
            cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power
        else:
            cur_decay_value = (1 + step) / (10 + step)

        cur_decay_value = min(cur_decay_value, self.decay)
        # make sure decay is not smaller than min_decay
        cur_decay_value = max(cur_decay_value, self.min_decay)
        return cur_decay_value

    @torch.no_grad()
    def step(self, parameters: Iterable[torch.nn.Parameter]):
        # Advance the EMA by one optimization step using the given live parameters.
        if isinstance(parameters, torch.nn.Module):
            # Back-compat: accept a module, deprecated in favor of parameters.
            deprecation_message = (
                "Passing a `torch.nn.Module` to `ExponentialMovingAverage.step` is deprecated. "
                "Please pass the parameters of the module instead."
            )
            deprecate(
                "passing a `torch.nn.Module` to `ExponentialMovingAverage.step`",
                "1.0.0",
                deprecation_message,
                standard_warn=False,
            )
            parameters = parameters.parameters()

        parameters = list(parameters)

        self.optimization_step += 1

        # Compute the decay factor for the exponential moving average.
        decay = self.get_decay(self.optimization_step)
        self.cur_decay_value = decay
        one_minus_decay = 1 - decay

        context_manager = contextlib.nullcontext()

        if self.foreach:
            # Fused multi-tensor path. Under DeepSpeed ZeRO-3 the parameters are
            # sharded, so gather them first (read-only: modifier_rank=None).
            if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
                context_manager = deepspeed.zero.GatheredParameters(parameters, modifier_rank=None)

            with context_manager:
                params_grad = [param for param in parameters if param.requires_grad]
                s_params_grad = [
                    s_param for s_param, param in zip(self.shadow_params, parameters) if param.requires_grad
                ]

                # Frozen parameters are mirrored verbatim instead of averaged.
                if len(params_grad) < len(parameters):
                    torch._foreach_copy_(
                        [s_param for s_param, param in zip(self.shadow_params, parameters) if not param.requires_grad],
                        [param for param in parameters if not param.requires_grad],
                        non_blocking=True,
                    )

                # shadow <- shadow - (1 - decay) * (shadow - param)
                torch._foreach_sub_(
                    s_params_grad, torch._foreach_sub(s_params_grad, params_grad), alpha=one_minus_decay
                )

        else:
            # Per-parameter path; the ZeRO-3 gather happens per tensor here.
            for s_param, param in zip(self.shadow_params, parameters):
                if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
                    context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)

                with context_manager:
                    if param.requires_grad:
                        # shadow <- shadow - (1 - decay) * (shadow - param)
                        s_param.sub_(one_minus_decay * (s_param - param))
                    else:
                        s_param.copy_(param)

    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        """
        Copy current averaged parameters into given collection of parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored moving averages. If `None`, the parameters with which this
                `ExponentialMovingAverage` was initialized will be used.
        """
        parameters = list(parameters)
        if self.foreach:
            torch._foreach_copy_(
                [param.data for param in parameters],
                [s_param.to(param.device).data for s_param, param in zip(self.shadow_params, parameters)],
            )
        else:
            for s_param, param in zip(self.shadow_params, parameters):
                param.data.copy_(s_param.to(param.device).data)

    def pin_memory(self) -> None:
        r"""
        Move internal buffers of the ExponentialMovingAverage to pinned memory. Useful for non-blocking transfers for
        offloading EMA params to the host.
        """

        self.shadow_params = [p.pin_memory() for p in self.shadow_params]

    def to(self, device=None, dtype=None, non_blocking=False) -> None:
        r"""
        Move internal buffers of the ExponentialMovingAverage to `device`.

        Args:
            device: like `device` argument to `torch.Tensor.to`
        """
        # .to() on the tensors handles None correctly
        # Only floating-point buffers take the dtype; integer buffers keep theirs.
        self.shadow_params = [
            p.to(device=device, dtype=dtype, non_blocking=non_blocking)
            if p.is_floating_point()
            else p.to(device=device, non_blocking=non_blocking)
            for p in self.shadow_params
        ]

    def state_dict(self) -> dict:
        r"""
        Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during
        checkpointing to save the ema state dict.
        """
        # Following PyTorch conventions, references to tensors are returned:
        # "returns a reference to the state and not its copy!" -
        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
        return {
            "decay": self.decay,
            "min_decay": self.min_decay,
            "optimization_step": self.optimization_step,
            "update_after_step": self.update_after_step,
            "use_ema_warmup": self.use_ema_warmup,
            "inv_gamma": self.inv_gamma,
            "power": self.power,
            "shadow_params": self.shadow_params,
        }

    def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        r"""
        Saves the current parameters for restoring later.

        Args:
            parameters: Iterable of `torch.nn.Parameter`. The parameters to be temporarily stored.
        """
        # CPU clones so the stash doesn't hold accelerator memory.
        self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]

    def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        r"""
        Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters
        without: affecting the original optimization process. Store the parameters before the `copy_to()` method. After
        validation (or model saving), use this to restore the former parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters. If `None`, the parameters with which this
                `ExponentialMovingAverage` was initialized will be used.
        """

        if self.temp_stored_params is None:
            raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`")
        if self.foreach:
            torch._foreach_copy_(
                [param.data for param in parameters], [c_param.data for c_param in self.temp_stored_params]
            )
        else:
            for c_param, param in zip(self.temp_stored_params, parameters):
                param.data.copy_(c_param.data)

        # Better memory-wise.
        self.temp_stored_params = None

    def load_state_dict(self, state_dict: dict) -> None:
        r"""
        Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the
        ema state dict.

        Args:
            state_dict (dict): EMA state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # deepcopy, to be consistent with module API
        state_dict = copy.deepcopy(state_dict)

        # Each field falls back to the current value when absent, then is validated.
        self.decay = state_dict.get("decay", self.decay)
        if self.decay < 0.0 or self.decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")

        self.min_decay = state_dict.get("min_decay", self.min_decay)
        if not isinstance(self.min_decay, float):
            raise ValueError("Invalid min_decay")

        self.optimization_step = state_dict.get("optimization_step", self.optimization_step)
        if not isinstance(self.optimization_step, int):
            raise ValueError("Invalid optimization_step")

        self.update_after_step = state_dict.get("update_after_step", self.update_after_step)
        if not isinstance(self.update_after_step, int):
            raise ValueError("Invalid update_after_step")

        self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
        if not isinstance(self.use_ema_warmup, bool):
            raise ValueError("Invalid use_ema_warmup")

        self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
        if not isinstance(self.inv_gamma, (float, int)):
            raise ValueError("Invalid inv_gamma")

        self.power = state_dict.get("power", self.power)
        if not isinstance(self.power, (float, int)):
            raise ValueError("Invalid power")

        shadow_params = state_dict.get("shadow_params", None)
        if shadow_params is not None:
            self.shadow_params = shadow_params
            if not isinstance(self.shadow_params, list):
                raise ValueError("shadow_params must be a list")
            if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
                raise ValueError("shadow_params must all be Tensors")
|
pythonProject/diffusers-main/build/lib/diffusers/video_processor.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import warnings
|
| 16 |
+
from typing import List, Optional, Union
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import PIL
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
from .image_processor import VaeImageProcessor, is_valid_image, is_valid_image_imagelist
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class VideoProcessor(VaeImageProcessor):
    r"""Simple video processor."""

    def preprocess_video(self, video, height: Optional[int] = None, width: Optional[int] = None) -> torch.Tensor:
        r"""
        Preprocesses input video(s).

        Args:
            video (`List[PIL.Image]`, `List[List[PIL.Image]]`, `torch.Tensor`, `np.array`, `List[torch.Tensor]`, `List[np.array]`):
                The input video. It can be one of the following:
                * List of the PIL images.
                * List of list of PIL images.
                * 4D Torch tensors (expected shape for each tensor `(num_frames, num_channels, height, width)`).
                * 4D NumPy arrays (expected shape for each array `(num_frames, height, width, num_channels)`).
                * List of 4D Torch tensors (expected shape for each tensor `(num_frames, num_channels, height,
                  width)`).
                * List of 4D NumPy arrays (expected shape for each array `(num_frames, height, width, num_channels)`).
                * 5D NumPy arrays: expected shape for each array `(batch_size, num_frames, height, width,
                  num_channels)`.
                * 5D Torch tensors: expected shape for each array `(batch_size, num_frames, num_channels, height,
                  width)`.
            height (`int`, *optional*, defaults to `None`):
                The height in preprocessed frames of the video. If `None`, will use the `get_default_height_width()` to
                get default height.
            width (`int`, *optional*`, defaults to `None`):
                The width in preprocessed frames of the video. If `None`, will use get_default_height_width()` to get
                the default width.
        """
        # Legacy calling convention: a *list* of 5d arrays/tensors is collapsed
        # into a single batched 5d object before normal handling.
        is_list = isinstance(video, list)
        if is_list and isinstance(video[0], np.ndarray) and video[0].ndim == 5:
            warnings.warn(
                "Passing `video` as a list of 5d np.ndarray is deprecated."
                "Please concatenate the list along the batch dimension and pass it as a single 5d np.ndarray",
                FutureWarning,
            )
            video = np.concatenate(video, axis=0)
        if is_list and isinstance(video[0], torch.Tensor) and video[0].ndim == 5:
            warnings.warn(
                "Passing `video` as a list of 5d torch.Tensor is deprecated."
                "Please concatenate the list along the batch dimension and pass it as a single 5d torch.Tensor",
                FutureWarning,
            )
            video = torch.cat(video, axis=0)

        # Normalize every accepted form to a list of videos:
        # - a 5d batch (torch.Tensor / np.ndarray) becomes a list of 4d videos;
        # - a single video (image list, or 4d array/tensor) becomes a one-item list.
        if isinstance(video, (np.ndarray, torch.Tensor)) and video.ndim == 5:
            video = list(video)
        elif (isinstance(video, list) and is_valid_image(video[0])) or is_valid_image_imagelist(video):
            video = [video]
        elif isinstance(video, list) and is_valid_image_imagelist(video[0]):
            video = video
        else:
            raise ValueError(
                "Input is in incorrect format. Currently, we only support numpy.ndarray, torch.Tensor, PIL.Image.Image"
            )

        preprocessed = [self.preprocess(frames, height=height, width=width) for frames in video]
        batch = torch.stack(preprocessed, dim=0)

        # Move the channel axis before the frame axis:
        # (batch, frames, channels, h, w) -> (batch, channels, frames, h, w).
        return batch.permute(0, 2, 1, 3, 4)

    def postprocess_video(
        self, video: torch.Tensor, output_type: str = "np"
    ) -> Union[np.ndarray, torch.Tensor, List[PIL.Image.Image]]:
        r"""
        Converts a video tensor to a list of frames for export.

        Args:
            video (`torch.Tensor`): The video as a tensor.
            output_type (`str`, defaults to `"np"`): Output type of the postprocessed `video` tensor.
        """
        # Per sample: (channels, frames, h, w) -> (frames, channels, h, w),
        # then defer frame-level conversion to the image postprocessor.
        per_sample = [
            self.postprocess(video[idx].permute(1, 0, 2, 3), output_type) for idx in range(video.shape[0])
        ]

        if output_type == "np":
            return np.stack(per_sample)
        if output_type == "pt":
            return torch.stack(per_sample)
        if output_type == "pil":
            return per_sample
        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")