Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/animatediff/__init__.py +57 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1118 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1309 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1023 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +1065 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1353 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/animatediff/pipeline_output.py +24 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/audioldm/__init__.py +51 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/audioldm/pipeline_audioldm.py +558 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/audioldm2/__init__.py +50 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/audioldm2/modeling_audioldm2.py +1475 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1104 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/aura_flow/__init__.py +48 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/aura_flow/pipeline_aura_flow.py +677 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/blip_diffusion/__init__.py +20 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/blip_diffusion/blip_image_processing.py +318 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/blip_diffusion/modeling_blip2.py +639 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +223 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +361 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/bria/__init__.py +48 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/bria/pipeline_bria.py +729 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/bria/pipeline_output.py +21 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/chroma/__init__.py +49 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/chroma/pipeline_output.py +21 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogvideo/__init__.py +54 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +789 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +842 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +903 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +868 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogview3/__init__.py +47 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogview3/pipeline_cogview3plus.py +682 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogview4/__init__.py +49 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogview4/pipeline_cogview4.py +685 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogview4/pipeline_output.py +21 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/consisid/__init__.py +49 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/consisid/consisid_utils.py +357 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/consisid/pipeline_output.py +20 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/consistency_models/__init__.py +24 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/consistency_models/pipeline_consistency_models.py +286 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/controlnet/__init__.py +86 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/controlnet/multicontrolnet.py +12 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/controlnet/pipeline_controlnet.py +1366 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +427 -0
- pythonProject/diffusers-main/build/lib/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1338 -0
pythonProject/diffusers-main/build/lib/diffusers/pipelines/animatediff/__init__.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import TYPE_CHECKING
|
| 2 |
+
|
| 3 |
+
from ...utils import (
|
| 4 |
+
DIFFUSERS_SLOW_IMPORT,
|
| 5 |
+
OptionalDependencyNotAvailable,
|
| 6 |
+
_LazyModule,
|
| 7 |
+
get_objects_from_module,
|
| 8 |
+
is_torch_available,
|
| 9 |
+
is_transformers_available,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
_dummy_objects = {}
|
| 14 |
+
_import_structure = {"pipeline_output": ["AnimateDiffPipelineOutput"]}
|
| 15 |
+
|
| 16 |
+
try:
|
| 17 |
+
if not (is_transformers_available() and is_torch_available()):
|
| 18 |
+
raise OptionalDependencyNotAvailable()
|
| 19 |
+
except OptionalDependencyNotAvailable:
|
| 20 |
+
from ...utils import dummy_torch_and_transformers_objects
|
| 21 |
+
|
| 22 |
+
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
| 23 |
+
else:
|
| 24 |
+
_import_structure["pipeline_animatediff"] = ["AnimateDiffPipeline"]
|
| 25 |
+
_import_structure["pipeline_animatediff_controlnet"] = ["AnimateDiffControlNetPipeline"]
|
| 26 |
+
_import_structure["pipeline_animatediff_sdxl"] = ["AnimateDiffSDXLPipeline"]
|
| 27 |
+
_import_structure["pipeline_animatediff_sparsectrl"] = ["AnimateDiffSparseControlNetPipeline"]
|
| 28 |
+
_import_structure["pipeline_animatediff_video2video"] = ["AnimateDiffVideoToVideoPipeline"]
|
| 29 |
+
_import_structure["pipeline_animatediff_video2video_controlnet"] = ["AnimateDiffVideoToVideoControlNetPipeline"]
|
| 30 |
+
|
| 31 |
+
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
| 32 |
+
try:
|
| 33 |
+
if not (is_transformers_available() and is_torch_available()):
|
| 34 |
+
raise OptionalDependencyNotAvailable()
|
| 35 |
+
except OptionalDependencyNotAvailable:
|
| 36 |
+
from ...utils.dummy_torch_and_transformers_objects import *
|
| 37 |
+
|
| 38 |
+
else:
|
| 39 |
+
from .pipeline_animatediff import AnimateDiffPipeline
|
| 40 |
+
from .pipeline_animatediff_controlnet import AnimateDiffControlNetPipeline
|
| 41 |
+
from .pipeline_animatediff_sdxl import AnimateDiffSDXLPipeline
|
| 42 |
+
from .pipeline_animatediff_sparsectrl import AnimateDiffSparseControlNetPipeline
|
| 43 |
+
from .pipeline_animatediff_video2video import AnimateDiffVideoToVideoPipeline
|
| 44 |
+
from .pipeline_animatediff_video2video_controlnet import AnimateDiffVideoToVideoControlNetPipeline
|
| 45 |
+
from .pipeline_output import AnimateDiffPipelineOutput
|
| 46 |
+
|
| 47 |
+
else:
|
| 48 |
+
import sys
|
| 49 |
+
|
| 50 |
+
sys.modules[__name__] = _LazyModule(
|
| 51 |
+
__name__,
|
| 52 |
+
globals()["__file__"],
|
| 53 |
+
_import_structure,
|
| 54 |
+
module_spec=__spec__,
|
| 55 |
+
)
|
| 56 |
+
for name, value in _dummy_objects.items():
|
| 57 |
+
setattr(sys.modules[__name__], name, value)
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py
ADDED
|
@@ -0,0 +1,1118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import inspect
|
| 16 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
| 21 |
+
|
| 22 |
+
from ...image_processor import PipelineImageInput
|
| 23 |
+
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
| 24 |
+
from ...models import (
|
| 25 |
+
AutoencoderKL,
|
| 26 |
+
ControlNetModel,
|
| 27 |
+
ImageProjection,
|
| 28 |
+
MultiControlNetModel,
|
| 29 |
+
UNet2DConditionModel,
|
| 30 |
+
UNetMotionModel,
|
| 31 |
+
)
|
| 32 |
+
from ...models.lora import adjust_lora_scale_text_encoder
|
| 33 |
+
from ...models.unets.unet_motion_model import MotionAdapter
|
| 34 |
+
from ...schedulers import KarrasDiffusionSchedulers
|
| 35 |
+
from ...utils import USE_PEFT_BACKEND, is_torch_xla_available, logging, scale_lora_layers, unscale_lora_layers
|
| 36 |
+
from ...utils.torch_utils import is_compiled_module, randn_tensor
|
| 37 |
+
from ...video_processor import VideoProcessor
|
| 38 |
+
from ..free_init_utils import FreeInitMixin
|
| 39 |
+
from ..free_noise_utils import AnimateDiffFreeNoiseMixin
|
| 40 |
+
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
| 41 |
+
from .pipeline_output import AnimateDiffPipelineOutput
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
if is_torch_xla_available():
|
| 45 |
+
import torch_xla.core.xla_model as xm
|
| 46 |
+
|
| 47 |
+
XLA_AVAILABLE = True
|
| 48 |
+
else:
|
| 49 |
+
XLA_AVAILABLE = False
|
| 50 |
+
|
| 51 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
EXAMPLE_DOC_STRING = """
|
| 55 |
+
Examples:
|
| 56 |
+
```py
|
| 57 |
+
>>> import torch
|
| 58 |
+
>>> from diffusers import (
|
| 59 |
+
... AnimateDiffControlNetPipeline,
|
| 60 |
+
... AutoencoderKL,
|
| 61 |
+
... ControlNetModel,
|
| 62 |
+
... MotionAdapter,
|
| 63 |
+
... LCMScheduler,
|
| 64 |
+
... )
|
| 65 |
+
>>> from diffusers.utils import export_to_gif, load_video
|
| 66 |
+
|
| 67 |
+
>>> # Additionally, you will need a preprocess videos before they can be used with the ControlNet
|
| 68 |
+
>>> # HF maintains just the right package for it: `pip install controlnet_aux`
|
| 69 |
+
>>> from controlnet_aux.processor import ZoeDetector
|
| 70 |
+
|
| 71 |
+
>>> # Download controlnets from https://huggingface.co/lllyasviel/ControlNet-v1-1 to use .from_single_file
|
| 72 |
+
>>> # Download Diffusers-format controlnets, such as https://huggingface.co/lllyasviel/sd-controlnet-depth, to use .from_pretrained()
|
| 73 |
+
>>> controlnet = ControlNetModel.from_single_file("control_v11f1p_sd15_depth.pth", torch_dtype=torch.float16)
|
| 74 |
+
|
| 75 |
+
>>> # We use AnimateLCM for this example but one can use the original motion adapters as well (for example, https://huggingface.co/guoyww/animatediff-motion-adapter-v1-5-3)
|
| 76 |
+
>>> motion_adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM")
|
| 77 |
+
|
| 78 |
+
>>> vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
|
| 79 |
+
>>> pipe: AnimateDiffControlNetPipeline = AnimateDiffControlNetPipeline.from_pretrained(
|
| 80 |
+
... "SG161222/Realistic_Vision_V5.1_noVAE",
|
| 81 |
+
... motion_adapter=motion_adapter,
|
| 82 |
+
... controlnet=controlnet,
|
| 83 |
+
... vae=vae,
|
| 84 |
+
... ).to(device="cuda", dtype=torch.float16)
|
| 85 |
+
>>> pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
|
| 86 |
+
>>> pipe.load_lora_weights(
|
| 87 |
+
... "wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora"
|
| 88 |
+
... )
|
| 89 |
+
>>> pipe.set_adapters(["lcm-lora"], [0.8])
|
| 90 |
+
|
| 91 |
+
>>> depth_detector = ZoeDetector.from_pretrained("lllyasviel/Annotators").to("cuda")
|
| 92 |
+
>>> video = load_video(
|
| 93 |
+
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif"
|
| 94 |
+
... )
|
| 95 |
+
>>> conditioning_frames = []
|
| 96 |
+
|
| 97 |
+
>>> with pipe.progress_bar(total=len(video)) as progress_bar:
|
| 98 |
+
... for frame in video:
|
| 99 |
+
... conditioning_frames.append(depth_detector(frame))
|
| 100 |
+
... progress_bar.update()
|
| 101 |
+
|
| 102 |
+
>>> prompt = "a panda, playing a guitar, sitting in a pink boat, in the ocean, mountains in background, realistic, high quality"
|
| 103 |
+
>>> negative_prompt = "bad quality, worst quality"
|
| 104 |
+
|
| 105 |
+
>>> video = pipe(
|
| 106 |
+
... prompt=prompt,
|
| 107 |
+
... negative_prompt=negative_prompt,
|
| 108 |
+
... num_frames=len(video),
|
| 109 |
+
... num_inference_steps=10,
|
| 110 |
+
... guidance_scale=2.0,
|
| 111 |
+
... conditioning_frames=conditioning_frames,
|
| 112 |
+
... generator=torch.Generator().manual_seed(42),
|
| 113 |
+
... ).frames[0]
|
| 114 |
+
|
| 115 |
+
>>> export_to_gif(video, "animatediff_controlnet.gif", fps=8)
|
| 116 |
+
```
|
| 117 |
+
"""
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class AnimateDiffControlNetPipeline(
|
| 121 |
+
DiffusionPipeline,
|
| 122 |
+
StableDiffusionMixin,
|
| 123 |
+
TextualInversionLoaderMixin,
|
| 124 |
+
IPAdapterMixin,
|
| 125 |
+
StableDiffusionLoraLoaderMixin,
|
| 126 |
+
FreeInitMixin,
|
| 127 |
+
AnimateDiffFreeNoiseMixin,
|
| 128 |
+
FromSingleFileMixin,
|
| 129 |
+
):
|
| 130 |
+
r"""
|
| 131 |
+
Pipeline for text-to-video generation with ControlNet guidance.
|
| 132 |
+
|
| 133 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
| 134 |
+
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
| 135 |
+
|
| 136 |
+
The pipeline also inherits the following loading methods:
|
| 137 |
+
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
| 138 |
+
- [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
| 139 |
+
- [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
| 140 |
+
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
| 141 |
+
|
| 142 |
+
Args:
|
| 143 |
+
vae ([`AutoencoderKL`]):
|
| 144 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
| 145 |
+
text_encoder ([`CLIPTextModel`]):
|
| 146 |
+
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
| 147 |
+
tokenizer (`CLIPTokenizer`):
|
| 148 |
+
A [`~transformers.CLIPTokenizer`] to tokenize text.
|
| 149 |
+
unet ([`UNet2DConditionModel`]):
|
| 150 |
+
A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents.
|
| 151 |
+
motion_adapter ([`MotionAdapter`]):
|
| 152 |
+
A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents.
|
| 153 |
+
scheduler ([`SchedulerMixin`]):
|
| 154 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
| 155 |
+
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
| 156 |
+
"""
|
| 157 |
+
|
| 158 |
+
model_cpu_offload_seq = "text_encoder->unet->vae"
|
| 159 |
+
_optional_components = ["feature_extractor", "image_encoder"]
|
| 160 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
| 161 |
+
|
| 162 |
+
def __init__(
|
| 163 |
+
self,
|
| 164 |
+
vae: AutoencoderKL,
|
| 165 |
+
text_encoder: CLIPTextModel,
|
| 166 |
+
tokenizer: CLIPTokenizer,
|
| 167 |
+
unet: Union[UNet2DConditionModel, UNetMotionModel],
|
| 168 |
+
motion_adapter: MotionAdapter,
|
| 169 |
+
controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
|
| 170 |
+
scheduler: KarrasDiffusionSchedulers,
|
| 171 |
+
feature_extractor: Optional[CLIPImageProcessor] = None,
|
| 172 |
+
image_encoder: Optional[CLIPVisionModelWithProjection] = None,
|
| 173 |
+
):
|
| 174 |
+
super().__init__()
|
| 175 |
+
if isinstance(unet, UNet2DConditionModel):
|
| 176 |
+
unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
|
| 177 |
+
|
| 178 |
+
if isinstance(controlnet, (list, tuple)):
|
| 179 |
+
controlnet = MultiControlNetModel(controlnet)
|
| 180 |
+
|
| 181 |
+
self.register_modules(
|
| 182 |
+
vae=vae,
|
| 183 |
+
text_encoder=text_encoder,
|
| 184 |
+
tokenizer=tokenizer,
|
| 185 |
+
unet=unet,
|
| 186 |
+
motion_adapter=motion_adapter,
|
| 187 |
+
controlnet=controlnet,
|
| 188 |
+
scheduler=scheduler,
|
| 189 |
+
feature_extractor=feature_extractor,
|
| 190 |
+
image_encoder=image_encoder,
|
| 191 |
+
)
|
| 192 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
| 193 |
+
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
|
| 194 |
+
self.control_video_processor = VideoProcessor(
|
| 195 |
+
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
|
| 199 |
+
def encode_prompt(
|
| 200 |
+
self,
|
| 201 |
+
prompt,
|
| 202 |
+
device,
|
| 203 |
+
num_images_per_prompt,
|
| 204 |
+
do_classifier_free_guidance,
|
| 205 |
+
negative_prompt=None,
|
| 206 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 207 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 208 |
+
lora_scale: Optional[float] = None,
|
| 209 |
+
clip_skip: Optional[int] = None,
|
| 210 |
+
):
|
| 211 |
+
r"""
|
| 212 |
+
Encodes the prompt into text encoder hidden states.
|
| 213 |
+
|
| 214 |
+
Args:
|
| 215 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 216 |
+
prompt to be encoded
|
| 217 |
+
device: (`torch.device`):
|
| 218 |
+
torch device
|
| 219 |
+
num_images_per_prompt (`int`):
|
| 220 |
+
number of images that should be generated per prompt
|
| 221 |
+
do_classifier_free_guidance (`bool`):
|
| 222 |
+
whether to use classifier free guidance or not
|
| 223 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 224 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 225 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 226 |
+
less than `1`).
|
| 227 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 228 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 229 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 230 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 231 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 232 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 233 |
+
argument.
|
| 234 |
+
lora_scale (`float`, *optional*):
|
| 235 |
+
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
| 236 |
+
clip_skip (`int`, *optional*):
|
| 237 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
| 238 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
| 239 |
+
"""
|
| 240 |
+
# set lora scale so that monkey patched LoRA
|
| 241 |
+
# function of text encoder can correctly access it
|
| 242 |
+
if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
|
| 243 |
+
self._lora_scale = lora_scale
|
| 244 |
+
|
| 245 |
+
# dynamically adjust the LoRA scale
|
| 246 |
+
if not USE_PEFT_BACKEND:
|
| 247 |
+
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
| 248 |
+
else:
|
| 249 |
+
scale_lora_layers(self.text_encoder, lora_scale)
|
| 250 |
+
|
| 251 |
+
if prompt is not None and isinstance(prompt, str):
|
| 252 |
+
batch_size = 1
|
| 253 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 254 |
+
batch_size = len(prompt)
|
| 255 |
+
else:
|
| 256 |
+
batch_size = prompt_embeds.shape[0]
|
| 257 |
+
|
| 258 |
+
if prompt_embeds is None:
|
| 259 |
+
# textual inversion: process multi-vector tokens if necessary
|
| 260 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
| 261 |
+
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
| 262 |
+
|
| 263 |
+
text_inputs = self.tokenizer(
|
| 264 |
+
prompt,
|
| 265 |
+
padding="max_length",
|
| 266 |
+
max_length=self.tokenizer.model_max_length,
|
| 267 |
+
truncation=True,
|
| 268 |
+
return_tensors="pt",
|
| 269 |
+
)
|
| 270 |
+
text_input_ids = text_inputs.input_ids
|
| 271 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
| 272 |
+
|
| 273 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
| 274 |
+
text_input_ids, untruncated_ids
|
| 275 |
+
):
|
| 276 |
+
removed_text = self.tokenizer.batch_decode(
|
| 277 |
+
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
| 278 |
+
)
|
| 279 |
+
logger.warning(
|
| 280 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
| 281 |
+
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
| 285 |
+
attention_mask = text_inputs.attention_mask.to(device)
|
| 286 |
+
else:
|
| 287 |
+
attention_mask = None
|
| 288 |
+
|
| 289 |
+
if clip_skip is None:
|
| 290 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
|
| 291 |
+
prompt_embeds = prompt_embeds[0]
|
| 292 |
+
else:
|
| 293 |
+
prompt_embeds = self.text_encoder(
|
| 294 |
+
text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
|
| 295 |
+
)
|
| 296 |
+
# Access the `hidden_states` first, that contains a tuple of
|
| 297 |
+
# all the hidden states from the encoder layers. Then index into
|
| 298 |
+
# the tuple to access the hidden states from the desired layer.
|
| 299 |
+
prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
|
| 300 |
+
# We also need to apply the final LayerNorm here to not mess with the
|
| 301 |
+
# representations. The `last_hidden_states` that we typically use for
|
| 302 |
+
# obtaining the final prompt representations passes through the LayerNorm
|
| 303 |
+
# layer.
|
| 304 |
+
prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
|
| 305 |
+
|
| 306 |
+
if self.text_encoder is not None:
|
| 307 |
+
prompt_embeds_dtype = self.text_encoder.dtype
|
| 308 |
+
elif self.unet is not None:
|
| 309 |
+
prompt_embeds_dtype = self.unet.dtype
|
| 310 |
+
else:
|
| 311 |
+
prompt_embeds_dtype = prompt_embeds.dtype
|
| 312 |
+
|
| 313 |
+
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
| 314 |
+
|
| 315 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
| 316 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 317 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 318 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
| 319 |
+
|
| 320 |
+
# get unconditional embeddings for classifier free guidance
|
| 321 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
| 322 |
+
uncond_tokens: List[str]
|
| 323 |
+
if negative_prompt is None:
|
| 324 |
+
uncond_tokens = [""] * batch_size
|
| 325 |
+
elif prompt is not None and type(prompt) is not type(negative_prompt):
|
| 326 |
+
raise TypeError(
|
| 327 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
| 328 |
+
f" {type(prompt)}."
|
| 329 |
+
)
|
| 330 |
+
elif isinstance(negative_prompt, str):
|
| 331 |
+
uncond_tokens = [negative_prompt]
|
| 332 |
+
elif batch_size != len(negative_prompt):
|
| 333 |
+
raise ValueError(
|
| 334 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
| 335 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
| 336 |
+
" the batch size of `prompt`."
|
| 337 |
+
)
|
| 338 |
+
else:
|
| 339 |
+
uncond_tokens = negative_prompt
|
| 340 |
+
|
| 341 |
+
# textual inversion: process multi-vector tokens if necessary
|
| 342 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
| 343 |
+
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
|
| 344 |
+
|
| 345 |
+
max_length = prompt_embeds.shape[1]
|
| 346 |
+
uncond_input = self.tokenizer(
|
| 347 |
+
uncond_tokens,
|
| 348 |
+
padding="max_length",
|
| 349 |
+
max_length=max_length,
|
| 350 |
+
truncation=True,
|
| 351 |
+
return_tensors="pt",
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
| 355 |
+
attention_mask = uncond_input.attention_mask.to(device)
|
| 356 |
+
else:
|
| 357 |
+
attention_mask = None
|
| 358 |
+
|
| 359 |
+
negative_prompt_embeds = self.text_encoder(
|
| 360 |
+
uncond_input.input_ids.to(device),
|
| 361 |
+
attention_mask=attention_mask,
|
| 362 |
+
)
|
| 363 |
+
negative_prompt_embeds = negative_prompt_embeds[0]
|
| 364 |
+
|
| 365 |
+
if do_classifier_free_guidance:
|
| 366 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
| 367 |
+
seq_len = negative_prompt_embeds.shape[1]
|
| 368 |
+
|
| 369 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
| 370 |
+
|
| 371 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 372 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 373 |
+
|
| 374 |
+
if self.text_encoder is not None:
|
| 375 |
+
if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
|
| 376 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
| 377 |
+
unscale_lora_layers(self.text_encoder, lora_scale)
|
| 378 |
+
|
| 379 |
+
return prompt_embeds, negative_prompt_embeds
|
| 380 |
+
|
| 381 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
    """Encode an IP-Adapter conditioning image with the CLIP image encoder.

    Returns a `(embeds, negative_embeds)` pair. In hidden-state mode the
    negative branch encodes a zero image; in pooled mode it is an all-zeros
    tensor of the same shape as the positive embeddings.
    """
    dtype = next(self.image_encoder.parameters()).dtype

    # PIL/ndarray inputs are preprocessed into pixel tensors; tensors pass through.
    if not isinstance(image, torch.Tensor):
        image = self.feature_extractor(image, return_tensors="pt").pixel_values

    image = image.to(device=device, dtype=dtype)

    if not output_hidden_states:
        # Pooled-projection path: the unconditional embedding is simply zeros.
        image_embeds = self.image_encoder(image).image_embeds
        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        uncond_image_embeds = torch.zeros_like(image_embeds)

        return image_embeds, uncond_image_embeds

    # Hidden-state path: use the penultimate layer's states, and encode a
    # zero image for the unconditional branch.
    cond_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
    cond_hidden_states = cond_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
    uncond_hidden_states = self.image_encoder(
        torch.zeros_like(image), output_hidden_states=True
    ).hidden_states[-2]
    uncond_hidden_states = uncond_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
    return cond_hidden_states, uncond_hidden_states
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
    self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
    """Build the per-adapter list of image embeddings fed to the UNet's IP-Adapter layers.

    Either encodes `ip_adapter_image` with the image encoder (one image per
    adapter) or reuses precomputed `ip_adapter_image_embeds`. When
    classifier-free guidance is enabled, each returned tensor is
    `cat([negative, positive])` along dim 0 and repeated `num_images_per_prompt`
    times.
    """
    image_embeds = []
    if do_classifier_free_guidance:
        negative_image_embeds = []
    if ip_adapter_image_embeds is None:
        if not isinstance(ip_adapter_image, list):
            ip_adapter_image = [ip_adapter_image]

        # One conditioning image is required per installed IP-Adapter.
        if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
            raise ValueError(
                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
            )

        for single_ip_adapter_image, image_proj_layer in zip(
            ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
        ):
            # `ImageProjection` layers consume pooled embeds; any other
            # projection layer type consumes penultimate hidden states.
            output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
            single_image_embeds, single_negative_image_embeds = self.encode_image(
                single_ip_adapter_image, device, 1, output_hidden_state
            )

            image_embeds.append(single_image_embeds[None, :])
            if do_classifier_free_guidance:
                negative_image_embeds.append(single_negative_image_embeds[None, :])
    else:
        for single_image_embeds in ip_adapter_image_embeds:
            if do_classifier_free_guidance:
                # Precomputed embeds are stored as cat([negative, positive]);
                # split them back apart before repeating.
                single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
                negative_image_embeds.append(single_negative_image_embeds)
            image_embeds.append(single_image_embeds)

    ip_adapter_image_embeds = []
    for i, single_image_embeds in enumerate(image_embeds):
        # Repeat for num_images_per_prompt, then re-join negative/positive for CFG.
        single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
            single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)

        single_image_embeds = single_image_embeds.to(device=device)
        ip_adapter_image_embeds.append(single_image_embeds)

    return ip_adapter_image_embeds
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents
def decode_latents(self, latents, decode_chunk_size: int = 16):
    """Decode video latents to pixel space with the VAE, `decode_chunk_size` frames at a time.

    Returns a float32 tensor of shape (batch, channels, frames, height, width).
    """
    latents = 1 / self.vae.config.scaling_factor * latents

    batch_size, channels, num_frames, height, width = latents.shape
    # Fold the frame axis into the batch axis so the image VAE can decode frames.
    latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)

    # Decode in chunks to bound peak memory.
    decoded_chunks = [
        self.vae.decode(latents[start : start + decode_chunk_size]).sample
        for start in range(0, latents.shape[0], decode_chunk_size)
    ]
    video = torch.cat(decoded_chunks)

    # Restore the (batch, channels, frames, height, width) layout.
    video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
    # We always cast to float32: negligible overhead and compatible with bfloat16.
    return video.float()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    """Assemble optional kwargs for `scheduler.step`, keeping only ones the scheduler accepts.

    `eta` corresponds to η in the DDIM paper (https://huggingface.co/papers/2010.02502),
    should lie in [0, 1], and is ignored by non-DDIM schedulers. Not every
    scheduler accepts a `generator` either, hence the signature inspection.
    """
    step_parameters = set(inspect.signature(self.scheduler.step).parameters.keys())

    extra_step_kwargs = {}
    if "eta" in step_parameters:
        extra_step_kwargs["eta"] = eta
    if "generator" in step_parameters:
        extra_step_kwargs["generator"] = generator
    return extra_step_kwargs
def check_inputs(
    self,
    prompt,
    height,
    width,
    num_frames,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    callback_on_step_end_tensor_inputs=None,
    video=None,
    controlnet_conditioning_scale=1.0,
    control_guidance_start=0.0,
    control_guidance_end=1.0,
):
    """Validate user-facing pipeline arguments before generation.

    Raises `ValueError`/`TypeError` on the first inconsistency found: bad
    spatial sizes, conflicting prompt inputs, malformed conditioning `video`
    (single vs. multi ControlNet shapes), bad `controlnet_conditioning_scale`,
    and out-of-range control guidance windows.
    """
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    if callback_on_step_end_tensor_inputs is not None and not all(
        k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and not isinstance(prompt, (str, list, dict)):
        raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    # `prompt` needs more sophisticated handling when there are multiple
    # conditionings.
    if isinstance(self.controlnet, MultiControlNetModel):
        if isinstance(prompt, list):
            logger.warning(
                f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
                " prompts. The conditionings will be fixed across the prompts."
            )

    # Check `video` (the per-frame conditioning input). A torch.compile'd
    # controlnet hides the real class behind `_orig_mod`.
    is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
        self.controlnet, torch._dynamo.eval_frame.OptimizedModule
    )
    if isinstance(self.controlnet, ControlNetModel) or (
        is_compiled and isinstance(self.controlnet._orig_mod, ControlNetModel)
    ):
        if not isinstance(video, list):
            raise TypeError(f"For single controlnet, `image` must be of type `list` but got {type(video)}")
        if len(video) != num_frames:
            # Fixed typo in message: "Excepted" -> "Expected".
            raise ValueError(f"Expected image to have length {num_frames} but got {len(video)=}")
    elif isinstance(self.controlnet, MultiControlNetModel) or (
        is_compiled and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
    ):
        if not isinstance(video, list) or not isinstance(video[0], list):
            raise TypeError(f"For multiple controlnets: `image` must be type list of lists but got {type(video)=}")
        if len(video[0]) != num_frames:
            raise ValueError(f"Expected length of image sublist as {num_frames} but got {len(video[0])=}")
        if any(len(img) != len(video[0]) for img in video):
            raise ValueError("All conditioning frame batches for multicontrolnet must be same size")
    else:
        assert False

    # Check `controlnet_conditioning_scale`
    if isinstance(self.controlnet, ControlNetModel) or (
        is_compiled and isinstance(self.controlnet._orig_mod, ControlNetModel)
    ):
        if not isinstance(controlnet_conditioning_scale, float):
            raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
    elif isinstance(self.controlnet, MultiControlNetModel) or (
        is_compiled and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
    ):
        if isinstance(controlnet_conditioning_scale, list):
            if any(isinstance(i, list) for i in controlnet_conditioning_scale):
                raise ValueError("A single batch of multiple conditionings are supported at the moment.")
            # BUGFIX: this length check previously lived in an unreachable
            # `elif isinstance(..., list) and ...` branch (the `if` above
            # already consumed every list), so a scale list of the wrong
            # length was silently accepted.
            if len(controlnet_conditioning_scale) != len(self.controlnet.nets):
                raise ValueError(
                    "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
                    " the same length as the number of controlnets"
                )
    else:
        assert False

    if not isinstance(control_guidance_start, (tuple, list)):
        control_guidance_start = [control_guidance_start]

    if not isinstance(control_guidance_end, (tuple, list)):
        control_guidance_end = [control_guidance_end]

    if len(control_guidance_start) != len(control_guidance_end):
        raise ValueError(
            f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
        )

    if isinstance(self.controlnet, MultiControlNetModel):
        if len(control_guidance_start) != len(self.controlnet.nets):
            raise ValueError(
                f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
            )

    # Each control window must be a non-empty sub-interval of [0, 1].
    for start, end in zip(control_guidance_start, control_guidance_end):
        if start >= end:
            raise ValueError(
                f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
            )
        if start < 0.0:
            raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
        if end > 1.0:
            raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.prepare_latents
def prepare_latents(
    self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
    """Create (or adopt) the initial latent video tensor, scaled by the scheduler's initial sigma."""
    # If FreeNoise is enabled, generate latents as described in Equation (7)
    # of [FreeNoise](https://huggingface.co/papers/2310.15169).
    if self.free_noise_enabled:
        latents = self._prepare_latents_free_noise(
            batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents
        )

    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    latent_shape = (
        batch_size,
        num_channels_latents,
        num_frames,
        height // self.vae_scale_factor,
        width // self.vae_scale_factor,
    )

    if latents is None:
        latents = randn_tensor(latent_shape, generator=generator, device=device, dtype=dtype)
    else:
        latents = latents.to(device)

    # Schedulers expect the initial noise scaled by their starting sigma.
    return latents * self.scheduler.init_noise_sigma
def prepare_video(
    self,
    video,
    width,
    height,
    batch_size,
    num_videos_per_prompt,
    device,
    dtype,
    do_classifier_free_guidance=False,
    guess_mode=False,
):
    """Preprocess ControlNet conditioning frames into a batched tensor on `device`.

    The batch is duplicated for classifier-free guidance unless `guess_mode`
    is on (guess mode feeds only the conditional branch to the ControlNet).
    """
    video = self.control_video_processor.preprocess_video(video, height=height, width=width).to(
        dtype=torch.float32
    )
    # (b, c, f, h, w) -> (b*f, c, h, w): every frame is treated as one image.
    video = video.permute(0, 2, 1, 3, 4).flatten(0, 1)

    # A singleton conditioning batch is broadcast to the full batch size;
    # otherwise (batch matches prompt batch) repeat once per generated video.
    repeat_by = batch_size if video.shape[0] == 1 else num_videos_per_prompt
    video = video.repeat_interleave(repeat_by, dim=0)
    video = video.to(device=device, dtype=dtype)

    if do_classifier_free_guidance and not guess_mode:
        video = torch.cat([video, video])

    return video
@property
def guidance_scale(self):
    """Classifier-free guidance scale set by the current `__call__`."""
    return self._guidance_scale
|
| 698 |
+
@property
def clip_skip(self):
    """Number of final CLIP layers skipped when computing prompt embeddings."""
    return self._clip_skip
|
| 702 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
    """Whether classifier-free guidance is active (guidance scale strictly > 1)."""
    return self._guidance_scale > 1
@property
def cross_attention_kwargs(self):
    """Kwargs forwarded to the attention processors for the current `__call__`."""
    return self._cross_attention_kwargs
|
| 713 |
+
@property
def num_timesteps(self):
    """Number of denoising timesteps of the current/most recent `__call__`."""
    return self._num_timesteps
|
| 717 |
+
@property
def interrupt(self):
    """Flag a callback can set to stop the denoising loop early."""
    return self._interrupt
|
| 721 |
+
@torch.no_grad()
|
| 722 |
+
def __call__(
|
| 723 |
+
self,
|
| 724 |
+
prompt: Union[str, List[str]] = None,
|
| 725 |
+
num_frames: Optional[int] = 16,
|
| 726 |
+
height: Optional[int] = None,
|
| 727 |
+
width: Optional[int] = None,
|
| 728 |
+
num_inference_steps: int = 50,
|
| 729 |
+
guidance_scale: float = 7.5,
|
| 730 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 731 |
+
num_videos_per_prompt: Optional[int] = 1,
|
| 732 |
+
eta: float = 0.0,
|
| 733 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 734 |
+
latents: Optional[torch.Tensor] = None,
|
| 735 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 736 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 737 |
+
ip_adapter_image: Optional[PipelineImageInput] = None,
|
| 738 |
+
ip_adapter_image_embeds: Optional[PipelineImageInput] = None,
|
| 739 |
+
conditioning_frames: Optional[List[PipelineImageInput]] = None,
|
| 740 |
+
output_type: Optional[str] = "pil",
|
| 741 |
+
return_dict: bool = True,
|
| 742 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 743 |
+
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
| 744 |
+
guess_mode: bool = False,
|
| 745 |
+
control_guidance_start: Union[float, List[float]] = 0.0,
|
| 746 |
+
control_guidance_end: Union[float, List[float]] = 1.0,
|
| 747 |
+
clip_skip: Optional[int] = None,
|
| 748 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 749 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 750 |
+
decode_chunk_size: int = 16,
|
| 751 |
+
):
|
| 752 |
+
r"""
|
| 753 |
+
The call function to the pipeline for generation.
|
| 754 |
+
|
| 755 |
+
Args:
|
| 756 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 757 |
+
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
| 758 |
+
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
| 759 |
+
The height in pixels of the generated video.
|
| 760 |
+
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
| 761 |
+
The width in pixels of the generated video.
|
| 762 |
+
num_frames (`int`, *optional*, defaults to 16):
|
| 763 |
+
The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds
|
| 764 |
+
amounts to 2 seconds of video.
|
| 765 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 766 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality videos at the
|
| 767 |
+
expense of slower inference.
|
| 768 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
| 769 |
+
A higher guidance scale value encourages the model to generate images closely linked to the text
|
| 770 |
+
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
| 771 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 772 |
+
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
| 773 |
+
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
| 774 |
+
eta (`float`, *optional*, defaults to 0.0):
|
| 775 |
+
Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
|
| 776 |
+
applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
| 777 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 778 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
| 779 |
+
generation deterministic.
|
| 780 |
+
latents (`torch.Tensor`, *optional*):
|
| 781 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
|
| 782 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 783 |
+
tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
|
| 784 |
+
`(batch_size, num_channel, num_frames, height, width)`.
|
| 785 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 786 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
| 787 |
+
provided, text embeddings are generated from the `prompt` input argument.
|
| 788 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 789 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
| 790 |
+
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
| 791 |
+
ip_adapter_image (`PipelineImageInput`, *optional*):
|
| 792 |
+
Optional image input to work with IP Adapters.
|
| 793 |
+
ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
|
| 794 |
+
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
|
| 795 |
+
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
|
| 796 |
+
contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
|
| 797 |
+
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
| 798 |
+
conditioning_frames (`List[PipelineImageInput]`, *optional*):
|
| 799 |
+
The ControlNet input condition to provide guidance to the `unet` for generation. If multiple
|
| 800 |
+
ControlNets are specified, images must be passed as a list such that each element of the list can be
|
| 801 |
+
correctly batched for input to a single ControlNet.
|
| 802 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 803 |
+
The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`.
|
| 804 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 805 |
+
Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead
|
| 806 |
+
of a plain tuple.
|
| 807 |
+
cross_attention_kwargs (`dict`, *optional*):
|
| 808 |
+
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
| 809 |
+
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 810 |
+
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
|
| 811 |
+
The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
|
| 812 |
+
to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
|
| 813 |
+
the corresponding scale as a list.
|
| 814 |
+
guess_mode (`bool`, *optional*, defaults to `False`):
|
| 815 |
+
The ControlNet encoder tries to recognize the content of the input image even if you remove all
|
| 816 |
+
prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
|
| 817 |
+
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
|
| 818 |
+
The percentage of total steps at which the ControlNet starts applying.
|
| 819 |
+
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
|
| 820 |
+
The percentage of total steps at which the ControlNet stops applying.
|
| 821 |
+
clip_skip (`int`, *optional*):
|
| 822 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
| 823 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
| 824 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 825 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
| 826 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 827 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 828 |
+
`callback_on_step_end_tensor_inputs`.
|
| 829 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 830 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 831 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 832 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 833 |
+
|
| 834 |
+
Examples:
|
| 835 |
+
|
| 836 |
+
Returns:
|
| 837 |
+
[`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
|
| 838 |
+
If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
|
| 839 |
+
returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
|
| 840 |
+
"""
|
| 841 |
+
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
|
| 842 |
+
|
| 843 |
+
# align format for control guidance
|
| 844 |
+
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
|
| 845 |
+
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
|
| 846 |
+
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
|
| 847 |
+
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
|
| 848 |
+
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
|
| 849 |
+
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
|
| 850 |
+
control_guidance_start, control_guidance_end = (
|
| 851 |
+
mult * [control_guidance_start],
|
| 852 |
+
mult * [control_guidance_end],
|
| 853 |
+
)
|
| 854 |
+
|
| 855 |
+
# 0. Default height and width to unet
|
| 856 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
| 857 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
| 858 |
+
|
| 859 |
+
num_videos_per_prompt = 1
|
| 860 |
+
|
| 861 |
+
# 1. Check inputs. Raise error if not correct
|
| 862 |
+
self.check_inputs(
|
| 863 |
+
prompt=prompt,
|
| 864 |
+
height=height,
|
| 865 |
+
width=width,
|
| 866 |
+
num_frames=num_frames,
|
| 867 |
+
negative_prompt=negative_prompt,
|
| 868 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
| 869 |
+
prompt_embeds=prompt_embeds,
|
| 870 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 871 |
+
video=conditioning_frames,
|
| 872 |
+
controlnet_conditioning_scale=controlnet_conditioning_scale,
|
| 873 |
+
control_guidance_start=control_guidance_start,
|
| 874 |
+
control_guidance_end=control_guidance_end,
|
| 875 |
+
)
|
| 876 |
+
|
| 877 |
+
self._guidance_scale = guidance_scale
|
| 878 |
+
self._clip_skip = clip_skip
|
| 879 |
+
self._cross_attention_kwargs = cross_attention_kwargs
|
| 880 |
+
self._interrupt = False
|
| 881 |
+
|
| 882 |
+
# 2. Define call parameters
|
| 883 |
+
if prompt is not None and isinstance(prompt, (str, dict)):
|
| 884 |
+
batch_size = 1
|
| 885 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 886 |
+
batch_size = len(prompt)
|
| 887 |
+
else:
|
| 888 |
+
batch_size = prompt_embeds.shape[0]
|
| 889 |
+
|
| 890 |
+
device = self._execution_device
|
| 891 |
+
|
| 892 |
+
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
|
| 893 |
+
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
|
| 894 |
+
|
| 895 |
+
global_pool_conditions = (
|
| 896 |
+
controlnet.config.global_pool_conditions
|
| 897 |
+
if isinstance(controlnet, ControlNetModel)
|
| 898 |
+
else controlnet.nets[0].config.global_pool_conditions
|
| 899 |
+
)
|
| 900 |
+
guess_mode = guess_mode or global_pool_conditions
|
| 901 |
+
|
| 902 |
+
# 3. Encode input prompt
|
| 903 |
+
text_encoder_lora_scale = (
|
| 904 |
+
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
| 905 |
+
)
|
| 906 |
+
if self.free_noise_enabled:
|
| 907 |
+
prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise(
|
| 908 |
+
prompt=prompt,
|
| 909 |
+
num_frames=num_frames,
|
| 910 |
+
device=device,
|
| 911 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 912 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
| 913 |
+
negative_prompt=negative_prompt,
|
| 914 |
+
prompt_embeds=prompt_embeds,
|
| 915 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 916 |
+
lora_scale=text_encoder_lora_scale,
|
| 917 |
+
clip_skip=self.clip_skip,
|
| 918 |
+
)
|
| 919 |
+
else:
|
| 920 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
| 921 |
+
prompt,
|
| 922 |
+
device,
|
| 923 |
+
num_videos_per_prompt,
|
| 924 |
+
self.do_classifier_free_guidance,
|
| 925 |
+
negative_prompt,
|
| 926 |
+
prompt_embeds=prompt_embeds,
|
| 927 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 928 |
+
lora_scale=text_encoder_lora_scale,
|
| 929 |
+
clip_skip=self.clip_skip,
|
| 930 |
+
)
|
| 931 |
+
|
| 932 |
+
# For classifier free guidance, we need to do two forward passes.
|
| 933 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
| 934 |
+
# to avoid doing two forward passes
|
| 935 |
+
if self.do_classifier_free_guidance:
|
| 936 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
| 937 |
+
|
| 938 |
+
prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
|
| 939 |
+
|
| 940 |
+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
| 941 |
+
image_embeds = self.prepare_ip_adapter_image_embeds(
|
| 942 |
+
ip_adapter_image,
|
| 943 |
+
ip_adapter_image_embeds,
|
| 944 |
+
device,
|
| 945 |
+
batch_size * num_videos_per_prompt,
|
| 946 |
+
self.do_classifier_free_guidance,
|
| 947 |
+
)
|
| 948 |
+
|
| 949 |
+
if isinstance(controlnet, ControlNetModel):
|
| 950 |
+
conditioning_frames = self.prepare_video(
|
| 951 |
+
video=conditioning_frames,
|
| 952 |
+
width=width,
|
| 953 |
+
height=height,
|
| 954 |
+
batch_size=batch_size * num_videos_per_prompt * num_frames,
|
| 955 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 956 |
+
device=device,
|
| 957 |
+
dtype=controlnet.dtype,
|
| 958 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
| 959 |
+
guess_mode=guess_mode,
|
| 960 |
+
)
|
| 961 |
+
elif isinstance(controlnet, MultiControlNetModel):
|
| 962 |
+
cond_prepared_videos = []
|
| 963 |
+
for frame_ in conditioning_frames:
|
| 964 |
+
prepared_video = self.prepare_video(
|
| 965 |
+
video=frame_,
|
| 966 |
+
width=width,
|
| 967 |
+
height=height,
|
| 968 |
+
batch_size=batch_size * num_videos_per_prompt * num_frames,
|
| 969 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 970 |
+
device=device,
|
| 971 |
+
dtype=controlnet.dtype,
|
| 972 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
| 973 |
+
guess_mode=guess_mode,
|
| 974 |
+
)
|
| 975 |
+
cond_prepared_videos.append(prepared_video)
|
| 976 |
+
conditioning_frames = cond_prepared_videos
|
| 977 |
+
else:
|
| 978 |
+
assert False
|
| 979 |
+
|
| 980 |
+
# 4. Prepare timesteps
|
| 981 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
| 982 |
+
timesteps = self.scheduler.timesteps
|
| 983 |
+
|
| 984 |
+
# 5. Prepare latent variables
|
| 985 |
+
num_channels_latents = self.unet.config.in_channels
|
| 986 |
+
latents = self.prepare_latents(
|
| 987 |
+
batch_size * num_videos_per_prompt,
|
| 988 |
+
num_channels_latents,
|
| 989 |
+
num_frames,
|
| 990 |
+
height,
|
| 991 |
+
width,
|
| 992 |
+
prompt_embeds.dtype,
|
| 993 |
+
device,
|
| 994 |
+
generator,
|
| 995 |
+
latents,
|
| 996 |
+
)
|
| 997 |
+
|
| 998 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 999 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 1000 |
+
|
| 1001 |
+
# 7. Add image embeds for IP-Adapter
|
| 1002 |
+
added_cond_kwargs = (
|
| 1003 |
+
{"image_embeds": image_embeds}
|
| 1004 |
+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
|
| 1005 |
+
else None
|
| 1006 |
+
)
|
| 1007 |
+
|
| 1008 |
+
# 7.1 Create tensor stating which controlnets to keep
|
| 1009 |
+
controlnet_keep = []
|
| 1010 |
+
for i in range(len(timesteps)):
|
| 1011 |
+
keeps = [
|
| 1012 |
+
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
|
| 1013 |
+
for s, e in zip(control_guidance_start, control_guidance_end)
|
| 1014 |
+
]
|
| 1015 |
+
controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
|
| 1016 |
+
|
| 1017 |
+
num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
|
| 1018 |
+
for free_init_iter in range(num_free_init_iters):
|
| 1019 |
+
if self.free_init_enabled:
|
| 1020 |
+
latents, timesteps = self._apply_free_init(
|
| 1021 |
+
latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
|
| 1022 |
+
)
|
| 1023 |
+
|
| 1024 |
+
self._num_timesteps = len(timesteps)
|
| 1025 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
| 1026 |
+
|
| 1027 |
+
# 8. Denoising loop
|
| 1028 |
+
with self.progress_bar(total=self._num_timesteps) as progress_bar:
|
| 1029 |
+
for i, t in enumerate(timesteps):
|
| 1030 |
+
if self.interrupt:
|
| 1031 |
+
continue
|
| 1032 |
+
|
| 1033 |
+
# expand the latents if we are doing classifier free guidance
|
| 1034 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
| 1035 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 1036 |
+
|
| 1037 |
+
if guess_mode and self.do_classifier_free_guidance:
|
| 1038 |
+
# Infer ControlNet only for the conditional batch.
|
| 1039 |
+
control_model_input = latents
|
| 1040 |
+
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
|
| 1041 |
+
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
|
| 1042 |
+
else:
|
| 1043 |
+
control_model_input = latent_model_input
|
| 1044 |
+
controlnet_prompt_embeds = prompt_embeds
|
| 1045 |
+
|
| 1046 |
+
if isinstance(controlnet_keep[i], list):
|
| 1047 |
+
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
|
| 1048 |
+
else:
|
| 1049 |
+
controlnet_cond_scale = controlnet_conditioning_scale
|
| 1050 |
+
if isinstance(controlnet_cond_scale, list):
|
| 1051 |
+
controlnet_cond_scale = controlnet_cond_scale[0]
|
| 1052 |
+
cond_scale = controlnet_cond_scale * controlnet_keep[i]
|
| 1053 |
+
|
| 1054 |
+
control_model_input = torch.transpose(control_model_input, 1, 2)
|
| 1055 |
+
control_model_input = control_model_input.reshape(
|
| 1056 |
+
(-1, control_model_input.shape[2], control_model_input.shape[3], control_model_input.shape[4])
|
| 1057 |
+
)
|
| 1058 |
+
|
| 1059 |
+
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
| 1060 |
+
control_model_input,
|
| 1061 |
+
t,
|
| 1062 |
+
encoder_hidden_states=controlnet_prompt_embeds,
|
| 1063 |
+
controlnet_cond=conditioning_frames,
|
| 1064 |
+
conditioning_scale=cond_scale,
|
| 1065 |
+
guess_mode=guess_mode,
|
| 1066 |
+
return_dict=False,
|
| 1067 |
+
)
|
| 1068 |
+
|
| 1069 |
+
# predict the noise residual
|
| 1070 |
+
noise_pred = self.unet(
|
| 1071 |
+
latent_model_input,
|
| 1072 |
+
t,
|
| 1073 |
+
encoder_hidden_states=prompt_embeds,
|
| 1074 |
+
cross_attention_kwargs=self.cross_attention_kwargs,
|
| 1075 |
+
added_cond_kwargs=added_cond_kwargs,
|
| 1076 |
+
down_block_additional_residuals=down_block_res_samples,
|
| 1077 |
+
mid_block_additional_residual=mid_block_res_sample,
|
| 1078 |
+
).sample
|
| 1079 |
+
|
| 1080 |
+
# perform guidance
|
| 1081 |
+
if self.do_classifier_free_guidance:
|
| 1082 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 1083 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 1084 |
+
|
| 1085 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 1086 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
| 1087 |
+
|
| 1088 |
+
if callback_on_step_end is not None:
|
| 1089 |
+
callback_kwargs = {}
|
| 1090 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 1091 |
+
callback_kwargs[k] = locals()[k]
|
| 1092 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 1093 |
+
|
| 1094 |
+
latents = callback_outputs.pop("latents", latents)
|
| 1095 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 1096 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 1097 |
+
|
| 1098 |
+
# call the callback, if provided
|
| 1099 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 1100 |
+
progress_bar.update()
|
| 1101 |
+
|
| 1102 |
+
if XLA_AVAILABLE:
|
| 1103 |
+
xm.mark_step()
|
| 1104 |
+
|
| 1105 |
+
# 9. Post processing
|
| 1106 |
+
if output_type == "latent":
|
| 1107 |
+
video = latents
|
| 1108 |
+
else:
|
| 1109 |
+
video_tensor = self.decode_latents(latents, decode_chunk_size)
|
| 1110 |
+
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
|
| 1111 |
+
|
| 1112 |
+
# 10. Offload all models
|
| 1113 |
+
self.maybe_free_model_hooks()
|
| 1114 |
+
|
| 1115 |
+
if not return_dict:
|
| 1116 |
+
return (video,)
|
| 1117 |
+
|
| 1118 |
+
return AnimateDiffPipelineOutput(frames=video)
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py
ADDED
|
@@ -0,0 +1,1309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import inspect
|
| 16 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
from transformers import (
|
| 20 |
+
CLIPImageProcessor,
|
| 21 |
+
CLIPTextModel,
|
| 22 |
+
CLIPTextModelWithProjection,
|
| 23 |
+
CLIPTokenizer,
|
| 24 |
+
CLIPVisionModelWithProjection,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
from ...image_processor import PipelineImageInput
|
| 28 |
+
from ...loaders import (
|
| 29 |
+
FromSingleFileMixin,
|
| 30 |
+
IPAdapterMixin,
|
| 31 |
+
StableDiffusionXLLoraLoaderMixin,
|
| 32 |
+
TextualInversionLoaderMixin,
|
| 33 |
+
)
|
| 34 |
+
from ...models import AutoencoderKL, ImageProjection, MotionAdapter, UNet2DConditionModel, UNetMotionModel
|
| 35 |
+
from ...models.attention_processor import (
|
| 36 |
+
AttnProcessor2_0,
|
| 37 |
+
FusedAttnProcessor2_0,
|
| 38 |
+
XFormersAttnProcessor,
|
| 39 |
+
)
|
| 40 |
+
from ...models.lora import adjust_lora_scale_text_encoder
|
| 41 |
+
from ...schedulers import (
|
| 42 |
+
DDIMScheduler,
|
| 43 |
+
DPMSolverMultistepScheduler,
|
| 44 |
+
EulerAncestralDiscreteScheduler,
|
| 45 |
+
EulerDiscreteScheduler,
|
| 46 |
+
LMSDiscreteScheduler,
|
| 47 |
+
PNDMScheduler,
|
| 48 |
+
)
|
| 49 |
+
from ...utils import (
|
| 50 |
+
USE_PEFT_BACKEND,
|
| 51 |
+
is_torch_xla_available,
|
| 52 |
+
logging,
|
| 53 |
+
replace_example_docstring,
|
| 54 |
+
scale_lora_layers,
|
| 55 |
+
unscale_lora_layers,
|
| 56 |
+
)
|
| 57 |
+
from ...utils.torch_utils import randn_tensor
|
| 58 |
+
from ...video_processor import VideoProcessor
|
| 59 |
+
from ..free_init_utils import FreeInitMixin
|
| 60 |
+
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
| 61 |
+
from .pipeline_output import AnimateDiffPipelineOutput
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
if is_torch_xla_available():
|
| 65 |
+
import torch_xla.core.xla_model as xm
|
| 66 |
+
|
| 67 |
+
XLA_AVAILABLE = True
|
| 68 |
+
else:
|
| 69 |
+
XLA_AVAILABLE = False
|
| 70 |
+
|
| 71 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
EXAMPLE_DOC_STRING = """
|
| 75 |
+
Examples:
|
| 76 |
+
```py
|
| 77 |
+
>>> import torch
|
| 78 |
+
>>> from diffusers.models import MotionAdapter
|
| 79 |
+
>>> from diffusers import AnimateDiffSDXLPipeline, DDIMScheduler
|
| 80 |
+
>>> from diffusers.utils import export_to_gif
|
| 81 |
+
|
| 82 |
+
>>> adapter = MotionAdapter.from_pretrained(
|
| 83 |
+
... "a-r-r-o-w/animatediff-motion-adapter-sdxl-beta", torch_dtype=torch.float16
|
| 84 |
+
... )
|
| 85 |
+
|
| 86 |
+
>>> model_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
| 87 |
+
>>> scheduler = DDIMScheduler.from_pretrained(
|
| 88 |
+
... model_id,
|
| 89 |
+
... subfolder="scheduler",
|
| 90 |
+
... clip_sample=False,
|
| 91 |
+
... timestep_spacing="linspace",
|
| 92 |
+
... beta_schedule="linear",
|
| 93 |
+
... steps_offset=1,
|
| 94 |
+
... )
|
| 95 |
+
>>> pipe = AnimateDiffSDXLPipeline.from_pretrained(
|
| 96 |
+
... model_id,
|
| 97 |
+
... motion_adapter=adapter,
|
| 98 |
+
... scheduler=scheduler,
|
| 99 |
+
... torch_dtype=torch.float16,
|
| 100 |
+
... variant="fp16",
|
| 101 |
+
... ).to("cuda")
|
| 102 |
+
|
| 103 |
+
>>> # enable memory savings
|
| 104 |
+
>>> pipe.enable_vae_slicing()
|
| 105 |
+
>>> pipe.enable_vae_tiling()
|
| 106 |
+
|
| 107 |
+
>>> output = pipe(
|
| 108 |
+
... prompt="a panda surfing in the ocean, realistic, high quality",
|
| 109 |
+
... negative_prompt="low quality, worst quality",
|
| 110 |
+
... num_inference_steps=20,
|
| 111 |
+
... guidance_scale=8,
|
| 112 |
+
... width=1024,
|
| 113 |
+
... height=1024,
|
| 114 |
+
... num_frames=16,
|
| 115 |
+
... )
|
| 116 |
+
|
| 117 |
+
>>> frames = output.frames[0]
|
| 118 |
+
>>> export_to_gif(frames, "animation.gif")
|
| 119 |
+
```
|
| 120 |
+
"""
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
| 124 |
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
| 125 |
+
r"""
|
| 126 |
+
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
|
| 127 |
+
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
| 128 |
+
Flawed](https://huggingface.co/papers/2305.08891).
|
| 129 |
+
|
| 130 |
+
Args:
|
| 131 |
+
noise_cfg (`torch.Tensor`):
|
| 132 |
+
The predicted noise tensor for the guided diffusion process.
|
| 133 |
+
noise_pred_text (`torch.Tensor`):
|
| 134 |
+
The predicted noise tensor for the text-guided diffusion process.
|
| 135 |
+
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
| 136 |
+
A rescale factor applied to the noise predictions.
|
| 137 |
+
|
| 138 |
+
Returns:
|
| 139 |
+
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
|
| 140 |
+
"""
|
| 141 |
+
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
| 142 |
+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
| 143 |
+
# rescale the results from guidance (fixes overexposure)
|
| 144 |
+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
| 145 |
+
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
| 146 |
+
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
| 147 |
+
return noise_cfg
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 151 |
+
def retrieve_timesteps(
|
| 152 |
+
scheduler,
|
| 153 |
+
num_inference_steps: Optional[int] = None,
|
| 154 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 155 |
+
timesteps: Optional[List[int]] = None,
|
| 156 |
+
sigmas: Optional[List[float]] = None,
|
| 157 |
+
**kwargs,
|
| 158 |
+
):
|
| 159 |
+
r"""
|
| 160 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 161 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 162 |
+
|
| 163 |
+
Args:
|
| 164 |
+
scheduler (`SchedulerMixin`):
|
| 165 |
+
The scheduler to get timesteps from.
|
| 166 |
+
num_inference_steps (`int`):
|
| 167 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 168 |
+
must be `None`.
|
| 169 |
+
device (`str` or `torch.device`, *optional*):
|
| 170 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 171 |
+
timesteps (`List[int]`, *optional*):
|
| 172 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 173 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 174 |
+
sigmas (`List[float]`, *optional*):
|
| 175 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 176 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 177 |
+
|
| 178 |
+
Returns:
|
| 179 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 180 |
+
second element is the number of inference steps.
|
| 181 |
+
"""
|
| 182 |
+
if timesteps is not None and sigmas is not None:
|
| 183 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 184 |
+
if timesteps is not None:
|
| 185 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 186 |
+
if not accepts_timesteps:
|
| 187 |
+
raise ValueError(
|
| 188 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 189 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 190 |
+
)
|
| 191 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 192 |
+
timesteps = scheduler.timesteps
|
| 193 |
+
num_inference_steps = len(timesteps)
|
| 194 |
+
elif sigmas is not None:
|
| 195 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 196 |
+
if not accept_sigmas:
|
| 197 |
+
raise ValueError(
|
| 198 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 199 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 200 |
+
)
|
| 201 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 202 |
+
timesteps = scheduler.timesteps
|
| 203 |
+
num_inference_steps = len(timesteps)
|
| 204 |
+
else:
|
| 205 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 206 |
+
timesteps = scheduler.timesteps
|
| 207 |
+
return timesteps, num_inference_steps
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class AnimateDiffSDXLPipeline(
|
| 211 |
+
DiffusionPipeline,
|
| 212 |
+
StableDiffusionMixin,
|
| 213 |
+
FromSingleFileMixin,
|
| 214 |
+
StableDiffusionXLLoraLoaderMixin,
|
| 215 |
+
TextualInversionLoaderMixin,
|
| 216 |
+
IPAdapterMixin,
|
| 217 |
+
FreeInitMixin,
|
| 218 |
+
):
|
| 219 |
+
r"""
|
| 220 |
+
Pipeline for text-to-video generation using Stable Diffusion XL.
|
| 221 |
+
|
| 222 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 223 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
| 224 |
+
|
| 225 |
+
The pipeline also inherits the following loading methods:
|
| 226 |
+
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
| 227 |
+
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
| 228 |
+
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
| 229 |
+
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
| 230 |
+
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
| 231 |
+
|
| 232 |
+
Args:
|
| 233 |
+
vae ([`AutoencoderKL`]):
|
| 234 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
| 235 |
+
text_encoder ([`CLIPTextModel`]):
|
| 236 |
+
Frozen text-encoder. Stable Diffusion XL uses the text portion of
|
| 237 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
| 238 |
+
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
| 239 |
+
text_encoder_2 ([` CLIPTextModelWithProjection`]):
|
| 240 |
+
Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
|
| 241 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
|
| 242 |
+
specifically the
|
| 243 |
+
[laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
|
| 244 |
+
variant.
|
| 245 |
+
tokenizer (`CLIPTokenizer`):
|
| 246 |
+
Tokenizer of class
|
| 247 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
| 248 |
+
tokenizer_2 (`CLIPTokenizer`):
|
| 249 |
+
Second Tokenizer of class
|
| 250 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
| 251 |
+
unet ([`UNet2DConditionModel`]):
|
| 252 |
+
Conditional U-Net architecture to denoise the encoded image latents.
|
| 253 |
+
scheduler ([`SchedulerMixin`]):
|
| 254 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
| 255 |
+
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
| 256 |
+
force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
|
| 257 |
+
Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
|
| 258 |
+
`stabilityai/stable-diffusion-xl-base-1-0`.
|
| 259 |
+
"""
|
| 260 |
+
|
| 261 |
+
model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
|
| 262 |
+
_optional_components = [
|
| 263 |
+
"tokenizer",
|
| 264 |
+
"tokenizer_2",
|
| 265 |
+
"text_encoder",
|
| 266 |
+
"text_encoder_2",
|
| 267 |
+
"image_encoder",
|
| 268 |
+
"feature_extractor",
|
| 269 |
+
]
|
| 270 |
+
_callback_tensor_inputs = [
|
| 271 |
+
"latents",
|
| 272 |
+
"prompt_embeds",
|
| 273 |
+
"negative_prompt_embeds",
|
| 274 |
+
"add_text_embeds",
|
| 275 |
+
"add_time_ids",
|
| 276 |
+
"negative_pooled_prompt_embeds",
|
| 277 |
+
"negative_add_time_ids",
|
| 278 |
+
]
|
| 279 |
+
|
| 280 |
+
def __init__(
|
| 281 |
+
self,
|
| 282 |
+
vae: AutoencoderKL,
|
| 283 |
+
text_encoder: CLIPTextModel,
|
| 284 |
+
text_encoder_2: CLIPTextModelWithProjection,
|
| 285 |
+
tokenizer: CLIPTokenizer,
|
| 286 |
+
tokenizer_2: CLIPTokenizer,
|
| 287 |
+
unet: Union[UNet2DConditionModel, UNetMotionModel],
|
| 288 |
+
motion_adapter: MotionAdapter,
|
| 289 |
+
scheduler: Union[
|
| 290 |
+
DDIMScheduler,
|
| 291 |
+
PNDMScheduler,
|
| 292 |
+
LMSDiscreteScheduler,
|
| 293 |
+
EulerDiscreteScheduler,
|
| 294 |
+
EulerAncestralDiscreteScheduler,
|
| 295 |
+
DPMSolverMultistepScheduler,
|
| 296 |
+
],
|
| 297 |
+
image_encoder: CLIPVisionModelWithProjection = None,
|
| 298 |
+
feature_extractor: CLIPImageProcessor = None,
|
| 299 |
+
force_zeros_for_empty_prompt: bool = True,
|
| 300 |
+
):
|
| 301 |
+
super().__init__()
|
| 302 |
+
|
| 303 |
+
if isinstance(unet, UNet2DConditionModel):
|
| 304 |
+
unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
|
| 305 |
+
|
| 306 |
+
self.register_modules(
|
| 307 |
+
vae=vae,
|
| 308 |
+
text_encoder=text_encoder,
|
| 309 |
+
text_encoder_2=text_encoder_2,
|
| 310 |
+
tokenizer=tokenizer,
|
| 311 |
+
tokenizer_2=tokenizer_2,
|
| 312 |
+
unet=unet,
|
| 313 |
+
motion_adapter=motion_adapter,
|
| 314 |
+
scheduler=scheduler,
|
| 315 |
+
image_encoder=image_encoder,
|
| 316 |
+
feature_extractor=feature_extractor,
|
| 317 |
+
)
|
| 318 |
+
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
|
| 319 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
| 320 |
+
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
|
| 321 |
+
|
| 322 |
+
self.default_sample_size = (
|
| 323 |
+
self.unet.config.sample_size
|
| 324 |
+
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
|
| 325 |
+
else 128
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt with num_images_per_prompt->num_videos_per_prompt
def encode_prompt(
    self,
    prompt: str,
    prompt_2: Optional[str] = None,
    device: Optional[torch.device] = None,
    num_videos_per_prompt: int = 1,
    do_classifier_free_guidance: bool = True,
    negative_prompt: Optional[str] = None,
    negative_prompt_2: Optional[str] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    pooled_prompt_embeds: Optional[torch.Tensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
            used in both text-encoders
        device: (`torch.device`):
            torch device
        num_videos_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        negative_prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
            `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        pooled_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
            If not provided, pooled text embeddings will be generated from `prompt` input argument.
        negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
            input argument.
        lora_scale (`float`, *optional*):
            A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.
    """
    device = device or self._execution_device

    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
        self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
        if self.text_encoder is not None:
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        if self.text_encoder_2 is not None:
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
            else:
                scale_lora_layers(self.text_encoder_2, lora_scale)

    prompt = [prompt] if isinstance(prompt, str) else prompt

    if prompt is not None:
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # Define tokenizers and text encoders
    tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
    text_encoders = (
        [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
    )

    if prompt_embeds is None:
        prompt_2 = prompt_2 or prompt
        prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

        # textual inversion: process multi-vector tokens if necessary
        prompt_embeds_list = []
        prompts = [prompt, prompt_2]
        for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, tokenizer)

            text_inputs = tokenizer(
                prompt,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )

            text_input_ids = text_inputs.input_ids
            untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            # Warn (do not fail) when the prompt exceeds the tokenizer's max length.
            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {tokenizer.model_max_length} tokens: {removed_text}"
                )

            prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

            # We are only ALWAYS interested in the pooled output of the final text encoder
            if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
                pooled_prompt_embeds = prompt_embeds[0]

            if clip_skip is None:
                prompt_embeds = prompt_embeds.hidden_states[-2]
            else:
                # "2" because SDXL always indexes from the penultimate layer.
                prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]

            prompt_embeds_list.append(prompt_embeds)

        # Concatenate both encoders' sequences along the feature (last) dimension.
        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

    # get unconditional embeddings for classifier free guidance
    zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
    if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
        negative_prompt_embeds = torch.zeros_like(prompt_embeds)
        negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
    elif do_classifier_free_guidance and negative_prompt_embeds is None:
        negative_prompt = negative_prompt or ""
        negative_prompt_2 = negative_prompt_2 or negative_prompt

        # normalize str to list
        negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
        negative_prompt_2 = (
            batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
        )

        uncond_tokens: List[str]
        if prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = [negative_prompt, negative_prompt_2]

        negative_prompt_embeds_list = []
        for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
            if isinstance(self, TextualInversionLoaderMixin):
                negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)

            # Pad the negative prompt to the positive sequence length so shapes match.
            max_length = prompt_embeds.shape[1]
            uncond_input = tokenizer(
                negative_prompt,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            negative_prompt_embeds = text_encoder(
                uncond_input.input_ids.to(device),
                output_hidden_states=True,
            )

            # We are only ALWAYS interested in the pooled output of the final text encoder
            if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
                negative_pooled_prompt_embeds = negative_prompt_embeds[0]
            negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

            negative_prompt_embeds_list.append(negative_prompt_embeds)

        negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

    if self.text_encoder_2 is not None:
        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
    else:
        prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        if self.text_encoder_2 is not None:
            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
        else:
            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_videos_per_prompt).view(
        bs_embed * num_videos_per_prompt, -1
    )
    if do_classifier_free_guidance:
        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_videos_per_prompt).view(
            bs_embed * num_videos_per_prompt, -1
        )

    if self.text_encoder is not None:
        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder, lora_scale)

    if self.text_encoder_2 is not None:
        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder_2, lora_scale)

    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
| 566 |
+
|
| 567 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
    """Encode an IP-Adapter image into conditional and unconditional embeddings.

    When `output_hidden_states` is truthy, returns the penultimate hidden state of
    the image encoder for `image` and for an all-zeros image (the unconditional
    branch), each repeated `num_images_per_prompt` times along the batch axis.
    Otherwise returns the projected `image_embeds` with a zeros tensor as the
    unconditional embeddings.
    """
    dtype = next(self.image_encoder.parameters()).dtype

    # Non-tensor inputs (e.g. PIL images) are converted to pixel tensors first.
    if not isinstance(image, torch.Tensor):
        image = self.feature_extractor(image, return_tensors="pt").pixel_values

    image = image.to(device=device, dtype=dtype)
    if output_hidden_states:
        image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
        image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
        # Unconditional branch: encode a black (all-zeros) image of the same shape.
        uncond_image_enc_hidden_states = self.image_encoder(
            torch.zeros_like(image), output_hidden_states=True
        ).hidden_states[-2]
        uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
            num_images_per_prompt, dim=0
        )
        return image_enc_hidden_states, uncond_image_enc_hidden_states
    else:
        image_embeds = self.image_encoder(image).image_embeds
        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        uncond_image_embeds = torch.zeros_like(image_embeds)

        return image_embeds, uncond_image_embeds
|
| 591 |
+
|
| 592 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
    self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
    """Assemble per-adapter IP-Adapter image embeddings.

    Either encodes raw `ip_adapter_image`s (one per loaded IP-Adapter) via
    `self.encode_image`, or splits precomputed `ip_adapter_image_embeds`. Each
    entry is then tiled `num_images_per_prompt` times and, under classifier-free
    guidance, the negative embeddings are concatenated in front along dim 0.
    """
    image_embeds = []
    if do_classifier_free_guidance:
        negative_image_embeds = []
    if ip_adapter_image_embeds is None:
        if not isinstance(ip_adapter_image, list):
            ip_adapter_image = [ip_adapter_image]

        # One input image is required per IP-Adapter projection layer.
        if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
            raise ValueError(
                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
            )

        for single_ip_adapter_image, image_proj_layer in zip(
            ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
        ):
            # Plain `ImageProjection` layers consume pooled embeds; any other
            # projection layer consumes encoder hidden states instead.
            output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
            single_image_embeds, single_negative_image_embeds = self.encode_image(
                single_ip_adapter_image, device, 1, output_hidden_state
            )

            image_embeds.append(single_image_embeds[None, :])
            if do_classifier_free_guidance:
                negative_image_embeds.append(single_negative_image_embeds[None, :])
    else:
        for single_image_embeds in ip_adapter_image_embeds:
            if do_classifier_free_guidance:
                # Precomputed embeds store [negative, positive] stacked on dim 0.
                single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
                negative_image_embeds.append(single_negative_image_embeds)
            image_embeds.append(single_image_embeds)

    ip_adapter_image_embeds = []
    for i, single_image_embeds in enumerate(image_embeds):
        single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
            single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)

        single_image_embeds = single_image_embeds.to(device=device)
        ip_adapter_image_embeds.append(single_image_embeds)

    return ip_adapter_image_embeds
|
| 637 |
+
|
| 638 |
+
# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
def decode_latents(self, latents):
    """Decode video latents of shape (batch, channels, frames, h, w) into pixel video frames."""
    # Undo the VAE scaling that was applied when the latents were produced.
    latents = 1 / self.vae.config.scaling_factor * latents

    batch_size, channels, num_frames, height, width = latents.shape
    # Fold the frame axis into the batch axis so the image VAE can decode all
    # frames in one call: (b, c, f, h, w) -> (b*f, c, h, w).
    latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)

    image = self.vae.decode(latents).sample
    # Restore the video layout: (b*f, c, h, w) -> (b, c, f, h, w).
    video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    video = video.float()
    return video
|
| 650 |
+
|
| 651 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    """Build the optional kwargs dict for `self.scheduler.step`, forwarding `eta`
    and/or `generator` only when the scheduler's signature accepts them."""
    # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
    # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
    # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
    # and should be between [0, 1]

    accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
    extra_step_kwargs = {}
    if accepts_eta:
        extra_step_kwargs["eta"] = eta

    # check if the scheduler accepts generator
    accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
    if accepts_generator:
        extra_step_kwargs["generator"] = generator
    return extra_step_kwargs
|
| 668 |
+
|
| 669 |
+
def check_inputs(
    self,
    prompt,
    prompt_2,
    height,
    width,
    negative_prompt=None,
    negative_prompt_2=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    pooled_prompt_embeds=None,
    negative_pooled_prompt_embeds=None,
    callback_on_step_end_tensor_inputs=None,
):
    """Validate the arguments passed to `__call__`.

    Raises `ValueError` when dimensions are not divisible by 8, when callback tensor
    names are unknown, when mutually exclusive prompt arguments are combined, when
    prompt types are wrong, or when embeddings are passed without their pooled
    counterparts (or with mismatched shapes). Returns `None` on success.
    """
    # Spatial dims must be multiples of the 8x VAE downsampling factor.
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    if callback_on_step_end_tensor_inputs is not None:
        unknown = [k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]
        if unknown:
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {unknown}"
            )

    # Exactly one of (prompt, prompt_embeds) must be supplied; `prompt_2` may only
    # accompany `prompt`. Each guard raises, so plain `if`s are equivalent to the
    # usual elif chain.
    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if prompt_2 is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
    if prompt_2 is not None and not isinstance(prompt_2, (str, list)):
        raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

    # Same exclusivity rules for the negative prompts.
    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )
    if negative_prompt_2 is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    # Positive/negative embeddings must agree in shape for CFG concatenation.
    if (
        prompt_embeds is not None
        and negative_prompt_embeds is not None
        and prompt_embeds.shape != negative_prompt_embeds.shape
    ):
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )

    # SDXL additionally needs the pooled embeddings whenever raw embeds are given.
    if prompt_embeds is not None and pooled_prompt_embeds is None:
        raise ValueError(
            "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
        )

    if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
        raise ValueError(
            "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
        )
|
| 740 |
+
|
| 741 |
+
# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
def prepare_latents(
    self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
    """Sample (or reuse) initial video latents and scale them by the scheduler's
    initial noise sigma. Returned shape: (batch, channels, frames, h/vsf, w/vsf)."""
    # Latent spatial dims are downscaled by the VAE scale factor.
    shape = (
        batch_size,
        num_channels_latents,
        num_frames,
        height // self.vae_scale_factor,
        width // self.vae_scale_factor,
    )
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is None:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    else:
        # Caller-provided latents are trusted as-is; only the device is adjusted.
        latents = latents.to(device)

    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * self.scheduler.init_noise_sigma
    return latents
|
| 766 |
+
|
| 767 |
+
def _get_add_time_ids(
    self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
):
    """Build the SDXL micro-conditioning tensor from size/crop tuples.

    Concatenates `original_size + crops_coords_top_left + target_size` into a
    single row tensor of shape (1, 6), after verifying that the UNet's additional
    embedding layer was built for exactly this many time ids plus the text
    projection width.
    """
    time_ids = [*original_size, *crops_coords_top_left, *target_size]

    # Width the UNet would receive vs. width its `add_embedding` was built for;
    # a mismatch means the checkpoint's config is inconsistent, so fail loudly.
    passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(time_ids) + text_encoder_projection_dim
    expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
    if expected_add_embed_dim != passed_add_embed_dim:
        raise ValueError(
            f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
        )

    return torch.tensor([time_ids], dtype=dtype)
|
| 784 |
+
|
| 785 |
+
def upcast_vae(self):
    """Move the VAE to float32, keeping selected decoder submodules in the original
    dtype when a numerically-safe attention processor is in use."""
    dtype = self.vae.dtype
    self.vae.to(dtype=torch.float32)
    use_torch_2_0_or_xformers = isinstance(
        self.vae.decoder.mid_block.attentions[0].processor,
        (
            AttnProcessor2_0,
            XFormersAttnProcessor,
            FusedAttnProcessor2_0,
        ),
    )
    # if xformers or torch_2_0 is used attention block does not need
    # to be in float32 which can save lots of memory
    if use_torch_2_0_or_xformers:
        # Cast these decoder entry points back to the original (lower) precision.
        self.vae.post_quant_conv.to(dtype)
        self.vae.decoder.conv_in.to(dtype)
        self.vae.decoder.mid_block.to(dtype)
|
| 802 |
+
|
| 803 |
+
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(
    self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """
    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`):
            Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
        embedding_dim (`int`, *optional*, defaults to 512):
            Dimension of the embeddings to generate.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            Data type of the generated embeddings.

    Returns:
        `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
    """
    assert len(w.shape) == 1
    w = w * 1000.0

    # Standard sinusoidal embedding: half the dims get sin, half get cos, with
    # log-spaced frequencies between 1 and 1/10000.
    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1))
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb
|
| 833 |
+
|
| 834 |
+
@property
def guidance_scale(self):
    # Read-only mirror of the private `_guidance_scale` attribute.
    return self._guidance_scale

@property
def guidance_rescale(self):
    # Read-only mirror of the private `_guidance_rescale` attribute.
    return self._guidance_rescale

@property
def clip_skip(self):
    # Read-only mirror of the private `_clip_skip` attribute.
    return self._clip_skip

# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
    # CFG is active only for scale > 1 and when the UNet has no timestep-conditioning
    # projection (`time_cond_proj_dim` unset).
    return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None

@property
def cross_attention_kwargs(self):
    # Read-only mirror of the private `_cross_attention_kwargs` attribute.
    return self._cross_attention_kwargs

@property
def denoising_end(self):
    # Read-only mirror of the private `_denoising_end` attribute.
    return self._denoising_end

@property
def num_timesteps(self):
    # Read-only mirror of the private `_num_timesteps` attribute.
    return self._num_timesteps

@property
def interrupt(self):
    # Read-only mirror of the private `_interrupt` attribute.
    return self._interrupt
|
| 868 |
+
|
| 869 |
+
@torch.no_grad()
|
| 870 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 871 |
+
def __call__(
|
| 872 |
+
self,
|
| 873 |
+
prompt: Union[str, List[str]] = None,
|
| 874 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
| 875 |
+
num_frames: int = 16,
|
| 876 |
+
height: Optional[int] = None,
|
| 877 |
+
width: Optional[int] = None,
|
| 878 |
+
num_inference_steps: int = 50,
|
| 879 |
+
timesteps: List[int] = None,
|
| 880 |
+
sigmas: List[float] = None,
|
| 881 |
+
denoising_end: Optional[float] = None,
|
| 882 |
+
guidance_scale: float = 5.0,
|
| 883 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 884 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
| 885 |
+
num_videos_per_prompt: Optional[int] = 1,
|
| 886 |
+
eta: float = 0.0,
|
| 887 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 888 |
+
latents: Optional[torch.Tensor] = None,
|
| 889 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 890 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 891 |
+
pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
| 892 |
+
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
| 893 |
+
ip_adapter_image: Optional[PipelineImageInput] = None,
|
| 894 |
+
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
| 895 |
+
output_type: Optional[str] = "pil",
|
| 896 |
+
return_dict: bool = True,
|
| 897 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 898 |
+
guidance_rescale: float = 0.0,
|
| 899 |
+
original_size: Optional[Tuple[int, int]] = None,
|
| 900 |
+
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
| 901 |
+
target_size: Optional[Tuple[int, int]] = None,
|
| 902 |
+
negative_original_size: Optional[Tuple[int, int]] = None,
|
| 903 |
+
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
|
| 904 |
+
negative_target_size: Optional[Tuple[int, int]] = None,
|
| 905 |
+
clip_skip: Optional[int] = None,
|
| 906 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 907 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 908 |
+
):
|
| 909 |
+
r"""
|
| 910 |
+
Function invoked when calling the pipeline for generation.
|
| 911 |
+
|
| 912 |
+
Args:
|
| 913 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 914 |
+
The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`.
|
| 915 |
+
instead.
|
| 916 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
| 917 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
| 918 |
+
used in both text-encoders
|
| 919 |
+
num_frames:
|
| 920 |
+
The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds
|
| 921 |
+
amounts to 2 seconds of video.
|
| 922 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 923 |
+
The height in pixels of the generated video. This is set to 1024 by default for the best results.
|
| 924 |
+
Anything below 512 pixels won't work well for
|
| 925 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
| 926 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
| 927 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 928 |
+
The width in pixels of the generated video. This is set to 1024 by default for the best results.
|
| 929 |
+
Anything below 512 pixels won't work well for
|
| 930 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
| 931 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
| 932 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 933 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality video at the
|
| 934 |
+
expense of slower inference.
|
| 935 |
+
timesteps (`List[int]`, *optional*):
|
| 936 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
| 937 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
| 938 |
+
passed will be used. Must be in descending order.
|
| 939 |
+
sigmas (`List[float]`, *optional*):
|
| 940 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
| 941 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
| 942 |
+
will be used.
|
| 943 |
+
denoising_end (`float`, *optional*):
|
| 944 |
+
When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
|
| 945 |
+
completed before it is intentionally prematurely terminated. As a result, the returned sample will
|
| 946 |
+
still retain a substantial amount of noise as determined by the discrete timesteps selected by the
|
| 947 |
+
scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
|
| 948 |
+
"Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
|
| 949 |
+
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
|
| 950 |
+
guidance_scale (`float`, *optional*, defaults to 5.0):
|
| 951 |
+
Guidance scale as defined in [Classifier-Free Diffusion
|
| 952 |
+
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
| 953 |
+
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
| 954 |
+
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
| 955 |
+
the text `prompt`, usually at the expense of lower video quality.
|
| 956 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 957 |
+
The prompt or prompts not to guide the video generation. If not defined, one has to pass
|
| 958 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 959 |
+
less than `1`).
|
| 960 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
| 961 |
+
The prompt or prompts not to guide the video generation to be sent to `tokenizer_2` and
|
| 962 |
+
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
|
| 963 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 964 |
+
The number of videos to generate per prompt.
|
| 965 |
+
eta (`float`, *optional*, defaults to 0.0):
|
| 966 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
|
| 967 |
+
applies to [`schedulers.DDIMScheduler`], will be ignored for others.
|
| 968 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 969 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 970 |
+
to make generation deterministic.
|
| 971 |
+
latents (`torch.Tensor`, *optional*):
|
| 972 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
|
| 973 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 974 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
| 975 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 976 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 977 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 978 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 979 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 980 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 981 |
+
argument.
|
| 982 |
+
pooled_prompt_embeds (`torch.Tensor`, *optional*):
|
| 983 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
| 984 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
| 985 |
+
negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
|
| 986 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 987 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
| 988 |
+
input argument.
|
| 989 |
+
ip_adapter_image: (`PipelineImageInput`, *optional*):
|
| 990 |
+
Optional image input to work with IP Adapters.
|
| 991 |
+
ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
|
| 992 |
+
Pre-generated image embeddings for IP-Adapter. If not provided, embeddings are computed from the
|
| 993 |
+
`ip_adapter_image` input argument.
|
| 994 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 995 |
+
The output format of the generated video. Choose between
|
| 996 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 997 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 998 |
+
Whether or not to return a [`~pipelines.stable_diffusion_xl.AnimateDiffPipelineOutput`] instead of a
|
| 999 |
+
plain tuple.
|
| 1000 |
+
cross_attention_kwargs (`dict`, *optional*):
|
| 1001 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 1002 |
+
`self.processor` in
|
| 1003 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 1004 |
+
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
| 1005 |
+
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
|
| 1006 |
+
Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of
|
| 1007 |
+
[Common Diffusion Noise Schedules and Sample Steps are
|
| 1008 |
+
Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
|
| 1009 |
+
using zero terminal SNR.
|
| 1010 |
+
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
| 1011 |
+
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
|
| 1012 |
+
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
|
| 1013 |
+
explained in section 2.2 of
|
| 1014 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
| 1015 |
+
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
| 1016 |
+
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
|
| 1017 |
+
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
|
| 1018 |
+
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
|
| 1019 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
| 1020 |
+
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
| 1021 |
+
For most cases, `target_size` should be set to the desired height and width of the generated image. If
|
| 1022 |
+
not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
|
| 1023 |
+
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
| 1024 |
+
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
| 1025 |
+
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
|
| 1026 |
+
micro-conditioning as explained in section 2.2 of
|
| 1027 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
| 1028 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
| 1029 |
+
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
| 1030 |
+
To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
|
| 1031 |
+
micro-conditioning as explained in section 2.2 of
|
| 1032 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
| 1033 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
| 1034 |
+
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
| 1035 |
+
To negatively condition the generation process based on a target image resolution. It should be as same
|
| 1036 |
+
as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
|
| 1037 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
| 1038 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
| 1039 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 1040 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
| 1041 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 1042 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 1043 |
+
`callback_on_step_end_tensor_inputs`.
|
| 1044 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 1045 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 1046 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 1047 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 1048 |
+
|
| 1049 |
+
Examples:
|
| 1050 |
+
|
| 1051 |
+
Returns:
|
| 1052 |
+
[`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
|
| 1053 |
+
If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
|
| 1054 |
+
returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
|
| 1055 |
+
"""
|
| 1056 |
+
|
| 1057 |
+
# 0. Default height and width to unet
|
| 1058 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
| 1059 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
| 1060 |
+
|
| 1061 |
+
num_videos_per_prompt = 1
|
| 1062 |
+
|
| 1063 |
+
original_size = original_size or (height, width)
|
| 1064 |
+
target_size = target_size or (height, width)
|
| 1065 |
+
|
| 1066 |
+
# 1. Check inputs. Raise error if not correct
|
| 1067 |
+
self.check_inputs(
|
| 1068 |
+
prompt,
|
| 1069 |
+
prompt_2,
|
| 1070 |
+
height,
|
| 1071 |
+
width,
|
| 1072 |
+
negative_prompt,
|
| 1073 |
+
negative_prompt_2,
|
| 1074 |
+
prompt_embeds,
|
| 1075 |
+
negative_prompt_embeds,
|
| 1076 |
+
pooled_prompt_embeds,
|
| 1077 |
+
negative_pooled_prompt_embeds,
|
| 1078 |
+
callback_on_step_end_tensor_inputs,
|
| 1079 |
+
)
|
| 1080 |
+
|
| 1081 |
+
self._guidance_scale = guidance_scale
|
| 1082 |
+
self._guidance_rescale = guidance_rescale
|
| 1083 |
+
self._clip_skip = clip_skip
|
| 1084 |
+
self._cross_attention_kwargs = cross_attention_kwargs
|
| 1085 |
+
self._denoising_end = denoising_end
|
| 1086 |
+
self._interrupt = False
|
| 1087 |
+
|
| 1088 |
+
# 2. Define call parameters
|
| 1089 |
+
if prompt is not None and isinstance(prompt, str):
|
| 1090 |
+
batch_size = 1
|
| 1091 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 1092 |
+
batch_size = len(prompt)
|
| 1093 |
+
else:
|
| 1094 |
+
batch_size = prompt_embeds.shape[0]
|
| 1095 |
+
|
| 1096 |
+
device = self._execution_device
|
| 1097 |
+
|
| 1098 |
+
# 3. Encode input prompt
|
| 1099 |
+
lora_scale = (
|
| 1100 |
+
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
| 1101 |
+
)
|
| 1102 |
+
|
| 1103 |
+
(
|
| 1104 |
+
prompt_embeds,
|
| 1105 |
+
negative_prompt_embeds,
|
| 1106 |
+
pooled_prompt_embeds,
|
| 1107 |
+
negative_pooled_prompt_embeds,
|
| 1108 |
+
) = self.encode_prompt(
|
| 1109 |
+
prompt=prompt,
|
| 1110 |
+
prompt_2=prompt_2,
|
| 1111 |
+
device=device,
|
| 1112 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 1113 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
| 1114 |
+
negative_prompt=negative_prompt,
|
| 1115 |
+
negative_prompt_2=negative_prompt_2,
|
| 1116 |
+
prompt_embeds=prompt_embeds,
|
| 1117 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 1118 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 1119 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
| 1120 |
+
lora_scale=lora_scale,
|
| 1121 |
+
clip_skip=self.clip_skip,
|
| 1122 |
+
)
|
| 1123 |
+
|
| 1124 |
+
# 4. Prepare timesteps
|
| 1125 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 1126 |
+
self.scheduler, num_inference_steps, device, timesteps, sigmas
|
| 1127 |
+
)
|
| 1128 |
+
|
| 1129 |
+
# 5. Prepare latent variables
|
| 1130 |
+
num_channels_latents = self.unet.config.in_channels
|
| 1131 |
+
latents = self.prepare_latents(
|
| 1132 |
+
batch_size * num_videos_per_prompt,
|
| 1133 |
+
num_channels_latents,
|
| 1134 |
+
num_frames,
|
| 1135 |
+
height,
|
| 1136 |
+
width,
|
| 1137 |
+
prompt_embeds.dtype,
|
| 1138 |
+
device,
|
| 1139 |
+
generator,
|
| 1140 |
+
latents,
|
| 1141 |
+
)
|
| 1142 |
+
|
| 1143 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 1144 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 1145 |
+
|
| 1146 |
+
# 7. Prepare added time ids & embeddings
|
| 1147 |
+
add_text_embeds = pooled_prompt_embeds
|
| 1148 |
+
if self.text_encoder_2 is None:
|
| 1149 |
+
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
| 1150 |
+
else:
|
| 1151 |
+
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
| 1152 |
+
|
| 1153 |
+
add_time_ids = self._get_add_time_ids(
|
| 1154 |
+
original_size,
|
| 1155 |
+
crops_coords_top_left,
|
| 1156 |
+
target_size,
|
| 1157 |
+
dtype=prompt_embeds.dtype,
|
| 1158 |
+
text_encoder_projection_dim=text_encoder_projection_dim,
|
| 1159 |
+
)
|
| 1160 |
+
if negative_original_size is not None and negative_target_size is not None:
|
| 1161 |
+
negative_add_time_ids = self._get_add_time_ids(
|
| 1162 |
+
negative_original_size,
|
| 1163 |
+
negative_crops_coords_top_left,
|
| 1164 |
+
negative_target_size,
|
| 1165 |
+
dtype=prompt_embeds.dtype,
|
| 1166 |
+
text_encoder_projection_dim=text_encoder_projection_dim,
|
| 1167 |
+
)
|
| 1168 |
+
else:
|
| 1169 |
+
negative_add_time_ids = add_time_ids
|
| 1170 |
+
|
| 1171 |
+
if self.do_classifier_free_guidance:
|
| 1172 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
| 1173 |
+
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
| 1174 |
+
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
| 1175 |
+
|
| 1176 |
+
prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
|
| 1177 |
+
|
| 1178 |
+
prompt_embeds = prompt_embeds.to(device)
|
| 1179 |
+
add_text_embeds = add_text_embeds.to(device)
|
| 1180 |
+
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_videos_per_prompt, 1)
|
| 1181 |
+
|
| 1182 |
+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
| 1183 |
+
image_embeds = self.prepare_ip_adapter_image_embeds(
|
| 1184 |
+
ip_adapter_image,
|
| 1185 |
+
ip_adapter_image_embeds,
|
| 1186 |
+
device,
|
| 1187 |
+
batch_size * num_videos_per_prompt,
|
| 1188 |
+
self.do_classifier_free_guidance,
|
| 1189 |
+
)
|
| 1190 |
+
|
| 1191 |
+
# 7.1 Apply denoising_end
|
| 1192 |
+
if (
|
| 1193 |
+
self.denoising_end is not None
|
| 1194 |
+
and isinstance(self.denoising_end, float)
|
| 1195 |
+
and self.denoising_end > 0
|
| 1196 |
+
and self.denoising_end < 1
|
| 1197 |
+
):
|
| 1198 |
+
discrete_timestep_cutoff = int(
|
| 1199 |
+
round(
|
| 1200 |
+
self.scheduler.config.num_train_timesteps
|
| 1201 |
+
- (self.denoising_end * self.scheduler.config.num_train_timesteps)
|
| 1202 |
+
)
|
| 1203 |
+
)
|
| 1204 |
+
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
|
| 1205 |
+
timesteps = timesteps[:num_inference_steps]
|
| 1206 |
+
|
| 1207 |
+
# 8. Optionally get Guidance Scale Embedding
|
| 1208 |
+
timestep_cond = None
|
| 1209 |
+
if self.unet.config.time_cond_proj_dim is not None:
|
| 1210 |
+
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_videos_per_prompt)
|
| 1211 |
+
timestep_cond = self.get_guidance_scale_embedding(
|
| 1212 |
+
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
|
| 1213 |
+
).to(device=device, dtype=latents.dtype)
|
| 1214 |
+
|
| 1215 |
+
num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
|
| 1216 |
+
for free_init_iter in range(num_free_init_iters):
|
| 1217 |
+
if self.free_init_enabled:
|
| 1218 |
+
latents, timesteps = self._apply_free_init(
|
| 1219 |
+
latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
|
| 1220 |
+
)
|
| 1221 |
+
|
| 1222 |
+
self._num_timesteps = len(timesteps)
|
| 1223 |
+
|
| 1224 |
+
# 9. Denoising loop
|
| 1225 |
+
with self.progress_bar(total=self._num_timesteps) as progress_bar:
|
| 1226 |
+
for i, t in enumerate(timesteps):
|
| 1227 |
+
if self.interrupt:
|
| 1228 |
+
continue
|
| 1229 |
+
|
| 1230 |
+
# expand the latents if we are doing classifier free guidance
|
| 1231 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
| 1232 |
+
|
| 1233 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 1234 |
+
|
| 1235 |
+
# predict the noise residual
|
| 1236 |
+
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
| 1237 |
+
if ip_adapter_image is not None or ip_adapter_image_embeds:
|
| 1238 |
+
added_cond_kwargs["image_embeds"] = image_embeds
|
| 1239 |
+
|
| 1240 |
+
noise_pred = self.unet(
|
| 1241 |
+
latent_model_input,
|
| 1242 |
+
t,
|
| 1243 |
+
encoder_hidden_states=prompt_embeds,
|
| 1244 |
+
timestep_cond=timestep_cond,
|
| 1245 |
+
cross_attention_kwargs=self.cross_attention_kwargs,
|
| 1246 |
+
added_cond_kwargs=added_cond_kwargs,
|
| 1247 |
+
return_dict=False,
|
| 1248 |
+
)[0]
|
| 1249 |
+
|
| 1250 |
+
# perform guidance
|
| 1251 |
+
if self.do_classifier_free_guidance:
|
| 1252 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 1253 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 1254 |
+
|
| 1255 |
+
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
|
| 1256 |
+
# Based on 3.4. in https://huggingface.co/papers/2305.08891
|
| 1257 |
+
noise_pred = rescale_noise_cfg(
|
| 1258 |
+
noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale
|
| 1259 |
+
)
|
| 1260 |
+
|
| 1261 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 1262 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
| 1263 |
+
|
| 1264 |
+
if callback_on_step_end is not None:
|
| 1265 |
+
callback_kwargs = {}
|
| 1266 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 1267 |
+
callback_kwargs[k] = locals()[k]
|
| 1268 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 1269 |
+
|
| 1270 |
+
latents = callback_outputs.pop("latents", latents)
|
| 1271 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 1272 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 1273 |
+
add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
|
| 1274 |
+
negative_pooled_prompt_embeds = callback_outputs.pop(
|
| 1275 |
+
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
|
| 1276 |
+
)
|
| 1277 |
+
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
|
| 1278 |
+
negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
|
| 1279 |
+
|
| 1280 |
+
progress_bar.update()
|
| 1281 |
+
|
| 1282 |
+
if XLA_AVAILABLE:
|
| 1283 |
+
xm.mark_step()
|
| 1284 |
+
|
| 1285 |
+
# make sure the VAE is in float32 mode, as it overflows in float16
|
| 1286 |
+
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
| 1287 |
+
|
| 1288 |
+
if needs_upcasting:
|
| 1289 |
+
self.upcast_vae()
|
| 1290 |
+
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
| 1291 |
+
|
| 1292 |
+
# 10. Post processing
|
| 1293 |
+
if output_type == "latent":
|
| 1294 |
+
video = latents
|
| 1295 |
+
else:
|
| 1296 |
+
video_tensor = self.decode_latents(latents)
|
| 1297 |
+
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
|
| 1298 |
+
|
| 1299 |
+
# cast back to fp16 if needed
|
| 1300 |
+
if needs_upcasting:
|
| 1301 |
+
self.vae.to(dtype=torch.float16)
|
| 1302 |
+
|
| 1303 |
+
# 11. Offload all models
|
| 1304 |
+
self.maybe_free_model_hooks()
|
| 1305 |
+
|
| 1306 |
+
if not return_dict:
|
| 1307 |
+
return (video,)
|
| 1308 |
+
|
| 1309 |
+
return AnimateDiffPipelineOutput(frames=video)
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py
ADDED
|
@@ -0,0 +1,1023 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import inspect
|
| 16 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import PIL
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn.functional as F
|
| 22 |
+
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
| 23 |
+
|
| 24 |
+
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
| 25 |
+
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
| 26 |
+
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
|
| 27 |
+
from ...models.controlnets.controlnet_sparsectrl import SparseControlNetModel
|
| 28 |
+
from ...models.lora import adjust_lora_scale_text_encoder
|
| 29 |
+
from ...models.unets.unet_motion_model import MotionAdapter
|
| 30 |
+
from ...schedulers import KarrasDiffusionSchedulers
|
| 31 |
+
from ...utils import (
|
| 32 |
+
USE_PEFT_BACKEND,
|
| 33 |
+
is_torch_xla_available,
|
| 34 |
+
logging,
|
| 35 |
+
replace_example_docstring,
|
| 36 |
+
scale_lora_layers,
|
| 37 |
+
unscale_lora_layers,
|
| 38 |
+
)
|
| 39 |
+
from ...utils.torch_utils import is_compiled_module, randn_tensor
|
| 40 |
+
from ...video_processor import VideoProcessor
|
| 41 |
+
from ..free_init_utils import FreeInitMixin
|
| 42 |
+
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
| 43 |
+
from .pipeline_output import AnimateDiffPipelineOutput
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
if is_torch_xla_available():
    # NOTE(review): `xm` is presumably used for XLA device synchronization later in
    # the file (not visible in this chunk) — confirm before removing.
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# Usage example injected into the pipeline's `__call__` docstring via
# `@replace_example_docstring` elsewhere in this file.
EXAMPLE_DOC_STRING = """
    Examples:
        ```python
        >>> import torch
        >>> from diffusers import AnimateDiffSparseControlNetPipeline
        >>> from diffusers.models import AutoencoderKL, MotionAdapter, SparseControlNetModel
        >>> from diffusers.schedulers import DPMSolverMultistepScheduler
        >>> from diffusers.utils import export_to_gif, load_image

        >>> model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
        >>> motion_adapter_id = "guoyww/animatediff-motion-adapter-v1-5-3"
        >>> controlnet_id = "guoyww/animatediff-sparsectrl-scribble"
        >>> lora_adapter_id = "guoyww/animatediff-motion-lora-v1-5-3"
        >>> vae_id = "stabilityai/sd-vae-ft-mse"
        >>> device = "cuda"

        >>> motion_adapter = MotionAdapter.from_pretrained(motion_adapter_id, torch_dtype=torch.float16).to(device)
        >>> controlnet = SparseControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch.float16).to(device)
        >>> vae = AutoencoderKL.from_pretrained(vae_id, torch_dtype=torch.float16).to(device)
        >>> scheduler = DPMSolverMultistepScheduler.from_pretrained(
        ...     model_id,
        ...     subfolder="scheduler",
        ...     beta_schedule="linear",
        ...     algorithm_type="dpmsolver++",
        ...     use_karras_sigmas=True,
        ... )
        >>> pipe = AnimateDiffSparseControlNetPipeline.from_pretrained(
        ...     model_id,
        ...     motion_adapter=motion_adapter,
        ...     controlnet=controlnet,
        ...     vae=vae,
        ...     scheduler=scheduler,
        ...     torch_dtype=torch.float16,
        ... ).to(device)
        >>> pipe.load_lora_weights(lora_adapter_id, adapter_name="motion_lora")
        >>> pipe.fuse_lora(lora_scale=1.0)

        >>> prompt = "an aerial view of a cyberpunk city, night time, neon lights, masterpiece, high quality"
        >>> negative_prompt = "low quality, worst quality, letterboxed"

        >>> image_files = [
        ...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-1.png",
        ...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-2.png",
        ...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-3.png",
        ... ]
        >>> condition_frame_indices = [0, 8, 15]
        >>> conditioning_frames = [load_image(img_file) for img_file in image_files]

        >>> video = pipe(
        ...     prompt=prompt,
        ...     negative_prompt=negative_prompt,
        ...     num_inference_steps=25,
        ...     conditioning_frames=conditioning_frames,
        ...     controlnet_conditioning_scale=1.0,
        ...     controlnet_frame_indices=condition_frame_indices,
        ...     generator=torch.Generator().manual_seed(1337),
        ... ).frames[0]
        >>> export_to_gif(video, "output.gif")
        ```
"""
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Pull a latent tensor out of a VAE encoder output.

    Supports outputs exposing a `latent_dist` (sampled or mode, depending on
    `sample_mode`) as well as outputs exposing a plain `latents` attribute.
    Raises `AttributeError` when neither is available.
    """
    dist = getattr(encoder_output, "latent_dist", None)
    if dist is not None and sample_mode == "sample":
        return dist.sample(generator)
    if dist is not None and sample_mode == "argmax":
        return dist.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class AnimateDiffSparseControlNetPipeline(
|
| 133 |
+
DiffusionPipeline,
|
| 134 |
+
StableDiffusionMixin,
|
| 135 |
+
TextualInversionLoaderMixin,
|
| 136 |
+
IPAdapterMixin,
|
| 137 |
+
StableDiffusionLoraLoaderMixin,
|
| 138 |
+
FreeInitMixin,
|
| 139 |
+
FromSingleFileMixin,
|
| 140 |
+
):
|
| 141 |
+
r"""
|
| 142 |
+
Pipeline for controlled text-to-video generation using the method described in [SparseCtrl: Adding Sparse Controls
|
| 143 |
+
to Text-to-Video Diffusion Models](https://huggingface.co/papers/2311.16933).
|
| 144 |
+
|
| 145 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
| 146 |
+
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
| 147 |
+
|
| 148 |
+
The pipeline also inherits the following loading methods:
|
| 149 |
+
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
| 150 |
+
- [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
| 151 |
+
- [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
| 152 |
+
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
| 153 |
+
|
| 154 |
+
Args:
|
| 155 |
+
vae ([`AutoencoderKL`]):
|
| 156 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
| 157 |
+
text_encoder ([`CLIPTextModel`]):
|
| 158 |
+
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
| 159 |
+
tokenizer (`CLIPTokenizer`):
|
| 160 |
+
A [`~transformers.CLIPTokenizer`] to tokenize text.
|
| 161 |
+
unet ([`UNet2DConditionModel`]):
|
| 162 |
+
A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents.
|
| 163 |
+
motion_adapter ([`MotionAdapter`]):
|
| 164 |
+
A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents.
|
| 165 |
+
scheduler ([`SchedulerMixin`]):
|
| 166 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
| 167 |
+
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
| 168 |
+
"""
|
| 169 |
+
|
| 170 |
+
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
|
| 171 |
+
_optional_components = ["feature_extractor", "image_encoder", "motion_adapter"]
|
| 172 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
| 173 |
+
|
| 174 |
+
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: Union[UNet2DConditionModel, UNetMotionModel],
        motion_adapter: MotionAdapter,
        controlnet: SparseControlNetModel,
        scheduler: KarrasDiffusionSchedulers,
        feature_extractor: CLIPImageProcessor = None,
        image_encoder: CLIPVisionModelWithProjection = None,
    ):
        """Register all sub-models and set up the image/video processors."""
        super().__init__()
        # A plain 2D UNet is inflated into a motion-aware UNet by merging in the
        # motion adapter's temporal layers.
        if isinstance(unet, UNet2DConditionModel):
            unet = UNetMotionModel.from_unet2d(unet, motion_adapter)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            motion_adapter=motion_adapter,
            controlnet=controlnet,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
        )
        # Spatial downsampling factor of the VAE (2 ** (num_blocks - 1));
        # falls back to 8 when no VAE is registered.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
        # Conditioning frames are kept in [0, 1] (do_normalize=False) — SparseControlNet
        # expects un-normalized control images (see the assertion in `prepare_image`).
        self.control_image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
        )
|
| 206 |
+
|
| 207 |
+
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
    def encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
        """
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            # `check_inputs` guarantees `prompt_embeds` is set when `prompt` is None.
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            # Re-tokenize without truncation to detect (and warn about) clipped text.
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            if clip_skip is None:
                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
                prompt_embeds = prompt_embeds[0]
            else:
                prompt_embeds = self.text_encoder(
                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
                )
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
                # the tuple to access the hidden states from the desired layer.
                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
                # We also need to apply the final LayerNorm here to not mess with the
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet is not None:
            prompt_embeds_dtype = self.unet.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            # Pad the negative prompt to the same sequence length as the positive one
            # so the two batches can be concatenated for CFG.
            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        if self.text_encoder is not None:
            if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        return prompt_embeds, negative_prompt_embeds
|
| 389 |
+
|
| 390 |
+
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
        """Encode an IP-Adapter image with the CLIP image encoder.

        Returns a `(image_embeds, uncond_image_embeds)` pair. When
        `output_hidden_states` is truthy, the penultimate hidden state is
        returned instead of the pooled projection, and the unconditional branch
        is computed from a zero image; otherwise the unconditional branch is a
        zero embedding.
        """
        dtype = next(self.image_encoder.parameters()).dtype

        # PIL / numpy inputs are converted to pixel tensors by the feature extractor.
        if not isinstance(image, torch.Tensor):
            image = self.feature_extractor(image, return_tensors="pt").pixel_values

        image = image.to(device=device, dtype=dtype)
        if output_hidden_states:
            # hidden_states[-2]: penultimate encoder layer, used by IP-Adapter-Plus style projections.
            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
            uncond_image_enc_hidden_states = self.image_encoder(
                torch.zeros_like(image), output_hidden_states=True
            ).hidden_states[-2]
            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
                num_images_per_prompt, dim=0
            )
            return image_enc_hidden_states, uncond_image_enc_hidden_states
        else:
            image_embeds = self.image_encoder(image).image_embeds
            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
            # Zero embedding serves as the "no image" unconditional input for CFG.
            uncond_image_embeds = torch.zeros_like(image_embeds)

            return image_embeds, uncond_image_embeds
|
| 414 |
+
|
| 415 |
+
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
    def prepare_ip_adapter_image_embeds(
        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
    ):
        """Build per-adapter IP-Adapter embeddings for the UNet.

        Either encodes `ip_adapter_image` (one entry per loaded IP Adapter) or
        splits precomputed `ip_adapter_image_embeds` (negative embeds stacked
        first when CFG was used). Each result is duplicated
        `num_images_per_prompt` times and, under CFG, concatenated as
        [negative, positive] along the batch dimension.
        """
        image_embeds = []
        if do_classifier_free_guidance:
            negative_image_embeds = []
        if ip_adapter_image_embeds is None:
            if not isinstance(ip_adapter_image, list):
                ip_adapter_image = [ip_adapter_image]

            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
                raise ValueError(
                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
                )

            for single_ip_adapter_image, image_proj_layer in zip(
                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
            ):
                # `ImageProjection` consumes pooled embeds; other projection layers
                # consume the penultimate hidden state (see `encode_image`).
                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
                single_image_embeds, single_negative_image_embeds = self.encode_image(
                    single_ip_adapter_image, device, 1, output_hidden_state
                )

                image_embeds.append(single_image_embeds[None, :])
                if do_classifier_free_guidance:
                    negative_image_embeds.append(single_negative_image_embeds[None, :])
        else:
            for single_image_embeds in ip_adapter_image_embeds:
                if do_classifier_free_guidance:
                    # Precomputed embeds carry the negative half stacked first.
                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
                    negative_image_embeds.append(single_negative_image_embeds)
                image_embeds.append(single_image_embeds)

        ip_adapter_image_embeds = []
        for i, single_image_embeds in enumerate(image_embeds):
            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
            if do_classifier_free_guidance:
                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)

            single_image_embeds = single_image_embeds.to(device=device)
            ip_adapter_image_embeds.append(single_image_embeds)

        return ip_adapter_image_embeds
|
| 460 |
+
|
| 461 |
+
# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
|
| 462 |
+
def decode_latents(self, latents):
|
| 463 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
| 464 |
+
|
| 465 |
+
batch_size, channels, num_frames, height, width = latents.shape
|
| 466 |
+
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
|
| 467 |
+
|
| 468 |
+
image = self.vae.decode(latents).sample
|
| 469 |
+
video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
|
| 470 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
| 471 |
+
video = video.float()
|
| 472 |
+
return video
|
| 473 |
+
|
| 474 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
| 475 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
| 476 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 477 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 478 |
+
# eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
|
| 479 |
+
# and should be between [0, 1]
|
| 480 |
+
|
| 481 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 482 |
+
extra_step_kwargs = {}
|
| 483 |
+
if accepts_eta:
|
| 484 |
+
extra_step_kwargs["eta"] = eta
|
| 485 |
+
|
| 486 |
+
# check if the scheduler accepts generator
|
| 487 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 488 |
+
if accepts_generator:
|
| 489 |
+
extra_step_kwargs["generator"] = generator
|
| 490 |
+
return extra_step_kwargs
|
| 491 |
+
|
| 492 |
+
def check_inputs(
|
| 493 |
+
self,
|
| 494 |
+
prompt,
|
| 495 |
+
height,
|
| 496 |
+
width,
|
| 497 |
+
negative_prompt=None,
|
| 498 |
+
prompt_embeds=None,
|
| 499 |
+
negative_prompt_embeds=None,
|
| 500 |
+
ip_adapter_image=None,
|
| 501 |
+
ip_adapter_image_embeds=None,
|
| 502 |
+
callback_on_step_end_tensor_inputs=None,
|
| 503 |
+
image=None,
|
| 504 |
+
controlnet_conditioning_scale: float = 1.0,
|
| 505 |
+
):
|
| 506 |
+
if height % 8 != 0 or width % 8 != 0:
|
| 507 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 508 |
+
|
| 509 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 510 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 511 |
+
):
|
| 512 |
+
raise ValueError(
|
| 513 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 514 |
+
)
|
| 515 |
+
|
| 516 |
+
if prompt is not None and prompt_embeds is not None:
|
| 517 |
+
raise ValueError(
|
| 518 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 519 |
+
" only forward one of the two."
|
| 520 |
+
)
|
| 521 |
+
elif prompt is None and prompt_embeds is None:
|
| 522 |
+
raise ValueError(
|
| 523 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 524 |
+
)
|
| 525 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 526 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 527 |
+
|
| 528 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 529 |
+
raise ValueError(
|
| 530 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 531 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 532 |
+
)
|
| 533 |
+
|
| 534 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 535 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 536 |
+
raise ValueError(
|
| 537 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 538 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 539 |
+
f" {negative_prompt_embeds.shape}."
|
| 540 |
+
)
|
| 541 |
+
|
| 542 |
+
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
|
| 543 |
+
raise ValueError(
|
| 544 |
+
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
|
| 545 |
+
)
|
| 546 |
+
|
| 547 |
+
if ip_adapter_image_embeds is not None:
|
| 548 |
+
if not isinstance(ip_adapter_image_embeds, list):
|
| 549 |
+
raise ValueError(
|
| 550 |
+
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
|
| 551 |
+
)
|
| 552 |
+
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
|
| 553 |
+
raise ValueError(
|
| 554 |
+
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
| 555 |
+
)
|
| 556 |
+
|
| 557 |
+
is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
|
| 558 |
+
self.controlnet, torch._dynamo.eval_frame.OptimizedModule
|
| 559 |
+
)
|
| 560 |
+
|
| 561 |
+
# check `image`
|
| 562 |
+
if (
|
| 563 |
+
isinstance(self.controlnet, SparseControlNetModel)
|
| 564 |
+
or is_compiled
|
| 565 |
+
and isinstance(self.controlnet._orig_mod, SparseControlNetModel)
|
| 566 |
+
):
|
| 567 |
+
if isinstance(image, list):
|
| 568 |
+
for image_ in image:
|
| 569 |
+
self.check_image(image_, prompt, prompt_embeds)
|
| 570 |
+
else:
|
| 571 |
+
self.check_image(image, prompt, prompt_embeds)
|
| 572 |
+
else:
|
| 573 |
+
assert False
|
| 574 |
+
|
| 575 |
+
# Check `controlnet_conditioning_scale`
|
| 576 |
+
if (
|
| 577 |
+
isinstance(self.controlnet, SparseControlNetModel)
|
| 578 |
+
or is_compiled
|
| 579 |
+
and isinstance(self.controlnet._orig_mod, SparseControlNetModel)
|
| 580 |
+
):
|
| 581 |
+
if not isinstance(controlnet_conditioning_scale, float):
|
| 582 |
+
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
|
| 583 |
+
else:
|
| 584 |
+
assert False
|
| 585 |
+
|
| 586 |
+
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
|
| 587 |
+
def check_image(self, image, prompt, prompt_embeds):
|
| 588 |
+
image_is_pil = isinstance(image, PIL.Image.Image)
|
| 589 |
+
image_is_tensor = isinstance(image, torch.Tensor)
|
| 590 |
+
image_is_np = isinstance(image, np.ndarray)
|
| 591 |
+
image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
|
| 592 |
+
image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
|
| 593 |
+
image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
|
| 594 |
+
|
| 595 |
+
if (
|
| 596 |
+
not image_is_pil
|
| 597 |
+
and not image_is_tensor
|
| 598 |
+
and not image_is_np
|
| 599 |
+
and not image_is_pil_list
|
| 600 |
+
and not image_is_tensor_list
|
| 601 |
+
and not image_is_np_list
|
| 602 |
+
):
|
| 603 |
+
raise TypeError(
|
| 604 |
+
f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
|
| 605 |
+
)
|
| 606 |
+
|
| 607 |
+
if image_is_pil:
|
| 608 |
+
image_batch_size = 1
|
| 609 |
+
else:
|
| 610 |
+
image_batch_size = len(image)
|
| 611 |
+
|
| 612 |
+
if prompt is not None and isinstance(prompt, str):
|
| 613 |
+
prompt_batch_size = 1
|
| 614 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 615 |
+
prompt_batch_size = len(prompt)
|
| 616 |
+
elif prompt_embeds is not None:
|
| 617 |
+
prompt_batch_size = prompt_embeds.shape[0]
|
| 618 |
+
|
| 619 |
+
if image_batch_size != 1 and image_batch_size != prompt_batch_size:
|
| 620 |
+
raise ValueError(
|
| 621 |
+
f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
|
| 622 |
+
)
|
| 623 |
+
|
| 624 |
+
# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
def prepare_latents(
    self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
    """Create (or reuse) the initial latent video tensor for the denoising loop.

    Returns a tensor of shape
    `(batch_size, num_channels_latents, num_frames, height // vae_scale_factor, width // vae_scale_factor)`
    sampled from a standard normal (or the user-supplied `latents` moved to `device`), scaled by the
    scheduler's `init_noise_sigma`.

    Raises:
        ValueError: a list of generators was passed whose length does not match `batch_size`.
    """
    # `height`/`width` are in pixel space; divide by the VAE downscale factor for latent resolution.
    shape = (
        batch_size,
        num_channels_latents,
        num_frames,
        height // self.vae_scale_factor,
        width // self.vae_scale_factor,
    )
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is None:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    else:
        # NOTE(review): user-provided latents are moved to `device` but their shape/dtype is not validated here.
        latents = latents.to(device)

    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * self.scheduler.init_noise_sigma
    return latents
|
| 649 |
+
|
| 650 |
+
def prepare_image(self, image, width, height, device, dtype):
    """Preprocess the sparse conditioning frames into the layout expected by the SparseControlNet.

    Returns a tensor of shape `[batch, channels, frames, height, width]` on `device` with `dtype`.
    When the controlnet uses the simplified condition embedding, the frames are VAE-encoded into
    4-channel latents first.
    """
    # Resize/normalize via the control image processor; the result is expected to be in [0, 1]
    # (enforced by the assert below).
    image = self.control_image_processor.preprocess(image, height=height, width=width)
    # Add a leading batch dimension: the preprocessed frames form a single video.
    controlnet_images = image.unsqueeze(0).to(device, dtype)
    batch_size, num_frames, channels, height, width = controlnet_images.shape

    # TODO: remove below line
    # NOTE(review): `assert` is stripped under `python -O`; an explicit ValueError would be more robust.
    assert controlnet_images.min() >= 0 and controlnet_images.max() <= 1

    if self.controlnet.use_simplified_condition_embedding:
        # Encode each frame with the VAE: fold frames into the batch dim, map [0, 1] -> [-1, 1],
        # encode + scale, then unfold back to per-frame 4-channel latents at latent resolution.
        controlnet_images = controlnet_images.reshape(batch_size * num_frames, channels, height, width)
        controlnet_images = 2 * controlnet_images - 1
        conditioning_frames = retrieve_latents(self.vae.encode(controlnet_images)) * self.vae.config.scaling_factor
        conditioning_frames = conditioning_frames.reshape(
            batch_size, num_frames, 4, height // self.vae_scale_factor, width // self.vae_scale_factor
        )
    else:
        conditioning_frames = controlnet_images

    conditioning_frames = conditioning_frames.permute(0, 2, 1, 3, 4)  # [b, c, f, h, w]
    return conditioning_frames
|
| 670 |
+
|
| 671 |
+
def prepare_sparse_control_conditioning(
    self,
    conditioning_frames: torch.Tensor,
    num_frames: int,
    # Fixed: this was annotated `int` but is used as a list of frame indices (`len(...)`, fancy indexing).
    controlnet_frame_indices: List[int],
    device: torch.device,
    dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Scatter the conditioning frames into a full-length, mostly-zero conditioning tensor.

    Args:
        conditioning_frames: `[batch, channels, frames, height, width]` tensor of preprocessed frames.
        num_frames: total number of frames in the generated video.
        controlnet_frame_indices: target frame indices at which conditioning is applied.
        device: device for the output tensors.
        dtype: dtype for the output tensors.

    Returns:
        `(controlnet_cond, controlnet_cond_mask)`: a `[batch, channels, num_frames, h, w]` tensor holding the
        conditioning frames at `controlnet_frame_indices` (zeros elsewhere) and a `[batch, 1, num_frames, h, w]`
        mask that is 1 at those indices and 0 elsewhere.
    """
    num_conditioned = len(controlnet_frame_indices)
    # NOTE(review): `assert` is stripped under `python -O`; kept as-is to preserve the raised exception type.
    assert conditioning_frames.shape[2] >= num_conditioned

    batch_size, channels, _, height, width = conditioning_frames.shape
    controlnet_cond = torch.zeros((batch_size, channels, num_frames, height, width), dtype=dtype, device=device)
    controlnet_cond_mask = torch.zeros((batch_size, 1, num_frames, height, width), dtype=dtype, device=device)

    # Place each provided frame at its target index and flag that index in the mask.
    controlnet_cond[:, :, controlnet_frame_indices] = conditioning_frames[:, :, :num_conditioned]
    controlnet_cond_mask[:, :, controlnet_frame_indices] = 1

    return controlnet_cond, controlnet_cond_mask
|
| 688 |
+
|
| 689 |
+
@property
def guidance_scale(self):
    # Classifier-free guidance weight captured at the start of `__call__`.
    return self._guidance_scale
|
| 692 |
+
|
| 693 |
+
@property
def clip_skip(self):
    # Number of final CLIP layers to skip during prompt encoding (set in `__call__`).
    return self._clip_skip
|
| 696 |
+
|
| 697 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
    # CFG is active only for guidance_scale > 1 (the denoising loop then runs a doubled batch).
    return self._guidance_scale > 1
|
| 703 |
+
|
| 704 |
+
@property
def cross_attention_kwargs(self):
    # Extra kwargs forwarded to the attention processors (set in `__call__`).
    return self._cross_attention_kwargs
|
| 707 |
+
|
| 708 |
+
@property
def num_timesteps(self):
    # Number of denoising timesteps of the current/last run (set per free-init iteration).
    return self._num_timesteps
|
| 711 |
+
|
| 712 |
+
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_frames: int = 16,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_videos_per_prompt: int = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
    conditioning_frames: Optional[List[PipelineImageInput]] = None,
    output_type: str = "pil",
    return_dict: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
    # Fixed: mutable default arguments (`[0]`, `["latents"]`) replaced with None sentinels;
    # the documented defaults are applied inside the body, so caller behavior is unchanged.
    controlnet_frame_indices: Optional[List[int]] = None,
    guess_mode: bool = False,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The height in pixels of the generated video.
        width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The width in pixels of the generated video.
        num_frames (`int`, *optional*, defaults to 16):
            The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds
            amounts to 2 seconds of video.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality videos at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
            applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
            `(batch_size, num_channel, num_frames, height, width)`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
            not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
        ip_adapter_image: (`PipelineImageInput`, *optional*):
            Optional image input to work with IP Adapters.
        ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
            Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
            IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
            contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
            provided, embeddings are computed from the `ip_adapter_image` input argument.
        conditioning_frames (`List[PipelineImageInput]`, *optional*):
            The SparseControlNet input to provide guidance to the `unet` for generation.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead
            of a plain tuple.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
            The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
            to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
            the corresponding scale as a list.
        controlnet_frame_indices (`List[int]`, *optional*, defaults to `[0]`):
            The indices where the conditioning frames must be applied for generation. Multiple frames can be
            provided to guide the model to generate similar structure outputs, where the `unet` can
            "fill-in-the-gaps" for interpolation videos, or a single frame could be provided for general expected
            structure. Must have the same length as `conditioning_frames`.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.
        callback_on_step_end (`Callable`, *optional*):
            A function that calls at the end of each denoising steps during the inference. The function is called
            with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
            callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
            `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*, defaults to `["latents"]`):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.

    Examples:

    Returns:
        [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
            returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
    """
    # Resolve the documented defaults for the previously-mutable default arguments.
    if controlnet_frame_indices is None:
        controlnet_frame_indices = [0]
    if callback_on_step_end_tensor_inputs is None:
        callback_on_step_end_tensor_inputs = ["latents"]

    # Unwrap a torch.compile'd controlnet so config/attribute access below works.
    controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor
    # NOTE(review): the `num_videos_per_prompt` argument is accepted but force-overridden to 1 here —
    # presumably this pipeline only supports one video per prompt; confirm before exposing it.
    num_videos_per_prompt = 1

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt=prompt,
        height=height,
        width=width,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        ip_adapter_image=ip_adapter_image,
        ip_adapter_image_embeds=ip_adapter_image_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        image=conditioning_frames,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
    )

    self._guidance_scale = guidance_scale
    self._clip_skip = clip_skip
    self._cross_attention_kwargs = cross_attention_kwargs

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    global_pool_conditions = (
        controlnet.config.global_pool_conditions
        if isinstance(controlnet, SparseControlNetModel)
        else controlnet.nets[0].config.global_pool_conditions
    )
    guess_mode = guess_mode or global_pool_conditions

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
    )
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        device,
        num_videos_per_prompt,
        self.do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
        clip_skip=self.clip_skip,
    )
    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    if self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    # Replicate the prompt embedding per frame so the motion modules see per-frame conditioning.
    prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)

    # 4. Prepare IP-Adapter embeddings
    if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
        image_embeds = self.prepare_ip_adapter_image_embeds(
            ip_adapter_image,
            ip_adapter_image_embeds,
            device,
            batch_size * num_videos_per_prompt,
            self.do_classifier_free_guidance,
        )

    # 5. Prepare controlnet conditioning
    conditioning_frames = self.prepare_image(conditioning_frames, width, height, device, controlnet.dtype)
    controlnet_cond, controlnet_cond_mask = self.prepare_sparse_control_conditioning(
        conditioning_frames, num_frames, controlnet_frame_indices, device, controlnet.dtype
    )

    # 6. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 7. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_videos_per_prompt,
        num_channels_latents,
        num_frames,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 9. Add image embeds for IP-Adapter
    added_cond_kwargs = (
        {"image_embeds": image_embeds}
        if ip_adapter_image is not None or ip_adapter_image_embeds is not None
        else None
    )

    num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
    for free_init_iter in range(num_free_init_iters):
        if self.free_init_enabled:
            latents, timesteps = self._apply_free_init(
                latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
            )

        self._num_timesteps = len(timesteps)
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

        # 10. Denoising loop
        with self.progress_bar(total=self._num_timesteps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                if guess_mode and self.do_classifier_free_guidance:
                    # Infer SparseControlNetModel only for the conditional batch.
                    control_model_input = latents
                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)
                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
                else:
                    control_model_input = latent_model_input
                    controlnet_prompt_embeds = prompt_embeds

                down_block_res_samples, mid_block_res_sample = self.controlnet(
                    control_model_input,
                    t,
                    encoder_hidden_states=controlnet_prompt_embeds,
                    controlnet_cond=controlnet_cond,
                    conditioning_mask=controlnet_cond_mask,
                    conditioning_scale=controlnet_conditioning_scale,
                    guess_mode=guess_mode,
                    return_dict=False,
                )

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                ).sample

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

    # 11. Post processing
    if output_type == "latent":
        video = latents
    else:
        video_tensor = self.decode_latents(latents)
        video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

    # 12. Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (video,)

    return AnimateDiffPipelineOutput(frames=video)
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
ADDED
|
@@ -0,0 +1,1065 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import inspect
|
| 16 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
| 20 |
+
|
| 21 |
+
from ...image_processor import PipelineImageInput
|
| 22 |
+
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
| 23 |
+
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
|
| 24 |
+
from ...models.lora import adjust_lora_scale_text_encoder
|
| 25 |
+
from ...models.unets.unet_motion_model import MotionAdapter
|
| 26 |
+
from ...schedulers import (
|
| 27 |
+
DDIMScheduler,
|
| 28 |
+
DPMSolverMultistepScheduler,
|
| 29 |
+
EulerAncestralDiscreteScheduler,
|
| 30 |
+
EulerDiscreteScheduler,
|
| 31 |
+
LMSDiscreteScheduler,
|
| 32 |
+
PNDMScheduler,
|
| 33 |
+
)
|
| 34 |
+
from ...utils import USE_PEFT_BACKEND, is_torch_xla_available, logging, scale_lora_layers, unscale_lora_layers
|
| 35 |
+
from ...utils.torch_utils import randn_tensor
|
| 36 |
+
from ...video_processor import VideoProcessor
|
| 37 |
+
from ..free_init_utils import FreeInitMixin
|
| 38 |
+
from ..free_noise_utils import AnimateDiffFreeNoiseMixin
|
| 39 |
+
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
| 40 |
+
from .pipeline_output import AnimateDiffPipelineOutput
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# Optional TPU/XLA support: when torch_xla is installed, the denoising loop flushes
# pending XLA ops each step via `xm.mark_step()`.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# Fixed: the example passed `timespace_spacing=...` to DDIMScheduler; the actual
# parameter name is `timestep_spacing`, so the example as written did not configure spacing.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import imageio
        >>> import requests
        >>> import torch
        >>> from diffusers import AnimateDiffVideoToVideoPipeline, DDIMScheduler, MotionAdapter
        >>> from diffusers.utils import export_to_gif
        >>> from io import BytesIO
        >>> from PIL import Image

        >>> adapter = MotionAdapter.from_pretrained(
        ...     "guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16
        ... )
        >>> pipe = AnimateDiffVideoToVideoPipeline.from_pretrained(
        ...     "SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter
        ... ).to("cuda")
        >>> pipe.scheduler = DDIMScheduler(
        ...     beta_schedule="linear", steps_offset=1, clip_sample=False, timestep_spacing="linspace"
        ... )


        >>> def load_video(file_path: str):
        ...     images = []

        ...     if file_path.startswith(("http://", "https://")):
        ...         # If the file_path is a URL
        ...         response = requests.get(file_path)
        ...         response.raise_for_status()
        ...         content = BytesIO(response.content)
        ...         vid = imageio.get_reader(content)
        ...     else:
        ...         # Assuming it's a local file path
        ...         vid = imageio.get_reader(file_path)

        ...     for frame in vid:
        ...         pil_image = Image.fromarray(frame)
        ...         images.append(pil_image)

        ...     return images


        >>> video = load_video(
        ...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif"
        ... )
        >>> output = pipe(
        ...     video=video, prompt="panda playing a guitar, on a boat, in the ocean, high quality", strength=0.5
        ... )
        >>> frames = output.frames[0]
        >>> export_to_gif(frames, "animation.gif")
        ```
"""
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
| 108 |
+
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Extract a latent tensor from a VAE encoder output.

    Handles outputs that expose a ``latent_dist`` distribution (sampled with
    ``generator`` when ``sample_mode == "sample"``, or its mode when
    ``sample_mode == "argmax"``) as well as outputs that expose ``latents``
    directly. Raises ``AttributeError`` when neither access path exists.
    """
    if hasattr(encoder_output, "latent_dist"):
        if sample_mode == "sample":
            return encoder_output.latent_dist.sample(generator)
        if sample_mode == "argmax":
            return encoder_output.latent_dist.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 122 |
+
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call.
    Handles custom timestep/sigma schedules; any extra kwargs are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps overriding the scheduler's spacing strategy. Mutually exclusive with
            `num_inference_steps` and `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas overriding the scheduler's spacing strategy. Mutually exclusive with
            `num_inference_steps` and `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: the timestep schedule from the scheduler and the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Remember whether a custom schedule was requested before `timesteps` is rebound below.
    custom_schedule = timesteps is not None or sigmas is not None

    def _accepts(param_name: str) -> bool:
        # Not every scheduler's `set_timesteps` supports custom schedules; check its signature.
        return param_name in inspect.signature(scheduler.set_timesteps).parameters

    if timesteps is not None:
        if not _accepts("timesteps"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    elif sigmas is not None:
        if not _accepts("sigmas"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)

    timesteps = scheduler.timesteps
    if custom_schedule:
        # The scheduler determines the effective step count for custom schedules.
        num_inference_steps = len(timesteps)
    return timesteps, num_inference_steps
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
class AnimateDiffVideoToVideoPipeline(
|
| 182 |
+
DiffusionPipeline,
|
| 183 |
+
StableDiffusionMixin,
|
| 184 |
+
TextualInversionLoaderMixin,
|
| 185 |
+
IPAdapterMixin,
|
| 186 |
+
StableDiffusionLoraLoaderMixin,
|
| 187 |
+
FreeInitMixin,
|
| 188 |
+
AnimateDiffFreeNoiseMixin,
|
| 189 |
+
FromSingleFileMixin,
|
| 190 |
+
):
|
| 191 |
+
r"""
|
| 192 |
+
Pipeline for video-to-video generation.
|
| 193 |
+
|
| 194 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
| 195 |
+
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
| 196 |
+
|
| 197 |
+
The pipeline also inherits the following loading methods:
|
| 198 |
+
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
| 199 |
+
- [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
| 200 |
+
- [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
| 201 |
+
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
| 202 |
+
|
| 203 |
+
Args:
|
| 204 |
+
vae ([`AutoencoderKL`]):
|
| 205 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
| 206 |
+
text_encoder ([`CLIPTextModel`]):
|
| 207 |
+
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
| 208 |
+
tokenizer (`CLIPTokenizer`):
|
| 209 |
+
A [`~transformers.CLIPTokenizer`] to tokenize text.
|
| 210 |
+
unet ([`UNet2DConditionModel`]):
|
| 211 |
+
A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents.
|
| 212 |
+
motion_adapter ([`MotionAdapter`]):
|
| 213 |
+
A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents.
|
| 214 |
+
scheduler ([`SchedulerMixin`]):
|
| 215 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
| 216 |
+
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
| 217 |
+
"""
|
| 218 |
+
|
| 219 |
+
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
|
| 220 |
+
_optional_components = ["feature_extractor", "image_encoder", "motion_adapter"]
|
| 221 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
| 222 |
+
|
| 223 |
+
def __init__(
|
| 224 |
+
self,
|
| 225 |
+
vae: AutoencoderKL,
|
| 226 |
+
text_encoder: CLIPTextModel,
|
| 227 |
+
tokenizer: CLIPTokenizer,
|
| 228 |
+
unet: Union[UNet2DConditionModel, UNetMotionModel],
|
| 229 |
+
motion_adapter: MotionAdapter,
|
| 230 |
+
scheduler: Union[
|
| 231 |
+
DDIMScheduler,
|
| 232 |
+
PNDMScheduler,
|
| 233 |
+
LMSDiscreteScheduler,
|
| 234 |
+
EulerDiscreteScheduler,
|
| 235 |
+
EulerAncestralDiscreteScheduler,
|
| 236 |
+
DPMSolverMultistepScheduler,
|
| 237 |
+
],
|
| 238 |
+
feature_extractor: CLIPImageProcessor = None,
|
| 239 |
+
image_encoder: CLIPVisionModelWithProjection = None,
|
| 240 |
+
):
|
| 241 |
+
super().__init__()
|
| 242 |
+
if isinstance(unet, UNet2DConditionModel):
|
| 243 |
+
unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
|
| 244 |
+
|
| 245 |
+
self.register_modules(
|
| 246 |
+
vae=vae,
|
| 247 |
+
text_encoder=text_encoder,
|
| 248 |
+
tokenizer=tokenizer,
|
| 249 |
+
unet=unet,
|
| 250 |
+
motion_adapter=motion_adapter,
|
| 251 |
+
scheduler=scheduler,
|
| 252 |
+
feature_extractor=feature_extractor,
|
| 253 |
+
image_encoder=image_encoder,
|
| 254 |
+
)
|
| 255 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
| 256 |
+
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
|
| 257 |
+
|
| 258 |
+
def encode_prompt(
|
| 259 |
+
self,
|
| 260 |
+
prompt,
|
| 261 |
+
device,
|
| 262 |
+
num_images_per_prompt,
|
| 263 |
+
do_classifier_free_guidance,
|
| 264 |
+
negative_prompt=None,
|
| 265 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 266 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 267 |
+
lora_scale: Optional[float] = None,
|
| 268 |
+
clip_skip: Optional[int] = None,
|
| 269 |
+
):
|
| 270 |
+
r"""
|
| 271 |
+
Encodes the prompt into text encoder hidden states.
|
| 272 |
+
|
| 273 |
+
Args:
|
| 274 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 275 |
+
prompt to be encoded
|
| 276 |
+
device: (`torch.device`):
|
| 277 |
+
torch device
|
| 278 |
+
num_images_per_prompt (`int`):
|
| 279 |
+
number of images that should be generated per prompt
|
| 280 |
+
do_classifier_free_guidance (`bool`):
|
| 281 |
+
whether to use classifier free guidance or not
|
| 282 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 283 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 284 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 285 |
+
less than `1`).
|
| 286 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 287 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 288 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 289 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 290 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 291 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 292 |
+
argument.
|
| 293 |
+
lora_scale (`float`, *optional*):
|
| 294 |
+
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
| 295 |
+
clip_skip (`int`, *optional*):
|
| 296 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
| 297 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
| 298 |
+
"""
|
| 299 |
+
# set lora scale so that monkey patched LoRA
|
| 300 |
+
# function of text encoder can correctly access it
|
| 301 |
+
if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
|
| 302 |
+
self._lora_scale = lora_scale
|
| 303 |
+
|
| 304 |
+
# dynamically adjust the LoRA scale
|
| 305 |
+
if not USE_PEFT_BACKEND:
|
| 306 |
+
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
| 307 |
+
else:
|
| 308 |
+
scale_lora_layers(self.text_encoder, lora_scale)
|
| 309 |
+
|
| 310 |
+
if prompt is not None and isinstance(prompt, (str, dict)):
|
| 311 |
+
batch_size = 1
|
| 312 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 313 |
+
batch_size = len(prompt)
|
| 314 |
+
else:
|
| 315 |
+
batch_size = prompt_embeds.shape[0]
|
| 316 |
+
|
| 317 |
+
if prompt_embeds is None:
|
| 318 |
+
# textual inversion: process multi-vector tokens if necessary
|
| 319 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
| 320 |
+
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
| 321 |
+
|
| 322 |
+
text_inputs = self.tokenizer(
|
| 323 |
+
prompt,
|
| 324 |
+
padding="max_length",
|
| 325 |
+
max_length=self.tokenizer.model_max_length,
|
| 326 |
+
truncation=True,
|
| 327 |
+
return_tensors="pt",
|
| 328 |
+
)
|
| 329 |
+
text_input_ids = text_inputs.input_ids
|
| 330 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
| 331 |
+
|
| 332 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
| 333 |
+
text_input_ids, untruncated_ids
|
| 334 |
+
):
|
| 335 |
+
removed_text = self.tokenizer.batch_decode(
|
| 336 |
+
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
| 337 |
+
)
|
| 338 |
+
logger.warning(
|
| 339 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
| 340 |
+
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
| 344 |
+
attention_mask = text_inputs.attention_mask.to(device)
|
| 345 |
+
else:
|
| 346 |
+
attention_mask = None
|
| 347 |
+
|
| 348 |
+
if clip_skip is None:
|
| 349 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
|
| 350 |
+
prompt_embeds = prompt_embeds[0]
|
| 351 |
+
else:
|
| 352 |
+
prompt_embeds = self.text_encoder(
|
| 353 |
+
text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
|
| 354 |
+
)
|
| 355 |
+
# Access the `hidden_states` first, that contains a tuple of
|
| 356 |
+
# all the hidden states from the encoder layers. Then index into
|
| 357 |
+
# the tuple to access the hidden states from the desired layer.
|
| 358 |
+
prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
|
| 359 |
+
# We also need to apply the final LayerNorm here to not mess with the
|
| 360 |
+
# representations. The `last_hidden_states` that we typically use for
|
| 361 |
+
# obtaining the final prompt representations passes through the LayerNorm
|
| 362 |
+
# layer.
|
| 363 |
+
prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
|
| 364 |
+
|
| 365 |
+
if self.text_encoder is not None:
|
| 366 |
+
prompt_embeds_dtype = self.text_encoder.dtype
|
| 367 |
+
elif self.unet is not None:
|
| 368 |
+
prompt_embeds_dtype = self.unet.dtype
|
| 369 |
+
else:
|
| 370 |
+
prompt_embeds_dtype = prompt_embeds.dtype
|
| 371 |
+
|
| 372 |
+
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
| 373 |
+
|
| 374 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
| 375 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 376 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 377 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
| 378 |
+
|
| 379 |
+
# get unconditional embeddings for classifier free guidance
|
| 380 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
| 381 |
+
uncond_tokens: List[str]
|
| 382 |
+
if negative_prompt is None:
|
| 383 |
+
uncond_tokens = [""] * batch_size
|
| 384 |
+
elif prompt is not None and type(prompt) is not type(negative_prompt):
|
| 385 |
+
raise TypeError(
|
| 386 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
| 387 |
+
f" {type(prompt)}."
|
| 388 |
+
)
|
| 389 |
+
elif isinstance(negative_prompt, str):
|
| 390 |
+
uncond_tokens = [negative_prompt]
|
| 391 |
+
elif batch_size != len(negative_prompt):
|
| 392 |
+
raise ValueError(
|
| 393 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
| 394 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
| 395 |
+
" the batch size of `prompt`."
|
| 396 |
+
)
|
| 397 |
+
else:
|
| 398 |
+
uncond_tokens = negative_prompt
|
| 399 |
+
|
| 400 |
+
# textual inversion: process multi-vector tokens if necessary
|
| 401 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
| 402 |
+
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
|
| 403 |
+
|
| 404 |
+
max_length = prompt_embeds.shape[1]
|
| 405 |
+
uncond_input = self.tokenizer(
|
| 406 |
+
uncond_tokens,
|
| 407 |
+
padding="max_length",
|
| 408 |
+
max_length=max_length,
|
| 409 |
+
truncation=True,
|
| 410 |
+
return_tensors="pt",
|
| 411 |
+
)
|
| 412 |
+
|
| 413 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
| 414 |
+
attention_mask = uncond_input.attention_mask.to(device)
|
| 415 |
+
else:
|
| 416 |
+
attention_mask = None
|
| 417 |
+
|
| 418 |
+
negative_prompt_embeds = self.text_encoder(
|
| 419 |
+
uncond_input.input_ids.to(device),
|
| 420 |
+
attention_mask=attention_mask,
|
| 421 |
+
)
|
| 422 |
+
negative_prompt_embeds = negative_prompt_embeds[0]
|
| 423 |
+
|
| 424 |
+
if do_classifier_free_guidance:
|
| 425 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
| 426 |
+
seq_len = negative_prompt_embeds.shape[1]
|
| 427 |
+
|
| 428 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
| 429 |
+
|
| 430 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 431 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 432 |
+
|
| 433 |
+
if self.text_encoder is not None:
|
| 434 |
+
if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
|
| 435 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
| 436 |
+
unscale_lora_layers(self.text_encoder, lora_scale)
|
| 437 |
+
|
| 438 |
+
return prompt_embeds, negative_prompt_embeds
|
| 439 |
+
|
| 440 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
|
| 441 |
+
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
|
| 442 |
+
dtype = next(self.image_encoder.parameters()).dtype
|
| 443 |
+
|
| 444 |
+
if not isinstance(image, torch.Tensor):
|
| 445 |
+
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
| 446 |
+
|
| 447 |
+
image = image.to(device=device, dtype=dtype)
|
| 448 |
+
if output_hidden_states:
|
| 449 |
+
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
|
| 450 |
+
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
| 451 |
+
uncond_image_enc_hidden_states = self.image_encoder(
|
| 452 |
+
torch.zeros_like(image), output_hidden_states=True
|
| 453 |
+
).hidden_states[-2]
|
| 454 |
+
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
|
| 455 |
+
num_images_per_prompt, dim=0
|
| 456 |
+
)
|
| 457 |
+
return image_enc_hidden_states, uncond_image_enc_hidden_states
|
| 458 |
+
else:
|
| 459 |
+
image_embeds = self.image_encoder(image).image_embeds
|
| 460 |
+
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
| 461 |
+
uncond_image_embeds = torch.zeros_like(image_embeds)
|
| 462 |
+
|
| 463 |
+
return image_embeds, uncond_image_embeds
|
| 464 |
+
|
| 465 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
|
| 466 |
+
def prepare_ip_adapter_image_embeds(
|
| 467 |
+
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
|
| 468 |
+
):
|
| 469 |
+
image_embeds = []
|
| 470 |
+
if do_classifier_free_guidance:
|
| 471 |
+
negative_image_embeds = []
|
| 472 |
+
if ip_adapter_image_embeds is None:
|
| 473 |
+
if not isinstance(ip_adapter_image, list):
|
| 474 |
+
ip_adapter_image = [ip_adapter_image]
|
| 475 |
+
|
| 476 |
+
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
|
| 477 |
+
raise ValueError(
|
| 478 |
+
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
| 479 |
+
)
|
| 480 |
+
|
| 481 |
+
for single_ip_adapter_image, image_proj_layer in zip(
|
| 482 |
+
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
|
| 483 |
+
):
|
| 484 |
+
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
|
| 485 |
+
single_image_embeds, single_negative_image_embeds = self.encode_image(
|
| 486 |
+
single_ip_adapter_image, device, 1, output_hidden_state
|
| 487 |
+
)
|
| 488 |
+
|
| 489 |
+
image_embeds.append(single_image_embeds[None, :])
|
| 490 |
+
if do_classifier_free_guidance:
|
| 491 |
+
negative_image_embeds.append(single_negative_image_embeds[None, :])
|
| 492 |
+
else:
|
| 493 |
+
for single_image_embeds in ip_adapter_image_embeds:
|
| 494 |
+
if do_classifier_free_guidance:
|
| 495 |
+
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
|
| 496 |
+
negative_image_embeds.append(single_negative_image_embeds)
|
| 497 |
+
image_embeds.append(single_image_embeds)
|
| 498 |
+
|
| 499 |
+
ip_adapter_image_embeds = []
|
| 500 |
+
for i, single_image_embeds in enumerate(image_embeds):
|
| 501 |
+
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
|
| 502 |
+
if do_classifier_free_guidance:
|
| 503 |
+
single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
|
| 504 |
+
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
|
| 505 |
+
|
| 506 |
+
single_image_embeds = single_image_embeds.to(device=device)
|
| 507 |
+
ip_adapter_image_embeds.append(single_image_embeds)
|
| 508 |
+
|
| 509 |
+
return ip_adapter_image_embeds
|
| 510 |
+
|
| 511 |
+
def encode_video(self, video, generator, decode_chunk_size: int = 16) -> torch.Tensor:
|
| 512 |
+
latents = []
|
| 513 |
+
for i in range(0, len(video), decode_chunk_size):
|
| 514 |
+
batch_video = video[i : i + decode_chunk_size]
|
| 515 |
+
batch_video = retrieve_latents(self.vae.encode(batch_video), generator=generator)
|
| 516 |
+
latents.append(batch_video)
|
| 517 |
+
return torch.cat(latents)
|
| 518 |
+
|
| 519 |
+
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents
|
| 520 |
+
def decode_latents(self, latents, decode_chunk_size: int = 16):
|
| 521 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
| 522 |
+
|
| 523 |
+
batch_size, channels, num_frames, height, width = latents.shape
|
| 524 |
+
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
|
| 525 |
+
|
| 526 |
+
video = []
|
| 527 |
+
for i in range(0, latents.shape[0], decode_chunk_size):
|
| 528 |
+
batch_latents = latents[i : i + decode_chunk_size]
|
| 529 |
+
batch_latents = self.vae.decode(batch_latents).sample
|
| 530 |
+
video.append(batch_latents)
|
| 531 |
+
|
| 532 |
+
video = torch.cat(video)
|
| 533 |
+
video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
|
| 534 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
| 535 |
+
video = video.float()
|
| 536 |
+
return video
|
| 537 |
+
|
| 538 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
| 539 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
| 540 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 541 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 542 |
+
# eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
|
| 543 |
+
# and should be between [0, 1]
|
| 544 |
+
|
| 545 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 546 |
+
extra_step_kwargs = {}
|
| 547 |
+
if accepts_eta:
|
| 548 |
+
extra_step_kwargs["eta"] = eta
|
| 549 |
+
|
| 550 |
+
# check if the scheduler accepts generator
|
| 551 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 552 |
+
if accepts_generator:
|
| 553 |
+
extra_step_kwargs["generator"] = generator
|
| 554 |
+
return extra_step_kwargs
|
| 555 |
+
|
| 556 |
+
def check_inputs(
|
| 557 |
+
self,
|
| 558 |
+
prompt,
|
| 559 |
+
strength,
|
| 560 |
+
height,
|
| 561 |
+
width,
|
| 562 |
+
video=None,
|
| 563 |
+
latents=None,
|
| 564 |
+
negative_prompt=None,
|
| 565 |
+
prompt_embeds=None,
|
| 566 |
+
negative_prompt_embeds=None,
|
| 567 |
+
ip_adapter_image=None,
|
| 568 |
+
ip_adapter_image_embeds=None,
|
| 569 |
+
callback_on_step_end_tensor_inputs=None,
|
| 570 |
+
):
|
| 571 |
+
if strength < 0 or strength > 1:
|
| 572 |
+
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
| 573 |
+
|
| 574 |
+
if height % 8 != 0 or width % 8 != 0:
|
| 575 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 576 |
+
|
| 577 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 578 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 579 |
+
):
|
| 580 |
+
raise ValueError(
|
| 581 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 582 |
+
)
|
| 583 |
+
|
| 584 |
+
if prompt is not None and prompt_embeds is not None:
|
| 585 |
+
raise ValueError(
|
| 586 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 587 |
+
" only forward one of the two."
|
| 588 |
+
)
|
| 589 |
+
elif prompt is None and prompt_embeds is None:
|
| 590 |
+
raise ValueError(
|
| 591 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 592 |
+
)
|
| 593 |
+
elif prompt is not None and not isinstance(prompt, (str, list, dict)):
|
| 594 |
+
raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)}")
|
| 595 |
+
|
| 596 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 597 |
+
raise ValueError(
|
| 598 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 599 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 600 |
+
)
|
| 601 |
+
|
| 602 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 603 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 604 |
+
raise ValueError(
|
| 605 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 606 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 607 |
+
f" {negative_prompt_embeds.shape}."
|
| 608 |
+
)
|
| 609 |
+
|
| 610 |
+
if video is not None and latents is not None:
|
| 611 |
+
raise ValueError("Only one of `video` or `latents` should be provided")
|
| 612 |
+
|
| 613 |
+
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
|
| 614 |
+
raise ValueError(
|
| 615 |
+
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
|
| 616 |
+
)
|
| 617 |
+
|
| 618 |
+
if ip_adapter_image_embeds is not None:
|
| 619 |
+
if not isinstance(ip_adapter_image_embeds, list):
|
| 620 |
+
raise ValueError(
|
| 621 |
+
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
|
| 622 |
+
)
|
| 623 |
+
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
|
| 624 |
+
raise ValueError(
|
| 625 |
+
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
| 626 |
+
)
|
| 627 |
+
|
| 628 |
+
def get_timesteps(self, num_inference_steps, timesteps, strength, device):
|
| 629 |
+
# get the original timestep using init_timestep
|
| 630 |
+
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
| 631 |
+
|
| 632 |
+
t_start = max(num_inference_steps - init_timestep, 0)
|
| 633 |
+
timesteps = timesteps[t_start * self.scheduler.order :]
|
| 634 |
+
|
| 635 |
+
return timesteps, num_inference_steps - t_start
|
| 636 |
+
|
| 637 |
+
def prepare_latents(
|
| 638 |
+
self,
|
| 639 |
+
video: Optional[torch.Tensor] = None,
|
| 640 |
+
height: int = 64,
|
| 641 |
+
width: int = 64,
|
| 642 |
+
num_channels_latents: int = 4,
|
| 643 |
+
batch_size: int = 1,
|
| 644 |
+
timestep: Optional[int] = None,
|
| 645 |
+
dtype: Optional[torch.dtype] = None,
|
| 646 |
+
device: Optional[torch.device] = None,
|
| 647 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 648 |
+
latents: Optional[torch.Tensor] = None,
|
| 649 |
+
decode_chunk_size: int = 16,
|
| 650 |
+
add_noise: bool = False,
|
| 651 |
+
) -> torch.Tensor:
|
| 652 |
+
num_frames = video.shape[1] if latents is None else latents.shape[2]
|
| 653 |
+
shape = (
|
| 654 |
+
batch_size,
|
| 655 |
+
num_channels_latents,
|
| 656 |
+
num_frames,
|
| 657 |
+
height // self.vae_scale_factor,
|
| 658 |
+
width // self.vae_scale_factor,
|
| 659 |
+
)
|
| 660 |
+
|
| 661 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 662 |
+
raise ValueError(
|
| 663 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 664 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 665 |
+
)
|
| 666 |
+
|
| 667 |
+
if latents is None:
|
| 668 |
+
# make sure the VAE is in float32 mode, as it overflows in float16
|
| 669 |
+
if self.vae.config.force_upcast:
|
| 670 |
+
video = video.float()
|
| 671 |
+
self.vae.to(dtype=torch.float32)
|
| 672 |
+
|
| 673 |
+
if isinstance(generator, list):
|
| 674 |
+
init_latents = [
|
| 675 |
+
self.encode_video(video[i], generator[i], decode_chunk_size).unsqueeze(0)
|
| 676 |
+
for i in range(batch_size)
|
| 677 |
+
]
|
| 678 |
+
else:
|
| 679 |
+
init_latents = [self.encode_video(vid, generator, decode_chunk_size).unsqueeze(0) for vid in video]
|
| 680 |
+
|
| 681 |
+
init_latents = torch.cat(init_latents, dim=0)
|
| 682 |
+
|
| 683 |
+
# restore vae to original dtype
|
| 684 |
+
if self.vae.config.force_upcast:
|
| 685 |
+
self.vae.to(dtype)
|
| 686 |
+
|
| 687 |
+
init_latents = init_latents.to(dtype)
|
| 688 |
+
init_latents = self.vae.config.scaling_factor * init_latents
|
| 689 |
+
|
| 690 |
+
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
|
| 691 |
+
# expand init_latents for batch_size
|
| 692 |
+
error_message = (
|
| 693 |
+
f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
|
| 694 |
+
" images (`image`). Please make sure to update your script to pass as many initial images as text prompts"
|
| 695 |
+
)
|
| 696 |
+
raise ValueError(error_message)
|
| 697 |
+
elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
|
| 698 |
+
raise ValueError(
|
| 699 |
+
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
|
| 700 |
+
)
|
| 701 |
+
else:
|
| 702 |
+
init_latents = torch.cat([init_latents], dim=0)
|
| 703 |
+
|
| 704 |
+
noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype)
|
| 705 |
+
latents = self.scheduler.add_noise(init_latents, noise, timestep).permute(0, 2, 1, 3, 4)
|
| 706 |
+
else:
|
| 707 |
+
if shape != latents.shape:
|
| 708 |
+
# [B, C, F, H, W]
|
| 709 |
+
raise ValueError(f"`latents` expected to have {shape=}, but found {latents.shape=}")
|
| 710 |
+
|
| 711 |
+
latents = latents.to(device, dtype=dtype)
|
| 712 |
+
|
| 713 |
+
if add_noise:
|
| 714 |
+
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 715 |
+
latents = self.scheduler.add_noise(latents, noise, timestep)
|
| 716 |
+
|
| 717 |
+
return latents
|
| 718 |
+
|
| 719 |
+
@property
|
| 720 |
+
def guidance_scale(self):
|
| 721 |
+
return self._guidance_scale
|
| 722 |
+
|
| 723 |
+
@property
|
| 724 |
+
def clip_skip(self):
|
| 725 |
+
return self._clip_skip
|
| 726 |
+
|
| 727 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
    """Whether classifier-free guidance is active, i.e. `guidance_scale > 1`."""
    return self._guidance_scale > 1
|
| 733 |
+
|
| 734 |
+
@property
def cross_attention_kwargs(self):
    """Extra kwargs forwarded to the attention processors for the current `__call__`."""
    return self._cross_attention_kwargs
|
| 737 |
+
|
| 738 |
+
@property
def num_timesteps(self):
    """Number of denoising timesteps used in the current `__call__` (set just before the loop)."""
    return self._num_timesteps
|
| 741 |
+
|
| 742 |
+
@property
def interrupt(self):
    """Cooperative-cancellation flag checked at each denoising step of `__call__`."""
    return self._interrupt
|
| 745 |
+
|
| 746 |
+
@torch.no_grad()
def __call__(
    self,
    video: List[List[PipelineImageInput]] = None,
    prompt: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    enforce_inference_steps: bool = False,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    guidance_scale: float = 7.5,
    strength: float = 0.8,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_videos_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    decode_chunk_size: int = 16,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        video (`List[PipelineImageInput]`):
            The input video to condition the generation on. Must be a list of images/frames of the video.
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The height in pixels of the generated video.
        width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The width in pixels of the generated video.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality videos at the
            expense of slower inference.
        enforce_inference_steps (`bool`, *optional*, defaults to `False`):
            If `True`, exactly `num_inference_steps` denoising steps are run (the scheduler is configured with
            `num_inference_steps / strength` steps and only the final `num_inference_steps` are used); otherwise
            the effective step count is reduced by `strength` as in standard img2img.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
            in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
            passed will be used. Must be in descending order.
        sigmas (`List[float]`, *optional*):
            Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
            their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
            will be used.
        strength (`float`, *optional*, defaults to 0.8):
            Higher strength leads to more differences between original video and generated video.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
            applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
            `(batch_size, num_channel, num_frames, height, width)`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
            not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
        ip_adapter_image: (`PipelineImageInput`, *optional*):
            Optional image input to work with IP Adapters.
        ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
            Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
            IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
            contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
            provided, embeddings are computed from the `ip_adapter_image` input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`AnimateDiffPipelineOutput`] instead of a plain tuple.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.
        callback_on_step_end (`Callable`, *optional*):
            A function that calls at the end of each denoising steps during the inference. The function is called
            with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
            callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
            `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.
        decode_chunk_size (`int`, defaults to `16`):
            The number of frames to decode at a time when calling `decode_latents` method.

    Examples:

    Returns:
        [`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
            returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
    """

    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # Only a single video per prompt is supported; the parameter is kept for API parity
    # with other pipelines but is deliberately overridden here.
    num_videos_per_prompt = 1

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt=prompt,
        strength=strength,
        height=height,
        width=width,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        video=video,
        latents=latents,
        ip_adapter_image=ip_adapter_image,
        ip_adapter_image_embeds=ip_adapter_image_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
    )

    # Stash per-call settings so the read-only properties reflect this invocation.
    self._guidance_scale = guidance_scale
    self._clip_skip = clip_skip
    self._cross_attention_kwargs = cross_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, (str, dict)):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    dtype = self.dtype

    # 3. Prepare timesteps
    if not enforce_inference_steps:
        # Standard img2img behavior: `strength` trims the schedule, so fewer than
        # `num_inference_steps` denoising steps are actually run.
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps, sigmas
        )
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
        latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
    else:
        # Run exactly `num_inference_steps` steps: over-provision the schedule by
        # 1/strength and keep only its tail.
        denoising_inference_steps = int(num_inference_steps / strength)
        timesteps, denoising_inference_steps = retrieve_timesteps(
            self.scheduler, denoising_inference_steps, device, timesteps, sigmas
        )
        timesteps = timesteps[-num_inference_steps:]
        latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)

    # 4. Prepare latent variables
    if latents is None:
        video = self.video_processor.preprocess_video(video, height=height, width=width)
        # Move the number of frames before the number of channels.
        video = video.permute(0, 2, 1, 3, 4)
        video = video.to(device=device, dtype=dtype)
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            video=video,
            height=height,
            width=width,
            num_channels_latents=num_channels_latents,
            batch_size=batch_size * num_videos_per_prompt,
            timestep=latent_timestep,
            dtype=dtype,
            device=device,
            generator=generator,
            latents=latents,
            decode_chunk_size=decode_chunk_size,
            add_noise=enforce_inference_steps,
        )

    # 5. Encode input prompt
    text_encoder_lora_scale = (
        self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
    )
    num_frames = latents.shape[2]
    if self.free_noise_enabled:
        prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise(
            prompt=prompt,
            num_frames=num_frames,
            device=device,
            num_videos_per_prompt=num_videos_per_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
            clip_skip=self.clip_skip,
        )
    else:
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_videos_per_prompt,
            self.do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
            clip_skip=self.clip_skip,
        )

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)

    # 6. Prepare IP-Adapter embeddings
    if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
        image_embeds = self.prepare_ip_adapter_image_embeds(
            ip_adapter_image,
            ip_adapter_image_embeds,
            device,
            batch_size * num_videos_per_prompt,
            self.do_classifier_free_guidance,
        )

    # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 8. Add image embeds for IP-Adapter
    added_cond_kwargs = (
        {"image_embeds": image_embeds}
        if ip_adapter_image is not None or ip_adapter_image_embeds is not None
        else None
    )

    # FreeInit reruns the whole denoising loop several times; without it this loop body
    # executes exactly once.
    num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
    for free_init_iter in range(num_free_init_iters):
        if self.free_init_enabled:
            latents, timesteps = self._apply_free_init(
                latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
            )
            num_inference_steps = len(timesteps)
            # make sure to readjust timesteps based on strength
            timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)

        self._num_timesteps = len(timesteps)
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

        # 9. Denoising loop
        with self.progress_bar(total=self._num_timesteps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                ).sample

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    # The callback may replace tensors in-flight (e.g. latent rescaling).
                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

    # 10. Post-processing
    if output_type == "latent":
        video = latents
    else:
        video_tensor = self.decode_latents(latents, decode_chunk_size)
        video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

    # 11. Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (video,)

    return AnimateDiffPipelineOutput(frames=video)
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py
ADDED
|
@@ -0,0 +1,1353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import inspect
|
| 16 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
| 21 |
+
|
| 22 |
+
from ...image_processor import PipelineImageInput
|
| 23 |
+
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
| 24 |
+
from ...models import (
|
| 25 |
+
AutoencoderKL,
|
| 26 |
+
ControlNetModel,
|
| 27 |
+
ImageProjection,
|
| 28 |
+
MultiControlNetModel,
|
| 29 |
+
UNet2DConditionModel,
|
| 30 |
+
UNetMotionModel,
|
| 31 |
+
)
|
| 32 |
+
from ...models.lora import adjust_lora_scale_text_encoder
|
| 33 |
+
from ...models.unets.unet_motion_model import MotionAdapter
|
| 34 |
+
from ...schedulers import (
|
| 35 |
+
DDIMScheduler,
|
| 36 |
+
DPMSolverMultistepScheduler,
|
| 37 |
+
EulerAncestralDiscreteScheduler,
|
| 38 |
+
EulerDiscreteScheduler,
|
| 39 |
+
LMSDiscreteScheduler,
|
| 40 |
+
PNDMScheduler,
|
| 41 |
+
)
|
| 42 |
+
from ...utils import USE_PEFT_BACKEND, is_torch_xla_available, logging, scale_lora_layers, unscale_lora_layers
|
| 43 |
+
from ...utils.torch_utils import is_compiled_module, randn_tensor
|
| 44 |
+
from ...video_processor import VideoProcessor
|
| 45 |
+
from ..free_init_utils import FreeInitMixin
|
| 46 |
+
from ..free_noise_utils import AnimateDiffFreeNoiseMixin
|
| 47 |
+
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
| 48 |
+
from .pipeline_output import AnimateDiffPipelineOutput
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
if is_torch_xla_available():
|
| 52 |
+
import torch_xla.core.xla_model as xm
|
| 53 |
+
|
| 54 |
+
XLA_AVAILABLE = True
|
| 55 |
+
else:
|
| 56 |
+
XLA_AVAILABLE = False
|
| 57 |
+
|
| 58 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# Usage example rendered into the pipeline's docstring. Fixes relative to the previous
# version: `conditioning_frames` is initialized before the loop that appends to it
# (the example previously raised NameError if copy-pasted), and the placeholder-free
# f-string in `export_to_gif` is a plain string literal.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from PIL import Image
        >>> from tqdm.auto import tqdm

        >>> from diffusers import AnimateDiffVideoToVideoControlNetPipeline
        >>> from diffusers.utils import export_to_gif, load_video
        >>> from diffusers import AutoencoderKL, ControlNetModel, MotionAdapter, LCMScheduler

        >>> controlnet = ControlNetModel.from_pretrained(
        ...     "lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16
        ... )
        >>> motion_adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM")
        >>> vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)

        >>> pipe = AnimateDiffVideoToVideoControlNetPipeline.from_pretrained(
        ...     "SG161222/Realistic_Vision_V5.1_noVAE",
        ...     motion_adapter=motion_adapter,
        ...     controlnet=controlnet,
        ...     vae=vae,
        ... ).to(device="cuda", dtype=torch.float16)

        >>> pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
        >>> pipe.load_lora_weights(
        ...     "wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora"
        ... )
        >>> pipe.set_adapters(["lcm-lora"], [0.8])

        >>> video = load_video(
        ...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/dance.gif"
        ... )
        >>> video = [frame.convert("RGB") for frame in video]

        >>> from controlnet_aux.processor import OpenposeDetector

        >>> open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators").to("cuda")
        >>> conditioning_frames = []
        >>> for frame in tqdm(video):
        ...     conditioning_frames.append(open_pose(frame))

        >>> prompt = "astronaut in space, dancing"
        >>> negative_prompt = "bad quality, worst quality, jpeg artifacts, ugly"

        >>> strength = 0.8
        >>> with torch.inference_mode():
        ...     video = pipe(
        ...         video=video,
        ...         prompt=prompt,
        ...         negative_prompt=negative_prompt,
        ...         num_inference_steps=10,
        ...         guidance_scale=2.0,
        ...         controlnet_conditioning_scale=0.75,
        ...         conditioning_frames=conditioning_frames,
        ...         strength=strength,
        ...         generator=torch.Generator().manual_seed(42),
        ...     ).frames[0]

        >>> video = [frame.resize(conditioning_frames[0].size) for frame in video]
        >>> export_to_gif(video, "animatediff_vid2vid_controlnet.gif", fps=8)
        ```
"""
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Extract a latent tensor from a VAE encoder output.

    Supports outputs that expose a `latent_dist` (sampled or argmax/mode, per
    `sample_mode`) as well as outputs that carry plain `latents`.
    """
    has_latent_dist = hasattr(encoder_output, "latent_dist")
    if has_latent_dist and sample_mode == "sample":
        # Stochastic draw from the posterior, reproducible via `generator`.
        return encoder_output.latent_dist.sample(generator)
    if has_latent_dist and sample_mode == "argmax":
        # Deterministic mode of the posterior distribution.
        return encoder_output.latent_dist.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")
| 138 |
+
|
| 139 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Call the scheduler's `set_timesteps` and return the resulting timestep schedule.

    The schedule is driven either by `num_inference_steps` (default), or by a custom `timesteps` or `sigmas`
    list — at most one of the two custom lists may be given, and each is forwarded only to schedulers whose
    `set_timesteps` signature accepts it. Any extra `kwargs` are passed through to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is
            passed, `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler
        and the second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    if timesteps is None and sigmas is None:
        # Default path: the scheduler derives the schedule from the step count alone.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom-schedule path: only forward the override if the scheduler supports it.
    accepted_params = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
    if timesteps is not None:
        if "timesteps" not in accepted_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in accepted_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
| 197 |
+
|
| 198 |
+
|
| 199 |
+
class AnimateDiffVideoToVideoControlNetPipeline(
    DiffusionPipeline,
    StableDiffusionMixin,
    TextualInversionLoaderMixin,
    IPAdapterMixin,
    StableDiffusionLoraLoaderMixin,
    FreeInitMixin,
    AnimateDiffFreeNoiseMixin,
    FromSingleFileMixin,
):
    r"""
    Pipeline for video-to-video generation with ControlNet guidance.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    The pipeline also inherits the following loading methods:
        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
        tokenizer (`CLIPTokenizer`):
            A [`~transformers.CLIPTokenizer`] to tokenize text.
        unet ([`UNet2DConditionModel`]):
            A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents.
        motion_adapter ([`MotionAdapter`]):
            A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents.
        controlnet ([`ControlNetModel`] or `List[ControlNetModel]` or `Tuple[ControlNetModel]` or `MultiControlNetModel`):
            Provides additional conditioning to the `unet` during the denoising process. If you set multiple
            ControlNets as a list, the outputs from each ControlNet are added together to create one combined
            additional conditioning.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
    """

    # Order in which model components are moved on/off the accelerator during CPU offload.
    model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
    # Components the pipeline can run without (they may be None at construction time).
    _optional_components = ["feature_extractor", "image_encoder", "motion_adapter"]
    # Tensor names that `callback_on_step_end` callbacks are allowed to read/modify.
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
| 245 |
+
def __init__(
|
| 246 |
+
self,
|
| 247 |
+
vae: AutoencoderKL,
|
| 248 |
+
text_encoder: CLIPTextModel,
|
| 249 |
+
tokenizer: CLIPTokenizer,
|
| 250 |
+
unet: Union[UNet2DConditionModel, UNetMotionModel],
|
| 251 |
+
motion_adapter: MotionAdapter,
|
| 252 |
+
controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
|
| 253 |
+
scheduler: Union[
|
| 254 |
+
DDIMScheduler,
|
| 255 |
+
PNDMScheduler,
|
| 256 |
+
LMSDiscreteScheduler,
|
| 257 |
+
EulerDiscreteScheduler,
|
| 258 |
+
EulerAncestralDiscreteScheduler,
|
| 259 |
+
DPMSolverMultistepScheduler,
|
| 260 |
+
],
|
| 261 |
+
feature_extractor: CLIPImageProcessor = None,
|
| 262 |
+
image_encoder: CLIPVisionModelWithProjection = None,
|
| 263 |
+
):
|
| 264 |
+
super().__init__()
|
| 265 |
+
if isinstance(unet, UNet2DConditionModel):
|
| 266 |
+
unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
|
| 267 |
+
|
| 268 |
+
if isinstance(controlnet, (list, tuple)):
|
| 269 |
+
controlnet = MultiControlNetModel(controlnet)
|
| 270 |
+
|
| 271 |
+
self.register_modules(
|
| 272 |
+
vae=vae,
|
| 273 |
+
text_encoder=text_encoder,
|
| 274 |
+
tokenizer=tokenizer,
|
| 275 |
+
unet=unet,
|
| 276 |
+
motion_adapter=motion_adapter,
|
| 277 |
+
controlnet=controlnet,
|
| 278 |
+
scheduler=scheduler,
|
| 279 |
+
feature_extractor=feature_extractor,
|
| 280 |
+
image_encoder=image_encoder,
|
| 281 |
+
)
|
| 282 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
| 283 |
+
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
|
| 284 |
+
self.control_video_processor = VideoProcessor(
|
| 285 |
+
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
    # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.encode_prompt
    def encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
        """
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        # Determine the batch size from whichever prompt representation was given.
        # NOTE(review): a `dict` prompt is treated as a single prompt here — presumably a
        # FreeNoise per-frame prompt mapping; confirm against the FreeNoise mixin.
        if prompt is not None and isinstance(prompt, (str, dict)):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            # Warn (but do not fail) when the prompt exceeds the tokenizer's max length.
            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            if clip_skip is None:
                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
                prompt_embeds = prompt_embeds[0]
            else:
                prompt_embeds = self.text_encoder(
                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
                )
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
                # the tuple to access the hidden states from the desired layer.
                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
                # We also need to apply the final LayerNorm here to not mess with the
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

        # Pick a dtype for the embeddings (text encoder preferred, then unet).
        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet is not None:
            prompt_embeds_dtype = self.unet.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            # Pad/truncate negatives to the same sequence length as the positives.
            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        if self.text_encoder is not None:
            if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        return prompt_embeds, negative_prompt_embeds
|
| 470 |
+
|
| 471 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
|
| 472 |
+
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
|
| 473 |
+
dtype = next(self.image_encoder.parameters()).dtype
|
| 474 |
+
|
| 475 |
+
if not isinstance(image, torch.Tensor):
|
| 476 |
+
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
| 477 |
+
|
| 478 |
+
image = image.to(device=device, dtype=dtype)
|
| 479 |
+
if output_hidden_states:
|
| 480 |
+
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
|
| 481 |
+
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
| 482 |
+
uncond_image_enc_hidden_states = self.image_encoder(
|
| 483 |
+
torch.zeros_like(image), output_hidden_states=True
|
| 484 |
+
).hidden_states[-2]
|
| 485 |
+
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
|
| 486 |
+
num_images_per_prompt, dim=0
|
| 487 |
+
)
|
| 488 |
+
return image_enc_hidden_states, uncond_image_enc_hidden_states
|
| 489 |
+
else:
|
| 490 |
+
image_embeds = self.image_encoder(image).image_embeds
|
| 491 |
+
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
| 492 |
+
uncond_image_embeds = torch.zeros_like(image_embeds)
|
| 493 |
+
|
| 494 |
+
return image_embeds, uncond_image_embeds
|
| 495 |
+
|
| 496 |
+
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
    def prepare_ip_adapter_image_embeds(
        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
    ):
        """Build per-adapter IP-Adapter image embeddings.

        Either encodes `ip_adapter_image` (one image per loaded IP Adapter) or
        re-uses precomputed `ip_adapter_image_embeds`. Each returned tensor is
        repeated `num_images_per_prompt` times and, under classifier-free
        guidance, has the negative embeddings concatenated in front along dim 0.
        """
        image_embeds = []
        if do_classifier_free_guidance:
            negative_image_embeds = []
        if ip_adapter_image_embeds is None:
            if not isinstance(ip_adapter_image, list):
                ip_adapter_image = [ip_adapter_image]

            # One input image is required per image projection layer (i.e. per IP Adapter).
            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
                raise ValueError(
                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
                )

            for single_ip_adapter_image, image_proj_layer in zip(
                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
            ):
                # Plain ImageProjection layers consume pooled embeds; other projection
                # layers consume penultimate hidden states.
                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
                single_image_embeds, single_negative_image_embeds = self.encode_image(
                    single_ip_adapter_image, device, 1, output_hidden_state
                )

                image_embeds.append(single_image_embeds[None, :])
                if do_classifier_free_guidance:
                    negative_image_embeds.append(single_negative_image_embeds[None, :])
        else:
            for single_image_embeds in ip_adapter_image_embeds:
                if do_classifier_free_guidance:
                    # Precomputed embeds store [negative, positive] stacked along dim 0.
                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
                    negative_image_embeds.append(single_negative_image_embeds)
                image_embeds.append(single_image_embeds)

        ip_adapter_image_embeds = []
        for i, single_image_embeds in enumerate(image_embeds):
            # Repeat for each generation per prompt, then prepend negatives for CFG.
            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
            if do_classifier_free_guidance:
                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)

            single_image_embeds = single_image_embeds.to(device=device)
            ip_adapter_image_embeds.append(single_image_embeds)

        return ip_adapter_image_embeds
|
| 541 |
+
|
| 542 |
+
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.encode_video
|
| 543 |
+
def encode_video(self, video, generator, decode_chunk_size: int = 16) -> torch.Tensor:
|
| 544 |
+
latents = []
|
| 545 |
+
for i in range(0, len(video), decode_chunk_size):
|
| 546 |
+
batch_video = video[i : i + decode_chunk_size]
|
| 547 |
+
batch_video = retrieve_latents(self.vae.encode(batch_video), generator=generator)
|
| 548 |
+
latents.append(batch_video)
|
| 549 |
+
return torch.cat(latents)
|
| 550 |
+
|
| 551 |
+
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents
|
| 552 |
+
def decode_latents(self, latents, decode_chunk_size: int = 16):
|
| 553 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
| 554 |
+
|
| 555 |
+
batch_size, channels, num_frames, height, width = latents.shape
|
| 556 |
+
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
|
| 557 |
+
|
| 558 |
+
video = []
|
| 559 |
+
for i in range(0, latents.shape[0], decode_chunk_size):
|
| 560 |
+
batch_latents = latents[i : i + decode_chunk_size]
|
| 561 |
+
batch_latents = self.vae.decode(batch_latents).sample
|
| 562 |
+
video.append(batch_latents)
|
| 563 |
+
|
| 564 |
+
video = torch.cat(video)
|
| 565 |
+
video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
|
| 566 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
| 567 |
+
video = video.float()
|
| 568 |
+
return video
|
| 569 |
+
|
| 570 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
| 571 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
| 572 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 573 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 574 |
+
# eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
|
| 575 |
+
# and should be between [0, 1]
|
| 576 |
+
|
| 577 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 578 |
+
extra_step_kwargs = {}
|
| 579 |
+
if accepts_eta:
|
| 580 |
+
extra_step_kwargs["eta"] = eta
|
| 581 |
+
|
| 582 |
+
# check if the scheduler accepts generator
|
| 583 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 584 |
+
if accepts_generator:
|
| 585 |
+
extra_step_kwargs["generator"] = generator
|
| 586 |
+
return extra_step_kwargs
|
| 587 |
+
|
| 588 |
+
def check_inputs(
|
| 589 |
+
self,
|
| 590 |
+
prompt,
|
| 591 |
+
strength,
|
| 592 |
+
height,
|
| 593 |
+
width,
|
| 594 |
+
video=None,
|
| 595 |
+
conditioning_frames=None,
|
| 596 |
+
latents=None,
|
| 597 |
+
negative_prompt=None,
|
| 598 |
+
prompt_embeds=None,
|
| 599 |
+
negative_prompt_embeds=None,
|
| 600 |
+
ip_adapter_image=None,
|
| 601 |
+
ip_adapter_image_embeds=None,
|
| 602 |
+
callback_on_step_end_tensor_inputs=None,
|
| 603 |
+
controlnet_conditioning_scale=1.0,
|
| 604 |
+
control_guidance_start=0.0,
|
| 605 |
+
control_guidance_end=1.0,
|
| 606 |
+
):
|
| 607 |
+
if strength < 0 or strength > 1:
|
| 608 |
+
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
| 609 |
+
|
| 610 |
+
if height % 8 != 0 or width % 8 != 0:
|
| 611 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 612 |
+
|
| 613 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 614 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 615 |
+
):
|
| 616 |
+
raise ValueError(
|
| 617 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 618 |
+
)
|
| 619 |
+
|
| 620 |
+
if prompt is not None and prompt_embeds is not None:
|
| 621 |
+
raise ValueError(
|
| 622 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 623 |
+
" only forward one of the two."
|
| 624 |
+
)
|
| 625 |
+
elif prompt is None and prompt_embeds is None:
|
| 626 |
+
raise ValueError(
|
| 627 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 628 |
+
)
|
| 629 |
+
elif prompt is not None and not isinstance(prompt, (str, list, dict)):
|
| 630 |
+
raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)}")
|
| 631 |
+
|
| 632 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 633 |
+
raise ValueError(
|
| 634 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 635 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 636 |
+
)
|
| 637 |
+
|
| 638 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 639 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 640 |
+
raise ValueError(
|
| 641 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 642 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 643 |
+
f" {negative_prompt_embeds.shape}."
|
| 644 |
+
)
|
| 645 |
+
|
| 646 |
+
if video is not None and latents is not None:
|
| 647 |
+
raise ValueError("Only one of `video` or `latents` should be provided")
|
| 648 |
+
|
| 649 |
+
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
|
| 650 |
+
raise ValueError(
|
| 651 |
+
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
|
| 652 |
+
)
|
| 653 |
+
|
| 654 |
+
if ip_adapter_image_embeds is not None:
|
| 655 |
+
if not isinstance(ip_adapter_image_embeds, list):
|
| 656 |
+
raise ValueError(
|
| 657 |
+
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
|
| 658 |
+
)
|
| 659 |
+
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
|
| 660 |
+
raise ValueError(
|
| 661 |
+
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
| 662 |
+
)
|
| 663 |
+
|
| 664 |
+
if isinstance(self.controlnet, MultiControlNetModel):
|
| 665 |
+
if isinstance(prompt, list):
|
| 666 |
+
logger.warning(
|
| 667 |
+
f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
|
| 668 |
+
" prompts. The conditionings will be fixed across the prompts."
|
| 669 |
+
)
|
| 670 |
+
is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
|
| 671 |
+
self.controlnet, torch._dynamo.eval_frame.OptimizedModule
|
| 672 |
+
)
|
| 673 |
+
|
| 674 |
+
num_frames = len(video) if latents is None else latents.shape[2]
|
| 675 |
+
|
| 676 |
+
if (
|
| 677 |
+
isinstance(self.controlnet, ControlNetModel)
|
| 678 |
+
or is_compiled
|
| 679 |
+
and isinstance(self.controlnet._orig_mod, ControlNetModel)
|
| 680 |
+
):
|
| 681 |
+
if not isinstance(conditioning_frames, list):
|
| 682 |
+
raise TypeError(
|
| 683 |
+
f"For single controlnet, `image` must be of type `list` but got {type(conditioning_frames)}"
|
| 684 |
+
)
|
| 685 |
+
if len(conditioning_frames) != num_frames:
|
| 686 |
+
raise ValueError(f"Excepted image to have length {num_frames} but got {len(conditioning_frames)=}")
|
| 687 |
+
elif (
|
| 688 |
+
isinstance(self.controlnet, MultiControlNetModel)
|
| 689 |
+
or is_compiled
|
| 690 |
+
and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
|
| 691 |
+
):
|
| 692 |
+
if not isinstance(conditioning_frames, list) or not isinstance(conditioning_frames[0], list):
|
| 693 |
+
raise TypeError(
|
| 694 |
+
f"For multiple controlnets: `image` must be type list of lists but got {type(conditioning_frames)=}"
|
| 695 |
+
)
|
| 696 |
+
if len(conditioning_frames[0]) != num_frames:
|
| 697 |
+
raise ValueError(
|
| 698 |
+
f"Expected length of image sublist as {num_frames} but got {len(conditioning_frames)=}"
|
| 699 |
+
)
|
| 700 |
+
if any(len(img) != len(conditioning_frames[0]) for img in conditioning_frames):
|
| 701 |
+
raise ValueError("All conditioning frame batches for multicontrolnet must be same size")
|
| 702 |
+
else:
|
| 703 |
+
assert False
|
| 704 |
+
|
| 705 |
+
# Check `controlnet_conditioning_scale`
|
| 706 |
+
if (
|
| 707 |
+
isinstance(self.controlnet, ControlNetModel)
|
| 708 |
+
or is_compiled
|
| 709 |
+
and isinstance(self.controlnet._orig_mod, ControlNetModel)
|
| 710 |
+
):
|
| 711 |
+
if not isinstance(controlnet_conditioning_scale, float):
|
| 712 |
+
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
|
| 713 |
+
elif (
|
| 714 |
+
isinstance(self.controlnet, MultiControlNetModel)
|
| 715 |
+
or is_compiled
|
| 716 |
+
and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
|
| 717 |
+
):
|
| 718 |
+
if isinstance(controlnet_conditioning_scale, list):
|
| 719 |
+
if any(isinstance(i, list) for i in controlnet_conditioning_scale):
|
| 720 |
+
raise ValueError("A single batch of multiple conditionings are supported at the moment.")
|
| 721 |
+
elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
|
| 722 |
+
self.controlnet.nets
|
| 723 |
+
):
|
| 724 |
+
raise ValueError(
|
| 725 |
+
"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
|
| 726 |
+
" the same length as the number of controlnets"
|
| 727 |
+
)
|
| 728 |
+
else:
|
| 729 |
+
assert False
|
| 730 |
+
|
| 731 |
+
if not isinstance(control_guidance_start, (tuple, list)):
|
| 732 |
+
control_guidance_start = [control_guidance_start]
|
| 733 |
+
|
| 734 |
+
if not isinstance(control_guidance_end, (tuple, list)):
|
| 735 |
+
control_guidance_end = [control_guidance_end]
|
| 736 |
+
|
| 737 |
+
if len(control_guidance_start) != len(control_guidance_end):
|
| 738 |
+
raise ValueError(
|
| 739 |
+
f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
|
| 740 |
+
)
|
| 741 |
+
|
| 742 |
+
if isinstance(self.controlnet, MultiControlNetModel):
|
| 743 |
+
if len(control_guidance_start) != len(self.controlnet.nets):
|
| 744 |
+
raise ValueError(
|
| 745 |
+
f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
|
| 746 |
+
)
|
| 747 |
+
|
| 748 |
+
for start, end in zip(control_guidance_start, control_guidance_end):
|
| 749 |
+
if start >= end:
|
| 750 |
+
raise ValueError(
|
| 751 |
+
f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
|
| 752 |
+
)
|
| 753 |
+
if start < 0.0:
|
| 754 |
+
raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
|
| 755 |
+
if end > 1.0:
|
| 756 |
+
raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
|
| 757 |
+
|
| 758 |
+
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps
|
| 759 |
+
def get_timesteps(self, num_inference_steps, timesteps, strength, device):
|
| 760 |
+
# get the original timestep using init_timestep
|
| 761 |
+
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
| 762 |
+
|
| 763 |
+
t_start = max(num_inference_steps - init_timestep, 0)
|
| 764 |
+
timesteps = timesteps[t_start * self.scheduler.order :]
|
| 765 |
+
|
| 766 |
+
return timesteps, num_inference_steps - t_start
|
| 767 |
+
|
| 768 |
+
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.prepare_latents
def prepare_latents(
    self,
    video: Optional[torch.Tensor] = None,
    height: int = 64,
    width: int = 64,
    num_channels_latents: int = 4,
    batch_size: int = 1,
    timestep: Optional[int] = None,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    decode_chunk_size: int = 16,
    add_noise: bool = False,
) -> torch.Tensor:
    """Build the initial latent tensor for denoising.

    Either encodes ``video`` through the VAE and noises the result to ``timestep``,
    or validates/reuses caller-supplied ``latents`` (optionally adding noise when
    ``add_noise`` is True). Returns latents of shape
    ``(batch_size, num_channels_latents, num_frames, height // vae_scale_factor,
    width // vae_scale_factor)``.
    """
    # assumes `video` is (batch, frames, channels, height, width) — the frames
    # dimension is read here; confirm against the caller's permute.
    num_frames = video.shape[1] if latents is None else latents.shape[2]
    shape = (
        batch_size,
        num_channels_latents,
        num_frames,
        height // self.vae_scale_factor,
        width // self.vae_scale_factor,
    )

    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is None:
        # make sure the VAE is in float32 mode, as it overflows in float16
        if self.vae.config.force_upcast:
            video = video.float()
            self.vae.to(dtype=torch.float32)

        if isinstance(generator, list):
            # One generator per batch item: encode each video with its own generator.
            init_latents = [
                self.encode_video(video[i], generator[i], decode_chunk_size).unsqueeze(0)
                for i in range(batch_size)
            ]
        else:
            init_latents = [self.encode_video(vid, generator, decode_chunk_size).unsqueeze(0) for vid in video]

        init_latents = torch.cat(init_latents, dim=0)

        # restore vae to original dtype
        if self.vae.config.force_upcast:
            self.vae.to(dtype)

        init_latents = init_latents.to(dtype)
        init_latents = self.vae.config.scaling_factor * init_latents

        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
            # NOTE(review): despite the "expand init_latents" comment, duplicating
            # latents to match the prompt count is intentionally rejected here.
            error_message = (
                f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
                " images (`image`). Please make sure to update your script to pass as many initial images as text prompts"
            )
            raise ValueError(error_message)
        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
            )
        else:
            init_latents = torch.cat([init_latents], dim=0)

        # Noise the clean latents to the starting timestep, then reorder from
        # [B, F, C, H, W] to [B, C, F, H, W].
        noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype)
        latents = self.scheduler.add_noise(init_latents, noise, timestep).permute(0, 2, 1, 3, 4)
    else:
        if shape != latents.shape:
            # [B, C, F, H, W]
            raise ValueError(f"`latents` expected to have {shape=}, but found {latents.shape=}")

        latents = latents.to(device, dtype=dtype)

        if add_noise:
            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            latents = self.scheduler.add_noise(latents, noise, timestep)

    return latents
|
| 850 |
+
|
| 851 |
+
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff_controlnet.AnimateDiffControlNetPipeline.prepare_video
def prepare_conditioning_frames(
    self,
    video,
    width,
    height,
    batch_size,
    num_videos_per_prompt,
    device,
    dtype,
    do_classifier_free_guidance=False,
    guess_mode=False,
):
    """Preprocess ControlNet conditioning frames into a frame-flattened batch.

    The video is resized/normalized by ``control_video_processor``, flattened so
    frames join the batch dimension, repeated to match the requested batch size,
    and doubled when classifier-free guidance is active (unless ``guess_mode``,
    where only the conditional half is run through the ControlNet).
    """
    video = self.control_video_processor.preprocess_video(video, height=height, width=width).to(
        dtype=torch.float32
    )
    # (B, C, F, H, W) -> (B, F, C, H, W) -> (B * F, C, H, W)
    video = video.permute(0, 2, 1, 3, 4).flatten(0, 1)
    video_batch_size = video.shape[0]

    if video_batch_size == 1:
        repeat_by = batch_size
    else:
        # image batch size is the same as prompt batch size
        repeat_by = num_videos_per_prompt

    video = video.repeat_interleave(repeat_by, dim=0)
    video = video.to(device=device, dtype=dtype)

    # Duplicate for the unconditional branch of classifier-free guidance.
    if do_classifier_free_guidance and not guess_mode:
        video = torch.cat([video] * 2)

    return video
|
| 883 |
+
|
| 884 |
+
@property
def guidance_scale(self):
    """Classifier-free guidance scale configured by the current `__call__`."""
    return self._guidance_scale
|
| 887 |
+
|
| 888 |
+
@property
def clip_skip(self):
    """Number of final CLIP layers to skip for prompt embeddings (set in `__call__`)."""
    return self._clip_skip
|
| 891 |
+
|
| 892 |
+
@property
def do_classifier_free_guidance(self):
    """Whether classifier-free guidance is active.

    `guidance_scale` is defined analogously to the guidance weight `w` of
    equation (2) of the Imagen paper (https://huggingface.co/papers/2205.11487);
    a value of 1 corresponds to doing no classifier-free guidance.
    """
    return 1 < self._guidance_scale
|
| 898 |
+
|
| 899 |
+
@property
def cross_attention_kwargs(self):
    """Extra kwargs forwarded to the attention processors (set in `__call__`)."""
    return self._cross_attention_kwargs
|
| 902 |
+
|
| 903 |
+
@property
def num_timesteps(self):
    """Number of denoising timesteps used by the current generation run."""
    return self._num_timesteps
|
| 906 |
+
|
| 907 |
+
@property
def interrupt(self):
    """Flag checked each denoising step; when True the remaining steps are skipped."""
    return self._interrupt
|
| 910 |
+
|
| 911 |
+
@torch.no_grad()
def __call__(
    self,
    video: List[List[PipelineImageInput]] = None,
    prompt: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    enforce_inference_steps: bool = False,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    guidance_scale: float = 7.5,
    strength: float = 0.8,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_videos_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
    conditioning_frames: Optional[List[PipelineImageInput]] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
    guess_mode: bool = False,
    control_guidance_start: Union[float, List[float]] = 0.0,
    control_guidance_end: Union[float, List[float]] = 1.0,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    decode_chunk_size: int = 16,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        video (`List[PipelineImageInput]`):
            The input video to condition the generation on. Must be a list of images/frames of the video.
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The height in pixels of the generated video.
        width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The width in pixels of the generated video.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality videos at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
            in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
            passed will be used. Must be in descending order.
        sigmas (`List[float]`, *optional*):
            Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
            their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
            will be used.
        strength (`float`, *optional*, defaults to 0.8):
            Higher strength leads to more differences between original video and generated video.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
            applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
            `(batch_size, num_channel, num_frames, height, width)`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
            not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
        ip_adapter_image: (`PipelineImageInput`, *optional*):
            Optional image input to work with IP Adapters.
        ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
            Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
            IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
            contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
            provided, embeddings are computed from the `ip_adapter_image` input argument.
        conditioning_frames (`List[PipelineImageInput]`, *optional*):
            The ControlNet input condition to provide guidance to the `unet` for generation. If multiple
            ControlNets are specified, images must be passed as a list such that each element of the list can be
            correctly batched for input to a single ControlNet.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`AnimateDiffPipelineOutput`] instead of a plain tuple.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
            The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
            to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
            the corresponding scale as a list.
        guess_mode (`bool`, *optional*, defaults to `False`):
            The ControlNet encoder tries to recognize the content of the input image even if you remove all
            prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
        control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
            The percentage of total steps at which the ControlNet starts applying.
        control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
            The percentage of total steps at which the ControlNet stops applying.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.
        callback_on_step_end (`Callable`, *optional*):
            A function that calls at the end of each denoising steps during the inference. The function is called
            with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
            callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
            `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.
        decode_chunk_size (`int`, defaults to `16`):
            The number of frames to decode at a time when calling `decode_latents` method.

    Examples:

    Returns:
        [`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
            returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
    """

    # Unwrap a torch.compile()-wrapped controlnet to reach its real module.
    controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

    # align format for control guidance: normalize scalars to per-controlnet lists
    if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
        control_guidance_start = len(control_guidance_end) * [control_guidance_start]
    elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
        control_guidance_end = len(control_guidance_start) * [control_guidance_end]
    elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
        mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
        control_guidance_start, control_guidance_end = (
            mult * [control_guidance_start],
            mult * [control_guidance_end],
        )

    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # NOTE(review): the `num_videos_per_prompt` argument is overridden — this
    # pipeline always generates a single video per prompt.
    num_videos_per_prompt = 1

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt=prompt,
        strength=strength,
        height=height,
        width=width,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        video=video,
        conditioning_frames=conditioning_frames,
        latents=latents,
        ip_adapter_image=ip_adapter_image,
        ip_adapter_image_embeds=ip_adapter_image_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        control_guidance_start=control_guidance_start,
        control_guidance_end=control_guidance_end,
    )

    self._guidance_scale = guidance_scale
    self._clip_skip = clip_skip
    self._cross_attention_kwargs = cross_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, (str, dict)):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    dtype = self.dtype

    # 3. Prepare timesteps
    if not enforce_inference_steps:
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps, sigmas
        )
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
        latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
    else:
        # Over-schedule, then keep only the last `num_inference_steps` entries so
        # exactly that many denoising steps run regardless of `strength`.
        denoising_inference_steps = int(num_inference_steps / strength)
        timesteps, denoising_inference_steps = retrieve_timesteps(
            self.scheduler, denoising_inference_steps, device, timesteps, sigmas
        )
        timesteps = timesteps[-num_inference_steps:]
        latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)

    # 4. Prepare latent variables
    if latents is None:
        video = self.video_processor.preprocess_video(video, height=height, width=width)
        # Move the number of frames before the number of channels.
        video = video.permute(0, 2, 1, 3, 4)
        video = video.to(device=device, dtype=dtype)

        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            video=video,
            height=height,
            width=width,
            num_channels_latents=num_channels_latents,
            batch_size=batch_size * num_videos_per_prompt,
            timestep=latent_timestep,
            dtype=dtype,
            device=device,
            generator=generator,
            latents=latents,
            decode_chunk_size=decode_chunk_size,
            add_noise=enforce_inference_steps,
        )

    # 5. Encode input prompt
    text_encoder_lora_scale = (
        self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
    )
    num_frames = latents.shape[2]
    if self.free_noise_enabled:
        # FreeNoise produces per-frame embeddings and handles CFG internally.
        prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise(
            prompt=prompt,
            num_frames=num_frames,
            device=device,
            num_videos_per_prompt=num_videos_per_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
            clip_skip=self.clip_skip,
        )
    else:
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_videos_per_prompt,
            self.do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
            clip_skip=self.clip_skip,
        )

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        # One embedding copy per frame, matching the frame-flattened controlnet batch.
        prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)

    # 6. Prepare IP-Adapter embeddings
    if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
        image_embeds = self.prepare_ip_adapter_image_embeds(
            ip_adapter_image,
            ip_adapter_image_embeds,
            device,
            batch_size * num_videos_per_prompt,
            self.do_classifier_free_guidance,
        )

    # 7. Prepare ControlNet conditions
    if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
        controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)

    global_pool_conditions = (
        controlnet.config.global_pool_conditions
        if isinstance(controlnet, ControlNetModel)
        else controlnet.nets[0].config.global_pool_conditions
    )
    guess_mode = guess_mode or global_pool_conditions

    # Per-step keep mask: 1.0 inside each controlnet's [start, end] window, else 0.0.
    controlnet_keep = []
    for i in range(len(timesteps)):
        keeps = [
            1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
            for s, e in zip(control_guidance_start, control_guidance_end)
        ]
        controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)

    if isinstance(controlnet, ControlNetModel):
        conditioning_frames = self.prepare_conditioning_frames(
            video=conditioning_frames,
            width=width,
            height=height,
            batch_size=batch_size * num_videos_per_prompt * num_frames,
            num_videos_per_prompt=num_videos_per_prompt,
            device=device,
            dtype=controlnet.dtype,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            guess_mode=guess_mode,
        )
    elif isinstance(controlnet, MultiControlNetModel):
        cond_prepared_videos = []
        for frame_ in conditioning_frames:
            prepared_video = self.prepare_conditioning_frames(
                video=frame_,
                width=width,
                height=height,
                batch_size=batch_size * num_videos_per_prompt * num_frames,
                num_videos_per_prompt=num_videos_per_prompt,
                device=device,
                dtype=controlnet.dtype,
                do_classifier_free_guidance=self.do_classifier_free_guidance,
                guess_mode=guess_mode,
            )
            cond_prepared_videos.append(prepared_video)
        conditioning_frames = cond_prepared_videos
    else:
        assert False

    # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 9. Add image embeds for IP-Adapter
    added_cond_kwargs = (
        {"image_embeds": image_embeds}
        if ip_adapter_image is not None or ip_adapter_image_embeds is not None
        else None
    )

    num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
    for free_init_iter in range(num_free_init_iters):
        if self.free_init_enabled:
            latents, timesteps = self._apply_free_init(
                latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
            )
            num_inference_steps = len(timesteps)
            # make sure to readjust timesteps based on strength
            timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)

        self._num_timesteps = len(timesteps)
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

        # 10. Denoising loop
        with self.progress_bar(total=self._num_timesteps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                if guess_mode and self.do_classifier_free_guidance:
                    # Infer ControlNet only for the conditional batch.
                    control_model_input = latents
                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)
                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
                else:
                    control_model_input = latent_model_input
                    controlnet_prompt_embeds = prompt_embeds

                if isinstance(controlnet_keep[i], list):
                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
                else:
                    controlnet_cond_scale = controlnet_conditioning_scale
                    if isinstance(controlnet_cond_scale, list):
                        controlnet_cond_scale = controlnet_cond_scale[0]
                    cond_scale = controlnet_cond_scale * controlnet_keep[i]

                # Flatten frames into the batch dimension for the 2D ControlNet:
                # [B, C, F, H, W] -> [B, F, C, H, W] -> [B * F, C, H, W].
                control_model_input = torch.transpose(control_model_input, 1, 2)
                control_model_input = control_model_input.reshape(
                    (-1, control_model_input.shape[2], control_model_input.shape[3], control_model_input.shape[4])
                )

                down_block_res_samples, mid_block_res_sample = self.controlnet(
                    control_model_input,
                    t,
                    encoder_hidden_states=controlnet_prompt_embeds,
                    controlnet_cond=conditioning_frames,
                    conditioning_scale=cond_scale,
                    guess_mode=guess_mode,
                    return_dict=False,
                )

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                ).sample

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

    # 11. Post-processing
    if output_type == "latent":
        video = latents
    else:
        video_tensor = self.decode_latents(latents, decode_chunk_size)
        video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

    # 12. Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (video,)

    return AnimateDiffPipelineOutput(frames=video)
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/animatediff/pipeline_output.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import List, Union
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import PIL.Image
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from ...utils import BaseOutput
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
class AnimateDiffPipelineOutput(BaseOutput):
    r"""
    Output class for AnimateDiff pipelines.

    Args:
        frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
            List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
            denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of
            shape `(batch_size, num_frames, channels, height, width)`
    """

    # Container type depends on the `output_type` requested from the pipeline call.
    frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/audioldm/__init__.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_torch_available,
    is_transformers_available,
    is_transformers_version,
)


# Registry of placeholder ("dummy") objects exposed when optional deps are missing,
# and of the lazy-import structure used when they are present.
_dummy_objects = {}
_import_structure = {}

# AudioLDM needs both `torch` and `transformers` (>= 4.27.0, for CLAP support).
# If either is unavailable, expose a dummy pipeline that raises a helpful error on use.
try:
    if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils.dummy_torch_and_transformers_objects import (
        AudioLDMPipeline,
    )

    _dummy_objects.update({"AudioLDMPipeline": AudioLDMPipeline})
else:
    _import_structure["pipeline_audioldm"] = ["AudioLDMPipeline"]


# Eager path: type checkers and DIFFUSERS_SLOW_IMPORT get real imports immediately.
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import (
            AudioLDMPipeline,
        )

    else:
        from .pipeline_audioldm import AudioLDMPipeline
else:
    # Lazy path: replace this module object in sys.modules with a _LazyModule that
    # defers the heavy pipeline import until an attribute is actually accessed.
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    # Attach the dummy objects (if any) onto the lazy module so attribute access
    # still resolves — and raises the informative missing-dependency error.
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/audioldm/pipeline_audioldm.py
ADDED
|
@@ -0,0 +1,558 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import inspect
|
| 16 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
from transformers import ClapTextModelWithProjection, RobertaTokenizer, RobertaTokenizerFast, SpeechT5HifiGan
|
| 22 |
+
|
| 23 |
+
from ...models import AutoencoderKL, UNet2DConditionModel
|
| 24 |
+
from ...schedulers import KarrasDiffusionSchedulers
|
| 25 |
+
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
| 26 |
+
from ...utils.torch_utils import randn_tensor
|
| 27 |
+
from ..pipeline_utils import AudioPipelineOutput, DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# Optional PyTorch/XLA (TPU) support: `xm.mark_step()` is called inside the denoising
# loop when available to flush the lazy XLA graph each step.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# Usage example injected into `AudioLDMPipeline.__call__`'s docstring via
# `@replace_example_docstring(EXAMPLE_DOC_STRING)`.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> from diffusers import AudioLDMPipeline
        >>> import torch
        >>> import scipy

        >>> repo_id = "cvssp/audioldm-s-full-v2"
        >>> pipe = AudioLDMPipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
        >>> pipe = pipe.to("cuda")

        >>> prompt = "Techno music with a strong, upbeat tempo and high melodic riffs"
        >>> audio = pipe(prompt, num_inference_steps=10, audio_length_in_s=5.0).audios[0]

        >>> # save the audio sample as a .wav file
        >>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio)
        ```
"""
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class AudioLDMPipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin):
    r"""
    Pipeline for text-to-audio generation using AudioLDM.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`~transformers.ClapTextModelWithProjection`]):
            Frozen text-encoder (`ClapTextModelWithProjection`, specifically the
            [laion/clap-htsat-unfused](https://huggingface.co/laion/clap-htsat-unfused) variant.
        tokenizer ([`PreTrainedTokenizer`]):
            A [`~transformers.RobertaTokenizer`] to tokenize text.
        unet ([`UNet2DConditionModel`]):
            A `UNet2DConditionModel` to denoise the encoded audio latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded audio latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        vocoder ([`~transformers.SpeechT5HifiGan`]):
            Vocoder of class `SpeechT5HifiGan`.
    """

    # Last diffusers release this deprecated pipeline is guaranteed to work with.
    _last_supported_version = "0.33.1"
    # Order in which modules are moved to GPU when CPU offloading is enabled.
    model_cpu_offload_seq = "text_encoder->unet->vae"

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: ClapTextModelWithProjection,
        tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast],
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        vocoder: SpeechT5HifiGan,
    ):
        """Register all sub-models and derive the VAE spatial downsampling factor."""
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            vocoder=vocoder,
        )
        # Each VAE down-block halves the spatial resolution; fall back to 8 if no VAE.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8

    def _encode_prompt(
        self,
        prompt,
        device,
        num_waveforms_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device (`torch.device`):
                torch device
            num_waveforms_per_prompt (`int`):
                number of waveforms that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the audio generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
        """
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            attention_mask = text_inputs.attention_mask
            # Re-tokenize without truncation to detect (and warn about) dropped tokens.
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLAP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            prompt_embeds = self.text_encoder(
                text_input_ids.to(device),
                attention_mask=attention_mask.to(device),
            )
            # CLAP returns pooled text embeddings (2D: batch x embed_dim), not per-token states.
            prompt_embeds = prompt_embeds.text_embeds
            # additional L_2 normalization over each hidden-state
            prompt_embeds = F.normalize(prompt_embeds, dim=-1)

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        (
            bs_embed,
            seq_len,
        ) = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_waveforms_per_prompt)
        prompt_embeds = prompt_embeds.view(bs_embed * num_waveforms_per_prompt, seq_len)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            uncond_input_ids = uncond_input.input_ids.to(device)
            attention_mask = uncond_input.attention_mask.to(device)

            negative_prompt_embeds = self.text_encoder(
                uncond_input_ids,
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds.text_embeds
            # additional L_2 normalization over each hidden-state
            negative_prompt_embeds = F.normalize(negative_prompt_embeds, dim=-1)

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_waveforms_per_prompt)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_waveforms_per_prompt, seq_len)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        return prompt_embeds

    def decode_latents(self, latents):
        """Rescale latents by the VAE scaling factor and decode them into a mel spectrogram."""
        latents = 1 / self.vae.config.scaling_factor * latents
        mel_spectrogram = self.vae.decode(latents).sample
        return mel_spectrogram

    def mel_spectrogram_to_waveform(self, mel_spectrogram):
        """Convert a (batched) mel spectrogram to an audio waveform via the vocoder.

        The result is always moved to CPU and cast to float32.
        """
        # Drop the singleton channel dimension if the spectrogram is 4D (B, 1, T, F).
        if mel_spectrogram.dim() == 4:
            mel_spectrogram = mel_spectrogram.squeeze(1)

        waveform = self.vocoder(mel_spectrogram)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        waveform = waveform.cpu().float()
        return waveform

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(
        self,
        prompt,
        audio_length_in_s,
        vocoder_upsample_factor,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        """Validate `__call__` arguments, raising ValueError/TypeError on inconsistent input."""
        # Shortest representable audio: one latent frame upsampled through the vocoder.
        min_audio_length_in_s = vocoder_upsample_factor * self.vae_scale_factor
        if audio_length_in_s < min_audio_length_in_s:
            raise ValueError(
                f"`audio_length_in_s` has to be a positive value greater than or equal to {min_audio_length_in_s}, but "
                f"is {audio_length_in_s}."
            )

        if self.vocoder.config.model_in_dim % self.vae_scale_factor != 0:
            raise ValueError(
                f"The number of frequency bins in the vocoder's log-mel spectrogram has to be divisible by the "
                f"VAE scale factor, but got {self.vocoder.config.model_in_dim} bins and a scale factor of "
                f"{self.vae_scale_factor}."
            )

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        # Exactly one of `prompt` / `prompt_embeds` must be provided.
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with width->self.vocoder.config.model_in_dim
    def prepare_latents(self, batch_size, num_channels_latents, height, dtype, device, generator, latents=None):
        """Sample (or reuse) initial noise latents scaled by the scheduler's `init_noise_sigma`."""
        # Latent "width" is the vocoder's mel-bin count downscaled by the VAE factor.
        shape = (
            batch_size,
            num_channels_latents,
            int(height) // self.vae_scale_factor,
            int(self.vocoder.config.model_in_dim) // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        audio_length_in_s: Optional[float] = None,
        num_inference_steps: int = 10,
        guidance_scale: float = 2.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_waveforms_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: Optional[int] = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        output_type: Optional[str] = "np",
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`.
            audio_length_in_s (`int`, *optional*, defaults to 5.12):
                The length of the generated audio sample in seconds.
            num_inference_steps (`int`, *optional*, defaults to 10):
                The number of denoising steps. More denoising steps usually lead to a higher quality audio at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 2.5):
                A higher guidance scale value encourages the model to generate audio that is closely linked to the text
                `prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in audio generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
            num_waveforms_per_prompt (`int`, *optional*, defaults to 1):
                The number of waveforms to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
                applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple.
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            output_type (`str`, *optional*, defaults to `"np"`):
                The output format of the generated image. Choose between `"np"` to return a NumPy `np.ndarray` or
                `"pt"` to return a PyTorch `torch.Tensor` object.

        Examples:

        Returns:
            [`~pipelines.AudioPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.AudioPipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated audio.
        """
        # 0. Convert audio input length from seconds to spectrogram height
        vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate

        if audio_length_in_s is None:
            audio_length_in_s = self.unet.config.sample_size * self.vae_scale_factor * vocoder_upsample_factor

        height = int(audio_length_in_s / vocoder_upsample_factor)

        original_waveform_length = int(audio_length_in_s * self.vocoder.config.sampling_rate)
        # Round height up to a multiple of the VAE factor; output is cut back afterwards.
        if height % self.vae_scale_factor != 0:
            height = int(np.ceil(height / self.vae_scale_factor)) * self.vae_scale_factor
            logger.info(
                f"Audio length in seconds {audio_length_in_s} is increased to {height * vocoder_upsample_factor} "
                f"so that it can be handled by the model. It will be cut to {audio_length_in_s} after the "
                f"denoising process."
            )

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            audio_length_in_s,
            vocoder_upsample_factor,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_waveforms_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_waveforms_per_prompt,
            num_channels_latents,
            height,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                # NOTE: AudioLDM conditions via `class_labels` (pooled CLAP embedding),
                # not via cross-attention `encoder_hidden_states`.
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=None,
                    class_labels=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

                if XLA_AVAILABLE:
                    xm.mark_step()

        # 8. Post-processing
        mel_spectrogram = self.decode_latents(latents)

        audio = self.mel_spectrogram_to_waveform(mel_spectrogram)

        # Trim the rounded-up output back to the exact requested duration.
        audio = audio[:, :original_waveform_length]

        if output_type == "np":
            audio = audio.numpy()

        if not return_dict:
            return (audio,)

        return AudioPipelineOutput(audios=audio)
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/audioldm2/__init__.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
    is_transformers_version,
)


# Dummy placeholder objects used when the optional dependencies are missing.
_dummy_objects = {}
# Map of submodule name -> list of public names, consumed by _LazyModule.
_import_structure = {}

# AudioLDM2 requires both torch and transformers >= 4.27.0; when either is
# missing, export dummy objects that raise a helpful error on use instead of
# failing at import time.
try:
    if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["modeling_audioldm2"] = ["AudioLDM2ProjectionModel", "AudioLDM2UNet2DConditionModel"]
    _import_structure["pipeline_audioldm2"] = ["AudioLDM2Pipeline"]


# Under type checkers (or when slow imports are forced) import eagerly so the
# names resolve statically; otherwise install a lazy module that defers the
# submodule imports until first attribute access.
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *

    else:
        from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
        from .pipeline_audioldm2 import AudioLDM2Pipeline

else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )
    # Attach the dummy objects (if any) so attribute access still succeeds
    # and raises the dependency error lazily.
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/audioldm2/modeling_audioldm2.py
ADDED
|
@@ -0,0 +1,1475 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from dataclasses import dataclass
|
| 16 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.utils.checkpoint
|
| 21 |
+
|
| 22 |
+
from ...configuration_utils import ConfigMixin, register_to_config
|
| 23 |
+
from ...loaders import UNet2DConditionLoadersMixin
|
| 24 |
+
from ...models.activations import get_activation
|
| 25 |
+
from ...models.attention_processor import (
|
| 26 |
+
ADDED_KV_ATTENTION_PROCESSORS,
|
| 27 |
+
CROSS_ATTENTION_PROCESSORS,
|
| 28 |
+
AttentionProcessor,
|
| 29 |
+
AttnAddedKVProcessor,
|
| 30 |
+
AttnProcessor,
|
| 31 |
+
)
|
| 32 |
+
from ...models.embeddings import (
|
| 33 |
+
TimestepEmbedding,
|
| 34 |
+
Timesteps,
|
| 35 |
+
)
|
| 36 |
+
from ...models.modeling_utils import ModelMixin
|
| 37 |
+
from ...models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
|
| 38 |
+
from ...models.transformers.transformer_2d import Transformer2DModel
|
| 39 |
+
from ...models.unets.unet_2d_blocks import DownBlock2D, UpBlock2D
|
| 40 |
+
from ...models.unets.unet_2d_condition import UNet2DConditionOutput
|
| 41 |
+
from ...utils import BaseOutput, logging
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def add_special_tokens(hidden_states, attention_mask, sos_token, eos_token):
    """Bracket each sequence with learned SOS / EOS embeddings.

    Prepends `sos_token` and appends `eos_token` to every sequence in
    `hidden_states`, and widens `attention_mask` by two always-attended
    positions to match. `attention_mask` may be None, in which case it is
    returned unchanged (as None).

    Args:
        hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim)`): input sequences.
        attention_mask (`torch.Tensor` of shape `(batch, seq_len)`, *optional*): 0/1 padding mask.
        sos_token (`torch.Tensor`): start-of-sequence embedding, broadcastable to `(batch, 1, dim)`.
        eos_token (`torch.Tensor`): end-of-sequence embedding, broadcastable to `(batch, 1, dim)`.

    Returns:
        Tuple of the extended hidden states `(batch, seq_len + 2, dim)` and the
        extended attention mask `(batch, seq_len + 2)` (or None).
    """
    num_sequences = hidden_states.shape[0]

    if attention_mask is not None:
        # The new SOS/EOS positions are always attended to.
        ones_column = attention_mask.new_ones((num_sequences, 1))
        attention_mask = torch.cat((ones_column, attention_mask, ones_column), dim=-1)

    # Broadcast the single token embeddings across the batch and splice them
    # onto either end of the sequence dimension.
    start = sos_token.expand(num_sequences, 1, -1)
    end = eos_token.expand(num_sequences, 1, -1)
    hidden_states = torch.cat((start, hidden_states, end), dim=1)
    return hidden_states, attention_mask
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
@dataclass
class AudioLDM2ProjectionModelOutput(BaseOutput):
    """
    Class for AudioLDM2 projection layer's outputs.

    Args:
        hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states obtained by linearly projecting the hidden-states for each of the text
            encoders and subsequently concatenating them together.
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices, formed by concatenating the attention masks
            for the two text encoders together. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
    """

    # Concatenated, projected hidden-states from both text encoders.
    hidden_states: torch.Tensor
    # Concatenated 0/1 attention mask; None when neither encoder supplied one.
    attention_mask: Optional[torch.LongTensor] = None
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class AudioLDM2ProjectionModel(ModelMixin, ConfigMixin):
    """
    A simple linear projection model to map two text embeddings to a shared latent space. It also inserts learned
    embedding vectors at the start and end of each text embedding sequence respectively. Each variable appended with
    `_1` refers to that corresponding to the second text encoder. Otherwise, it is from the first.

    Args:
        text_encoder_dim (`int`):
            Dimensionality of the text embeddings from the first text encoder (CLAP).
        text_encoder_1_dim (`int`):
            Dimensionality of the text embeddings from the second text encoder (T5 or VITS).
        langauge_model_dim (`int`):
            Dimensionality of the text embeddings from the language model (GPT2).
        use_learned_position_embedding (`bool`, *optional*):
            Whether to add a learned positional embedding to the second encoder's hidden states (used for VITS).
        max_seq_length (`int`, *optional*):
            Maximum sequence length of the second encoder; required when `use_learned_position_embedding` is set.
    """

    @register_to_config
    def __init__(
        self,
        text_encoder_dim,
        text_encoder_1_dim,
        # NOTE: "langauge" is a historical typo, but it is recorded in saved
        # configs via @register_to_config, so it cannot be renamed without
        # breaking checkpoint loading.
        langauge_model_dim,
        use_learned_position_embedding=None,
        max_seq_length=None,
    ):
        super().__init__()
        # additional projection layers for each text encoder
        self.projection = nn.Linear(text_encoder_dim, langauge_model_dim)
        self.projection_1 = nn.Linear(text_encoder_1_dim, langauge_model_dim)

        # learnable SOS / EOS token embeddings for each text encoder
        self.sos_embed = nn.Parameter(torch.ones(langauge_model_dim))
        self.eos_embed = nn.Parameter(torch.ones(langauge_model_dim))

        self.sos_embed_1 = nn.Parameter(torch.ones(langauge_model_dim))
        self.eos_embed_1 = nn.Parameter(torch.ones(langauge_model_dim))

        self.use_learned_position_embedding = use_learned_position_embedding

        # learnable positional embedding for the VITS encoder
        if self.use_learned_position_embedding is not None:
            self.learnable_positional_embedding = torch.nn.Parameter(
                torch.zeros((1, text_encoder_1_dim, max_seq_length))
            )

    def forward(
        self,
        hidden_states: Optional[torch.Tensor] = None,
        hidden_states_1: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        attention_mask_1: Optional[torch.LongTensor] = None,
    ):
        """Project both encoders' hidden states, bracket them with SOS/EOS tokens, and concatenate.

        Args:
            hidden_states: hidden states from the first text encoder (CLAP).
            hidden_states_1: hidden states from the second text encoder (T5 or VITS).
            attention_mask: 0/1 padding mask for `hidden_states`, *optional*.
            attention_mask_1: 0/1 padding mask for `hidden_states_1`, *optional*.

        Returns:
            [`AudioLDM2ProjectionModelOutput`] with the concatenated hidden states and
            (possibly None) concatenated attention mask.
        """
        hidden_states = self.projection(hidden_states)
        hidden_states, attention_mask = add_special_tokens(
            hidden_states, attention_mask, sos_token=self.sos_embed, eos_token=self.eos_embed
        )

        # Add positional embedding for Vits hidden state
        if self.use_learned_position_embedding is not None:
            hidden_states_1 = (hidden_states_1.permute(0, 2, 1) + self.learnable_positional_embedding).permute(0, 2, 1)

        hidden_states_1 = self.projection_1(hidden_states_1)
        hidden_states_1, attention_mask_1 = add_special_tokens(
            hidden_states_1, attention_mask_1, sos_token=self.sos_embed_1, eos_token=self.eos_embed_1
        )

        # concatenate clap and t5 text encoding
        hidden_states = torch.cat([hidden_states, hidden_states_1], dim=1)

        # concatenate attention masks; when only one mask is provided, treat the
        # other encoder's tokens as fully attended.
        # BUG FIX: the original called `new_ones((hidden_states[:2]))`, which passes
        # a tensor slice instead of a size tuple and raises a TypeError whenever
        # exactly one of the two masks is None. `shape[:2]` gives the intended
        # (batch_size, sequence_length) size.
        if attention_mask is None and attention_mask_1 is not None:
            attention_mask = attention_mask_1.new_ones(hidden_states.shape[:2])
        elif attention_mask is not None and attention_mask_1 is None:
            attention_mask_1 = attention_mask.new_ones(hidden_states_1.shape[:2])

        if attention_mask is not None and attention_mask_1 is not None:
            attention_mask = torch.cat([attention_mask, attention_mask_1], dim=-1)
        else:
            attention_mask = None

        return AudioLDM2ProjectionModelOutput(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
        )
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
| 168 |
+
r"""
|
| 169 |
+
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
|
| 170 |
+
shaped output. Compared to the vanilla [`UNet2DConditionModel`], this variant optionally includes an additional
|
| 171 |
+
self-attention layer in each Transformer block, as well as multiple cross-attention layers. It also allows for up
|
| 172 |
+
to two cross-attention embeddings, `encoder_hidden_states` and `encoder_hidden_states_1`.
|
| 173 |
+
|
| 174 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
| 175 |
+
for all models (such as downloading or saving).
|
| 176 |
+
|
| 177 |
+
Parameters:
|
| 178 |
+
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
| 179 |
+
Height and width of input/output sample.
|
| 180 |
+
in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
|
| 181 |
+
out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
|
| 182 |
+
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
|
| 183 |
+
Whether to flip the sin to cos in the time embedding.
|
| 184 |
+
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
|
| 185 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
| 186 |
+
The tuple of downsample blocks to use.
|
| 187 |
+
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
|
| 188 |
+
Block type for middle of UNet, it can only be `UNetMidBlock2DCrossAttn` for AudioLDM2.
|
| 189 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
|
| 190 |
+
The tuple of upsample blocks to use.
|
| 191 |
+
only_cross_attention (`bool` or `Tuple[bool]`, *optional*, default to `False`):
|
| 192 |
+
Whether to include self-attention in the basic transformer blocks, see
|
| 193 |
+
[`~models.attention.BasicTransformerBlock`].
|
| 194 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
| 195 |
+
The tuple of output channels for each block.
|
| 196 |
+
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
|
| 197 |
+
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
|
| 198 |
+
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
|
| 199 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
| 200 |
+
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
|
| 201 |
+
If `None`, normalization and activation layers is skipped in post-processing.
|
| 202 |
+
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
| 203 |
+
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
|
| 204 |
+
The dimension of the cross attention features.
|
| 205 |
+
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
| 206 |
+
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
| 207 |
+
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
| 208 |
+
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
| 209 |
+
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
| 210 |
+
num_attention_heads (`int`, *optional*):
|
| 211 |
+
The number of attention heads. If not defined, defaults to `attention_head_dim`
|
| 212 |
+
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
| 213 |
+
for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
|
| 214 |
+
class_embed_type (`str`, *optional*, defaults to `None`):
|
| 215 |
+
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
|
| 216 |
+
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
| 217 |
+
num_class_embeds (`int`, *optional*, defaults to `None`):
|
| 218 |
+
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
| 219 |
+
class conditioning with `class_embed_type` equal to `None`.
|
| 220 |
+
time_embedding_type (`str`, *optional*, defaults to `positional`):
|
| 221 |
+
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
|
| 222 |
+
time_embedding_dim (`int`, *optional*, defaults to `None`):
|
| 223 |
+
An optional override for the dimension of the projected time embedding.
|
| 224 |
+
time_embedding_act_fn (`str`, *optional*, defaults to `None`):
|
| 225 |
+
Optional activation function to use only once on the time embeddings before they are passed to the rest of
|
| 226 |
+
the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
|
| 227 |
+
timestep_post_act (`str`, *optional*, defaults to `None`):
|
| 228 |
+
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
|
| 229 |
+
time_cond_proj_dim (`int`, *optional*, defaults to `None`):
|
| 230 |
+
The dimension of `cond_proj` layer in the timestep embedding.
|
| 231 |
+
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
|
| 232 |
+
conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
|
| 233 |
+
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
|
| 234 |
+
`class_embed_type="projection"`. Required when `class_embed_type="projection"`.
|
| 235 |
+
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
|
| 236 |
+
embeddings with the class embeddings.
|
| 237 |
+
"""
|
| 238 |
+
|
| 239 |
+
_supports_gradient_checkpointing = True
|
| 240 |
+
|
| 241 |
+
@register_to_config
def __init__(
    self,
    sample_size: Optional[int] = None,
    in_channels: int = 4,
    out_channels: int = 4,
    flip_sin_to_cos: bool = True,
    freq_shift: int = 0,
    down_block_types: Tuple[str] = (
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    ),
    mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
    up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
    only_cross_attention: Union[bool, Tuple[bool]] = False,
    block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
    layers_per_block: Union[int, Tuple[int]] = 2,
    downsample_padding: int = 1,
    mid_block_scale_factor: float = 1,
    act_fn: str = "silu",
    norm_num_groups: Optional[int] = 32,
    norm_eps: float = 1e-5,
    cross_attention_dim: Union[int, Tuple[int]] = 1280,
    transformer_layers_per_block: Union[int, Tuple[int]] = 1,
    attention_head_dim: Union[int, Tuple[int]] = 8,
    num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
    use_linear_projection: bool = False,
    class_embed_type: Optional[str] = None,
    num_class_embeds: Optional[int] = None,
    upcast_attention: bool = False,
    resnet_time_scale_shift: str = "default",
    time_embedding_type: str = "positional",
    time_embedding_dim: Optional[int] = None,
    time_embedding_act_fn: Optional[str] = None,
    timestep_post_act: Optional[str] = None,
    time_cond_proj_dim: Optional[int] = None,
    conv_in_kernel: int = 3,
    conv_out_kernel: int = 3,
    projection_class_embeddings_input_dim: Optional[int] = None,
    class_embeddings_concat: bool = False,
):
    """Build the UNet: input conv, time/class embeddings, down blocks, mid block, up blocks, output conv.

    Parameter semantics are documented in the class docstring; `@register_to_config`
    records every argument on `self.config`.
    """
    super().__init__()

    self.sample_size = sample_size

    if num_attention_heads is not None:
        raise ValueError(
            "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
        )

    # If `num_attention_heads` is not defined (which is the case for most models)
    # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
    # The reason for this behavior is to correct for incorrectly named variables that were introduced
    # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
    # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
    # which is why we correct for the naming here.
    num_attention_heads = num_attention_heads or attention_head_dim

    # Check inputs: every per-block configurable must have one entry per down block.
    if len(down_block_types) != len(up_block_types):
        raise ValueError(
            f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
        )

    if len(block_out_channels) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
        )

    if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
        )

    if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
        )

    if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
        )

    if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
        )

    if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
        )

    # input
    conv_in_padding = (conv_in_kernel - 1) // 2
    self.conv_in = nn.Conv2d(
        in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
    )

    # time
    if time_embedding_type == "positional":
        time_embed_dim = time_embedding_dim or block_out_channels[0] * 4

        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
        timestep_input_dim = block_out_channels[0]
    else:
        raise ValueError(f"{time_embedding_type} does not exist. Please make sure to use `positional`.")

    self.time_embedding = TimestepEmbedding(
        timestep_input_dim,
        time_embed_dim,
        act_fn=act_fn,
        post_act_fn=timestep_post_act,
        cond_proj_dim=time_cond_proj_dim,
    )

    # class embedding
    if class_embed_type is None and num_class_embeds is not None:
        self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
    elif class_embed_type == "timestep":
        self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
    elif class_embed_type == "identity":
        self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
    elif class_embed_type == "projection":
        if projection_class_embeddings_input_dim is None:
            raise ValueError(
                "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
            )
        # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
        # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
        # 2. it projects from an arbitrary input dimension.
        #
        # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
        # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
        # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
        self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
    elif class_embed_type == "simple_projection":
        if projection_class_embeddings_input_dim is None:
            raise ValueError(
                "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
            )
        self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
    else:
        self.class_embedding = None

    if time_embedding_act_fn is None:
        self.time_embed_act = None
    else:
        self.time_embed_act = get_activation(time_embedding_act_fn)

    self.down_blocks = nn.ModuleList([])
    self.up_blocks = nn.ModuleList([])

    # Broadcast scalar per-block settings to one entry per down block.
    if isinstance(only_cross_attention, bool):
        only_cross_attention = [only_cross_attention] * len(down_block_types)

    if isinstance(num_attention_heads, int):
        num_attention_heads = (num_attention_heads,) * len(down_block_types)

    if isinstance(cross_attention_dim, int):
        cross_attention_dim = (cross_attention_dim,) * len(down_block_types)

    if isinstance(layers_per_block, int):
        layers_per_block = [layers_per_block] * len(down_block_types)

    if isinstance(transformer_layers_per_block, int):
        transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

    if class_embeddings_concat:
        # The time embeddings are concatenated with the class embeddings. The dimension of the
        # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
        # regular time embeddings
        blocks_time_embed_dim = time_embed_dim * 2
    else:
        blocks_time_embed_dim = time_embed_dim

    # down
    output_channel = block_out_channels[0]
    for i, down_block_type in enumerate(down_block_types):
        input_channel = output_channel
        output_channel = block_out_channels[i]
        is_final_block = i == len(block_out_channels) - 1

        down_block = get_down_block(
            down_block_type,
            num_layers=layers_per_block[i],
            transformer_layers_per_block=transformer_layers_per_block[i],
            in_channels=input_channel,
            out_channels=output_channel,
            temb_channels=blocks_time_embed_dim,
            add_downsample=not is_final_block,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            cross_attention_dim=cross_attention_dim[i],
            num_attention_heads=num_attention_heads[i],
            downsample_padding=downsample_padding,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention[i],
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
        self.down_blocks.append(down_block)

    # mid
    if mid_block_type == "UNetMidBlock2DCrossAttn":
        self.mid_block = UNetMidBlock2DCrossAttn(
            transformer_layers_per_block=transformer_layers_per_block[-1],
            in_channels=block_out_channels[-1],
            temb_channels=blocks_time_embed_dim,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift=resnet_time_scale_shift,
            cross_attention_dim=cross_attention_dim[-1],
            num_attention_heads=num_attention_heads[-1],
            resnet_groups=norm_num_groups,
            use_linear_projection=use_linear_projection,
            upcast_attention=upcast_attention,
        )
    else:
        raise ValueError(
            f"unknown mid_block_type : {mid_block_type}. Should be `UNetMidBlock2DCrossAttn` for AudioLDM2."
        )

    # count how many layers upsample the images
    self.num_upsamplers = 0

    # up
    reversed_block_out_channels = list(reversed(block_out_channels))
    reversed_num_attention_heads = list(reversed(num_attention_heads))
    reversed_layers_per_block = list(reversed(layers_per_block))
    reversed_cross_attention_dim = list(reversed(cross_attention_dim))
    reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
    only_cross_attention = list(reversed(only_cross_attention))

    output_channel = reversed_block_out_channels[0]
    for i, up_block_type in enumerate(up_block_types):
        is_final_block = i == len(block_out_channels) - 1

        prev_output_channel = output_channel
        output_channel = reversed_block_out_channels[i]
        input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

        # add upsample block for all BUT final layer
        if not is_final_block:
            add_upsample = True
            self.num_upsamplers += 1
        else:
            add_upsample = False

        up_block = get_up_block(
            up_block_type,
            num_layers=reversed_layers_per_block[i] + 1,
            transformer_layers_per_block=reversed_transformer_layers_per_block[i],
            in_channels=input_channel,
            out_channels=output_channel,
            prev_output_channel=prev_output_channel,
            temb_channels=blocks_time_embed_dim,
            add_upsample=add_upsample,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            cross_attention_dim=reversed_cross_attention_dim[i],
            num_attention_heads=reversed_num_attention_heads[i],
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention[i],
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
        self.up_blocks.append(up_block)
        prev_output_channel = output_channel

    # out
    if norm_num_groups is not None:
        self.conv_norm_out = nn.GroupNorm(
            num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
        )

        self.conv_act = get_activation(act_fn)

    else:
        self.conv_norm_out = None
        self.conv_act = None

    conv_out_padding = (conv_out_kernel - 1) // 2
    self.conv_out = nn.Conv2d(
        block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
    )
|
| 533 |
+
|
| 534 |
+
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
    r"""
    Returns:
        `dict` of attention processors: A dictionary containing all attention processors used in the model with
        indexed by its weight name.
    """
    # set recursively: walk the module tree and collect every sub-module
    # exposing `get_processor`, keyed by its dotted weight path.
    processors = {}

    def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
        if hasattr(module, "get_processor"):
            processors[f"{name}.processor"] = module.get_processor()

        for sub_name, child in module.named_children():
            fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

        return processors

    for name, module in self.named_children():
        fn_recursive_add_processors(name, module, processors)

    return processors
|
| 558 |
+
|
| 559 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
    r"""
    Sets the attention processor to use to compute attention.

    Parameters:
        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
            The instantiated processor class or a dictionary of processor classes that will be set as the processor
            for **all** `Attention` layers.

            If `processor` is a dict, the key needs to define the path to the corresponding cross attention
            processor. This is strongly recommended when setting trainable attention processors.

    """
    count = len(self.attn_processors.keys())

    # A dict must cover every attention layer exactly once.
    if isinstance(processor, dict) and len(processor) != count:
        raise ValueError(
            f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
            f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
        )

    def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
        if hasattr(module, "set_processor"):
            if not isinstance(processor, dict):
                module.set_processor(processor)
            else:
                # pop by dotted path so leftover keys indicate misaddressed entries
                module.set_processor(processor.pop(f"{name}.processor"))

        for sub_name, child in module.named_children():
            fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

    for name, module in self.named_children():
        fn_recursive_attn_processor(name, module, processor)
|
| 593 |
+
|
| 594 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
    """
    Disables custom attention processors and sets the default attention implementation.
    """
    # Pick the default processor family matching what the model currently uses;
    # mixed families cannot be reset uniformly.
    if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
        processor = AttnAddedKVProcessor()
    elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
        processor = AttnProcessor()
    else:
        raise ValueError(
            f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
        )

    self.set_attn_processor(processor)
|
| 609 |
+
|
| 610 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
def set_attention_slice(self, slice_size):
    r"""
    Enable sliced attention computation.

    When this option is enabled, the attention module splits the input tensor in slices to compute attention in
    several steps. This is useful for saving some memory in exchange for a small decrease in speed.

    Args:
        slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
            When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
            `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
            provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
            must be a multiple of `slice_size`.
    """
    sliceable_head_dims = []

    def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
        if hasattr(module, "set_attention_slice"):
            sliceable_head_dims.append(module.sliceable_head_dim)

        for child in module.children():
            fn_recursive_retrieve_sliceable_dims(child)

    # retrieve number of attention layers
    for module in self.children():
        fn_recursive_retrieve_sliceable_dims(module)

    num_sliceable_layers = len(sliceable_head_dims)

    if slice_size == "auto":
        # half the attention head size is usually a good trade-off between
        # speed and memory
        slice_size = [dim // 2 for dim in sliceable_head_dims]
    elif slice_size == "max":
        # make smallest slice possible
        slice_size = num_sliceable_layers * [1]

    slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size

    if len(slice_size) != len(sliceable_head_dims):
        raise ValueError(
            f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
            f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
        )

    for i in range(len(slice_size)):
        size = slice_size[i]
        dim = sliceable_head_dims[i]
        if size is not None and size > dim:
            raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

    # Recursively walk through all the children.
    # Any children which exposes the set_attention_slice method
    # gets the message
    def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
        if hasattr(module, "set_attention_slice"):
            module.set_attention_slice(slice_size.pop())

        for child in module.children():
            fn_recursive_set_attention_slice(child, slice_size)

    # pop() consumes from the end, so reverse to hand out sizes in traversal order
    reversed_slice_size = list(reversed(slice_size))
    for module in self.children():
        fn_recursive_set_attention_slice(module, reversed_slice_size)
|
| 675 |
+
|
| 676 |
+
def forward(
    self,
    sample: torch.Tensor,
    timestep: Union[torch.Tensor, float, int],
    encoder_hidden_states: torch.Tensor,
    class_labels: Optional[torch.Tensor] = None,
    timestep_cond: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    return_dict: bool = True,
    encoder_hidden_states_1: Optional[torch.Tensor] = None,
    encoder_attention_mask_1: Optional[torch.Tensor] = None,
) -> Union[UNet2DConditionOutput, Tuple]:
    r"""
    The [`AudioLDM2UNet2DConditionModel`] forward method.

    Args:
        sample (`torch.Tensor`):
            The noisy input tensor with the following shape `(batch, channel, height, width)`.
        timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
        encoder_hidden_states (`torch.Tensor`):
            The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
        encoder_attention_mask (`torch.Tensor`):
            A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
            `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
            which adds large negative values to the attention scores corresponding to "discard" tokens.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
            tuple.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
        encoder_hidden_states_1 (`torch.Tensor`, *optional*):
            A second set of encoder hidden states with shape `(batch, sequence_length_2, feature_dim_2)`. Can be
            used to condition the model on a different set of embeddings to `encoder_hidden_states`.
        encoder_attention_mask_1 (`torch.Tensor`, *optional*):
            A cross-attention mask of shape `(batch, sequence_length_2)` is applied to `encoder_hidden_states_1`.
            If `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
            which adds large negative values to the attention scores corresponding to "discard" tokens.

    Returns:
        [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
            If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
            otherwise a `tuple` is returned where the first element is the sample tensor.
    """
    # By default samples have to be AT least a multiple of the overall upsampling factor.
    # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
    # However, the upsampling interpolation output size can be forced to fit any upsampling size
    # on the fly if necessary.
    default_overall_up_factor = 2**self.num_upsamplers

    # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
    forward_upsample_size = False
    upsample_size = None

    if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
        logger.info("Forward upsample size to force interpolation output size.")
        forward_upsample_size = True

    # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
    # expects mask of shape:
    #   [batch, key_tokens]
    # adds singleton query_tokens dimension:
    #   [batch,                    1, key_tokens]
    # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
    #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
    #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
    if attention_mask is not None:
        # assume that mask is expressed as:
        #   (1 = keep,      0 = discard)
        # convert mask into a bias that can be added to attention scores:
        #       (keep = +0,     discard = -10000.0)
        attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
        attention_mask = attention_mask.unsqueeze(1)

    # convert encoder_attention_mask to a bias the same way we do for attention_mask
    if encoder_attention_mask is not None:
        encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
        encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

    if encoder_attention_mask_1 is not None:
        encoder_attention_mask_1 = (1 - encoder_attention_mask_1.to(sample.dtype)) * -10000.0
        encoder_attention_mask_1 = encoder_attention_mask_1.unsqueeze(1)

    # 1. time
    timesteps = timestep
    if not torch.is_tensor(timesteps):
        # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
        # This would be a good case for the `match` statement (Python 3.10+)
        is_mps = sample.device.type == "mps"
        is_npu = sample.device.type == "npu"
        if isinstance(timestep, float):
            dtype = torch.float32 if (is_mps or is_npu) else torch.float64
        else:
            dtype = torch.int32 if (is_mps or is_npu) else torch.int64
        timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
    elif len(timesteps.shape) == 0:
        timesteps = timesteps[None].to(sample.device)

    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
    timesteps = timesteps.expand(sample.shape[0])

    t_emb = self.time_proj(timesteps)

    # `Timesteps` does not contain any weights and will always return f32 tensors
    # but time_embedding might actually be running in fp16. so we need to cast here.
    # there might be better ways to encapsulate this.
    t_emb = t_emb.to(dtype=sample.dtype)

    emb = self.time_embedding(t_emb, timestep_cond)
    aug_emb = None

    if self.class_embedding is not None:
        if class_labels is None:
            raise ValueError("class_labels should be provided when num_class_embeds > 0")

        if self.config.class_embed_type == "timestep":
            class_labels = self.time_proj(class_labels)

            # `Timesteps` does not contain any weights and will always return f32 tensors
            # there might be better ways to encapsulate this.
            class_labels = class_labels.to(dtype=sample.dtype)

        class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)

        if self.config.class_embeddings_concat:
            emb = torch.cat([emb, class_emb], dim=-1)
        else:
            emb = emb + class_emb

    emb = emb + aug_emb if aug_emb is not None else emb

    if self.time_embed_act is not None:
        emb = self.time_embed_act(emb)

    # 2. pre-process
    sample = self.conv_in(sample)

    # 3. down
    down_block_res_samples = (sample,)
    for downsample_block in self.down_blocks:
        if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
            sample, res_samples = downsample_block(
                hidden_states=sample,
                temb=emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
                encoder_hidden_states_1=encoder_hidden_states_1,
                encoder_attention_mask_1=encoder_attention_mask_1,
            )
        else:
            sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

        down_block_res_samples += res_samples

    # 4. mid
    if self.mid_block is not None:
        sample = self.mid_block(
            sample,
            emb,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            cross_attention_kwargs=cross_attention_kwargs,
            encoder_attention_mask=encoder_attention_mask,
            encoder_hidden_states_1=encoder_hidden_states_1,
            encoder_attention_mask_1=encoder_attention_mask_1,
        )

    # 5. up
    for i, upsample_block in enumerate(self.up_blocks):
        is_final_block = i == len(self.up_blocks) - 1

        res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
        down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

        # if we have not reached the final block and need to forward the
        # upsample size, we do it here
        if not is_final_block and forward_upsample_size:
            upsample_size = down_block_res_samples[-1].shape[2:]

        if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
            sample = upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                upsample_size=upsample_size,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                encoder_hidden_states_1=encoder_hidden_states_1,
                encoder_attention_mask_1=encoder_attention_mask_1,
            )
        else:
            sample = upsample_block(
                hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
            )

    # 6. post-process
    if self.conv_norm_out:
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
    sample = self.conv_out(sample)

    if not return_dict:
        return (sample,)

    return UNet2DConditionOutput(sample=sample)
|
| 886 |
+
|
| 887 |
+
|
| 888 |
+
def get_down_block(
|
| 889 |
+
down_block_type,
|
| 890 |
+
num_layers,
|
| 891 |
+
in_channels,
|
| 892 |
+
out_channels,
|
| 893 |
+
temb_channels,
|
| 894 |
+
add_downsample,
|
| 895 |
+
resnet_eps,
|
| 896 |
+
resnet_act_fn,
|
| 897 |
+
transformer_layers_per_block=1,
|
| 898 |
+
num_attention_heads=None,
|
| 899 |
+
resnet_groups=None,
|
| 900 |
+
cross_attention_dim=None,
|
| 901 |
+
downsample_padding=None,
|
| 902 |
+
use_linear_projection=False,
|
| 903 |
+
only_cross_attention=False,
|
| 904 |
+
upcast_attention=False,
|
| 905 |
+
resnet_time_scale_shift="default",
|
| 906 |
+
):
|
| 907 |
+
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
|
| 908 |
+
if down_block_type == "DownBlock2D":
|
| 909 |
+
return DownBlock2D(
|
| 910 |
+
num_layers=num_layers,
|
| 911 |
+
in_channels=in_channels,
|
| 912 |
+
out_channels=out_channels,
|
| 913 |
+
temb_channels=temb_channels,
|
| 914 |
+
add_downsample=add_downsample,
|
| 915 |
+
resnet_eps=resnet_eps,
|
| 916 |
+
resnet_act_fn=resnet_act_fn,
|
| 917 |
+
resnet_groups=resnet_groups,
|
| 918 |
+
downsample_padding=downsample_padding,
|
| 919 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
| 920 |
+
)
|
| 921 |
+
elif down_block_type == "CrossAttnDownBlock2D":
|
| 922 |
+
if cross_attention_dim is None:
|
| 923 |
+
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
|
| 924 |
+
return CrossAttnDownBlock2D(
|
| 925 |
+
num_layers=num_layers,
|
| 926 |
+
transformer_layers_per_block=transformer_layers_per_block,
|
| 927 |
+
in_channels=in_channels,
|
| 928 |
+
out_channels=out_channels,
|
| 929 |
+
temb_channels=temb_channels,
|
| 930 |
+
add_downsample=add_downsample,
|
| 931 |
+
resnet_eps=resnet_eps,
|
| 932 |
+
resnet_act_fn=resnet_act_fn,
|
| 933 |
+
resnet_groups=resnet_groups,
|
| 934 |
+
downsample_padding=downsample_padding,
|
| 935 |
+
cross_attention_dim=cross_attention_dim,
|
| 936 |
+
num_attention_heads=num_attention_heads,
|
| 937 |
+
use_linear_projection=use_linear_projection,
|
| 938 |
+
only_cross_attention=only_cross_attention,
|
| 939 |
+
upcast_attention=upcast_attention,
|
| 940 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
| 941 |
+
)
|
| 942 |
+
raise ValueError(f"{down_block_type} does not exist.")
|
| 943 |
+
|
| 944 |
+
|
| 945 |
+
def get_up_block(
|
| 946 |
+
up_block_type,
|
| 947 |
+
num_layers,
|
| 948 |
+
in_channels,
|
| 949 |
+
out_channels,
|
| 950 |
+
prev_output_channel,
|
| 951 |
+
temb_channels,
|
| 952 |
+
add_upsample,
|
| 953 |
+
resnet_eps,
|
| 954 |
+
resnet_act_fn,
|
| 955 |
+
transformer_layers_per_block=1,
|
| 956 |
+
num_attention_heads=None,
|
| 957 |
+
resnet_groups=None,
|
| 958 |
+
cross_attention_dim=None,
|
| 959 |
+
use_linear_projection=False,
|
| 960 |
+
only_cross_attention=False,
|
| 961 |
+
upcast_attention=False,
|
| 962 |
+
resnet_time_scale_shift="default",
|
| 963 |
+
):
|
| 964 |
+
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
|
| 965 |
+
if up_block_type == "UpBlock2D":
|
| 966 |
+
return UpBlock2D(
|
| 967 |
+
num_layers=num_layers,
|
| 968 |
+
in_channels=in_channels,
|
| 969 |
+
out_channels=out_channels,
|
| 970 |
+
prev_output_channel=prev_output_channel,
|
| 971 |
+
temb_channels=temb_channels,
|
| 972 |
+
add_upsample=add_upsample,
|
| 973 |
+
resnet_eps=resnet_eps,
|
| 974 |
+
resnet_act_fn=resnet_act_fn,
|
| 975 |
+
resnet_groups=resnet_groups,
|
| 976 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
| 977 |
+
)
|
| 978 |
+
elif up_block_type == "CrossAttnUpBlock2D":
|
| 979 |
+
if cross_attention_dim is None:
|
| 980 |
+
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
|
| 981 |
+
return CrossAttnUpBlock2D(
|
| 982 |
+
num_layers=num_layers,
|
| 983 |
+
transformer_layers_per_block=transformer_layers_per_block,
|
| 984 |
+
in_channels=in_channels,
|
| 985 |
+
out_channels=out_channels,
|
| 986 |
+
prev_output_channel=prev_output_channel,
|
| 987 |
+
temb_channels=temb_channels,
|
| 988 |
+
add_upsample=add_upsample,
|
| 989 |
+
resnet_eps=resnet_eps,
|
| 990 |
+
resnet_act_fn=resnet_act_fn,
|
| 991 |
+
resnet_groups=resnet_groups,
|
| 992 |
+
cross_attention_dim=cross_attention_dim,
|
| 993 |
+
num_attention_heads=num_attention_heads,
|
| 994 |
+
use_linear_projection=use_linear_projection,
|
| 995 |
+
only_cross_attention=only_cross_attention,
|
| 996 |
+
upcast_attention=upcast_attention,
|
| 997 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
| 998 |
+
)
|
| 999 |
+
raise ValueError(f"{up_block_type} does not exist.")
|
| 1000 |
+
|
| 1001 |
+
|
| 1002 |
+
class CrossAttnDownBlock2D(nn.Module):
|
| 1003 |
+
def __init__(
|
| 1004 |
+
self,
|
| 1005 |
+
in_channels: int,
|
| 1006 |
+
out_channels: int,
|
| 1007 |
+
temb_channels: int,
|
| 1008 |
+
dropout: float = 0.0,
|
| 1009 |
+
num_layers: int = 1,
|
| 1010 |
+
transformer_layers_per_block: int = 1,
|
| 1011 |
+
resnet_eps: float = 1e-6,
|
| 1012 |
+
resnet_time_scale_shift: str = "default",
|
| 1013 |
+
resnet_act_fn: str = "swish",
|
| 1014 |
+
resnet_groups: int = 32,
|
| 1015 |
+
resnet_pre_norm: bool = True,
|
| 1016 |
+
num_attention_heads=1,
|
| 1017 |
+
cross_attention_dim=1280,
|
| 1018 |
+
output_scale_factor=1.0,
|
| 1019 |
+
downsample_padding=1,
|
| 1020 |
+
add_downsample=True,
|
| 1021 |
+
use_linear_projection=False,
|
| 1022 |
+
only_cross_attention=False,
|
| 1023 |
+
upcast_attention=False,
|
| 1024 |
+
):
|
| 1025 |
+
super().__init__()
|
| 1026 |
+
resnets = []
|
| 1027 |
+
attentions = []
|
| 1028 |
+
|
| 1029 |
+
self.has_cross_attention = True
|
| 1030 |
+
self.num_attention_heads = num_attention_heads
|
| 1031 |
+
|
| 1032 |
+
if isinstance(cross_attention_dim, int):
|
| 1033 |
+
cross_attention_dim = (cross_attention_dim,)
|
| 1034 |
+
if isinstance(cross_attention_dim, (list, tuple)) and len(cross_attention_dim) > 4:
|
| 1035 |
+
raise ValueError(
|
| 1036 |
+
"Only up to 4 cross-attention layers are supported. Ensure that the length of cross-attention "
|
| 1037 |
+
f"dims is less than or equal to 4. Got cross-attention dims {cross_attention_dim} of length {len(cross_attention_dim)}"
|
| 1038 |
+
)
|
| 1039 |
+
self.cross_attention_dim = cross_attention_dim
|
| 1040 |
+
|
| 1041 |
+
for i in range(num_layers):
|
| 1042 |
+
in_channels = in_channels if i == 0 else out_channels
|
| 1043 |
+
resnets.append(
|
| 1044 |
+
ResnetBlock2D(
|
| 1045 |
+
in_channels=in_channels,
|
| 1046 |
+
out_channels=out_channels,
|
| 1047 |
+
temb_channels=temb_channels,
|
| 1048 |
+
eps=resnet_eps,
|
| 1049 |
+
groups=resnet_groups,
|
| 1050 |
+
dropout=dropout,
|
| 1051 |
+
time_embedding_norm=resnet_time_scale_shift,
|
| 1052 |
+
non_linearity=resnet_act_fn,
|
| 1053 |
+
output_scale_factor=output_scale_factor,
|
| 1054 |
+
pre_norm=resnet_pre_norm,
|
| 1055 |
+
)
|
| 1056 |
+
)
|
| 1057 |
+
for j in range(len(cross_attention_dim)):
|
| 1058 |
+
attentions.append(
|
| 1059 |
+
Transformer2DModel(
|
| 1060 |
+
num_attention_heads,
|
| 1061 |
+
out_channels // num_attention_heads,
|
| 1062 |
+
in_channels=out_channels,
|
| 1063 |
+
num_layers=transformer_layers_per_block,
|
| 1064 |
+
cross_attention_dim=cross_attention_dim[j],
|
| 1065 |
+
norm_num_groups=resnet_groups,
|
| 1066 |
+
use_linear_projection=use_linear_projection,
|
| 1067 |
+
only_cross_attention=only_cross_attention,
|
| 1068 |
+
upcast_attention=upcast_attention,
|
| 1069 |
+
double_self_attention=True if cross_attention_dim[j] is None else False,
|
| 1070 |
+
)
|
| 1071 |
+
)
|
| 1072 |
+
self.attentions = nn.ModuleList(attentions)
|
| 1073 |
+
self.resnets = nn.ModuleList(resnets)
|
| 1074 |
+
|
| 1075 |
+
if add_downsample:
|
| 1076 |
+
self.downsamplers = nn.ModuleList(
|
| 1077 |
+
[
|
| 1078 |
+
Downsample2D(
|
| 1079 |
+
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
| 1080 |
+
)
|
| 1081 |
+
]
|
| 1082 |
+
)
|
| 1083 |
+
else:
|
| 1084 |
+
self.downsamplers = None
|
| 1085 |
+
|
| 1086 |
+
self.gradient_checkpointing = False
|
| 1087 |
+
|
| 1088 |
+
def forward(
|
| 1089 |
+
self,
|
| 1090 |
+
hidden_states: torch.Tensor,
|
| 1091 |
+
temb: Optional[torch.Tensor] = None,
|
| 1092 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 1093 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1094 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 1095 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
| 1096 |
+
encoder_hidden_states_1: Optional[torch.Tensor] = None,
|
| 1097 |
+
encoder_attention_mask_1: Optional[torch.Tensor] = None,
|
| 1098 |
+
):
|
| 1099 |
+
output_states = ()
|
| 1100 |
+
num_layers = len(self.resnets)
|
| 1101 |
+
num_attention_per_layer = len(self.attentions) // num_layers
|
| 1102 |
+
|
| 1103 |
+
encoder_hidden_states_1 = (
|
| 1104 |
+
encoder_hidden_states_1 if encoder_hidden_states_1 is not None else encoder_hidden_states
|
| 1105 |
+
)
|
| 1106 |
+
encoder_attention_mask_1 = (
|
| 1107 |
+
encoder_attention_mask_1 if encoder_hidden_states_1 is not None else encoder_attention_mask
|
| 1108 |
+
)
|
| 1109 |
+
|
| 1110 |
+
for i in range(num_layers):
|
| 1111 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 1112 |
+
hidden_states = self._gradient_checkpointing_func(self.resnets[i], hidden_states, temb)
|
| 1113 |
+
for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
|
| 1114 |
+
if cross_attention_dim is not None and idx <= 1:
|
| 1115 |
+
forward_encoder_hidden_states = encoder_hidden_states
|
| 1116 |
+
forward_encoder_attention_mask = encoder_attention_mask
|
| 1117 |
+
elif cross_attention_dim is not None and idx > 1:
|
| 1118 |
+
forward_encoder_hidden_states = encoder_hidden_states_1
|
| 1119 |
+
forward_encoder_attention_mask = encoder_attention_mask_1
|
| 1120 |
+
else:
|
| 1121 |
+
forward_encoder_hidden_states = None
|
| 1122 |
+
forward_encoder_attention_mask = None
|
| 1123 |
+
hidden_states = self._gradient_checkpointing_func(
|
| 1124 |
+
self.attentions[i * num_attention_per_layer + idx],
|
| 1125 |
+
hidden_states,
|
| 1126 |
+
forward_encoder_hidden_states,
|
| 1127 |
+
None, # timestep
|
| 1128 |
+
None, # class_labels
|
| 1129 |
+
cross_attention_kwargs,
|
| 1130 |
+
attention_mask,
|
| 1131 |
+
forward_encoder_attention_mask,
|
| 1132 |
+
)[0]
|
| 1133 |
+
else:
|
| 1134 |
+
hidden_states = self.resnets[i](hidden_states, temb)
|
| 1135 |
+
for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
|
| 1136 |
+
if cross_attention_dim is not None and idx <= 1:
|
| 1137 |
+
forward_encoder_hidden_states = encoder_hidden_states
|
| 1138 |
+
forward_encoder_attention_mask = encoder_attention_mask
|
| 1139 |
+
elif cross_attention_dim is not None and idx > 1:
|
| 1140 |
+
forward_encoder_hidden_states = encoder_hidden_states_1
|
| 1141 |
+
forward_encoder_attention_mask = encoder_attention_mask_1
|
| 1142 |
+
else:
|
| 1143 |
+
forward_encoder_hidden_states = None
|
| 1144 |
+
forward_encoder_attention_mask = None
|
| 1145 |
+
hidden_states = self.attentions[i * num_attention_per_layer + idx](
|
| 1146 |
+
hidden_states,
|
| 1147 |
+
attention_mask=attention_mask,
|
| 1148 |
+
encoder_hidden_states=forward_encoder_hidden_states,
|
| 1149 |
+
encoder_attention_mask=forward_encoder_attention_mask,
|
| 1150 |
+
return_dict=False,
|
| 1151 |
+
)[0]
|
| 1152 |
+
|
| 1153 |
+
output_states = output_states + (hidden_states,)
|
| 1154 |
+
|
| 1155 |
+
if self.downsamplers is not None:
|
| 1156 |
+
for downsampler in self.downsamplers:
|
| 1157 |
+
hidden_states = downsampler(hidden_states)
|
| 1158 |
+
|
| 1159 |
+
output_states = output_states + (hidden_states,)
|
| 1160 |
+
|
| 1161 |
+
return hidden_states, output_states
|
| 1162 |
+
|
| 1163 |
+
|
| 1164 |
+
class UNetMidBlock2DCrossAttn(nn.Module):
|
| 1165 |
+
def __init__(
|
| 1166 |
+
self,
|
| 1167 |
+
in_channels: int,
|
| 1168 |
+
temb_channels: int,
|
| 1169 |
+
dropout: float = 0.0,
|
| 1170 |
+
num_layers: int = 1,
|
| 1171 |
+
transformer_layers_per_block: int = 1,
|
| 1172 |
+
resnet_eps: float = 1e-6,
|
| 1173 |
+
resnet_time_scale_shift: str = "default",
|
| 1174 |
+
resnet_act_fn: str = "swish",
|
| 1175 |
+
resnet_groups: int = 32,
|
| 1176 |
+
resnet_pre_norm: bool = True,
|
| 1177 |
+
num_attention_heads=1,
|
| 1178 |
+
output_scale_factor=1.0,
|
| 1179 |
+
cross_attention_dim=1280,
|
| 1180 |
+
use_linear_projection=False,
|
| 1181 |
+
upcast_attention=False,
|
| 1182 |
+
):
|
| 1183 |
+
super().__init__()
|
| 1184 |
+
|
| 1185 |
+
self.has_cross_attention = True
|
| 1186 |
+
self.num_attention_heads = num_attention_heads
|
| 1187 |
+
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
| 1188 |
+
|
| 1189 |
+
if isinstance(cross_attention_dim, int):
|
| 1190 |
+
cross_attention_dim = (cross_attention_dim,)
|
| 1191 |
+
if isinstance(cross_attention_dim, (list, tuple)) and len(cross_attention_dim) > 4:
|
| 1192 |
+
raise ValueError(
|
| 1193 |
+
"Only up to 4 cross-attention layers are supported. Ensure that the length of cross-attention "
|
| 1194 |
+
f"dims is less than or equal to 4. Got cross-attention dims {cross_attention_dim} of length {len(cross_attention_dim)}"
|
| 1195 |
+
)
|
| 1196 |
+
self.cross_attention_dim = cross_attention_dim
|
| 1197 |
+
|
| 1198 |
+
# there is always at least one resnet
|
| 1199 |
+
resnets = [
|
| 1200 |
+
ResnetBlock2D(
|
| 1201 |
+
in_channels=in_channels,
|
| 1202 |
+
out_channels=in_channels,
|
| 1203 |
+
temb_channels=temb_channels,
|
| 1204 |
+
eps=resnet_eps,
|
| 1205 |
+
groups=resnet_groups,
|
| 1206 |
+
dropout=dropout,
|
| 1207 |
+
time_embedding_norm=resnet_time_scale_shift,
|
| 1208 |
+
non_linearity=resnet_act_fn,
|
| 1209 |
+
output_scale_factor=output_scale_factor,
|
| 1210 |
+
pre_norm=resnet_pre_norm,
|
| 1211 |
+
)
|
| 1212 |
+
]
|
| 1213 |
+
attentions = []
|
| 1214 |
+
|
| 1215 |
+
for i in range(num_layers):
|
| 1216 |
+
for j in range(len(cross_attention_dim)):
|
| 1217 |
+
attentions.append(
|
| 1218 |
+
Transformer2DModel(
|
| 1219 |
+
num_attention_heads,
|
| 1220 |
+
in_channels // num_attention_heads,
|
| 1221 |
+
in_channels=in_channels,
|
| 1222 |
+
num_layers=transformer_layers_per_block,
|
| 1223 |
+
cross_attention_dim=cross_attention_dim[j],
|
| 1224 |
+
norm_num_groups=resnet_groups,
|
| 1225 |
+
use_linear_projection=use_linear_projection,
|
| 1226 |
+
upcast_attention=upcast_attention,
|
| 1227 |
+
double_self_attention=True if cross_attention_dim[j] is None else False,
|
| 1228 |
+
)
|
| 1229 |
+
)
|
| 1230 |
+
resnets.append(
|
| 1231 |
+
ResnetBlock2D(
|
| 1232 |
+
in_channels=in_channels,
|
| 1233 |
+
out_channels=in_channels,
|
| 1234 |
+
temb_channels=temb_channels,
|
| 1235 |
+
eps=resnet_eps,
|
| 1236 |
+
groups=resnet_groups,
|
| 1237 |
+
dropout=dropout,
|
| 1238 |
+
time_embedding_norm=resnet_time_scale_shift,
|
| 1239 |
+
non_linearity=resnet_act_fn,
|
| 1240 |
+
output_scale_factor=output_scale_factor,
|
| 1241 |
+
pre_norm=resnet_pre_norm,
|
| 1242 |
+
)
|
| 1243 |
+
)
|
| 1244 |
+
|
| 1245 |
+
self.attentions = nn.ModuleList(attentions)
|
| 1246 |
+
self.resnets = nn.ModuleList(resnets)
|
| 1247 |
+
|
| 1248 |
+
self.gradient_checkpointing = False
|
| 1249 |
+
|
| 1250 |
+
def forward(
|
| 1251 |
+
self,
|
| 1252 |
+
hidden_states: torch.Tensor,
|
| 1253 |
+
temb: Optional[torch.Tensor] = None,
|
| 1254 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 1255 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1256 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 1257 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
| 1258 |
+
encoder_hidden_states_1: Optional[torch.Tensor] = None,
|
| 1259 |
+
encoder_attention_mask_1: Optional[torch.Tensor] = None,
|
| 1260 |
+
) -> torch.Tensor:
|
| 1261 |
+
hidden_states = self.resnets[0](hidden_states, temb)
|
| 1262 |
+
num_attention_per_layer = len(self.attentions) // (len(self.resnets) - 1)
|
| 1263 |
+
|
| 1264 |
+
encoder_hidden_states_1 = (
|
| 1265 |
+
encoder_hidden_states_1 if encoder_hidden_states_1 is not None else encoder_hidden_states
|
| 1266 |
+
)
|
| 1267 |
+
encoder_attention_mask_1 = (
|
| 1268 |
+
encoder_attention_mask_1 if encoder_hidden_states_1 is not None else encoder_attention_mask
|
| 1269 |
+
)
|
| 1270 |
+
|
| 1271 |
+
for i in range(len(self.resnets[1:])):
|
| 1272 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 1273 |
+
for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
|
| 1274 |
+
if cross_attention_dim is not None and idx <= 1:
|
| 1275 |
+
forward_encoder_hidden_states = encoder_hidden_states
|
| 1276 |
+
forward_encoder_attention_mask = encoder_attention_mask
|
| 1277 |
+
elif cross_attention_dim is not None and idx > 1:
|
| 1278 |
+
forward_encoder_hidden_states = encoder_hidden_states_1
|
| 1279 |
+
forward_encoder_attention_mask = encoder_attention_mask_1
|
| 1280 |
+
else:
|
| 1281 |
+
forward_encoder_hidden_states = None
|
| 1282 |
+
forward_encoder_attention_mask = None
|
| 1283 |
+
hidden_states = self._gradient_checkpointing_func(
|
| 1284 |
+
self.attentions[i * num_attention_per_layer + idx],
|
| 1285 |
+
hidden_states,
|
| 1286 |
+
forward_encoder_hidden_states,
|
| 1287 |
+
None, # timestep
|
| 1288 |
+
None, # class_labels
|
| 1289 |
+
cross_attention_kwargs,
|
| 1290 |
+
attention_mask,
|
| 1291 |
+
forward_encoder_attention_mask,
|
| 1292 |
+
)[0]
|
| 1293 |
+
hidden_states = self._gradient_checkpointing_func(self.resnets[i + 1], hidden_states, temb)
|
| 1294 |
+
else:
|
| 1295 |
+
for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
|
| 1296 |
+
if cross_attention_dim is not None and idx <= 1:
|
| 1297 |
+
forward_encoder_hidden_states = encoder_hidden_states
|
| 1298 |
+
forward_encoder_attention_mask = encoder_attention_mask
|
| 1299 |
+
elif cross_attention_dim is not None and idx > 1:
|
| 1300 |
+
forward_encoder_hidden_states = encoder_hidden_states_1
|
| 1301 |
+
forward_encoder_attention_mask = encoder_attention_mask_1
|
| 1302 |
+
else:
|
| 1303 |
+
forward_encoder_hidden_states = None
|
| 1304 |
+
forward_encoder_attention_mask = None
|
| 1305 |
+
hidden_states = self.attentions[i * num_attention_per_layer + idx](
|
| 1306 |
+
hidden_states,
|
| 1307 |
+
attention_mask=attention_mask,
|
| 1308 |
+
encoder_hidden_states=forward_encoder_hidden_states,
|
| 1309 |
+
encoder_attention_mask=forward_encoder_attention_mask,
|
| 1310 |
+
return_dict=False,
|
| 1311 |
+
)[0]
|
| 1312 |
+
|
| 1313 |
+
hidden_states = self.resnets[i + 1](hidden_states, temb)
|
| 1314 |
+
|
| 1315 |
+
return hidden_states
|
| 1316 |
+
|
| 1317 |
+
|
| 1318 |
+
class CrossAttnUpBlock2D(nn.Module):
|
| 1319 |
+
def __init__(
|
| 1320 |
+
self,
|
| 1321 |
+
in_channels: int,
|
| 1322 |
+
out_channels: int,
|
| 1323 |
+
prev_output_channel: int,
|
| 1324 |
+
temb_channels: int,
|
| 1325 |
+
dropout: float = 0.0,
|
| 1326 |
+
num_layers: int = 1,
|
| 1327 |
+
transformer_layers_per_block: int = 1,
|
| 1328 |
+
resnet_eps: float = 1e-6,
|
| 1329 |
+
resnet_time_scale_shift: str = "default",
|
| 1330 |
+
resnet_act_fn: str = "swish",
|
| 1331 |
+
resnet_groups: int = 32,
|
| 1332 |
+
resnet_pre_norm: bool = True,
|
| 1333 |
+
num_attention_heads=1,
|
| 1334 |
+
cross_attention_dim=1280,
|
| 1335 |
+
output_scale_factor=1.0,
|
| 1336 |
+
add_upsample=True,
|
| 1337 |
+
use_linear_projection=False,
|
| 1338 |
+
only_cross_attention=False,
|
| 1339 |
+
upcast_attention=False,
|
| 1340 |
+
):
|
| 1341 |
+
super().__init__()
|
| 1342 |
+
resnets = []
|
| 1343 |
+
attentions = []
|
| 1344 |
+
|
| 1345 |
+
self.has_cross_attention = True
|
| 1346 |
+
self.num_attention_heads = num_attention_heads
|
| 1347 |
+
|
| 1348 |
+
if isinstance(cross_attention_dim, int):
|
| 1349 |
+
cross_attention_dim = (cross_attention_dim,)
|
| 1350 |
+
if isinstance(cross_attention_dim, (list, tuple)) and len(cross_attention_dim) > 4:
|
| 1351 |
+
raise ValueError(
|
| 1352 |
+
"Only up to 4 cross-attention layers are supported. Ensure that the length of cross-attention "
|
| 1353 |
+
f"dims is less than or equal to 4. Got cross-attention dims {cross_attention_dim} of length {len(cross_attention_dim)}"
|
| 1354 |
+
)
|
| 1355 |
+
self.cross_attention_dim = cross_attention_dim
|
| 1356 |
+
|
| 1357 |
+
for i in range(num_layers):
|
| 1358 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
| 1359 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
| 1360 |
+
|
| 1361 |
+
resnets.append(
|
| 1362 |
+
ResnetBlock2D(
|
| 1363 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
| 1364 |
+
out_channels=out_channels,
|
| 1365 |
+
temb_channels=temb_channels,
|
| 1366 |
+
eps=resnet_eps,
|
| 1367 |
+
groups=resnet_groups,
|
| 1368 |
+
dropout=dropout,
|
| 1369 |
+
time_embedding_norm=resnet_time_scale_shift,
|
| 1370 |
+
non_linearity=resnet_act_fn,
|
| 1371 |
+
output_scale_factor=output_scale_factor,
|
| 1372 |
+
pre_norm=resnet_pre_norm,
|
| 1373 |
+
)
|
| 1374 |
+
)
|
| 1375 |
+
for j in range(len(cross_attention_dim)):
|
| 1376 |
+
attentions.append(
|
| 1377 |
+
Transformer2DModel(
|
| 1378 |
+
num_attention_heads,
|
| 1379 |
+
out_channels // num_attention_heads,
|
| 1380 |
+
in_channels=out_channels,
|
| 1381 |
+
num_layers=transformer_layers_per_block,
|
| 1382 |
+
cross_attention_dim=cross_attention_dim[j],
|
| 1383 |
+
norm_num_groups=resnet_groups,
|
| 1384 |
+
use_linear_projection=use_linear_projection,
|
| 1385 |
+
only_cross_attention=only_cross_attention,
|
| 1386 |
+
upcast_attention=upcast_attention,
|
| 1387 |
+
double_self_attention=True if cross_attention_dim[j] is None else False,
|
| 1388 |
+
)
|
| 1389 |
+
)
|
| 1390 |
+
self.attentions = nn.ModuleList(attentions)
|
| 1391 |
+
self.resnets = nn.ModuleList(resnets)
|
| 1392 |
+
|
| 1393 |
+
if add_upsample:
|
| 1394 |
+
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
| 1395 |
+
else:
|
| 1396 |
+
self.upsamplers = None
|
| 1397 |
+
|
| 1398 |
+
self.gradient_checkpointing = False
|
| 1399 |
+
|
| 1400 |
+
def forward(
|
| 1401 |
+
self,
|
| 1402 |
+
hidden_states: torch.Tensor,
|
| 1403 |
+
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
|
| 1404 |
+
temb: Optional[torch.Tensor] = None,
|
| 1405 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 1406 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 1407 |
+
upsample_size: Optional[int] = None,
|
| 1408 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1409 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
| 1410 |
+
encoder_hidden_states_1: Optional[torch.Tensor] = None,
|
| 1411 |
+
encoder_attention_mask_1: Optional[torch.Tensor] = None,
|
| 1412 |
+
):
|
| 1413 |
+
num_layers = len(self.resnets)
|
| 1414 |
+
num_attention_per_layer = len(self.attentions) // num_layers
|
| 1415 |
+
|
| 1416 |
+
encoder_hidden_states_1 = (
|
| 1417 |
+
encoder_hidden_states_1 if encoder_hidden_states_1 is not None else encoder_hidden_states
|
| 1418 |
+
)
|
| 1419 |
+
encoder_attention_mask_1 = (
|
| 1420 |
+
encoder_attention_mask_1 if encoder_hidden_states_1 is not None else encoder_attention_mask
|
| 1421 |
+
)
|
| 1422 |
+
|
| 1423 |
+
for i in range(num_layers):
|
| 1424 |
+
# pop res hidden states
|
| 1425 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
| 1426 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
| 1427 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
| 1428 |
+
|
| 1429 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 1430 |
+
hidden_states = self._gradient_checkpointing_func(self.resnets[i], hidden_states, temb)
|
| 1431 |
+
for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
|
| 1432 |
+
if cross_attention_dim is not None and idx <= 1:
|
| 1433 |
+
forward_encoder_hidden_states = encoder_hidden_states
|
| 1434 |
+
forward_encoder_attention_mask = encoder_attention_mask
|
| 1435 |
+
elif cross_attention_dim is not None and idx > 1:
|
| 1436 |
+
forward_encoder_hidden_states = encoder_hidden_states_1
|
| 1437 |
+
forward_encoder_attention_mask = encoder_attention_mask_1
|
| 1438 |
+
else:
|
| 1439 |
+
forward_encoder_hidden_states = None
|
| 1440 |
+
forward_encoder_attention_mask = None
|
| 1441 |
+
hidden_states = self._gradient_checkpointing_func(
|
| 1442 |
+
self.attentions[i * num_attention_per_layer + idx],
|
| 1443 |
+
hidden_states,
|
| 1444 |
+
forward_encoder_hidden_states,
|
| 1445 |
+
None, # timestep
|
| 1446 |
+
None, # class_labels
|
| 1447 |
+
cross_attention_kwargs,
|
| 1448 |
+
attention_mask,
|
| 1449 |
+
forward_encoder_attention_mask,
|
| 1450 |
+
)[0]
|
| 1451 |
+
else:
|
| 1452 |
+
hidden_states = self.resnets[i](hidden_states, temb)
|
| 1453 |
+
for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
|
| 1454 |
+
if cross_attention_dim is not None and idx <= 1:
|
| 1455 |
+
forward_encoder_hidden_states = encoder_hidden_states
|
| 1456 |
+
forward_encoder_attention_mask = encoder_attention_mask
|
| 1457 |
+
elif cross_attention_dim is not None and idx > 1:
|
| 1458 |
+
forward_encoder_hidden_states = encoder_hidden_states_1
|
| 1459 |
+
forward_encoder_attention_mask = encoder_attention_mask_1
|
| 1460 |
+
else:
|
| 1461 |
+
forward_encoder_hidden_states = None
|
| 1462 |
+
forward_encoder_attention_mask = None
|
| 1463 |
+
hidden_states = self.attentions[i * num_attention_per_layer + idx](
|
| 1464 |
+
hidden_states,
|
| 1465 |
+
attention_mask=attention_mask,
|
| 1466 |
+
encoder_hidden_states=forward_encoder_hidden_states,
|
| 1467 |
+
encoder_attention_mask=forward_encoder_attention_mask,
|
| 1468 |
+
return_dict=False,
|
| 1469 |
+
)[0]
|
| 1470 |
+
|
| 1471 |
+
if self.upsamplers is not None:
|
| 1472 |
+
for upsampler in self.upsamplers:
|
| 1473 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
| 1474 |
+
|
| 1475 |
+
return hidden_states
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
ADDED
|
@@ -0,0 +1,1104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 CVSSP, ByteDance and The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import inspect
|
| 16 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch
|
| 20 |
+
from transformers import (
|
| 21 |
+
ClapFeatureExtractor,
|
| 22 |
+
ClapModel,
|
| 23 |
+
GPT2LMHeadModel,
|
| 24 |
+
RobertaTokenizer,
|
| 25 |
+
RobertaTokenizerFast,
|
| 26 |
+
SpeechT5HifiGan,
|
| 27 |
+
T5EncoderModel,
|
| 28 |
+
T5Tokenizer,
|
| 29 |
+
T5TokenizerFast,
|
| 30 |
+
VitsModel,
|
| 31 |
+
VitsTokenizer,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
from ...models import AutoencoderKL
|
| 35 |
+
from ...schedulers import KarrasDiffusionSchedulers
|
| 36 |
+
from ...utils import (
|
| 37 |
+
is_accelerate_available,
|
| 38 |
+
is_accelerate_version,
|
| 39 |
+
is_librosa_available,
|
| 40 |
+
logging,
|
| 41 |
+
replace_example_docstring,
|
| 42 |
+
)
|
| 43 |
+
from ...utils.import_utils import is_transformers_version
|
| 44 |
+
from ...utils.torch_utils import empty_device_cache, randn_tensor
|
| 45 |
+
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
|
| 46 |
+
from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
if is_librosa_available():
|
| 50 |
+
import librosa
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
from ...utils import is_torch_xla_available
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
if is_torch_xla_available():
|
| 57 |
+
import torch_xla.core.xla_model as xm
|
| 58 |
+
|
| 59 |
+
XLA_AVAILABLE = True
|
| 60 |
+
else:
|
| 61 |
+
XLA_AVAILABLE = False
|
| 62 |
+
|
| 63 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# Usage examples injected into the pipeline `__call__` docstring via
# `replace_example_docstring`. Fixes: "equa" -> "equal", comment spacing.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import scipy
        >>> import torch
        >>> from diffusers import AudioLDM2Pipeline

        >>> repo_id = "cvssp/audioldm2"
        >>> pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
        >>> pipe = pipe.to("cuda")

        >>> # define the prompts
        >>> prompt = "The sound of a hammer hitting a wooden surface."
        >>> negative_prompt = "Low quality."

        >>> # set the seed for generator
        >>> generator = torch.Generator("cuda").manual_seed(0)

        >>> # run the generation
        >>> audio = pipe(
        ...     prompt,
        ...     negative_prompt=negative_prompt,
        ...     num_inference_steps=200,
        ...     audio_length_in_s=10.0,
        ...     num_waveforms_per_prompt=3,
        ...     generator=generator,
        ... ).audios

        >>> # save the best audio sample (index 0) as a .wav file
        >>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio[0])
        ```
        ```
        # Using AudioLDM2 for Text To Speech
        >>> import scipy
        >>> import torch
        >>> from diffusers import AudioLDM2Pipeline

        >>> repo_id = "anhnct/audioldm2_gigaspeech"
        >>> pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
        >>> pipe = pipe.to("cuda")

        >>> # define the prompts
        >>> prompt = "A female reporter is speaking"
        >>> transcript = "wish you have a good day"

        >>> # set the seed for generator
        >>> generator = torch.Generator("cuda").manual_seed(0)

        >>> # run the generation
        >>> audio = pipe(
        ...     prompt,
        ...     transcription=transcript,
        ...     num_inference_steps=200,
        ...     audio_length_in_s=10.0,
        ...     num_waveforms_per_prompt=2,
        ...     generator=generator,
        ...     max_new_tokens=512,  # Must set max_new_tokens equal to 512 for TTS
        ... ).audios

        >>> # save the best audio sample (index 0) as a .wav file
        >>> scipy.io.wavfile.write("tts.wav", rate=16000, data=audio[0])
        ```
"""
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def prepare_inputs_for_generation(
    inputs_embeds,
    attention_mask=None,
    past_key_values=None,
    **kwargs,
):
    """Assemble keyword arguments for one forward step of the language model.

    When a KV cache (`past_key_values`) is supplied, only the most recent
    position of `inputs_embeds` is forwarded, since earlier positions are
    already cached by the model.
    """
    has_cache = past_key_values is not None
    if has_cache:
        # The cached model only needs the newest embedding position.
        inputs_embeds = inputs_embeds[:, -1:]

    model_inputs = {
        "inputs_embeds": inputs_embeds,
        "attention_mask": attention_mask,
        "past_key_values": past_key_values,
        "use_cache": kwargs.get("use_cache"),
    }
    return model_inputs
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class AudioLDM2Pipeline(DiffusionPipeline):
|
| 150 |
+
r"""
|
| 151 |
+
Pipeline for text-to-audio generation using AudioLDM2.
|
| 152 |
+
|
| 153 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
| 154 |
+
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
| 155 |
+
|
| 156 |
+
Args:
|
| 157 |
+
vae ([`AutoencoderKL`]):
|
| 158 |
+
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
| 159 |
+
text_encoder ([`~transformers.ClapModel`]):
|
| 160 |
+
First frozen text-encoder. AudioLDM2 uses the joint audio-text embedding model
|
| 161 |
+
[CLAP](https://huggingface.co/docs/transformers/model_doc/clap#transformers.CLAPTextModelWithProjection),
|
| 162 |
+
specifically the [laion/clap-htsat-unfused](https://huggingface.co/laion/clap-htsat-unfused) variant. The
|
| 163 |
+
text branch is used to encode the text prompt to a prompt embedding. The full audio-text model is used to
|
| 164 |
+
rank generated waveforms against the text prompt by computing similarity scores.
|
| 165 |
+
text_encoder_2 ([`~transformers.T5EncoderModel`, `~transformers.VitsModel`]):
|
| 166 |
+
Second frozen text-encoder. AudioLDM2 uses the encoder of
|
| 167 |
+
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
|
| 168 |
+
[google/flan-t5-large](https://huggingface.co/google/flan-t5-large) variant. Second frozen text-encoder use
|
| 169 |
+
for TTS. AudioLDM2 uses the encoder of
|
| 170 |
+
[Vits](https://huggingface.co/docs/transformers/model_doc/vits#transformers.VitsModel).
|
| 171 |
+
projection_model ([`AudioLDM2ProjectionModel`]):
|
| 172 |
+
A trained model used to linearly project the hidden-states from the first and second text encoder models
|
| 173 |
+
and insert learned SOS and EOS token embeddings. The projected hidden-states from the two text encoders are
|
| 174 |
+
concatenated to give the input to the language model. A Learned Position Embedding for the Vits
|
| 175 |
+
hidden-states
|
| 176 |
+
language_model ([`~transformers.GPT2Model`]):
|
| 177 |
+
An auto-regressive language model used to generate a sequence of hidden-states conditioned on the projected
|
| 178 |
+
outputs from the two text encoders.
|
| 179 |
+
tokenizer ([`~transformers.RobertaTokenizer`]):
|
| 180 |
+
Tokenizer to tokenize text for the first frozen text-encoder.
|
| 181 |
+
tokenizer_2 ([`~transformers.T5Tokenizer`, `~transformers.VitsTokenizer`]):
|
| 182 |
+
Tokenizer to tokenize text for the second frozen text-encoder.
|
| 183 |
+
feature_extractor ([`~transformers.ClapFeatureExtractor`]):
|
| 184 |
+
Feature extractor to pre-process generated audio waveforms to log-mel spectrograms for automatic scoring.
|
| 185 |
+
unet ([`UNet2DConditionModel`]):
|
| 186 |
+
A `UNet2DConditionModel` to denoise the encoded audio latents.
|
| 187 |
+
scheduler ([`SchedulerMixin`]):
|
| 188 |
+
A scheduler to be used in combination with `unet` to denoise the encoded audio latents. Can be one of
|
| 189 |
+
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
| 190 |
+
vocoder ([`~transformers.SpeechT5HifiGan`]):
|
| 191 |
+
Vocoder of class `SpeechT5HifiGan` to convert the mel-spectrogram latents to the final audio waveform.
|
| 192 |
+
"""
|
| 193 |
+
|
| 194 |
+
def __init__(
|
| 195 |
+
self,
|
| 196 |
+
vae: AutoencoderKL,
|
| 197 |
+
text_encoder: ClapModel,
|
| 198 |
+
text_encoder_2: Union[T5EncoderModel, VitsModel],
|
| 199 |
+
projection_model: AudioLDM2ProjectionModel,
|
| 200 |
+
language_model: GPT2LMHeadModel,
|
| 201 |
+
tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast],
|
| 202 |
+
tokenizer_2: Union[T5Tokenizer, T5TokenizerFast, VitsTokenizer],
|
| 203 |
+
feature_extractor: ClapFeatureExtractor,
|
| 204 |
+
unet: AudioLDM2UNet2DConditionModel,
|
| 205 |
+
scheduler: KarrasDiffusionSchedulers,
|
| 206 |
+
vocoder: SpeechT5HifiGan,
|
| 207 |
+
):
|
| 208 |
+
super().__init__()
|
| 209 |
+
|
| 210 |
+
self.register_modules(
|
| 211 |
+
vae=vae,
|
| 212 |
+
text_encoder=text_encoder,
|
| 213 |
+
text_encoder_2=text_encoder_2,
|
| 214 |
+
projection_model=projection_model,
|
| 215 |
+
language_model=language_model,
|
| 216 |
+
tokenizer=tokenizer,
|
| 217 |
+
tokenizer_2=tokenizer_2,
|
| 218 |
+
feature_extractor=feature_extractor,
|
| 219 |
+
unet=unet,
|
| 220 |
+
scheduler=scheduler,
|
| 221 |
+
vocoder=vocoder,
|
| 222 |
+
)
|
| 223 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
| 224 |
+
|
| 225 |
+
    # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_slicing
    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        # Pure delegation: the slicing flag lives on the VAE module itself.
        self.vae.enable_slicing()
|
| 232 |
+
|
| 233 |
+
    # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_slicing
    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        # Pure delegation: simply clears the slicing flag on the VAE module.
        self.vae.disable_slicing()
|
| 240 |
+
|
| 241 |
+
    def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.

        Args:
            gpu_id (`int`, *optional*):
                Index of the accelerator to offload to. Mutually exclusive with an index embedded in `device`.
            device (`torch.device` or `str`, *optional*, defaults to `"cuda"`):
                Accelerator device type (optionally with index, e.g. `"cuda:1"`) that the models are moved to.

        Raises:
            ImportError: if `accelerate` >= 0.17.0 is not installed.
            ValueError: if both `gpu_id` and an index inside `device` are given.
        """
        # `cpu_offload_with_hook` only exists in accelerate >= 0.17.0.
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        torch_device = torch.device(device)
        device_index = torch_device.index

        # Reject ambiguous input such as gpu_id=1 together with device="cuda:0".
        if gpu_id is not None and device_index is not None:
            raise ValueError(
                f"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device}"
                f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of the device: `device`={torch_device.type}"
            )

        # Rebuild the target device string from type + whichever index was supplied.
        # NOTE(review): `gpu_id=0` is falsy, so "cuda:0" is collapsed to "cuda";
        # torch then resolves this to the current (default) device, which is
        # normally index 0 anyway — confirm this is the intended behavior.
        device_type = torch_device.type
        device_str = device_type
        if gpu_id or torch_device.index:
            device_str = f"{device_str}:{gpu_id or torch_device.index}"
        device = torch.device(device_str)

        # Move everything back to CPU first so hooks start from a clean state.
        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            empty_device_cache(device.type)

        # Offload order mirrors the order the models are used at inference time;
        # `text_encoder` appears last so its submodules (listed first) are hooked
        # individually before the wrapper model itself.
        model_sequence = [
            self.text_encoder.text_model,
            self.text_encoder.text_projection,
            self.text_encoder_2,
            self.projection_model,
            self.language_model,
            self.unet,
            self.vae,
            self.vocoder,
            self.text_encoder,
        ]

        # Chain the hooks: each model's hook offloads the previous one when invoked.
        hook = None
        for cpu_offloaded_model in model_sequence:
            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

        # We'll offload the last model manually.
        self.final_offload_hook = hook
|
| 290 |
+
|
| 291 |
+
    def generate_language_model(
        self,
        inputs_embeds: torch.Tensor = None,
        max_new_tokens: int = 8,
        **model_kwargs,
    ):
        """
        Generates a sequence of hidden-states from the language model, conditioned on the embedding inputs.

        Parameters:
            inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
                The sequence used as a prompt for the generation.
            max_new_tokens (`int`):
                Number of new tokens to generate.
            model_kwargs (`Dict[str, Any]`, *optional*):
                Ad hoc parametrization of additional model-specific kwargs that will be forwarded to the `forward`
                function of the model.

        Return:
            `inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
                The sequence of generated hidden-states.
        """
        # Build the kwargs for the (private) transformers cache-position helper,
        # whose signature changed in transformers 4.52.1.
        cache_position_kwargs = {}
        if is_transformers_version("<", "4.52.1"):
            cache_position_kwargs["input_ids"] = inputs_embeds
        else:
            # NOTE(review): `shape[0]` is the batch dimension here, while the
            # parameter is named `seq_length` — looks like it should be
            # `shape[1]`; confirm against the transformers API.
            cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
            cache_position_kwargs["device"] = (
                self.language_model.device if getattr(self, "language_model", None) is not None else self.device
            )
        cache_position_kwargs["model_kwargs"] = model_kwargs
        max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens
        model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)

        # Greedy-style loop: feed embeddings, take the last hidden state of the
        # final layer as the "next token" embedding, and append it.
        for _ in range(max_new_tokens):
            # prepare model inputs (trims to the last position once a KV cache exists)
            model_inputs = prepare_inputs_for_generation(inputs_embeds, **model_kwargs)

            # forward pass to get next hidden states
            output = self.language_model(**model_inputs, output_hidden_states=True, return_dict=True)

            next_hidden_states = output.hidden_states[-1]

            # Update the model input
            inputs_embeds = torch.cat([inputs_embeds, next_hidden_states[:, -1:, :]], dim=1)

            # Update generated hidden states, model inputs, and length for next step
            model_kwargs = self.language_model._update_model_kwargs_for_generation(output, model_kwargs)

        # Return only the newly generated positions, dropping the prompt prefix.
        return inputs_embeds[:, -max_new_tokens:, :]
|
| 342 |
+
|
| 343 |
+
    def encode_prompt(
        self,
        prompt,
        device,
        num_waveforms_per_prompt,
        do_classifier_free_guidance,
        transcription=None,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        generated_prompt_embeds: Optional[torch.Tensor] = None,
        negative_generated_prompt_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        negative_attention_mask: Optional[torch.LongTensor] = None,
        max_new_tokens: Optional[int] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            transcription (`str` or `List[str]`):
                transcription of text to speech
            device (`torch.device`):
                torch device
            num_waveforms_per_prompt (`int`):
                number of waveforms that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the audio generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-computed text embeddings from the Flan T5 model. Can be used to easily tweak text inputs, *e.g.*
                prompt weighting. If not provided, text embeddings will be computed from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-computed negative text embeddings from the Flan T5 model. Can be used to easily tweak text inputs,
                *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
                `negative_prompt` input argument.
            generated_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings from the GPT2 language model. Can be used to easily tweak text inputs,
                *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input
                argument.
            negative_generated_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings from the GPT2 language model. Can be used to easily tweak text
                inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
                `negative_prompt` input argument.
            attention_mask (`torch.LongTensor`, *optional*):
                Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will
                be computed from `prompt` input argument.
            negative_attention_mask (`torch.LongTensor`, *optional*):
                Pre-computed attention mask to be applied to the `negative_prompt_embeds`. If not provided, attention
                mask will be computed from `negative_prompt` input argument.
            max_new_tokens (`int`, *optional*, defaults to None):
                The number of new tokens to generate with the GPT2 language model.
        Returns:
            prompt_embeds (`torch.Tensor`):
                Text embeddings from the Flan T5 model.
            attention_mask (`torch.LongTensor`):
                Attention mask to be applied to the `prompt_embeds`.
            generated_prompt_embeds (`torch.Tensor`):
                Text embeddings generated from the GPT2 language model.

        Example:

        ```python
        >>> import scipy
        >>> import torch
        >>> from diffusers import AudioLDM2Pipeline

        >>> repo_id = "cvssp/audioldm2"
        >>> pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
        >>> pipe = pipe.to("cuda")

        >>> # Get text embedding vectors
        >>> prompt_embeds, attention_mask, generated_prompt_embeds = pipe.encode_prompt(
        ...     prompt="Techno music with a strong, upbeat tempo and high melodic riffs",
        ...     device="cuda",
        ...     do_classifier_free_guidance=True,
        ... )

        >>> # Pass text embeddings to pipeline for text-conditional audio generation
        >>> audio = pipe(
        ...     prompt_embeds=prompt_embeds,
        ...     attention_mask=attention_mask,
        ...     generated_prompt_embeds=generated_prompt_embeds,
        ...     num_inference_steps=200,
        ...     audio_length_in_s=10.0,
        ... ).audios[0]

        >>> # save generated audio sample
        >>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio)
        ```"""
        # Infer the batch size from whichever prompt representation was supplied.
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # Define tokenizers and text encoders
        # Pairing: (tokenizer, text_encoder) -> (CLAP tokenizer, CLAP) and
        # (tokenizer_2, T5 or inner Vits encoder).
        tokenizers = [self.tokenizer, self.tokenizer_2]
        is_vits_text_encoder = isinstance(self.text_encoder_2, VitsModel)

        if is_vits_text_encoder:
            text_encoders = [self.text_encoder, self.text_encoder_2.text_encoder]
        else:
            text_encoders = [self.text_encoder, self.text_encoder_2]

        if prompt_embeds is None:
            prompt_embeds_list = []
            attention_mask_list = []

            for tokenizer, text_encoder in zip(tokenizers, text_encoders):
                # Roberta/T5 tokenizers consume `prompt`; the Vits tokenizer
                # consumes the TTS `transcription` instead.
                use_prompt = isinstance(
                    tokenizer, (RobertaTokenizer, RobertaTokenizerFast, T5Tokenizer, T5TokenizerFast)
                )
                text_inputs = tokenizer(
                    prompt if use_prompt else transcription,
                    padding="max_length"
                    if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast, VitsTokenizer))
                    else True,
                    max_length=tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                )
                text_input_ids = text_inputs.input_ids
                attention_mask = text_inputs.attention_mask
                # NOTE(review): truncation detection always re-tokenizes `prompt`,
                # even when the Vits branch encoded `transcription` above —
                # confirm this is intentional.
                untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                    text_input_ids, untruncated_ids
                ):
                    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
                    logger.warning(
                        f"The following part of your input was truncated because {text_encoder.config.model_type} can "
                        f"only handle sequences up to {tokenizer.model_max_length} tokens: {removed_text}"
                    )

                text_input_ids = text_input_ids.to(device)
                attention_mask = attention_mask.to(device)

                if text_encoder.config.model_type == "clap":
                    # CLAP returns a single pooled text feature per prompt.
                    prompt_embeds = text_encoder.get_text_features(
                        text_input_ids,
                        attention_mask=attention_mask,
                    )
                    # append the seq-len dim: (bs, hidden_size) -> (bs, seq_len, hidden_size)
                    prompt_embeds = prompt_embeds[:, None, :]
                    # make sure that we attend to this single hidden-state
                    attention_mask = attention_mask.new_ones((batch_size, 1))
                elif is_vits_text_encoder:
                    # Add end_token_id and attention mask in the end of sequence phonemes
                    for text_input_id, text_attention_mask in zip(text_input_ids, attention_mask):
                        for idx, phoneme_id in enumerate(text_input_id):
                            if phoneme_id == 0:
                                # 182 presumably is the Vits end-of-sequence phoneme id — TODO confirm.
                                text_input_id[idx] = 182
                                text_attention_mask[idx] = 1
                                break
                    prompt_embeds = text_encoder(
                        text_input_ids, attention_mask=attention_mask, padding_mask=attention_mask.unsqueeze(-1)
                    )
                    prompt_embeds = prompt_embeds[0]
                else:
                    # T5 branch: take the last hidden state.
                    prompt_embeds = text_encoder(
                        text_input_ids,
                        attention_mask=attention_mask,
                    )
                    prompt_embeds = prompt_embeds[0]

                prompt_embeds_list.append(prompt_embeds)
                attention_mask_list.append(attention_mask)

            # Project both encoders' hidden-states into the joint space and add
            # the learned SOS/EOS embeddings.
            projection_output = self.projection_model(
                hidden_states=prompt_embeds_list[0],
                hidden_states_1=prompt_embeds_list[1],
                attention_mask=attention_mask_list[0],
                attention_mask_1=attention_mask_list[1],
            )
            projected_prompt_embeds = projection_output.hidden_states
            projected_attention_mask = projection_output.attention_mask

            # Auto-regressively generate extra hidden-states with the GPT2 model.
            generated_prompt_embeds = self.generate_language_model(
                projected_prompt_embeds,
                attention_mask=projected_attention_mask,
                max_new_tokens=max_new_tokens,
            )

        # NOTE: after the branch above, `prompt_embeds` holds the SECOND
        # encoder's output (T5/Vits); it is cast to that encoder's dtype.
        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
        attention_mask = (
            attention_mask.to(device=device)
            if attention_mask is not None
            else torch.ones(prompt_embeds.shape[:2], dtype=torch.long, device=device)
        )
        generated_prompt_embeds = generated_prompt_embeds.to(dtype=self.language_model.dtype, device=device)

        bs_embed, seq_len, hidden_size = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_waveforms_per_prompt, seq_len, hidden_size)

        # duplicate attention mask for each generation per prompt
        attention_mask = attention_mask.repeat(1, num_waveforms_per_prompt)
        attention_mask = attention_mask.view(bs_embed * num_waveforms_per_prompt, seq_len)

        bs_embed, seq_len, hidden_size = generated_prompt_embeds.shape
        # duplicate generated embeddings for each generation per prompt, using mps friendly method
        generated_prompt_embeds = generated_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
        generated_prompt_embeds = generated_prompt_embeds.view(
            bs_embed * num_waveforms_per_prompt, seq_len, hidden_size
        )

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # Mirror the positive pass: encode, project, then run the GPT2 model.
            negative_prompt_embeds_list = []
            negative_attention_mask_list = []
            max_length = prompt_embeds.shape[1]
            for tokenizer, text_encoder in zip(tokenizers, text_encoders):
                uncond_input = tokenizer(
                    uncond_tokens,
                    padding="max_length",
                    max_length=tokenizer.model_max_length
                    if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast, VitsTokenizer))
                    else max_length,
                    truncation=True,
                    return_tensors="pt",
                )

                uncond_input_ids = uncond_input.input_ids.to(device)
                negative_attention_mask = uncond_input.attention_mask.to(device)

                if text_encoder.config.model_type == "clap":
                    negative_prompt_embeds = text_encoder.get_text_features(
                        uncond_input_ids,
                        attention_mask=negative_attention_mask,
                    )
                    # append the seq-len dim: (bs, hidden_size) -> (bs, seq_len, hidden_size)
                    negative_prompt_embeds = negative_prompt_embeds[:, None, :]
                    # make sure that we attend to this single hidden-state
                    negative_attention_mask = negative_attention_mask.new_ones((batch_size, 1))
                elif is_vits_text_encoder:
                    # The unconditional TTS branch uses all-zero hidden-states and
                    # a fully masked (all-zero) attention mask.
                    negative_prompt_embeds = torch.zeros(
                        batch_size,
                        tokenizer.model_max_length,
                        text_encoder.config.hidden_size,
                    ).to(dtype=self.text_encoder_2.dtype, device=device)
                    negative_attention_mask = torch.zeros(batch_size, tokenizer.model_max_length).to(
                        dtype=self.text_encoder_2.dtype, device=device
                    )
                else:
                    negative_prompt_embeds = text_encoder(
                        uncond_input_ids,
                        attention_mask=negative_attention_mask,
                    )
                    negative_prompt_embeds = negative_prompt_embeds[0]

                negative_prompt_embeds_list.append(negative_prompt_embeds)
                negative_attention_mask_list.append(negative_attention_mask)

            projection_output = self.projection_model(
                hidden_states=negative_prompt_embeds_list[0],
                hidden_states_1=negative_prompt_embeds_list[1],
                attention_mask=negative_attention_mask_list[0],
                attention_mask_1=negative_attention_mask_list[1],
            )
            negative_projected_prompt_embeds = projection_output.hidden_states
            negative_projected_attention_mask = projection_output.attention_mask

            negative_generated_prompt_embeds = self.generate_language_model(
                negative_projected_prompt_embeds,
                attention_mask=negative_projected_attention_mask,
                max_new_tokens=max_new_tokens,
            )

        if do_classifier_free_guidance:
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
            negative_attention_mask = (
                negative_attention_mask.to(device=device)
                if negative_attention_mask is not None
                else torch.ones(negative_prompt_embeds.shape[:2], dtype=torch.long, device=device)
            )
            negative_generated_prompt_embeds = negative_generated_prompt_embeds.to(
                dtype=self.language_model.dtype, device=device
            )

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_waveforms_per_prompt, seq_len, -1)

            # duplicate unconditional attention mask for each generation per prompt
            negative_attention_mask = negative_attention_mask.repeat(1, num_waveforms_per_prompt)
            negative_attention_mask = negative_attention_mask.view(batch_size * num_waveforms_per_prompt, seq_len)

            # duplicate unconditional generated embeddings for each generation per prompt
            seq_len = negative_generated_prompt_embeds.shape[1]
            negative_generated_prompt_embeds = negative_generated_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
            negative_generated_prompt_embeds = negative_generated_prompt_embeds.view(
                batch_size * num_waveforms_per_prompt, seq_len, -1
            )

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
            attention_mask = torch.cat([negative_attention_mask, attention_mask])
            generated_prompt_embeds = torch.cat([negative_generated_prompt_embeds, generated_prompt_embeds])

        return prompt_embeds, attention_mask, generated_prompt_embeds
|
| 674 |
+
|
| 675 |
+
# Copied from diffusers.pipelines.audioldm.pipeline_audioldm.AudioLDMPipeline.mel_spectrogram_to_waveform
def mel_spectrogram_to_waveform(self, mel_spectrogram):
    """Run the vocoder on a log-mel spectrogram and return the waveform as a float32 CPU tensor."""
    # Drop the singleton channel axis so the vocoder receives the expected 3D input.
    spectrogram = mel_spectrogram.squeeze(1) if mel_spectrogram.dim() == 4 else mel_spectrogram
    # Casting to float32 adds negligible overhead and keeps the result usable under bfloat16 inference.
    return self.vocoder(spectrogram).cpu().float()
|
| 684 |
+
|
| 685 |
+
def score_waveforms(self, text, audio, num_waveforms_per_prompt, device, dtype):
    """Re-order generated waveforms so the best CLAP text-audio matches come first per prompt."""
    if not is_librosa_available():
        # Without librosa we cannot resample to the CLAP sampling rate; keep the original order.
        logger.info(
            "Automatic scoring of the generated audio waveforms against the input prompt text requires the "
            "`librosa` package to resample the generated waveforms. Returning the audios in the order they were "
            "generated. To enable automatic scoring, install `librosa` with: `pip install librosa`."
        )
        return audio

    clap_inputs = self.tokenizer(text, return_tensors="pt", padding=True)
    # Bring the vocoder-rate audio to the sampling rate the CLAP feature extractor expects.
    resampled = librosa.resample(
        audio.numpy(), orig_sr=self.vocoder.config.sampling_rate, target_sr=self.feature_extractor.sampling_rate
    )
    clap_inputs["input_features"] = self.feature_extractor(
        list(resampled), return_tensors="pt", sampling_rate=self.feature_extractor.sampling_rate
    ).input_features.type(dtype)
    clap_inputs = clap_inputs.to(device)

    # Audio-text similarity matrix from the CLAP model: one row per prompt, one column per waveform.
    similarity = self.text_encoder(**clap_inputs).logits_per_text
    # For each prompt, take the indices of the highest-scoring generations.
    best = torch.argsort(similarity, dim=1, descending=True)[:, :num_waveforms_per_prompt]
    return torch.index_select(audio, 0, best.reshape(-1).cpu())
|
| 708 |
+
|
| 709 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    """Build the optional kwargs for `scheduler.step`, forwarding only what its signature accepts.

    eta (η) corresponds to η in the DDIM paper (https://huggingface.co/papers/2010.02502), should be
    in [0, 1], and is ignored by schedulers other than DDIM.
    """
    step_params = inspect.signature(self.scheduler.step).parameters
    extra_step_kwargs = {}
    if "eta" in step_params:
        extra_step_kwargs["eta"] = eta
    if "generator" in step_params:
        extra_step_kwargs["generator"] = generator
    return extra_step_kwargs
|
| 726 |
+
|
| 727 |
+
def check_inputs(
    self,
    prompt,
    audio_length_in_s,
    vocoder_upsample_factor,
    callback_steps,
    transcription=None,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    generated_prompt_embeds=None,
    negative_generated_prompt_embeds=None,
    attention_mask=None,
    negative_attention_mask=None,
):
    """Validate the arguments passed to `__call__`, raising `ValueError` on the first problem found.

    Checks, in order: minimum audio length, vocoder/VAE frequency-bin divisibility, `callback_steps`
    positivity, mutual exclusivity of `prompt` vs `prompt_embeds`, matching shapes between positive
    and negative embeddings/attention masks, and presence of `transcription` when the secondary text
    encoder is a VITS (TTS) model. The order of checks determines which error a caller sees first.
    """
    # The requested audio must cover at least one latent frame after VAE downsampling.
    min_audio_length_in_s = vocoder_upsample_factor * self.vae_scale_factor
    if audio_length_in_s < min_audio_length_in_s:
        raise ValueError(
            f"`audio_length_in_s` has to be a positive value greater than or equal to {min_audio_length_in_s}, but "
            f"is {audio_length_in_s}."
        )

    # The VAE downsamples the frequency axis too, so the vocoder's mel bins must divide evenly.
    if self.vocoder.config.model_in_dim % self.vae_scale_factor != 0:
        raise ValueError(
            f"The number of frequency bins in the vocoder's log-mel spectrogram has to be divisible by the "
            f"VAE scale factor, but got {self.vocoder.config.model_in_dim} bins and a scale factor of "
            f"{self.vae_scale_factor}."
        )

    if (callback_steps is None) or (
        callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
    ):
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    # `prompt` and `prompt_embeds` are mutually exclusive; when embeddings are supplied, the
    # GPT2-generated embeddings must come with them.
    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and (prompt_embeds is None or generated_prompt_embeds is None):
        raise ValueError(
            "Provide either `prompt`, or `prompt_embeds` and `generated_prompt_embeds`. Cannot leave "
            "`prompt` undefined without specifying both `prompt_embeds` and `generated_prompt_embeds`."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )
    elif negative_prompt_embeds is not None and negative_generated_prompt_embeds is None:
        raise ValueError(
            "Cannot forward `negative_prompt_embeds` without `negative_generated_prompt_embeds`. Ensure that"
            "both arguments are specified"
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )
        if attention_mask is not None and attention_mask.shape != prompt_embeds.shape[:2]:
            raise ValueError(
                "`attention_mask should have the same batch size and sequence length as `prompt_embeds`, but got:"
                f"`attention_mask: {attention_mask.shape} != `prompt_embeds` {prompt_embeds.shape}"
            )

    # A transcription is mandatory for TTS-style (VITS) secondary text encoders.
    if transcription is None:
        if self.text_encoder_2.config.model_type == "vits":
            raise ValueError("Cannot forward without transcription. Please make sure to have transcription")
    elif transcription is not None and (
        not isinstance(transcription, str) and not isinstance(transcription, list)
    ):
        raise ValueError(f"`transcription` has to be of type `str` or `list` but is {type(transcription)}")

    if generated_prompt_embeds is not None and negative_generated_prompt_embeds is not None:
        if generated_prompt_embeds.shape != negative_generated_prompt_embeds.shape:
            raise ValueError(
                "`generated_prompt_embeds` and `negative_generated_prompt_embeds` must have the same shape when "
                f"passed directly, but got: `generated_prompt_embeds` {generated_prompt_embeds.shape} != "
                f"`negative_generated_prompt_embeds` {negative_generated_prompt_embeds.shape}."
            )
        # NOTE(review): this check compares the *negative* mask against the *negative* embeddings,
        # but the message below mentions `prompt_embeds`; message text kept verbatim from upstream.
        if (
            negative_attention_mask is not None
            and negative_attention_mask.shape != negative_prompt_embeds.shape[:2]
        ):
            raise ValueError(
                "`attention_mask should have the same batch size and sequence length as `prompt_embeds`, but got:"
                f"`attention_mask: {negative_attention_mask.shape} != `prompt_embeds` {negative_prompt_embeds.shape}"
            )
|
| 824 |
+
|
| 825 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with width->self.vocoder.config.model_in_dim
def prepare_latents(self, batch_size, num_channels_latents, height, dtype, device, generator, latents=None):
    """Create (or move to `device`) the initial latent tensor and scale it by the scheduler's init sigma."""
    # Latent spatial dims: spectrogram height and the vocoder's mel-bin count, VAE-downsampled.
    latent_height = int(height) // self.vae_scale_factor
    latent_width = int(self.vocoder.config.model_in_dim) // self.vae_scale_factor
    shape = (batch_size, num_channels_latents, latent_height, latent_width)

    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is not None:
        latents = latents.to(device)
    else:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

    # The scheduler expects initial noise scaled by its starting standard deviation.
    return latents * self.scheduler.init_noise_sigma
|
| 847 |
+
|
| 848 |
+
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    transcription: Union[str, List[str]] = None,
    audio_length_in_s: Optional[float] = None,
    num_inference_steps: int = 200,
    guidance_scale: float = 3.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_waveforms_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    generated_prompt_embeds: Optional[torch.Tensor] = None,
    negative_generated_prompt_embeds: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.LongTensor] = None,
    negative_attention_mask: Optional[torch.LongTensor] = None,
    max_new_tokens: Optional[int] = None,
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
    callback_steps: Optional[int] = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    output_type: Optional[str] = "np",
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`.
        transcription (`str` or `List[str]`, *optional*):
            The transcript for text to speech.
        audio_length_in_s (`int`, *optional*, defaults to 10.24):
            The length of the generated audio sample in seconds.
        num_inference_steps (`int`, *optional*, defaults to 200):
            The number of denoising steps. More denoising steps usually lead to a higher quality audio at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 3.5):
            A higher guidance scale value encourages the model to generate audio that is closely linked to the text
            `prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in audio generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
        num_waveforms_per_prompt (`int`, *optional*, defaults to 1):
            The number of waveforms to generate per prompt. If `num_waveforms_per_prompt > 1`, then automatic
            scoring is performed between the generated outputs and the text prompt. This scoring ranks the
            generated waveforms based on their cosine similarity with the text input in the joint text-audio
            embedding space.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
            applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for spectrogram
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
            not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
        generated_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings from the GPT2 language model. Can be used to easily tweak text inputs,
            *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input
            argument.
        negative_generated_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings from the GPT2 language model. Can be used to easily tweak text
            inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
            `negative_prompt` input argument.
        attention_mask (`torch.LongTensor`, *optional*):
            Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will
            be computed from `prompt` input argument.
        negative_attention_mask (`torch.LongTensor`, *optional*):
            Pre-computed attention mask to be applied to the `negative_prompt_embeds`. If not provided, attention
            mask will be computed from `negative_prompt` input argument.
        max_new_tokens (`int`, *optional*, defaults to None):
            Number of new tokens to generate with the GPT2 language model. If not provided, number of tokens will
            be taken from the config of the model.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback (`Callable`, *optional*):
            A function that calls every `callback_steps` steps during inference. The function is called with the
            following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function is called. If not specified, the callback is called at
            every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        output_type (`str`, *optional*, defaults to `"np"`):
            The output format of the generated audio. Choose between `"np"` to return a NumPy `np.ndarray` or
            `"pt"` to return a PyTorch `torch.Tensor` object. Set to `"latent"` to return the latent diffusion
            model (LDM) output.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
            otherwise a `tuple` is returned where the first element is a list with the generated audio.
    """
    # NOTE(review): `cross_attention_kwargs` is accepted but not forwarded anywhere in this body —
    # confirm against the UNet call whether it should be passed through.

    # 0. Convert audio input length from seconds to spectrogram height
    vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate

    if audio_length_in_s is None:
        audio_length_in_s = self.unet.config.sample_size * self.vae_scale_factor * vocoder_upsample_factor

    height = int(audio_length_in_s / vocoder_upsample_factor)

    original_waveform_length = int(audio_length_in_s * self.vocoder.config.sampling_rate)
    # Round the spectrogram height up to a multiple of the VAE scale factor; the waveform is
    # trimmed back to `original_waveform_length` after decoding.
    if height % self.vae_scale_factor != 0:
        height = int(np.ceil(height / self.vae_scale_factor)) * self.vae_scale_factor
        logger.info(
            f"Audio length in seconds {audio_length_in_s} is increased to {height * vocoder_upsample_factor} "
            f"so that it can be handled by the model. It will be cut to {audio_length_in_s} after the "
            f"denoising process."
        )

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        audio_length_in_s,
        vocoder_upsample_factor,
        callback_steps,
        transcription,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
        generated_prompt_embeds,
        negative_generated_prompt_embeds,
        attention_mask,
        negative_attention_mask,
    )

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    prompt_embeds, attention_mask, generated_prompt_embeds = self.encode_prompt(
        prompt,
        device,
        num_waveforms_per_prompt,
        do_classifier_free_guidance,
        transcription,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        generated_prompt_embeds=generated_prompt_embeds,
        negative_generated_prompt_embeds=negative_generated_prompt_embeds,
        attention_mask=attention_mask,
        negative_attention_mask=negative_attention_mask,
        max_new_tokens=max_new_tokens,
    )

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_waveforms_per_prompt,
        num_channels_latents,
        height,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=generated_prompt_embeds,
                encoder_hidden_states_1=prompt_embeds,
                encoder_attention_mask_1=attention_mask,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, latents)

            if XLA_AVAILABLE:
                xm.mark_step()

    self.maybe_free_model_hooks()

    # 8. Post-processing
    if not output_type == "latent":
        latents = 1 / self.vae.config.scaling_factor * latents
        mel_spectrogram = self.vae.decode(latents).sample
    else:
        return AudioPipelineOutput(audios=latents)

    audio = self.mel_spectrogram_to_waveform(mel_spectrogram)

    # Trim the padding introduced by rounding `height` up to a VAE-scale-factor multiple.
    audio = audio[:, :original_waveform_length]

    # 9. Automatic scoring
    if num_waveforms_per_prompt > 1 and prompt is not None:
        audio = self.score_waveforms(
            text=prompt,
            audio=audio,
            num_waveforms_per_prompt=num_waveforms_per_prompt,
            device=device,
            dtype=prompt_embeds.dtype,
        )

    if output_type == "np":
        audio = audio.numpy()

    if not return_dict:
        return (audio,)

    return AudioPipelineOutput(audios=audio)
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/aura_flow/__init__.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


# Names exported only when their optional backends are missing (dummy placeholders),
# and the lazy-import map consumed by `_LazyModule` below.
_dummy_objects = {}
_import_structure = {}


# The real pipeline needs both torch and transformers; otherwise fall back to dummy
# objects that raise a helpful error on use.
try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_aura_flow"] = ["AuraFlowPipeline"]

# Eager imports for type checkers and for DIFFUSERS_SLOW_IMPORT mode; otherwise the
# module is replaced with a lazy proxy so importing the package stays cheap.
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()

    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_aura_flow import AuraFlowPipeline

else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    # Attach the dummy objects to the lazy module so failed-dependency names still resolve.
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
ADDED
|
@@ -0,0 +1,677 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 AuraFlow Authors and The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import inspect
|
| 15 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
from transformers import T5Tokenizer, UMT5EncoderModel
|
| 19 |
+
|
| 20 |
+
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 21 |
+
from ...image_processor import VaeImageProcessor
|
| 22 |
+
from ...loaders import AuraFlowLoraLoaderMixin
|
| 23 |
+
from ...models import AuraFlowTransformer2DModel, AutoencoderKL
|
| 24 |
+
from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
|
| 25 |
+
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
| 26 |
+
from ...utils import (
|
| 27 |
+
USE_PEFT_BACKEND,
|
| 28 |
+
is_torch_xla_available,
|
| 29 |
+
logging,
|
| 30 |
+
replace_example_docstring,
|
| 31 |
+
scale_lora_layers,
|
| 32 |
+
unscale_lora_layers,
|
| 33 |
+
)
|
| 34 |
+
from ...utils.torch_utils import randn_tensor
|
| 35 |
+
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# Import the XLA stepping helper only when torch_xla is installed (TPU support).
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# Usage example spliced into `__call__`'s docstring via @replace_example_docstring.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import AuraFlowPipeline

        >>> pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16)
        >>> pipe = pipe.to("cuda")
        >>> prompt = "A cat holding a sign that says hello world"
        >>> image = pipe(prompt).images[0]
        >>> image.save("aura_flow.png")
        ```
"""
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Figure out which (if any) custom schedule was supplied; `label` is the word
    # used in the corresponding upstream error message ("timestep" vs "sigmas").
    if timesteps is not None:
        custom_kwarg, custom_value, label = "timesteps", timesteps, "timestep"
    elif sigmas is not None:
        custom_kwarg, custom_value, label = "sigmas", sigmas, "sigmas"
    else:
        # Default path: let the scheduler build its own schedule.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # A custom schedule only works if the scheduler's set_timesteps accepts it.
    supported_params = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
    if custom_kwarg not in supported_params:
        raise ValueError(
            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
            f" {label} schedules. Please check whether you are using the correct scheduler."
        )

    scheduler.set_timesteps(**{custom_kwarg: custom_value}, device=device, **kwargs)
    timesteps = scheduler.timesteps
    return timesteps, len(timesteps)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class AuraFlowPipeline(DiffusionPipeline, AuraFlowLoraLoaderMixin):
    r"""
    Text-to-image pipeline for AuraFlow (flow-matching transformer model).

    Args:
        tokenizer (`T5TokenizerFast`):
            Tokenizer of class
            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
        text_encoder ([`T5EncoderModel`]):
            Frozen text-encoder. AuraFlow uses
            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
            [EleutherAI/pile-t5-xl](https://huggingface.co/EleutherAI/pile-t5-xl) variant.
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        transformer ([`AuraFlowTransformer2DModel`]):
            Conditional Transformer (MMDiT and DiT) architecture to denoise the encoded image latents.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
    """

    # No components may be omitted when instantiating this pipeline.
    _optional_components = []
    # Order in which modules are moved to the accelerator under CPU offload.
    model_cpu_offload_seq = "text_encoder->transformer->vae"
    # Tensors that `callback_on_step_end` is allowed to inspect/replace each step.
    _callback_tensor_inputs = [
        "latents",
        "prompt_embeds",
    ]
|
| 147 |
+
|
| 148 |
+
    def __init__(
        self,
        tokenizer: T5Tokenizer,
        text_encoder: UMT5EncoderModel,
        vae: AutoencoderKL,
        transformer: AuraFlowTransformer2DModel,
        scheduler: FlowMatchEulerDiscreteScheduler,
    ):
        """Register all pipeline components and derive the VAE scale factor."""
        super().__init__()

        self.register_modules(
            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
        )

        # Spatial downsampling factor of the VAE: 2 per down block. Falls back to 8
        # when no VAE attribute is present after registration — TODO confirm the
        # fallback path is reachable in practice.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
| 164 |
+
|
| 165 |
+
def check_inputs(
|
| 166 |
+
self,
|
| 167 |
+
prompt,
|
| 168 |
+
height,
|
| 169 |
+
width,
|
| 170 |
+
negative_prompt,
|
| 171 |
+
prompt_embeds=None,
|
| 172 |
+
negative_prompt_embeds=None,
|
| 173 |
+
prompt_attention_mask=None,
|
| 174 |
+
negative_prompt_attention_mask=None,
|
| 175 |
+
callback_on_step_end_tensor_inputs=None,
|
| 176 |
+
):
|
| 177 |
+
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
|
| 178 |
+
raise ValueError(
|
| 179 |
+
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}."
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 183 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 184 |
+
):
|
| 185 |
+
raise ValueError(
|
| 186 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 187 |
+
)
|
| 188 |
+
if prompt is not None and prompt_embeds is not None:
|
| 189 |
+
raise ValueError(
|
| 190 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 191 |
+
" only forward one of the two."
|
| 192 |
+
)
|
| 193 |
+
elif prompt is None and prompt_embeds is None:
|
| 194 |
+
raise ValueError(
|
| 195 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 196 |
+
)
|
| 197 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 198 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 199 |
+
|
| 200 |
+
if prompt is not None and negative_prompt_embeds is not None:
|
| 201 |
+
raise ValueError(
|
| 202 |
+
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
|
| 203 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 207 |
+
raise ValueError(
|
| 208 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 209 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
if prompt_embeds is not None and prompt_attention_mask is None:
|
| 213 |
+
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
|
| 214 |
+
|
| 215 |
+
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
|
| 216 |
+
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
|
| 217 |
+
|
| 218 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 219 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 220 |
+
raise ValueError(
|
| 221 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 222 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 223 |
+
f" {negative_prompt_embeds.shape}."
|
| 224 |
+
)
|
| 225 |
+
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
|
| 226 |
+
raise ValueError(
|
| 227 |
+
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
|
| 228 |
+
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
|
| 229 |
+
f" {negative_prompt_attention_mask.shape}."
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]] = None,
        do_classifier_free_guidance: bool = True,
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        prompt_attention_mask: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        max_sequence_length: int = 256,
        lora_scale: Optional[float] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
                instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                whether to use classifier free guidance or not
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                number of images that should be generated per prompt
            device: (`torch.device`, *optional*):
                torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            prompt_attention_mask (`torch.Tensor`, *optional*):
                Pre-generated attention mask for text embeddings.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings.
            negative_prompt_attention_mask (`torch.Tensor`, *optional*):
                Pre-generated attention mask for negative text embeddings.
            max_sequence_length (`int`, defaults to 256): Maximum sequence length to use for the prompt.
            lora_scale (`float`, *optional*):
                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.

        Returns:
            Tuple of `(prompt_embeds, prompt_attention_mask, negative_prompt_embeds,
            negative_prompt_attention_mask)`; the negative pair is `None` when
            `do_classifier_free_guidance` is `False`.
        """
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, AuraFlowLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if self.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder, lora_scale)

        if device is None:
            device = self._execution_device
        # Batch size comes from the prompt when present, otherwise from the
        # precomputed embeddings (check_inputs guarantees one of the two exists).
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        max_length = max_sequence_length
        if prompt_embeds is None:
            # Tokenize with fixed-length padding/truncation to `max_length`.
            text_inputs = self.tokenizer(
                prompt,
                truncation=True,
                max_length=max_length,
                padding="max_length",
                return_tensors="pt",
            )
            text_input_ids = text_inputs["input_ids"]
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            # Warn the user about any text that was dropped by truncation.
            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
                logger.warning(
                    "The following part of your input was truncated because T5 can only handle sequences up to"
                    f" {max_length} tokens: {removed_text}"
                )

            text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
            prompt_embeds = self.text_encoder(**text_inputs)[0]
            # Broadcast the (batch, seq) mask over the hidden dim and zero out
            # embeddings at padding positions.
            prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape)
            prompt_embeds = prompt_embeds * prompt_attention_mask

        # Cast embeddings to the dtype of whichever model is available.
        if self.text_encoder is not None:
            dtype = self.text_encoder.dtype
        elif self.transformer is not None:
            dtype = self.transformer.dtype
        else:
            dtype = None

        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
        prompt_attention_mask = prompt_attention_mask.reshape(bs_embed, -1)
        prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
            # Pad the negative prompt to the same sequence length as the positive one.
            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                truncation=True,
                max_length=max_length,
                padding="max_length",
                return_tensors="pt",
            )
            uncond_input = {k: v.to(device) for k, v in uncond_input.items()}
            negative_prompt_embeds = self.text_encoder(**uncond_input)[0]
            negative_prompt_attention_mask = (
                uncond_input["attention_mask"].unsqueeze(-1).expand(negative_prompt_embeds.shape)
            )
            negative_prompt_embeds = negative_prompt_embeds * negative_prompt_attention_mask

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            # NOTE(review): the positive path uses `bs_embed` here while this uses
            # `batch_size`; they coincide when embeds were produced from `prompt`,
            # but may diverge for user-supplied embeds — confirm against upstream.
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

            negative_prompt_attention_mask = negative_prompt_attention_mask.reshape(bs_embed, -1)
            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
        else:
            negative_prompt_embeds = None
            negative_prompt_attention_mask = None

        if self.text_encoder is not None:
            if isinstance(self, AuraFlowLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
| 374 |
+
|
| 375 |
+
    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents
    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
        """Create initial Gaussian latents, or move user-supplied latents to the target device/dtype."""
        if latents is not None:
            # Caller provided latents: respect them, only adjust placement/dtype.
            return latents.to(device=device, dtype=dtype)

        # Latent grid is the pixel resolution divided by the VAE scale factor.
        shape = (
            batch_size,
            num_channels_latents,
            int(height) // self.vae_scale_factor,
            int(width) // self.vae_scale_factor,
        )

        # A per-sample generator list must match the effective batch size.
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

        return latents
|
| 406 |
+
|
| 407 |
+
    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae
    def upcast_vae(self):
        """Upcast the VAE to float32 (it overflows in fp16), keeping attention blocks in the original dtype when safe."""
        dtype = self.vae.dtype
        self.vae.to(dtype=torch.float32)
        use_torch_2_0_or_xformers = isinstance(
            self.vae.decoder.mid_block.attentions[0].processor,
            (
                AttnProcessor2_0,
                XFormersAttnProcessor,
                FusedAttnProcessor2_0,
            ),
        )
        # if xformers or torch_2_0 is used attention block does not need
        # to be in float32 which can save lots of memory
        if use_torch_2_0_or_xformers:
            self.vae.post_quant_conv.to(dtype)
            self.vae.decoder.conv_in.to(dtype)
            self.vae.decoder.mid_block.to(dtype)
|
| 425 |
+
|
| 426 |
+
    @property
    def guidance_scale(self):
        """Classifier-free guidance weight set for the current `__call__`."""
        return self._guidance_scale

    @property
    def attention_kwargs(self):
        """Kwargs forwarded to the attention processors (e.g. LoRA `scale`)."""
        return self._attention_kwargs

    @property
    def num_timesteps(self):
        """Number of scheduler timesteps used in the current/last generation."""
        return self._num_timesteps
|
| 437 |
+
|
| 438 |
+
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        negative_prompt: Union[str, List[str]] = None,
        num_inference_steps: int = 50,
        sigmas: List[float] = None,
        guidance_scale: float = 3.5,
        num_images_per_prompt: Optional[int] = 1,
        height: Optional[int] = 1024,
        width: Optional[int] = 1024,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        prompt_attention_mask: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        max_sequence_length: int = 256,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                instead.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image. This is set to 1024 by default for best results.
            width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. This is set to 1024 by default for best results.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            sigmas (`List[float]`, *optional*):
                Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
                `num_inference_steps` and `timesteps` must be `None`.
            guidance_scale (`float`, *optional*, defaults to 5.0):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            prompt_attention_mask (`torch.Tensor`, *optional*):
                Pre-generated attention mask for text embeddings.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            negative_prompt_attention_mask (`torch.Tensor`, *optional*):
                Pre-generated attention mask for negative text embeddings.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
                of a plain tuple.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.

        Examples:

        Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is returned
            where the first element is a list with the generated images.
        """
        # 1. Check inputs. Raise error if not correct
        height = height or self.transformer.config.sample_size * self.vae_scale_factor
        width = width or self.transformer.config.sample_size * self.vae_scale_factor

        self.check_inputs(
            prompt,
            height,
            width,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            prompt_attention_mask,
            negative_prompt_attention_mask,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        )

        self._guidance_scale = guidance_scale
        self._attention_kwargs = attention_kwargs

        # 2. Determine batch size.
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # LoRA scale (if any) is carried inside attention_kwargs under "scale".
        lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        (
            prompt_embeds,
            prompt_attention_mask,
            negative_prompt_embeds,
            negative_prompt_attention_mask,
        ) = self.encode_prompt(
            prompt=prompt,
            negative_prompt=negative_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            num_images_per_prompt=num_images_per_prompt,
            device=device,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            prompt_attention_mask=prompt_attention_mask,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )
        # Stack [uncond, cond] along the batch dim so one forward pass serves both.
        if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

        # 4. Prepare timesteps

        # sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)

        # 5. Prepare latents.
        latent_channels = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            latent_channels,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

                # aura use timestep value between 0 and 1, with t=1 as noise and t=0 as the image
                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = torch.tensor([t / 1000]).expand(latent_model_input.shape[0])
                timestep = timestep.to(latents.device, dtype=latents.dtype)

                # predict noise model_output
                noise_pred = self.transformer(
                    latent_model_input,
                    encoder_hidden_states=prompt_embeds,
                    timestep=timestep,
                    return_dict=False,
                    attention_kwargs=self.attention_kwargs,
                )[0]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                # Let the user inspect/replace whitelisted tensors at each step.
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        if output_type == "latent":
            image = latents
        else:
            # make sure the VAE is in float32 mode, as it overflows in float16
            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
            if needs_upcasting:
                self.upcast_vae()
                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/blip_diffusion/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import List, Optional, Union
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import PIL
|
| 6 |
+
from PIL import Image
|
| 7 |
+
|
| 8 |
+
from ...utils import OptionalDependencyNotAvailable, is_torch_available, is_transformers_available
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
if not (is_transformers_available() and is_torch_available()):
|
| 13 |
+
raise OptionalDependencyNotAvailable()
|
| 14 |
+
except OptionalDependencyNotAvailable:
|
| 15 |
+
from ...utils.dummy_torch_and_transformers_objects import ShapEPipeline
|
| 16 |
+
else:
|
| 17 |
+
from .blip_image_processing import BlipImageProcessor
|
| 18 |
+
from .modeling_blip2 import Blip2QFormerModel
|
| 19 |
+
from .modeling_ctx_clip import ContextCLIPTextModel
|
| 20 |
+
from .pipeline_blip_diffusion import BlipDiffusionPipeline
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/blip_diffusion/blip_image_processing.py
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Image processor class for BLIP."""
|
| 16 |
+
|
| 17 |
+
from typing import Dict, List, Optional, Union
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torch
|
| 21 |
+
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
| 22 |
+
from transformers.image_transforms import convert_to_rgb, resize, to_channel_dimension_format
|
| 23 |
+
from transformers.image_utils import (
|
| 24 |
+
OPENAI_CLIP_MEAN,
|
| 25 |
+
OPENAI_CLIP_STD,
|
| 26 |
+
ChannelDimension,
|
| 27 |
+
ImageInput,
|
| 28 |
+
PILImageResampling,
|
| 29 |
+
infer_channel_dimension_format,
|
| 30 |
+
is_scaled_image,
|
| 31 |
+
make_list_of_images,
|
| 32 |
+
to_numpy_array,
|
| 33 |
+
valid_images,
|
| 34 |
+
)
|
| 35 |
+
from transformers.utils import TensorType, is_vision_available, logging
|
| 36 |
+
|
| 37 |
+
from diffusers.utils import numpy_to_pil
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
if is_vision_available():
|
| 41 |
+
import PIL.Image
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
logger = logging.get_logger(__name__)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# We needed some extra functions on top of the ones in transformers.image_processing_utils.BaseImageProcessor, namely center crop
|
| 48 |
+
# Copy-pasted from transformers.models.blip.image_processing_blip.BlipImageProcessor
|
| 49 |
+
class BlipImageProcessor(BaseImageProcessor):
    r"""
    Constructs a BLIP image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
            `do_resize` parameter in the `preprocess` method.
        size (`dict`, *optional*, defaults to `{"height": 224, "width": 224}`):
            Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
            method.
        resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
            Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
            overridden by the `resample` parameter in the `preprocess` method.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
            `do_rescale` parameter in the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
            overridden by the `rescale_factor` parameter in the `preprocess` method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
            method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `OPENAI_CLIP_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `OPENAI_CLIP_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess`
            method.
        do_convert_rgb (`bool`, *optional*, defaults to `True`):
            Whether to convert the image to RGB.
        do_center_crop (`bool`, *optional*, defaults to `True`):
            Whether to center-crop the image to `size` after the other transforms.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Dict[str, int] = None,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = True,
        do_center_crop: bool = True,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        size = size if size is not None else {"height": 224, "width": 224}
        size = get_size_dict(size, default_to_square=True)

        self.do_resize = do_resize
        self.size = size
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
        self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
        self.do_convert_rgb = do_convert_rgb
        self.do_center_crop = do_center_crop

    # Copy-pasted from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.BICUBIC
    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Resize an image to `(size["height"], size["width"])`.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
            data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the output image. If unset, the channel dimension format of the input
                image is used. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

        Returns:
            `np.ndarray`: The resized image.
        """
        size = get_size_dict(size)
        if "height" not in size or "width" not in size:
            raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
        output_size = (size["height"], size["width"])
        return resize(
            image,
            size=output_size,
            resample=resample,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )

    def preprocess(
        self,
        images: ImageInput,
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = None,
        do_rescale: Optional[bool] = None,
        do_center_crop: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        do_convert_rgb: bool = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> PIL.Image.Image:
        """
        Preprocess an image or batch of images.

        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the image after `resize`, as a `{"height": int, "width": int}` dictionary.
            resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image values between [0 - 1].
            do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
                Whether to center-crop the image to `size` after resizing/rescaling/normalizing.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to normalize the image by if `do_normalize` is set to `True`.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to normalize the image by if `do_normalize` is set to `True`.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        resample = resample if resample is not None else self.resample
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
        do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop

        size = size if size is not None else self.size
        size = get_size_dict(size, default_to_square=False)
        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        # NOTE: the original condition `do_resize and size is None or resample is None` parsed as
        # `(do_resize and size is None) or (resample is None)` due to operator precedence, raising
        # whenever `resample` was None even with `do_resize=False`. Parenthesized to match the
        # error message's intent.
        if do_resize and (size is None or resample is None):
            raise ValueError("Size and resample must be specified if do_resize is True.")

        if do_rescale and rescale_factor is None:
            raise ValueError("Rescale factor must be specified if do_rescale is True.")

        if do_normalize and (image_mean is None or image_std is None):
            raise ValueError("Image mean and std must be specified if do_normalize is True.")

        # PIL RGBA images are converted to RGB
        if do_convert_rgb:
            images = [convert_to_rgb(image) for image in images]

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if is_scaled_image(images[0]) and do_rescale:
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )
        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        if do_resize:
            images = [
                self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
                for image in images
            ]

        if do_rescale:
            images = [
                self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
                for image in images
            ]
        if do_normalize:
            images = [
                self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
                for image in images
            ]
        if do_center_crop:
            images = [self.center_crop(image, size, input_data_format=input_data_format) for image in images]

        images = [
            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
        ]

        encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
        return encoded_outputs

    # Follows diffusers.VaeImageProcessor.postprocess
    def postprocess(self, sample: torch.Tensor, output_type: str = "pil"):
        """Map decoded samples from [-1, 1] to [0, 1] and convert to the requested output type."""
        if output_type not in ["pt", "np", "pil"]:
            raise ValueError(
                f"output_type={output_type} is not supported. Make sure to choose one of ['pt', 'np', or 'pil']"
            )

        # Equivalent to diffusers.VaeImageProcessor.denormalize
        sample = (sample / 2 + 0.5).clamp(0, 1)
        if output_type == "pt":
            return sample

        # Equivalent to diffusers.VaeImageProcessor.pt_to_numpy
        sample = sample.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "np":
            return sample
        # Output_type must be 'pil'
        sample = numpy_to_pil(sample)
        return sample
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/blip_diffusion/modeling_blip2.py
ADDED
|
@@ -0,0 +1,639 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import Optional, Tuple, Union
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.utils.checkpoint
|
| 18 |
+
from torch import nn
|
| 19 |
+
from transformers import BertTokenizer
|
| 20 |
+
from transformers.activations import QuickGELUActivation as QuickGELU
|
| 21 |
+
from transformers.modeling_outputs import (
|
| 22 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
| 23 |
+
BaseModelOutputWithPooling,
|
| 24 |
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
| 25 |
+
)
|
| 26 |
+
from transformers.models.blip_2.configuration_blip_2 import Blip2Config, Blip2VisionConfig
|
| 27 |
+
from transformers.models.blip_2.modeling_blip_2 import (
|
| 28 |
+
Blip2Encoder,
|
| 29 |
+
Blip2PreTrainedModel,
|
| 30 |
+
Blip2QFormerAttention,
|
| 31 |
+
Blip2QFormerIntermediate,
|
| 32 |
+
Blip2QFormerOutput,
|
| 33 |
+
)
|
| 34 |
+
from transformers.pytorch_utils import apply_chunking_to_forward
|
| 35 |
+
from transformers.utils import (
|
| 36 |
+
logging,
|
| 37 |
+
replace_return_docstrings,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
logger = logging.get_logger(__name__)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# There is an implementation of Blip2 in `transformers` : https://github.com/huggingface/transformers/blob/main/src/transformers/models/blip_2/modeling_blip_2.py.
|
| 45 |
+
# But it doesn't support getting multimodal embeddings. So, this module can be
|
| 46 |
+
# replaced with a future `transformers` version supports that.
|
| 47 |
+
class Blip2TextEmbeddings(nn.Module):
    """Construct the embeddings from word and position embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

        self.config = config

    def forward(
        self,
        input_ids=None,
        position_ids=None,
        query_embeds=None,
        past_key_values_length=0,
    ):
        """
        Embed text tokens and optionally prepend learned query embeddings.

        Args:
            input_ids: `(batch, seq_len)` token ids, or None for a query-only forward pass.
            position_ids: optional explicit position ids; sliced from the buffer when None.
            query_embeds: optional `(1, num_queries, hidden)` query embeddings prepended to the
                token embeddings (repeated along the batch dimension).
            past_key_values_length: offset applied to the default position ids.

        Returns:
            `(batch, [num_queries +] seq_len, hidden)` embeddings after LayerNorm and dropout.

        Raises:
            ValueError: if both `input_ids` and `query_embeds` are None (previously this fell
                through to an `UnboundLocalError`).
        """
        if input_ids is None and query_embeds is None:
            raise ValueError("You must provide at least one of `input_ids` or `query_embeds`.")

        if input_ids is not None:
            seq_length = input_ids.size()[1]
        else:
            seq_length = 0

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone()

        if input_ids is not None:
            embeddings = self.word_embeddings(input_ids)
            if self.position_embedding_type == "absolute":
                position_embeddings = self.position_embeddings(position_ids)
                embeddings = embeddings + position_embeddings

            if query_embeds is not None:
                batch_size = embeddings.shape[0]
                # repeat the query embeddings for batch size
                query_embeds = query_embeds.repeat(batch_size, 1, 1)
                embeddings = torch.cat((query_embeds, embeddings), dim=1)
        else:
            embeddings = query_embeds

        # Only cast when query embeddings are present: the original unconditional
        # `embeddings.to(query_embeds.dtype)` raised AttributeError on `None.dtype`
        # for text-only inputs.
        if query_embeds is not None:
            embeddings = embeddings.to(query_embeds.dtype)
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
# Copy-pasted from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->Blip2
|
| 101 |
+
# Copy-pasted from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->Blip2
class Blip2VisionEmbeddings(nn.Module):
    """Patch + class-token embeddings for the BLIP-2 vision encoder.

    Splits the input image into non-overlapping patches with a strided
    convolution, prepends a learned class embedding, and adds learned
    position embeddings.
    """

    def __init__(self, config: Blip2VisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Learned [CLS]-style token prepended to every patch sequence.
        self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))

        # Strided conv acts as a non-overlapping patchifier (no bias, as in CLIP).
        self.patch_embedding = nn.Conv2d(
            in_channels=3,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1

        self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        n_images = pixel_values.shape[0]
        weight_dtype = self.patch_embedding.weight.dtype

        # (B, embed_dim, H/ps, W/ps) -> (B, num_patches, embed_dim)
        patch_tokens = self.patch_embedding(pixel_values.to(dtype=weight_dtype))
        patch_tokens = patch_tokens.flatten(2).transpose(1, 2)

        cls_token = self.class_embedding.expand(n_images, 1, -1).to(weight_dtype)
        tokens = torch.cat([cls_token, patch_tokens], dim=1)
        # Slice the position table in case fewer tokens than positions are present.
        tokens = tokens + self.position_embedding[:, : tokens.size(1), :].to(weight_dtype)
        return tokens
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
# The Qformer encoder, which takes the visual embeddings, and the text input, to get multimodal embeddings
class Blip2QFormerEncoder(nn.Module):
    """Transformer encoder stack of the Q-Former.

    Runs `config.num_hidden_layers` `Blip2QFormerLayer`s over the concatenated
    query/text hidden states, cross-attending to `encoder_hidden_states`
    (the frozen image features) in the layers that have cross-attention.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [Blip2QFormerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        # Attribute kept for API parity; forward() actually consults config.gradient_checkpointing.
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        query_length=0,
    ):
        """Run every layer, optionally collecting hidden states, attentions, and the kv-cache.

        Returns a `BaseModelOutputWithPastAndCrossAttentions` when `return_dict` is
        True, otherwise a tuple of the non-None collected values.
        """
        # Accumulators are tuples, or None when the corresponding output is disabled.
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions else None

        next_decoder_cache = () if use_cache else None

        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            if output_hidden_states:
                # Record the *input* to this layer; the final state is appended after the loop.
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if getattr(self.config, "gradient_checkpointing", False) and torch.is_grad_enabled():
                if use_cache:
                    # The cache would be recomputed away during checkpointing, so disable it.
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                layer_outputs = self._gradient_checkpointing_func(
                    layer_module,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    query_length,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    query_length,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                # Each layer returns its present key/value tuple as the last element.
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if layer_module.has_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            # Tuple output drops any collections that were not requested (None entries).
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
# The layers making up the Qformer encoder
class Blip2QFormerLayer(nn.Module):
    """Single Q-Former block.

    Self-attention over the whole [queries ++ text] sequence, optional
    cross-attention (applied to the query tokens only), and two separate
    feed-forward paths: `intermediate_query`/`output_query` for the query
    tokens, `intermediate`/`output` for the text tokens.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        # Settings for apply_chunking_to_forward: chunk along the sequence dim.
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = Blip2QFormerAttention(config)

        self.layer_idx = layer_idx

        # Cross-attention is only inserted every `cross_attention_frequency` layers.
        if layer_idx % config.cross_attention_frequency == 0:
            self.crossattention = Blip2QFormerAttention(config, is_cross_attention=True)
            self.has_cross_attention = True
        else:
            self.has_cross_attention = False

        self.intermediate = Blip2QFormerIntermediate(config)
        self.intermediate_query = Blip2QFormerIntermediate(config)
        self.output_query = Blip2QFormerOutput(config)
        self.output = Blip2QFormerOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        query_length=0,
    ):
        """`hidden_states` is [queries ++ text] along dim 1; `query_length` marks where the queries end."""
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]
        # Middle entries are attention probs (when requested); the last is the kv cache.
        outputs = self_attention_outputs[1:-1]

        present_key_value = self_attention_outputs[-1]

        if query_length > 0:
            # Only the query tokens cross-attend to the image features.
            query_attention_output = attention_output[:, :query_length, :]

            if self.has_cross_attention:
                if encoder_hidden_states is None:
                    raise ValueError("encoder_hidden_states must be given for cross-attention layers")
                cross_attention_outputs = self.crossattention(
                    query_attention_output,
                    attention_mask,
                    head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions=output_attentions,
                )
                query_attention_output = cross_attention_outputs[0]
                # add cross attentions if we output attention weights
                outputs = outputs + cross_attention_outputs[1:-1]

            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk_query,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                query_attention_output,
            )

            if attention_output.shape[1] > query_length:
                # Text tokens (everything past the queries) go through the plain FFN,
                # then are re-concatenated behind the query outputs.
                layer_output_text = apply_chunking_to_forward(
                    self.feed_forward_chunk,
                    self.chunk_size_feed_forward,
                    self.seq_len_dim,
                    attention_output[:, query_length:, :],
                )
                layer_output = torch.cat([layer_output, layer_output_text], dim=1)
        else:
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attention_output,
            )
        outputs = (layer_output,) + outputs

        outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        # FFN path for text tokens.
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

    def feed_forward_chunk_query(self, attention_output):
        # FFN path for query tokens.
        intermediate_output = self.intermediate_query(attention_output)
        layer_output = self.output_query(intermediate_output, attention_output)
        return layer_output
# ProjLayer used to project the multimodal Blip2 embeddings to be used in the text encoder
class ProjLayer(nn.Module):
    """Residual MLP projection head: LayerNorm -> Dense -> QuickGELU -> Dense -> Dropout -> + input."""

    def __init__(self, in_dim, out_dim, hidden_dim, drop_p=0.1, eps=1e-12):
        super().__init__()

        # Dense1 -> Act -> Dense2 -> Drop -> Res -> Norm
        self.dense1 = nn.Linear(in_dim, hidden_dim)
        self.act_fn = QuickGELU()
        self.dense2 = nn.Linear(hidden_dim, out_dim)
        self.dropout = nn.Dropout(drop_p)

        self.LayerNorm = nn.LayerNorm(out_dim, eps=eps)

    def forward(self, x):
        # Keep the raw input for the residual connection.
        residual = x

        normed = self.LayerNorm(x)
        hidden = self.act_fn(self.dense1(normed))
        projected = self.dropout(self.dense2(hidden))

        return projected + residual
# Copy-pasted from transformers.models.blip.modeling_blip.BlipVisionModel with Blip->Blip2, BLIP->BLIP_2
class Blip2VisionModel(Blip2PreTrainedModel):
    """CLIP-style ViT producing the image features consumed by the Q-Former."""

    main_input_name = "pixel_values"
    config_class = Blip2VisionConfig

    def __init__(self, config: Blip2VisionConfig):
        super().__init__(config)
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = Blip2VisionEmbeddings(config)
        self.pre_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.encoder = Blip2Encoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        self.post_init()

    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Blip2VisionConfig)
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        """
        # Fall back to the config defaults for any unspecified flags.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        hidden_states = self.embeddings(pixel_values)
        hidden_states = self.pre_layernorm(hidden_states)
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.post_layernorm(last_hidden_state)

        # Pool on the class token (position 0). post_layernorm is applied to the
        # full sequence and then again to the pooled slice, mirroring upstream BLIP.
        pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.post_layernorm(pooled_output)

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def get_input_embeddings(self):
        """Return the patch/class embedding module (used by PreTrainedModel utilities)."""
        return self.embeddings
# Qformer model, used to get multimodal embeddings from the text and image inputs
class Blip2QFormerModel(Blip2PreTrainedModel):
    """
    Querying Transformer (Q-Former), used in BLIP-2.

    Combines a vision encoder, learned query tokens, and a BERT-style encoder
    (with periodic cross-attention to the image features). With
    `return_dict=False` the forward pass returns the projected query
    embeddings; otherwise the full encoder output is returned.
    """

    def __init__(self, config: Blip2Config):
        super().__init__(config)
        self.config = config
        self.embeddings = Blip2TextEmbeddings(config.qformer_config)
        self.visual_encoder = Blip2VisionModel(config.vision_config)
        # Learned query tokens, prepended to the text embeddings inside forward().
        self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
        if not hasattr(config, "tokenizer") or config.tokenizer is None:
            self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="right")
        else:
            self.tokenizer = BertTokenizer.from_pretrained(config.tokenizer, truncation_side="right")
        self.tokenizer.add_special_tokens({"bos_token": "[DEC]"})
        # Projects the query outputs into the text-encoder embedding space.
        self.proj_layer = ProjLayer(
            in_dim=config.qformer_config.hidden_size,
            out_dim=config.qformer_config.hidden_size,
            hidden_dim=config.qformer_config.hidden_size * 4,
            drop_p=0.1,
            eps=1e-12,
        )

        self.encoder = Blip2QFormerEncoder(config.qformer_config)

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def get_extended_attention_mask(
        self,
        attention_mask: torch.Tensor,
        input_shape: Tuple[int],
        device: torch.device,
        has_query: bool = False,  # accepted for API compatibility; not used in this implementation
    ) -> torch.Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (`Tuple[int]`):
                The shape of the input to the model.
            device (`torch.device`):
                The device of the input to the model.

        Returns:
            `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
        """
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - the model is an encoder, so make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
                    input_shape, attention_mask.shape
                )
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

    def forward(
        self,
        text_input=None,
        image_input=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        encoder_hidden_states  (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (`tuple(tuple(torch.Tensor))` of length `config.n_layers` with each tuple having 4 tensors of:
            shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and
            value hidden states of the attention blocks. Can be used to speed up decoding. If `past_key_values` are
            used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key
            value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape
            `(batch_size, sequence_length)`.
        use_cache (`bool`, `optional`):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        """

        # Tokenization happens inside forward: `text_input` is raw text, not ids.
        text = self.tokenizer(text_input, return_tensors="pt", padding=True)
        text = text.to(self.device)
        input_ids = text.input_ids
        batch_size = input_ids.shape[0]
        # Query tokens are always attended to (mask of ones), prepended to the text mask.
        query_atts = torch.ones((batch_size, self.query_tokens.size()[1]), dtype=torch.long).to(self.device)
        attention_mask = torch.cat([query_atts, text.attention_mask], dim=1)

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # past_key_values_length
        past_key_values_length = (
            past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0
        )

        query_length = self.query_tokens.shape[1]

        embedding_output = self.embeddings(
            input_ids=input_ids,
            query_embeds=self.query_tokens,
            past_key_values_length=past_key_values_length,
        )

        # embedding_output = self.layernorm(query_embeds)
        # embedding_output = self.dropout(embedding_output)

        input_shape = embedding_output.size()[:-1]
        batch_size, seq_length = input_shape
        device = embedding_output.device

        # NOTE(review): any caller-supplied `encoder_hidden_states` is overwritten here
        # by the image features from the vision encoder.
        image_embeds_frozen = self.visual_encoder(image_input).last_hidden_state
        # image_embeds_frozen = torch.ones_like(image_embeds_frozen)
        encoder_hidden_states = image_embeds_frozen

        # NOTE(review): `attention_mask` is always built from the tokenizer output above,
        # so this branch appears unreachable — confirm before removing.
        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if encoder_hidden_states is not None:
            if isinstance(encoder_hidden_states, list):
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
            else:
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)

            if isinstance(encoder_attention_mask, list):
                encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
            elif encoder_attention_mask is None:
                # Default: attend to every image token.
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
            else:
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.qformer_config.num_hidden_layers)

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            query_length=query_length,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = sequence_output[:, 0, :]

        if not return_dict:
            # Tuple mode returns only the projected query-token embeddings.
            return self.proj_layer(sequence_output[:, :query_length, :])

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Salesforce.com, inc.
|
| 2 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
from typing import Optional, Tuple, Union
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
from torch import nn
|
| 19 |
+
from transformers import CLIPPreTrainedModel
|
| 20 |
+
from transformers.modeling_outputs import BaseModelOutputWithPooling
|
| 21 |
+
from transformers.models.clip.configuration_clip import CLIPTextConfig
|
| 22 |
+
from transformers.models.clip.modeling_clip import CLIPEncoder
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
| 26 |
+
"""
|
| 27 |
+
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
| 28 |
+
"""
|
| 29 |
+
bsz, src_len = mask.size()
|
| 30 |
+
tgt_len = tgt_len if tgt_len is not None else src_len
|
| 31 |
+
|
| 32 |
+
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
| 33 |
+
|
| 34 |
+
inverted_mask = 1.0 - expanded_mask
|
| 35 |
+
|
| 36 |
+
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# This is a modified version of the CLIPTextModel from transformers.models.clip.modeling_clip
# Which allows for an extra input of "context embeddings", which are the query embeddings used in Qformer
# They pass through the clip model, along with the text embeddings, and interact with them using self attention
class ContextCLIPTextModel(CLIPPreTrainedModel):
    """CLIP text model extended with Q-Former "context" embeddings.

    Thin wrapper that delegates entirely to `ContextCLIPTextTransformer`.
    """

    config_class = CLIPTextConfig

    _no_split_modules = ["CLIPEncoderLayer"]

    def __init__(self, config: CLIPTextConfig):
        super().__init__(config)
        self.text_model = ContextCLIPTextTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        ctx_embeddings: torch.Tensor = None,
        ctx_begin_pos: list = None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        # Pure delegation; see ContextCLIPTextTransformer.forward for the semantics.
        return self.text_model(
            ctx_embeddings=ctx_embeddings,
            ctx_begin_pos=ctx_begin_pos,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
class ContextCLIPTextTransformer(nn.Module):
    """CLIP text transformer that splices Q-Former context embeddings into the
    token sequence (via `ContextCLIPTextEmbeddings`) before running the encoder."""

    def __init__(self, config: CLIPTextConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = ContextCLIPTextEmbeddings(config)
        self.encoder = CLIPEncoder(config)
        self.final_layer_norm = nn.LayerNorm(embed_dim)

    def forward(
        self,
        ctx_embeddings: torch.Tensor,
        ctx_begin_pos: list,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is None:
            raise ValueError("You have to specify either input_ids")

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])

        # The embedding layer inserts ctx_embeddings at ctx_begin_pos, lengthening the sequence.
        hidden_states = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            ctx_embeddings=ctx_embeddings,
            ctx_begin_pos=ctx_begin_pos,
        )

        bsz, seq_len = input_shape
        if ctx_embeddings is not None:
            # Account for the inserted context tokens when sizing the causal mask.
            seq_len += ctx_embeddings.size(1)
        # CLIP's text model uses causal mask, prepare it here.
        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
        causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
            hidden_states.device
        )
        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _expand_mask(attention_mask, hidden_states.dtype)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.final_layer_norm(last_hidden_state)

        # text_embeds.shape = [batch_size, sequence_length, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
        # NOTE(review): the eot index comes from the original `input_ids` and is not
        # shifted by the inserted context tokens — verify this matches callers' intent.
        pooled_output = last_hidden_state[
            torch.arange(last_hidden_state.shape[0], device=input_ids.device),
            input_ids.to(torch.int).argmax(dim=-1),
        ]

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def _build_causal_attention_mask(self, bsz, seq_len, dtype):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
        mask.fill_(torch.tensor(torch.finfo(dtype).min))
        mask.triu_(1)  # zero out the lower diagonal
        mask = mask.unsqueeze(1)  # expand mask
        return mask
class ContextCLIPTextEmbeddings(nn.Module):
|
| 173 |
+
def __init__(self, config: CLIPTextConfig):
|
| 174 |
+
super().__init__()
|
| 175 |
+
embed_dim = config.hidden_size
|
| 176 |
+
|
| 177 |
+
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
|
| 178 |
+
self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
|
| 179 |
+
|
| 180 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
| 181 |
+
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
| 182 |
+
|
| 183 |
+
def forward(
|
| 184 |
+
self,
|
| 185 |
+
ctx_embeddings: torch.Tensor,
|
| 186 |
+
ctx_begin_pos: list,
|
| 187 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 188 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 189 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 190 |
+
) -> torch.Tensor:
|
| 191 |
+
if ctx_embeddings is None:
|
| 192 |
+
ctx_len = 0
|
| 193 |
+
else:
|
| 194 |
+
ctx_len = ctx_embeddings.shape[1]
|
| 195 |
+
|
| 196 |
+
seq_length = (input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]) + ctx_len
|
| 197 |
+
|
| 198 |
+
if position_ids is None:
|
| 199 |
+
position_ids = self.position_ids[:, :seq_length]
|
| 200 |
+
|
| 201 |
+
if inputs_embeds is None:
|
| 202 |
+
inputs_embeds = self.token_embedding(input_ids)
|
| 203 |
+
|
| 204 |
+
# for each input embeddings, add the ctx embeddings at the correct position
|
| 205 |
+
input_embeds_ctx = []
|
| 206 |
+
bsz = inputs_embeds.shape[0]
|
| 207 |
+
|
| 208 |
+
if ctx_embeddings is not None:
|
| 209 |
+
for i in range(bsz):
|
| 210 |
+
cbp = ctx_begin_pos[i]
|
| 211 |
+
|
| 212 |
+
prefix = inputs_embeds[i, :cbp]
|
| 213 |
+
# remove the special token embedding
|
| 214 |
+
suffix = inputs_embeds[i, cbp:]
|
| 215 |
+
|
| 216 |
+
input_embeds_ctx.append(torch.cat([prefix, ctx_embeddings[i], suffix], dim=0))
|
| 217 |
+
|
| 218 |
+
inputs_embeds = torch.stack(input_embeds_ctx, dim=0)
|
| 219 |
+
|
| 220 |
+
position_embeddings = self.position_embedding(position_ids)
|
| 221 |
+
embeddings = inputs_embeds + position_embeddings
|
| 222 |
+
|
| 223 |
+
return embeddings
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py
ADDED
|
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Salesforce.com, inc.
|
| 2 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import List, Optional, Union
|
| 15 |
+
|
| 16 |
+
import PIL.Image
|
| 17 |
+
import torch
|
| 18 |
+
from transformers import CLIPTokenizer
|
| 19 |
+
|
| 20 |
+
from ...models import AutoencoderKL, UNet2DConditionModel
|
| 21 |
+
from ...schedulers import PNDMScheduler
|
| 22 |
+
from ...utils import (
|
| 23 |
+
is_torch_xla_available,
|
| 24 |
+
logging,
|
| 25 |
+
replace_example_docstring,
|
| 26 |
+
)
|
| 27 |
+
from ...utils.torch_utils import randn_tensor
|
| 28 |
+
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput
|
| 29 |
+
from .blip_image_processing import BlipImageProcessor
|
| 30 |
+
from .modeling_blip2 import Blip2QFormerModel
|
| 31 |
+
from .modeling_ctx_clip import ContextCLIPTextModel
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
if is_torch_xla_available():
|
| 35 |
+
import torch_xla.core.xla_model as xm
|
| 36 |
+
|
| 37 |
+
XLA_AVAILABLE = True
|
| 38 |
+
else:
|
| 39 |
+
XLA_AVAILABLE = False
|
| 40 |
+
|
| 41 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
EXAMPLE_DOC_STRING = """
|
| 45 |
+
Examples:
|
| 46 |
+
```py
|
| 47 |
+
>>> from diffusers.pipelines import BlipDiffusionPipeline
|
| 48 |
+
>>> from diffusers.utils import load_image
|
| 49 |
+
>>> import torch
|
| 50 |
+
|
| 51 |
+
>>> blip_diffusion_pipe = BlipDiffusionPipeline.from_pretrained(
|
| 52 |
+
... "Salesforce/blipdiffusion", torch_dtype=torch.float16
|
| 53 |
+
... ).to("cuda")
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
>>> cond_subject = "dog"
|
| 57 |
+
>>> tgt_subject = "dog"
|
| 58 |
+
>>> text_prompt_input = "swimming underwater"
|
| 59 |
+
|
| 60 |
+
>>> cond_image = load_image(
|
| 61 |
+
... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/dog.jpg"
|
| 62 |
+
... )
|
| 63 |
+
>>> guidance_scale = 7.5
|
| 64 |
+
>>> num_inference_steps = 25
|
| 65 |
+
>>> negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate"
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
>>> output = blip_diffusion_pipe(
|
| 69 |
+
... text_prompt_input,
|
| 70 |
+
... cond_image,
|
| 71 |
+
... cond_subject,
|
| 72 |
+
... tgt_subject,
|
| 73 |
+
... guidance_scale=guidance_scale,
|
| 74 |
+
... num_inference_steps=num_inference_steps,
|
| 75 |
+
... neg_prompt=negative_prompt,
|
| 76 |
+
... height=512,
|
| 77 |
+
... width=512,
|
| 78 |
+
... ).images
|
| 79 |
+
>>> output[0].save("image.png")
|
| 80 |
+
```
|
| 81 |
+
"""
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class BlipDiffusionPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
|
| 85 |
+
"""
|
| 86 |
+
Pipeline for Zero-Shot Subject Driven Generation using Blip Diffusion.
|
| 87 |
+
|
| 88 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 89 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
| 90 |
+
|
| 91 |
+
Args:
|
| 92 |
+
tokenizer ([`CLIPTokenizer`]):
|
| 93 |
+
Tokenizer for the text encoder
|
| 94 |
+
text_encoder ([`ContextCLIPTextModel`]):
|
| 95 |
+
Text encoder to encode the text prompt
|
| 96 |
+
vae ([`AutoencoderKL`]):
|
| 97 |
+
VAE model to map the latents to the image
|
| 98 |
+
unet ([`UNet2DConditionModel`]):
|
| 99 |
+
Conditional U-Net architecture to denoise the image embedding.
|
| 100 |
+
scheduler ([`PNDMScheduler`]):
|
| 101 |
+
A scheduler to be used in combination with `unet` to generate image latents.
|
| 102 |
+
qformer ([`Blip2QFormerModel`]):
|
| 103 |
+
QFormer model to get multi-modal embeddings from the text and image.
|
| 104 |
+
image_processor ([`BlipImageProcessor`]):
|
| 105 |
+
Image Processor to preprocess and postprocess the image.
|
| 106 |
+
ctx_begin_pos (int, `optional`, defaults to 2):
|
| 107 |
+
Position of the context token in the text encoder.
|
| 108 |
+
"""
|
| 109 |
+
|
| 110 |
+
_last_supported_version = "0.33.1"
|
| 111 |
+
model_cpu_offload_seq = "qformer->text_encoder->unet->vae"
|
| 112 |
+
|
| 113 |
+
def __init__(
|
| 114 |
+
self,
|
| 115 |
+
tokenizer: CLIPTokenizer,
|
| 116 |
+
text_encoder: ContextCLIPTextModel,
|
| 117 |
+
vae: AutoencoderKL,
|
| 118 |
+
unet: UNet2DConditionModel,
|
| 119 |
+
scheduler: PNDMScheduler,
|
| 120 |
+
qformer: Blip2QFormerModel,
|
| 121 |
+
image_processor: BlipImageProcessor,
|
| 122 |
+
ctx_begin_pos: int = 2,
|
| 123 |
+
mean: List[float] = None,
|
| 124 |
+
std: List[float] = None,
|
| 125 |
+
):
|
| 126 |
+
super().__init__()
|
| 127 |
+
|
| 128 |
+
self.register_modules(
|
| 129 |
+
tokenizer=tokenizer,
|
| 130 |
+
text_encoder=text_encoder,
|
| 131 |
+
vae=vae,
|
| 132 |
+
unet=unet,
|
| 133 |
+
scheduler=scheduler,
|
| 134 |
+
qformer=qformer,
|
| 135 |
+
image_processor=image_processor,
|
| 136 |
+
)
|
| 137 |
+
self.register_to_config(ctx_begin_pos=ctx_begin_pos, mean=mean, std=std)
|
| 138 |
+
|
| 139 |
+
def get_query_embeddings(self, input_image, src_subject):
|
| 140 |
+
return self.qformer(image_input=input_image, text_input=src_subject, return_dict=False)
|
| 141 |
+
|
| 142 |
+
# from the original Blip Diffusion code, specifies the target subject and augments the prompt by repeating it
|
| 143 |
+
def _build_prompt(self, prompts, tgt_subjects, prompt_strength=1.0, prompt_reps=20):
|
| 144 |
+
rv = []
|
| 145 |
+
for prompt, tgt_subject in zip(prompts, tgt_subjects):
|
| 146 |
+
prompt = f"a {tgt_subject} {prompt.strip()}"
|
| 147 |
+
# a trick to amplify the prompt
|
| 148 |
+
rv.append(", ".join([prompt] * int(prompt_strength * prompt_reps)))
|
| 149 |
+
|
| 150 |
+
return rv
|
| 151 |
+
|
| 152 |
+
# Copied from diffusers.pipelines.consistency_models.pipeline_consistency_models.ConsistencyModelPipeline.prepare_latents
|
| 153 |
+
def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None):
|
| 154 |
+
shape = (batch_size, num_channels, height, width)
|
| 155 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 156 |
+
raise ValueError(
|
| 157 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 158 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
if latents is None:
|
| 162 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 163 |
+
else:
|
| 164 |
+
latents = latents.to(device=device, dtype=dtype)
|
| 165 |
+
|
| 166 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
| 167 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 168 |
+
return latents
|
| 169 |
+
|
| 170 |
+
def encode_prompt(self, query_embeds, prompt, device=None):
|
| 171 |
+
device = device or self._execution_device
|
| 172 |
+
|
| 173 |
+
# embeddings for prompt, with query_embeds as context
|
| 174 |
+
max_len = self.text_encoder.text_model.config.max_position_embeddings
|
| 175 |
+
max_len -= self.qformer.config.num_query_tokens
|
| 176 |
+
|
| 177 |
+
tokenized_prompt = self.tokenizer(
|
| 178 |
+
prompt,
|
| 179 |
+
padding="max_length",
|
| 180 |
+
truncation=True,
|
| 181 |
+
max_length=max_len,
|
| 182 |
+
return_tensors="pt",
|
| 183 |
+
).to(device)
|
| 184 |
+
|
| 185 |
+
batch_size = query_embeds.shape[0]
|
| 186 |
+
ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size
|
| 187 |
+
|
| 188 |
+
text_embeddings = self.text_encoder(
|
| 189 |
+
input_ids=tokenized_prompt.input_ids,
|
| 190 |
+
ctx_embeddings=query_embeds,
|
| 191 |
+
ctx_begin_pos=ctx_begin_pos,
|
| 192 |
+
)[0]
|
| 193 |
+
|
| 194 |
+
return text_embeddings
|
| 195 |
+
|
| 196 |
+
@torch.no_grad()
|
| 197 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 198 |
+
def __call__(
|
| 199 |
+
self,
|
| 200 |
+
prompt: List[str],
|
| 201 |
+
reference_image: PIL.Image.Image,
|
| 202 |
+
source_subject_category: List[str],
|
| 203 |
+
target_subject_category: List[str],
|
| 204 |
+
latents: Optional[torch.Tensor] = None,
|
| 205 |
+
guidance_scale: float = 7.5,
|
| 206 |
+
height: int = 512,
|
| 207 |
+
width: int = 512,
|
| 208 |
+
num_inference_steps: int = 50,
|
| 209 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 210 |
+
neg_prompt: Optional[str] = "",
|
| 211 |
+
prompt_strength: float = 1.0,
|
| 212 |
+
prompt_reps: int = 20,
|
| 213 |
+
output_type: Optional[str] = "pil",
|
| 214 |
+
return_dict: bool = True,
|
| 215 |
+
):
|
| 216 |
+
"""
|
| 217 |
+
Function invoked when calling the pipeline for generation.
|
| 218 |
+
|
| 219 |
+
Args:
|
| 220 |
+
prompt (`List[str]`):
|
| 221 |
+
The prompt or prompts to guide the image generation.
|
| 222 |
+
reference_image (`PIL.Image.Image`):
|
| 223 |
+
The reference image to condition the generation on.
|
| 224 |
+
source_subject_category (`List[str]`):
|
| 225 |
+
The source subject category.
|
| 226 |
+
target_subject_category (`List[str]`):
|
| 227 |
+
The target subject category.
|
| 228 |
+
latents (`torch.Tensor`, *optional*):
|
| 229 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 230 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 231 |
+
tensor will be generated by random sampling.
|
| 232 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
| 233 |
+
Guidance scale as defined in [Classifier-Free Diffusion
|
| 234 |
+
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
| 235 |
+
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
| 236 |
+
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
| 237 |
+
the text `prompt`, usually at the expense of lower image quality.
|
| 238 |
+
height (`int`, *optional*, defaults to 512):
|
| 239 |
+
The height of the generated image.
|
| 240 |
+
width (`int`, *optional*, defaults to 512):
|
| 241 |
+
The width of the generated image.
|
| 242 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 243 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 244 |
+
expense of slower inference.
|
| 245 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 246 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 247 |
+
to make generation deterministic.
|
| 248 |
+
neg_prompt (`str`, *optional*, defaults to ""):
|
| 249 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
| 250 |
+
if `guidance_scale` is less than `1`).
|
| 251 |
+
prompt_strength (`float`, *optional*, defaults to 1.0):
|
| 252 |
+
The strength of the prompt. Specifies the number of times the prompt is repeated along with prompt_reps
|
| 253 |
+
to amplify the prompt.
|
| 254 |
+
prompt_reps (`int`, *optional*, defaults to 20):
|
| 255 |
+
The number of times the prompt is repeated along with prompt_strength to amplify the prompt.
|
| 256 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 257 |
+
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
|
| 258 |
+
(`np.array`) or `"pt"` (`torch.Tensor`).
|
| 259 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 260 |
+
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
|
| 261 |
+
Examples:
|
| 262 |
+
|
| 263 |
+
Returns:
|
| 264 |
+
[`~pipelines.ImagePipelineOutput`] or `tuple`
|
| 265 |
+
"""
|
| 266 |
+
device = self._execution_device
|
| 267 |
+
|
| 268 |
+
reference_image = self.image_processor.preprocess(
|
| 269 |
+
reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
|
| 270 |
+
)["pixel_values"]
|
| 271 |
+
reference_image = reference_image.to(device)
|
| 272 |
+
|
| 273 |
+
if isinstance(prompt, str):
|
| 274 |
+
prompt = [prompt]
|
| 275 |
+
if isinstance(source_subject_category, str):
|
| 276 |
+
source_subject_category = [source_subject_category]
|
| 277 |
+
if isinstance(target_subject_category, str):
|
| 278 |
+
target_subject_category = [target_subject_category]
|
| 279 |
+
|
| 280 |
+
batch_size = len(prompt)
|
| 281 |
+
|
| 282 |
+
prompt = self._build_prompt(
|
| 283 |
+
prompts=prompt,
|
| 284 |
+
tgt_subjects=target_subject_category,
|
| 285 |
+
prompt_strength=prompt_strength,
|
| 286 |
+
prompt_reps=prompt_reps,
|
| 287 |
+
)
|
| 288 |
+
query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
|
| 289 |
+
text_embeddings = self.encode_prompt(query_embeds, prompt, device)
|
| 290 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
| 291 |
+
if do_classifier_free_guidance:
|
| 292 |
+
max_length = self.text_encoder.text_model.config.max_position_embeddings
|
| 293 |
+
|
| 294 |
+
uncond_input = self.tokenizer(
|
| 295 |
+
[neg_prompt] * batch_size,
|
| 296 |
+
padding="max_length",
|
| 297 |
+
max_length=max_length,
|
| 298 |
+
return_tensors="pt",
|
| 299 |
+
)
|
| 300 |
+
uncond_embeddings = self.text_encoder(
|
| 301 |
+
input_ids=uncond_input.input_ids.to(device),
|
| 302 |
+
ctx_embeddings=None,
|
| 303 |
+
)[0]
|
| 304 |
+
# For classifier free guidance, we need to do two forward passes.
|
| 305 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
| 306 |
+
# to avoid doing two forward passes
|
| 307 |
+
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
| 308 |
+
|
| 309 |
+
scale_down_factor = 2 ** (len(self.unet.config.block_out_channels) - 1)
|
| 310 |
+
latents = self.prepare_latents(
|
| 311 |
+
batch_size=batch_size,
|
| 312 |
+
num_channels=self.unet.config.in_channels,
|
| 313 |
+
height=height // scale_down_factor,
|
| 314 |
+
width=width // scale_down_factor,
|
| 315 |
+
generator=generator,
|
| 316 |
+
latents=latents,
|
| 317 |
+
dtype=self.unet.dtype,
|
| 318 |
+
device=device,
|
| 319 |
+
)
|
| 320 |
+
# set timesteps
|
| 321 |
+
extra_set_kwargs = {}
|
| 322 |
+
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
|
| 323 |
+
|
| 324 |
+
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
|
| 325 |
+
# expand the latents if we are doing classifier free guidance
|
| 326 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
| 327 |
+
|
| 328 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 329 |
+
|
| 330 |
+
noise_pred = self.unet(
|
| 331 |
+
latent_model_input,
|
| 332 |
+
timestep=t,
|
| 333 |
+
encoder_hidden_states=text_embeddings,
|
| 334 |
+
down_block_additional_residuals=None,
|
| 335 |
+
mid_block_additional_residual=None,
|
| 336 |
+
)["sample"]
|
| 337 |
+
|
| 338 |
+
# perform guidance
|
| 339 |
+
if do_classifier_free_guidance:
|
| 340 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 341 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 342 |
+
|
| 343 |
+
latents = self.scheduler.step(
|
| 344 |
+
noise_pred,
|
| 345 |
+
t,
|
| 346 |
+
latents,
|
| 347 |
+
)["prev_sample"]
|
| 348 |
+
|
| 349 |
+
if XLA_AVAILABLE:
|
| 350 |
+
xm.mark_step()
|
| 351 |
+
|
| 352 |
+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
| 353 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
| 354 |
+
|
| 355 |
+
# Offload all models
|
| 356 |
+
self.maybe_free_model_hooks()
|
| 357 |
+
|
| 358 |
+
if not return_dict:
|
| 359 |
+
return (image,)
|
| 360 |
+
|
| 361 |
+
return ImagePipelineOutput(images=image)
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/bria/__init__.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import TYPE_CHECKING
|
| 2 |
+
|
| 3 |
+
from ...utils import (
|
| 4 |
+
DIFFUSERS_SLOW_IMPORT,
|
| 5 |
+
OptionalDependencyNotAvailable,
|
| 6 |
+
_LazyModule,
|
| 7 |
+
get_objects_from_module,
|
| 8 |
+
is_torch_available,
|
| 9 |
+
is_transformers_available,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
_dummy_objects = {}
|
| 14 |
+
_import_structure = {}
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
try:
|
| 18 |
+
if not (is_transformers_available() and is_torch_available()):
|
| 19 |
+
raise OptionalDependencyNotAvailable()
|
| 20 |
+
except OptionalDependencyNotAvailable:
|
| 21 |
+
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
| 22 |
+
|
| 23 |
+
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
| 24 |
+
else:
|
| 25 |
+
_import_structure["pipeline_bria"] = ["BriaPipeline"]
|
| 26 |
+
|
| 27 |
+
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
| 28 |
+
try:
|
| 29 |
+
if not (is_transformers_available() and is_torch_available()):
|
| 30 |
+
raise OptionalDependencyNotAvailable()
|
| 31 |
+
|
| 32 |
+
except OptionalDependencyNotAvailable:
|
| 33 |
+
from ...utils.dummy_torch_and_transformers_objects import *
|
| 34 |
+
else:
|
| 35 |
+
from .pipeline_bria import BriaPipeline
|
| 36 |
+
|
| 37 |
+
else:
|
| 38 |
+
import sys
|
| 39 |
+
|
| 40 |
+
sys.modules[__name__] = _LazyModule(
|
| 41 |
+
__name__,
|
| 42 |
+
globals()["__file__"],
|
| 43 |
+
_import_structure,
|
| 44 |
+
module_spec=__spec__,
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
for name, value in _dummy_objects.items():
|
| 48 |
+
setattr(sys.modules[__name__], name, value)
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/bria/pipeline_bria.py
ADDED
|
@@ -0,0 +1,729 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
from transformers import (
|
| 6 |
+
CLIPImageProcessor,
|
| 7 |
+
CLIPVisionModelWithProjection,
|
| 8 |
+
T5EncoderModel,
|
| 9 |
+
T5TokenizerFast,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
from ...image_processor import VaeImageProcessor
|
| 13 |
+
from ...loaders import FluxLoraLoaderMixin
|
| 14 |
+
from ...models import AutoencoderKL
|
| 15 |
+
from ...models.transformers.transformer_bria import BriaTransformer2DModel
|
| 16 |
+
from ...pipelines import DiffusionPipeline
|
| 17 |
+
from ...pipelines.bria.pipeline_output import BriaPipelineOutput
|
| 18 |
+
from ...pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
|
| 19 |
+
from ...schedulers import (
|
| 20 |
+
DDIMScheduler,
|
| 21 |
+
EulerAncestralDiscreteScheduler,
|
| 22 |
+
FlowMatchEulerDiscreteScheduler,
|
| 23 |
+
KarrasDiffusionSchedulers,
|
| 24 |
+
)
|
| 25 |
+
from ...utils import (
|
| 26 |
+
USE_PEFT_BACKEND,
|
| 27 |
+
is_torch_xla_available,
|
| 28 |
+
logging,
|
| 29 |
+
replace_example_docstring,
|
| 30 |
+
scale_lora_layers,
|
| 31 |
+
unscale_lora_layers,
|
| 32 |
+
)
|
| 33 |
+
from ...utils.torch_utils import randn_tensor
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
if is_torch_xla_available():
|
| 37 |
+
import torch_xla.core.xla_model as xm
|
| 38 |
+
|
| 39 |
+
XLA_AVAILABLE = True
|
| 40 |
+
else:
|
| 41 |
+
XLA_AVAILABLE = False
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 45 |
+
|
| 46 |
+
EXAMPLE_DOC_STRING = """
|
| 47 |
+
Examples:
|
| 48 |
+
```py
|
| 49 |
+
>>> import torch
|
| 50 |
+
>>> from diffusers import BriaPipeline
|
| 51 |
+
|
| 52 |
+
>>> pipe = BriaPipeline.from_pretrained("briaai/BRIA-3.2", torch_dtype=torch.bfloat16)
|
| 53 |
+
>>> pipe.to("cuda")
|
| 54 |
+
# BRIA's T5 text encoder is sensitive to precision. We need to cast it to bfloat16 and keep the final layer in float32.
|
| 55 |
+
|
| 56 |
+
>>> pipe.text_encoder = pipe.text_encoder.to(dtype=torch.bfloat16)
|
| 57 |
+
>>> for block in pipe.text_encoder.encoder.block:
|
| 58 |
+
... block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32)
|
| 59 |
+
# BRIA's VAE is not supported in mixed precision, so we use float32.
|
| 60 |
+
|
| 61 |
+
>>> if pipe.vae.config.shift_factor == 0:
|
| 62 |
+
... pipe.vae.to(dtype=torch.float32)
|
| 63 |
+
|
| 64 |
+
>>> prompt = "Photorealistic food photography of a stack of fluffy pancakes on a white plate, with maple syrup being poured over them. On top of the pancakes are the words 'BRIA 3.2' in bold, yellow, 3D letters. The background is dark and out of focus."
|
| 65 |
+
>>> image = pipe(prompt).images[0]
|
| 66 |
+
>>> image.save("bria.png")
|
| 67 |
+
```
|
| 68 |
+
"""
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def is_ng_none(negative_prompt):
|
| 72 |
+
return (
|
| 73 |
+
negative_prompt is None
|
| 74 |
+
or negative_prompt == ""
|
| 75 |
+
or (isinstance(negative_prompt, list) and negative_prompt[0] is None)
|
| 76 |
+
or (type(negative_prompt) == list and negative_prompt[0] == "")
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def get_original_sigmas(num_train_timesteps=1000, num_inference_steps=1000):
|
| 81 |
+
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
|
| 82 |
+
sigmas = timesteps / num_train_timesteps
|
| 83 |
+
|
| 84 |
+
inds = [int(ind) for ind in np.linspace(0, num_train_timesteps - 1, num_inference_steps)]
|
| 85 |
+
new_sigmas = sigmas[inds]
|
| 86 |
+
return new_sigmas
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class BriaPipeline(DiffusionPipeline):
    r"""
    Text-to-image pipeline for BRIA models. Based on FluxPipeline with several changes:
    - no pooled embeddings
    - We use zero padding for prompts
    - No guidance embedding since this is not a distilled version

    Args:
        transformer ([`BriaTransformer2DModel`]):
            Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`T5EncoderModel`]):
            Frozen text-encoder. Bria uses
            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
            [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
        tokenizer (`T5TokenizerFast`):
            Tokenizer of class
            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
    """

    # Order in which components are moved to CPU when model offloading is enabled.
    # NOTE(review): `text_encoder_2` appears here but is never registered in
    # `__init__` — presumably ignored by the offload hook; verify.
    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
    # Components that may legitimately be None.
    _optional_components = ["image_encoder", "feature_extractor"]
    # Tensors that `callback_on_step_end` is allowed to read and override.
    _callback_tensor_inputs = ["latents", "prompt_embeds"]
|
| 115 |
+
|
| 116 |
+
    def __init__(
        self,
        transformer: BriaTransformer2DModel,
        scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers],
        vae: AutoencoderKL,
        text_encoder: T5EncoderModel,
        tokenizer: T5TokenizerFast,
        image_encoder: CLIPVisionModelWithProjection = None,
        feature_extractor: CLIPImageProcessor = None,
    ):
        """Register all sub-models and derive image-processing defaults."""
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
        )

        # NOTE(review): unlike Flux (2 ** (len(block_out_channels) - 1)), the
        # exponent here is the full block count — 16 for a 4-level VAE, matching
        # the fallback value and the 1024px default below; presumably intentional
        # because latents are additionally 2x2-packed. Confirm before changing.
        self.vae_scale_factor = (
            2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
        )
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.default_sample_size = 64  # due to patchify=> 128,128 => res of 1k,1k

        # BRIA VAEs without a configured shift get shift_factor=0 and are kept in
        # float32 (this VAE variant is not supported in mixed precision).
        if self.vae.config.shift_factor is None:
            self.vae.config.shift_factor = 0
            self.vae.to(dtype=torch.float32)
|
| 145 |
+
|
| 146 |
+
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        do_classifier_free_guidance: bool = True,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        max_sequence_length: int = 128,
        lora_scale: Optional[float] = None,
    ):
        r"""
        Encode positive and (optionally) negative prompts with the T5 text encoder.

        When guidance is enabled and no negative prompt is given, zero embeddings
        are used as the unconditional branch (BRIA uses zero padding rather than
        an empty-string encoding).

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            max_sequence_length (`int`): token length prompts are truncated/zero-padded to.
            lora_scale (`float`, *optional*): scale applied to text-encoder LoRA layers.

        Returns:
            Tuple `(prompt_embeds, negative_prompt_embeds, text_ids)` where `text_ids`
            is an all-zero `(batch * num_images_per_prompt, seq_len, 3)` position-id tensor.
        """
        device = device or self._execution_device

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if self.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder, lora_scale)

        # Normalize to a list so batch size is the number of prompts.
        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            # Caller supplied embeddings directly; infer batch from them.
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            if not is_ng_none(negative_prompt):
                # Broadcast a single negative string across the batch.
                negative_prompt = (
                    batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
                )

                if prompt is not None and type(prompt) is not type(negative_prompt):
                    raise TypeError(
                        f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                        f" {type(prompt)}."
                    )
                elif batch_size != len(negative_prompt):
                    raise ValueError(
                        f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                        f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                        " the batch size of `prompt`."
                    )

                negative_prompt_embeds = self._get_t5_prompt_embeds(
                    prompt=negative_prompt,
                    num_images_per_prompt=num_images_per_prompt,
                    max_sequence_length=max_sequence_length,
                    device=device,
                )
            else:
                # No negative prompt: unconditional branch is all-zero embeddings.
                negative_prompt_embeds = torch.zeros_like(prompt_embeds)

        if self.text_encoder is not None:
            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        # BRIA does not use RoPE positions for text tokens; ids are all zeros.
        text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device)
        text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)

        return prompt_embeds, negative_prompt_embeds, text_ids
|
| 242 |
+
|
| 243 |
+
    @property
    def guidance_scale(self):
        # CFG weight captured at call time in `__call__`.
        return self._guidance_scale

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        # Guidance is only applied for scales strictly greater than 1.
        return self._guidance_scale > 1

    @property
    def attention_kwargs(self):
        # Extra kwargs forwarded to the transformer's attention processors.
        return self._attention_kwargs

    @attention_kwargs.setter
    def attention_kwargs(self, value):
        self._attention_kwargs = value

    @property
    def num_timesteps(self):
        # Number of denoising steps of the most recent `__call__`.
        return self._num_timesteps

    @property
    def interrupt(self):
        # Set externally to abort the denoising loop early.
        return self._interrupt
|
| 269 |
+
|
| 270 |
+
def check_inputs(
|
| 271 |
+
self,
|
| 272 |
+
prompt,
|
| 273 |
+
height,
|
| 274 |
+
width,
|
| 275 |
+
negative_prompt=None,
|
| 276 |
+
prompt_embeds=None,
|
| 277 |
+
negative_prompt_embeds=None,
|
| 278 |
+
callback_on_step_end_tensor_inputs=None,
|
| 279 |
+
max_sequence_length=None,
|
| 280 |
+
):
|
| 281 |
+
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
|
| 282 |
+
logger.warning(
|
| 283 |
+
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
|
| 284 |
+
)
|
| 285 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 286 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 287 |
+
):
|
| 288 |
+
raise ValueError(
|
| 289 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
if prompt is not None and prompt_embeds is not None:
|
| 293 |
+
raise ValueError(
|
| 294 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 295 |
+
" only forward one of the two."
|
| 296 |
+
)
|
| 297 |
+
elif prompt is None and prompt_embeds is None:
|
| 298 |
+
raise ValueError(
|
| 299 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 300 |
+
)
|
| 301 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 302 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 303 |
+
|
| 304 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 305 |
+
raise ValueError(
|
| 306 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 307 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 311 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 312 |
+
raise ValueError(
|
| 313 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 314 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 315 |
+
f" {negative_prompt_embeds.shape}."
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
if max_sequence_length is not None and max_sequence_length > 512:
|
| 319 |
+
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
| 320 |
+
|
| 321 |
+
def _get_t5_prompt_embeds(
|
| 322 |
+
self,
|
| 323 |
+
prompt: Union[str, List[str]] = None,
|
| 324 |
+
num_images_per_prompt: int = 1,
|
| 325 |
+
max_sequence_length: int = 128,
|
| 326 |
+
device: Optional[torch.device] = None,
|
| 327 |
+
):
|
| 328 |
+
tokenizer = self.tokenizer
|
| 329 |
+
text_encoder = self.text_encoder
|
| 330 |
+
device = device or text_encoder.device
|
| 331 |
+
|
| 332 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 333 |
+
batch_size = len(prompt)
|
| 334 |
+
prompt_embeds_list = []
|
| 335 |
+
for p in prompt:
|
| 336 |
+
text_inputs = tokenizer(
|
| 337 |
+
p,
|
| 338 |
+
# padding="max_length",
|
| 339 |
+
max_length=max_sequence_length,
|
| 340 |
+
truncation=True,
|
| 341 |
+
add_special_tokens=True,
|
| 342 |
+
return_tensors="pt",
|
| 343 |
+
)
|
| 344 |
+
text_input_ids = text_inputs.input_ids
|
| 345 |
+
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
| 346 |
+
|
| 347 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
| 348 |
+
text_input_ids, untruncated_ids
|
| 349 |
+
):
|
| 350 |
+
removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
| 351 |
+
logger.warning(
|
| 352 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
| 353 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
|
| 357 |
+
|
| 358 |
+
# Concat zeros to max_sequence
|
| 359 |
+
b, seq_len, dim = prompt_embeds.shape
|
| 360 |
+
if seq_len < max_sequence_length:
|
| 361 |
+
padding = torch.zeros(
|
| 362 |
+
(b, max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device
|
| 363 |
+
)
|
| 364 |
+
prompt_embeds = torch.concat([prompt_embeds, padding], dim=1)
|
| 365 |
+
prompt_embeds_list.append(prompt_embeds)
|
| 366 |
+
|
| 367 |
+
prompt_embeds = torch.concat(prompt_embeds_list, dim=0)
|
| 368 |
+
prompt_embeds = prompt_embeds.to(device=device)
|
| 369 |
+
|
| 370 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
| 371 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 372 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, max_sequence_length, -1)
|
| 373 |
+
prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype)
|
| 374 |
+
return prompt_embeds
|
| 375 |
+
|
| 376 |
+
def prepare_latents(
|
| 377 |
+
self,
|
| 378 |
+
batch_size,
|
| 379 |
+
num_channels_latents,
|
| 380 |
+
height,
|
| 381 |
+
width,
|
| 382 |
+
dtype,
|
| 383 |
+
device,
|
| 384 |
+
generator,
|
| 385 |
+
latents=None,
|
| 386 |
+
):
|
| 387 |
+
# VAE applies 8x compression on images but we must also account for packing which requires
|
| 388 |
+
# latent height and width to be divisible by 2.
|
| 389 |
+
height = 2 * (int(height) // self.vae_scale_factor)
|
| 390 |
+
width = 2 * (int(width) // self.vae_scale_factor)
|
| 391 |
+
|
| 392 |
+
shape = (batch_size, num_channels_latents, height, width)
|
| 393 |
+
|
| 394 |
+
if latents is not None:
|
| 395 |
+
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
| 396 |
+
return latents.to(device=device, dtype=dtype), latent_image_ids
|
| 397 |
+
|
| 398 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 399 |
+
raise ValueError(
|
| 400 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 401 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 402 |
+
)
|
| 403 |
+
|
| 404 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 405 |
+
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
| 406 |
+
|
| 407 |
+
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
| 408 |
+
|
| 409 |
+
return latents, latent_image_ids
|
| 410 |
+
|
| 411 |
+
@staticmethod
|
| 412 |
+
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
| 413 |
+
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
| 414 |
+
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
| 415 |
+
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
| 416 |
+
|
| 417 |
+
return latents
|
| 418 |
+
|
| 419 |
+
@staticmethod
|
| 420 |
+
def _unpack_latents(latents, height, width, vae_scale_factor):
|
| 421 |
+
batch_size, num_patches, channels = latents.shape
|
| 422 |
+
|
| 423 |
+
height = height // vae_scale_factor
|
| 424 |
+
width = width // vae_scale_factor
|
| 425 |
+
|
| 426 |
+
latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
|
| 427 |
+
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
| 428 |
+
|
| 429 |
+
latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
|
| 430 |
+
|
| 431 |
+
return latents
|
| 432 |
+
|
| 433 |
+
@staticmethod
|
| 434 |
+
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
| 435 |
+
latent_image_ids = torch.zeros(height, width, 3)
|
| 436 |
+
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
|
| 437 |
+
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
|
| 438 |
+
|
| 439 |
+
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
| 440 |
+
|
| 441 |
+
latent_image_ids = latent_image_ids.repeat(batch_size, 1, 1, 1)
|
| 442 |
+
latent_image_ids = latent_image_ids.reshape(
|
| 443 |
+
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
| 444 |
+
)
|
| 445 |
+
|
| 446 |
+
return latent_image_ids.to(device=device, dtype=dtype)
|
| 447 |
+
|
| 448 |
+
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 30,
        timesteps: List[int] = None,
        guidance_scale: float = 5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 128,
        clip_value: Union[None, float] = None,
        normalize: bool = False,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                instead.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image. This is set to 1024 by default for the best results.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. This is set to 1024 by default for the best results.
            num_inference_steps (`int`, *optional*, defaults to 30):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 5.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.bria.BriaPipelineOutput`] instead of a plain tuple.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int` defaults to 128): Maximum sequence length to use with the `prompt`.
            clip_value (`float`, *optional*): if set, the noise prediction is clipped to `[-clip_value, clip_value]`.
            normalize (`bool`, *optional*, defaults to `False`):
                if True, rescales the guided noise prediction towards the std of the conditional prediction
                (a CFG-rescale style correction).

        Examples:

        Returns:
            [`~pipelines.bria.BriaPipelineOutput`] or `tuple`: [`~pipelines.bria.BriaPipelineOutput`] if `return_dict`
            is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
            images.
        """

        # Default to the model's native 1024x1024 resolution.
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt=prompt,
            height=height,
            width=width,
            prompt_embeds=prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )

        self._guidance_scale = guidance_scale
        self.attention_kwargs = attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # LoRA scale is smuggled through attention_kwargs under the "scale" key.
        lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None

        # 3. Encode prompts (negative branch is zeros when no negative prompt given).
        (prompt_embeds, negative_prompt_embeds, text_ids) = self.encode_prompt(
            prompt=prompt,
            negative_prompt=negative_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )

        # Concatenate for a single batched forward pass: [uncond, cond].
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

        # 5. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels // 4  # patchify folds a 2x2 patch into channels
        latents, latent_image_ids = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 4. Prepare timesteps: dynamic-shift flow matching vs. fixed schedules.
        if (
            isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler)
            and self.scheduler.config["use_dynamic_shifting"]
        ):
            sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
            image_seq_len = latents.shape[1]

            # Shift depends on the token count so different resolutions denoise consistently.
            mu = calculate_shift(
                image_seq_len,
                self.scheduler.config.base_image_seq_len,
                self.scheduler.config.max_image_seq_len,
                self.scheduler.config.base_shift,
                self.scheduler.config.max_shift,
            )
            timesteps, num_inference_steps = retrieve_timesteps(
                self.scheduler,
                num_inference_steps,
                device,
                timesteps,
                sigmas,
                mu=mu,
            )
        else:
            # Sample from training sigmas
            if isinstance(self.scheduler, DDIMScheduler) or isinstance(
                self.scheduler, EulerAncestralDiscreteScheduler
            ):
                # These schedulers build their own schedule; no custom sigmas.
                timesteps, num_inference_steps = retrieve_timesteps(
                    self.scheduler, num_inference_steps, device, None, None
                )
            else:
                sigmas = get_original_sigmas(
                    num_train_timesteps=self.scheduler.config.num_train_timesteps,
                    num_inference_steps=num_inference_steps,
                )
                timesteps, num_inference_steps = retrieve_timesteps(
                    self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas
                )

        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # The transformer expects 2D position-id tensors (no batch dim).
        if len(latent_image_ids.shape) == 3:
            latent_image_ids = latent_image_ids[0]
        if len(text_ids.shape) == 3:
            text_ids = text_ids[0]

        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                if type(self.scheduler) != FlowMatchEulerDiscreteScheduler:
                    # Non-flow schedulers may need input scaling (e.g. DDIM/Euler).
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0])

                # Predicts "v" for flow-matching or eps for diffusion schedulers.
                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    attention_kwargs=self.attention_kwargs,
                    return_dict=False,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    cfg_noise_pred_text = noise_pred_text.std()
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                    if normalize:
                        # CFG-rescale style: pull the guided prediction's std back
                        # towards the conditional prediction's std (0.7/0.3 blend).
                        noise_pred = noise_pred * (0.7 * (cfg_noise_pred_text / noise_pred.std())) + 0.3 * noise_pred

                if clip_value:
                    assert clip_value > 0
                    noise_pred = noise_pred.clip(-clip_value, clip_value)

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        # names validated earlier against _callback_tensor_inputs
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    # callbacks may replace these tensors in-flight
                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # advance the progress bar, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        if output_type == "latent":
            image = latents

        else:
            # Unpack back to (B, C, H, W), undo VAE scaling/shift in float32, decode.
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
            latents = (latents.to(dtype=torch.float32) / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return BriaPipelineOutput(images=image)
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/bria/pipeline_output.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import List, Union
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import PIL.Image
|
| 6 |
+
|
| 7 |
+
from ...utils import BaseOutput
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass
class BriaPipelineOutput(BaseOutput):
    """
    Output class for Bria pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
    """

    # Decoded output images. Concrete type depends on the `output_type` the caller
    # requested from the pipeline (PIL list vs. a single stacked numpy array).
    images: Union[List[PIL.Image.Image], np.ndarray]
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/chroma/__init__.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


# Standard diffusers lazy-import boilerplate: names listed in `_import_structure`
# are only imported when first accessed, unless DIFFUSERS_SLOW_IMPORT forces
# eager imports (or a static type checker is reading the module).
_dummy_objects = {}
_additional_imports = {}
_import_structure = {"pipeline_output": ["ChromaPipelineOutput"]}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # torch/transformers are missing: expose dummy placeholder objects that raise
    # a helpful error when used, instead of failing at import time.
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_chroma"] = ["ChromaPipeline"]
    _import_structure["pipeline_chroma_img2img"] = ["ChromaImg2ImgPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager-import branch: real classes when dependencies are present, dummies otherwise.
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
    else:
        from .pipeline_chroma import ChromaPipeline
        from .pipeline_chroma_img2img import ChromaImg2ImgPipeline
else:
    import sys

    # Replace this module object with a lazy proxy that resolves submodules on
    # attribute access, then re-attach the dummy/additional names on the proxy.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
    for name, value in _additional_imports.items():
        setattr(sys.modules[__name__], name, value)
setattr(sys.modules[__name__], name, value)
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/chroma/pipeline_chroma.py
ADDED
|
@@ -0,0 +1,949 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import inspect
|
| 16 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch
|
| 20 |
+
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5TokenizerFast
|
| 21 |
+
|
| 22 |
+
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
| 23 |
+
from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
|
| 24 |
+
from ...models import AutoencoderKL, ChromaTransformer2DModel
|
| 25 |
+
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
| 26 |
+
from ...utils import (
|
| 27 |
+
USE_PEFT_BACKEND,
|
| 28 |
+
is_torch_xla_available,
|
| 29 |
+
logging,
|
| 30 |
+
replace_example_docstring,
|
| 31 |
+
scale_lora_layers,
|
| 32 |
+
unscale_lora_layers,
|
| 33 |
+
)
|
| 34 |
+
from ...utils.torch_utils import randn_tensor
|
| 35 |
+
from ..pipeline_utils import DiffusionPipeline
|
| 36 |
+
from .pipeline_output import ChromaPipelineOutput
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
if is_torch_xla_available():
    # torch_xla is used by the denoising loop to mark TPU step boundaries.
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
|
| 48 |
+
|
| 49 |
+
EXAMPLE_DOC_STRING = """
|
| 50 |
+
Examples:
|
| 51 |
+
```py
|
| 52 |
+
>>> import torch
|
| 53 |
+
>>> from diffusers import ChromaPipeline
|
| 54 |
+
|
| 55 |
+
>>> model_id = "lodestones/Chroma"
|
| 56 |
+
>>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
|
| 57 |
+
>>> transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
|
| 58 |
+
>>> pipe = ChromaPipeline.from_pretrained(
|
| 59 |
+
... model_id,
|
| 60 |
+
... transformer=transformer,
|
| 61 |
+
... torch_dtype=torch.bfloat16,
|
| 62 |
+
... )
|
| 63 |
+
>>> pipe.enable_model_cpu_offload()
|
| 64 |
+
>>> prompt = [
|
| 65 |
+
... "A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."
|
| 66 |
+
... ]
|
| 67 |
+
>>> negative_prompt = [
|
| 68 |
+
... "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
|
| 69 |
+
... ]
|
| 70 |
+
>>> image = pipe(prompt, negative_prompt=negative_prompt).images[0]
|
| 71 |
+
>>> image.save("chroma.png")
|
| 72 |
+
```
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """Return the flow-matching timestep-shift parameter ``mu`` for a sequence length.

    ``mu`` is linearly interpolated: it equals ``base_shift`` when
    ``image_seq_len == base_seq_len`` and ``max_shift`` when
    ``image_seq_len == max_seq_len`` (extrapolated outside that range).
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Call the scheduler's `set_timesteps` method and return the resulting schedule.

    Exactly one of `num_inference_steps`, `timesteps` or `sigmas` drives the schedule;
    passing both `timesteps` and `sigmas` raises. Extra keyword arguments are forwarded
    to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    def _accepts(keyword: str) -> bool:
        # True when the scheduler's `set_timesteps` signature exposes the given keyword.
        return keyword in set(inspect.signature(scheduler.set_timesteps).parameters.keys())

    if timesteps is not None:
        if not _accepts("timesteps"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    elif sigmas is not None:
        if not _accepts("sigmas"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
    else:
        # Default path: the scheduler derives its own spacing from the step count.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)

    resolved_timesteps = scheduler.timesteps
    if timesteps is not None or sigmas is not None:
        # A custom schedule determines the effective number of steps.
        num_inference_steps = len(resolved_timesteps)
    return resolved_timesteps, num_inference_steps
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class ChromaPipeline(
|
| 151 |
+
DiffusionPipeline,
|
| 152 |
+
FluxLoraLoaderMixin,
|
| 153 |
+
FromSingleFileMixin,
|
| 154 |
+
TextualInversionLoaderMixin,
|
| 155 |
+
FluxIPAdapterMixin,
|
| 156 |
+
):
|
| 157 |
+
r"""
|
| 158 |
+
The Chroma pipeline for text-to-image generation.
|
| 159 |
+
|
| 160 |
+
Reference: https://huggingface.co/lodestones/Chroma/
|
| 161 |
+
|
| 162 |
+
Args:
|
| 163 |
+
transformer ([`ChromaTransformer2DModel`]):
|
| 164 |
+
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
| 165 |
+
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
| 166 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
| 167 |
+
vae ([`AutoencoderKL`]):
|
| 168 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representation
|
| 169 |
+
text_encoder ([`T5EncoderModel`]):
|
| 170 |
+
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
| 171 |
+
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
| 172 |
+
tokenizer (`T5TokenizerFast`):
|
| 173 |
+
Second Tokenizer of class
|
| 174 |
+
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
| 175 |
+
"""
|
| 176 |
+
|
| 177 |
+
model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
|
| 178 |
+
_optional_components = ["image_encoder", "feature_extractor"]
|
| 179 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
| 180 |
+
|
| 181 |
+
    def __init__(
        self,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKL,
        text_encoder: T5EncoderModel,
        tokenizer: T5TokenizerFast,
        transformer: ChromaTransformer2DModel,
        image_encoder: CLIPVisionModelWithProjection = None,
        feature_extractor: CLIPImageProcessor = None,
    ):
        """Register the pipeline components and derive image-processing configuration.

        `image_encoder` and `feature_extractor` are optional components (see
        `_optional_components`); presumably they are only required for IP-Adapter
        image conditioning — confirm against `encode_image` callers.
        """
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
        )
        # Spatial downscaling applied by the VAE; falls back to 8 when no VAE is registered.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
        # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
        # NOTE(review): presumably the default latent sample size used when height/width
        # are not supplied by the caller — confirm against `__call__`.
        self.default_sample_size = 128
|
| 207 |
+
|
| 208 |
+
    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_images_per_prompt: int = 1,
        max_sequence_length: int = 512,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Tokenize and encode `prompt` with the T5 text encoder.

        Returns a `(prompt_embeds, attention_mask)` pair, each repeated
        `num_images_per_prompt` times along the batch dimension. The attention
        mask is rebuilt so that it covers every real token plus one padding
        token (a Chroma-specific requirement, see inline comment).
        """
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        # Resolve textual-inversion placeholder tokens before tokenization.
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        attention_mask = text_inputs.attention_mask.clone()

        # Chroma requires the attention mask to include one padding token
        # (the `<=` below extends the mask by exactly one position past the
        # real tokens; when a row has no padding it is simply all ones).
        seq_lengths = attention_mask.sum(dim=1)
        mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
        attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).long()

        prompt_embeds = self.text_encoder(
            text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
        )[0]

        # NOTE(review): this overwrites any caller-supplied `dtype` with the
        # encoder's dtype for the final cast — confirm that is intentional.
        dtype = self.text_encoder.dtype
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
        attention_mask = attention_mask.to(dtype=dtype, device=device)

        _, seq_len, _ = prompt_embeds.shape

        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        attention_mask = attention_mask.repeat(1, num_images_per_prompt)
        attention_mask = attention_mask.view(batch_size * num_images_per_prompt, seq_len)

        return prompt_embeds, attention_mask
|
| 260 |
+
|
| 261 |
+
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]] = None,
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        prompt_attention_mask: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        do_classifier_free_guidance: bool = True,
        max_sequence_length: int = 512,
        lora_scale: Optional[float] = None,
    ):
        r"""
        Encode positive (and optionally negative) prompts into T5 embeddings plus text ids.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
                instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings; computed from `negative_prompt` when omitted.
            prompt_attention_mask (`torch.Tensor`, *optional*):
                Attention mask matching `prompt_embeds`; produced internally when embeddings are computed here.
            negative_prompt_attention_mask (`torch.Tensor`, *optional*):
                Attention mask matching `negative_prompt_embeds`.
            do_classifier_free_guidance (`bool`):
                Whether to also encode the negative prompt for classifier-free guidance.
            max_sequence_length (`int`):
                Maximum token length passed to the T5 tokenizer.
            lora_scale (`float`, *optional*):
                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.

        Returns:
            A 6-tuple `(prompt_embeds, text_ids, prompt_attention_mask, negative_prompt_embeds,
            negative_text_ids, negative_prompt_attention_mask)`; the negative entries are `None`
            when classifier-free guidance is disabled.
        """
        device = device or self._execution_device

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if self.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
                prompt=prompt,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )

        dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
        # Text ids are zero-filled position ids, one row per text token (3 coordinate channels).
        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
        negative_text_ids = None

        if do_classifier_free_guidance:
            if negative_prompt_embeds is None:
                negative_prompt = negative_prompt or ""
                negative_prompt = (
                    batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
                )

                if prompt is not None and type(prompt) is not type(negative_prompt):
                    raise TypeError(
                        f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                        f" {type(prompt)}."
                    )
                elif batch_size != len(negative_prompt):
                    raise ValueError(
                        f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                        f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                        " the batch size of `prompt`."
                    )

                negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
                    prompt=negative_prompt,
                    num_images_per_prompt=num_images_per_prompt,
                    max_sequence_length=max_sequence_length,
                    device=device,
                )

            negative_text_ids = torch.zeros(negative_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

        if self.text_encoder is not None:
            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        return (
            prompt_embeds,
            text_ids,
            prompt_attention_mask,
            negative_prompt_embeds,
            negative_text_ids,
            negative_prompt_attention_mask,
        )
|
| 364 |
+
|
| 365 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
|
| 366 |
+
def encode_image(self, image, device, num_images_per_prompt):
|
| 367 |
+
dtype = next(self.image_encoder.parameters()).dtype
|
| 368 |
+
|
| 369 |
+
if not isinstance(image, torch.Tensor):
|
| 370 |
+
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
| 371 |
+
|
| 372 |
+
image = image.to(device=device, dtype=dtype)
|
| 373 |
+
image_embeds = self.image_encoder(image).image_embeds
|
| 374 |
+
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
| 375 |
+
return image_embeds
|
| 376 |
+
|
| 377 |
+
    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
    def prepare_ip_adapter_image_embeds(
        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
    ):
        """Return one image-embedding tensor per loaded IP-Adapter.

        When `ip_adapter_image_embeds` is given it is used as-is; otherwise each
        entry of `ip_adapter_image` is encoded with the CLIP image encoder. Either
        list must have exactly one entry per loaded IP-Adapter. Each resulting
        embedding is repeated `num_images_per_prompt` times and moved to `device`.
        """
        image_embeds = []
        if ip_adapter_image_embeds is None:
            if not isinstance(ip_adapter_image, list):
                ip_adapter_image = [ip_adapter_image]

            if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
                raise ValueError(
                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
                )

            for single_ip_adapter_image in ip_adapter_image:
                single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
                # Add a leading batch dimension so every entry has the same layout.
                image_embeds.append(single_image_embeds[None, :])
        else:
            if not isinstance(ip_adapter_image_embeds, list):
                ip_adapter_image_embeds = [ip_adapter_image_embeds]

            if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
                raise ValueError(
                    f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
                )

            for single_image_embeds in ip_adapter_image_embeds:
                image_embeds.append(single_image_embeds)

        ip_adapter_image_embeds = []
        for single_image_embeds in image_embeds:
            # Repeat each embedding for every image generated per prompt, then move to device.
            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
            single_image_embeds = single_image_embeds.to(device=device)
            ip_adapter_image_embeds.append(single_image_embeds)

        return ip_adapter_image_embeds
|
| 413 |
+
|
| 414 |
+
    def check_inputs(
        self,
        prompt,
        height,
        width,
        negative_prompt=None,
        prompt_embeds=None,
        prompt_attention_mask=None,
        negative_prompt_embeds=None,
        negative_prompt_attention_mask=None,
        callback_on_step_end_tensor_inputs=None,
        max_sequence_length=None,
    ):
        """Validate user-facing call arguments; raises on invalid combinations.

        Dimension problems only warn (the image processor resizes later); every
        other inconsistency raises `ValueError`/`TypeError`.
        """
        # Non-fatal: bad dimensions are resized downstream, so only warn.
        if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
            logger.warning(
                f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
            )

        # Only tensors registered in `_callback_tensor_inputs` may be requested by callbacks.
        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        # `prompt` and `prompt_embeds` are mutually exclusive, and exactly one is required.
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        # Precomputed embeddings must always come with their matching attention masks.
        if prompt_embeds is not None and prompt_attention_mask is None:
            raise ValueError("Cannot provide `prompt_embeds` without also providing `prompt_attention_mask")

        if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
            raise ValueError(
                "Cannot provide `negative_prompt_embeds` without also providing `negative_prompt_attention_mask"
            )

        # 512 is the hard T5 sequence-length cap used by this pipeline.
        if max_sequence_length is not None and max_sequence_length > 512:
            raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
| 467 |
+
|
| 468 |
+
@staticmethod
|
| 469 |
+
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
| 470 |
+
latent_image_ids = torch.zeros(height, width, 3)
|
| 471 |
+
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
|
| 472 |
+
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
|
| 473 |
+
|
| 474 |
+
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
| 475 |
+
|
| 476 |
+
latent_image_ids = latent_image_ids.reshape(
|
| 477 |
+
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
return latent_image_ids.to(device=device, dtype=dtype)
|
| 481 |
+
|
| 482 |
+
@staticmethod
|
| 483 |
+
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
| 484 |
+
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
| 485 |
+
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
| 486 |
+
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
| 487 |
+
|
| 488 |
+
return latents
|
| 489 |
+
|
| 490 |
+
@staticmethod
|
| 491 |
+
def _unpack_latents(latents, height, width, vae_scale_factor):
|
| 492 |
+
batch_size, num_patches, channels = latents.shape
|
| 493 |
+
|
| 494 |
+
# VAE applies 8x compression on images but we must also account for packing which requires
|
| 495 |
+
# latent height and width to be divisible by 2.
|
| 496 |
+
height = 2 * (int(height) // (vae_scale_factor * 2))
|
| 497 |
+
width = 2 * (int(width) // (vae_scale_factor * 2))
|
| 498 |
+
|
| 499 |
+
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
|
| 500 |
+
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
| 501 |
+
|
| 502 |
+
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
|
| 503 |
+
|
| 504 |
+
return latents
|
| 505 |
+
|
| 506 |
+
def enable_vae_slicing(self):
    r"""Turn on sliced VAE decoding.

    The VAE will split its input into slices and decode them in several steps,
    lowering peak memory usage and allowing larger batch sizes.
    """
    self.vae.enable_slicing()
|
| 512 |
+
|
| 513 |
+
def disable_vae_slicing(self):
    r"""Turn off sliced VAE decoding.

    Restores single-step decoding if `enable_vae_slicing` was previously
    activated.
    """
    self.vae.disable_slicing()
|
| 519 |
+
|
| 520 |
+
def enable_vae_tiling(self):
    r"""Turn on tiled VAE decoding.

    The VAE will split its input into tiles and encode/decode them in several
    steps, saving a large amount of memory and enabling larger images.
    """
    self.vae.enable_tiling()
|
| 527 |
+
|
| 528 |
+
def disable_vae_tiling(self):
    r"""Turn off tiled VAE decoding.

    Restores single-step decoding if `enable_vae_tiling` was previously
    activated.
    """
    self.vae.disable_tiling()
|
| 534 |
+
|
| 535 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
def prepare_latents(
    self,
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
):
    """Sample (or adopt) packed initial latents together with their positional ids.

    When ``latents`` is supplied it is only moved to ``device``/``dtype`` and is
    assumed to already be in packed form; otherwise Gaussian noise is drawn and
    packed. Returns ``(latents, latent_image_ids)``.
    """
    # VAE applies 8x compression on images but we must also account for packing which requires
    # latent height and width to be divisible by 2.
    latent_height = 2 * (int(height) // (self.vae_scale_factor * 2))
    latent_width = 2 * (int(width) // (self.vae_scale_factor * 2))

    if latents is not None:
        image_ids = self._prepare_latent_image_ids(
            batch_size, latent_height // 2, latent_width // 2, device, dtype
        )
        return latents.to(device=device, dtype=dtype), image_ids

    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    noise_shape = (batch_size, num_channels_latents, latent_height, latent_width)
    noise = randn_tensor(noise_shape, generator=generator, device=device, dtype=dtype)
    packed = self._pack_latents(noise, batch_size, num_channels_latents, latent_height, latent_width)

    image_ids = self._prepare_latent_image_ids(
        batch_size, latent_height // 2, latent_width // 2, device, dtype
    )

    return packed, image_ids
|
| 570 |
+
|
| 571 |
+
def _prepare_attention_mask(
|
| 572 |
+
self,
|
| 573 |
+
batch_size,
|
| 574 |
+
sequence_length,
|
| 575 |
+
dtype,
|
| 576 |
+
attention_mask=None,
|
| 577 |
+
):
|
| 578 |
+
if attention_mask is None:
|
| 579 |
+
return attention_mask
|
| 580 |
+
|
| 581 |
+
# Extend the prompt attention mask to account for image tokens in the final sequence
|
| 582 |
+
attention_mask = torch.cat(
|
| 583 |
+
[attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device)],
|
| 584 |
+
dim=1,
|
| 585 |
+
)
|
| 586 |
+
attention_mask = attention_mask.to(dtype)
|
| 587 |
+
|
| 588 |
+
return attention_mask
|
| 589 |
+
|
| 590 |
+
@property
def guidance_scale(self):
    """Guidance scale recorded by the most recent `__call__` invocation."""
    return self._guidance_scale
|
| 593 |
+
|
| 594 |
+
@property
def joint_attention_kwargs(self):
    """Extra kwargs forwarded to the attention processors (set in `__call__`)."""
    return self._joint_attention_kwargs
|
| 597 |
+
|
| 598 |
+
@property
def do_classifier_free_guidance(self):
    """Whether classifier-free guidance is active (guidance scale strictly > 1)."""
    return self._guidance_scale > 1
|
| 601 |
+
|
| 602 |
+
@property
def num_timesteps(self):
    """Total number of denoising timesteps used by the last `__call__`."""
    return self._num_timesteps
|
| 605 |
+
|
| 606 |
+
@property
def current_timestep(self):
    """Timestep currently being denoised, or ``None`` outside the denoise loop."""
    return self._current_timestep
|
| 609 |
+
|
| 610 |
+
@property
def interrupt(self):
    """Flag checked each step; setting `_interrupt` truthy skips remaining steps."""
    return self._interrupt
|
| 613 |
+
|
| 614 |
+
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    negative_prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 35,
    sigmas: Optional[List[float]] = None,
    guidance_scale: float = 5.0,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
    negative_ip_adapter_image: Optional[PipelineImageInput] = None,
    negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    prompt_attention_mask: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 512,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
            instead.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            not greater than `1`).
        height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The height in pixels of the generated image. This is set to 1024 by default for the best results.
        width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The width in pixels of the generated image. This is set to 1024 by default for the best results.
        num_inference_steps (`int`, *optional*, defaults to 35):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        sigmas (`List[float]`, *optional*):
            Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
            their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
            will be used.
        guidance_scale (`float`, *optional*, defaults to 5.0):
            Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
            a model to generate images more aligned with `prompt` at the expense of lower image quality.

            Guidance-distilled models approximates true classifier-free guidance for `guidance_scale` > 1. Refer to
            the [paper](https://huggingface.co/papers/2210.03142) to learn more.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
        ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
            Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
            IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
            provided, embeddings are computed from the `ip_adapter_image` input argument.
        negative_ip_adapter_image:
            (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
        negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
            Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
            IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
            provided, embeddings are computed from the `ip_adapter_image` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        prompt_attention_mask (torch.Tensor, *optional*):
            Attention mask for the prompt embeddings. Used to mask out padding tokens in the prompt sequence.
            Chroma requires a single padding token remain unmasked. Please refer to
            https://huggingface.co/lodestones/Chroma#tldr-masking-t5-padding-tokens-enhanced-fidelity-and-increased-stability-during-training
        negative_prompt_attention_mask (torch.Tensor, *optional*):
            Attention mask for the negative prompt embeddings. Used to mask out padding tokens in the negative
            prompt sequence. Chroma requires a single padding token remain unmasked. Please refer to
            https://huggingface.co/lodestones/Chroma#tldr-masking-t5-padding-tokens-enhanced-fidelity-and-increased-stability-during-training
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generate image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.flux.ChromaPipelineOutput`] instead of a plain tuple.
        joint_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        callback_on_step_end (`Callable`, *optional*):
            A function that calls at the end of each denoising steps during the inference. The function is called
            with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
            callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
            `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.
        max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.

    Examples:

    Returns:
        [`~pipelines.chroma.ChromaPipelineOutput`] or `tuple`: [`~pipelines.chroma.ChromaPipelineOutput`] if
        `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
        generated images.
    """

    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        height,
        width,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        prompt_attention_mask=prompt_attention_mask,
        negative_prompt_embeds=negative_prompt_embeds,
        negative_prompt_attention_mask=negative_prompt_attention_mask,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        max_sequence_length=max_sequence_length,
    )

    # Per-call state surfaced through the read-only properties above.
    self._guidance_scale = guidance_scale
    self._joint_attention_kwargs = joint_attention_kwargs
    self._current_timestep = None
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # 3. Encode the positive and (when CFG is active) negative prompts.
    lora_scale = (
        self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
    )
    (
        prompt_embeds,
        text_ids,
        prompt_attention_mask,
        negative_prompt_embeds,
        negative_text_ids,
        negative_prompt_attention_mask,
    ) = self.encode_prompt(
        prompt=prompt,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        prompt_attention_mask=prompt_attention_mask,
        negative_prompt_attention_mask=negative_prompt_attention_mask,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        lora_scale=lora_scale,
    )

    # 4. Prepare latent variables
    # Packing folds each 2x2 latent patch into channels, hence the // 4.
    num_channels_latents = self.transformer.config.in_channels // 4
    latents, latent_image_ids = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 5. Prepare timesteps
    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
    image_seq_len = latents.shape[1]
    # Resolution-dependent shift for the flow-match schedule; falls back to the
    # scheduler config defaults when the keys are absent.
    mu = calculate_shift(
        image_seq_len,
        self.scheduler.config.get("base_image_seq_len", 256),
        self.scheduler.config.get("max_image_seq_len", 4096),
        self.scheduler.config.get("base_shift", 0.5),
        self.scheduler.config.get("max_shift", 1.15),
    )

    # Extend the text masks to also cover the image tokens (always attended).
    attention_mask = self._prepare_attention_mask(
        batch_size=latents.shape[0],
        sequence_length=image_seq_len,
        dtype=latents.dtype,
        attention_mask=prompt_attention_mask,
    )
    negative_attention_mask = self._prepare_attention_mask(
        batch_size=latents.shape[0],
        sequence_length=image_seq_len,
        dtype=latents.dtype,
        attention_mask=negative_prompt_attention_mask,
    )

    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler,
        num_inference_steps,
        device,
        sigmas=sigmas,
        mu=mu,
    )
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
    self._num_timesteps = len(timesteps)

    # When only one side of the IP-adapter inputs is given, substitute a blank
    # image for the other side so both CFG branches receive image embeds.
    # NOTE(review): np.zeros((width, height, 3)) puts W before H, which is
    # unusual for an HWC image array — harmless for a square blank placeholder
    # but worth confirming for non-square sizes.
    if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
        negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
    ):
        negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
        negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters

    elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
        negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
    ):
        ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
        ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters

    # The denoise loop writes "ip_adapter_image_embeds" into this dict, so it
    # must exist even when the caller passed None.
    if self.joint_attention_kwargs is None:
        self._joint_attention_kwargs = {}

    image_embeds = None
    negative_image_embeds = None
    if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
        image_embeds = self.prepare_ip_adapter_image_embeds(
            ip_adapter_image,
            ip_adapter_image_embeds,
            device,
            batch_size * num_images_per_prompt,
        )
    if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
        negative_image_embeds = self.prepare_ip_adapter_image_embeds(
            negative_ip_adapter_image,
            negative_ip_adapter_image_embeds,
            device,
            batch_size * num_images_per_prompt,
        )

    # 6. Denoising loop
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            self._current_timestep = t
            if image_embeds is not None:
                self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds

            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latents.shape[0]).to(latents.dtype)

            noise_pred = self.transformer(
                hidden_states=latents,
                timestep=timestep / 1000,
                encoder_hidden_states=prompt_embeds,
                txt_ids=text_ids,
                img_ids=latent_image_ids,
                attention_mask=attention_mask,
                joint_attention_kwargs=self.joint_attention_kwargs,
                return_dict=False,
            )[0]

            if self.do_classifier_free_guidance:
                # Second forward pass with the negative conditioning, then the
                # standard CFG combination: neg + scale * (pos - neg).
                if negative_image_embeds is not None:
                    self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
                neg_noise_pred = self.transformer(
                    hidden_states=latents,
                    timestep=timestep / 1000,
                    encoder_hidden_states=negative_prompt_embeds,
                    txt_ids=negative_text_ids,
                    img_ids=latent_image_ids,
                    attention_mask=negative_attention_mask,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]
                noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred)

            # compute the previous noisy sample x_t -> x_t-1
            latents_dtype = latents.dtype
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            if latents.dtype != latents_dtype:
                if torch.backends.mps.is_available():
                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                    latents = latents.to(latents_dtype)

            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                # The callback may replace these tensors in-flight.
                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

            # advance the progress bar on scheduler-order boundaries
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

            if XLA_AVAILABLE:
                xm.mark_step()

    self._current_timestep = None

    if output_type == "latent":
        image = latents
    else:
        # Unpack tokens back to a latent grid, undo the VAE latent scaling,
        # then decode and post-process to the requested output type.
        latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
        latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
        image = self.vae.decode(latents, return_dict=False)[0]
        image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image,)

    return ChromaPipelineOutput(images=image)
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/chroma/pipeline_chroma_img2img.py
ADDED
|
@@ -0,0 +1,1034 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import inspect
|
| 16 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch
|
| 20 |
+
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5TokenizerFast
|
| 21 |
+
|
| 22 |
+
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
| 23 |
+
from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
|
| 24 |
+
from ...models import AutoencoderKL, ChromaTransformer2DModel
|
| 25 |
+
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
| 26 |
+
from ...utils import (
|
| 27 |
+
USE_PEFT_BACKEND,
|
| 28 |
+
is_torch_xla_available,
|
| 29 |
+
logging,
|
| 30 |
+
replace_example_docstring,
|
| 31 |
+
scale_lora_layers,
|
| 32 |
+
unscale_lora_layers,
|
| 33 |
+
)
|
| 34 |
+
from ...utils.torch_utils import randn_tensor
|
| 35 |
+
from ..pipeline_utils import DiffusionPipeline
|
| 36 |
+
from .pipeline_output import ChromaPipelineOutput
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
XLA_AVAILABLE = is_torch_xla_available()

if XLA_AVAILABLE:
    # xm.mark_step is used during the denoising loop to flush the XLA graph per step.
    import torch_xla.core.xla_model as xm  # noqa: F401


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
|
| 48 |
+
|
| 49 |
+
# Usage example injected into `__call__`'s docstring via @replace_example_docstring.
# The snippet must be self-contained: it now defines `transformer` (previously referenced
# but never created) and imports `load_image` (previously used without an import).
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import ChromaTransformer2DModel, ChromaImg2ImgPipeline
        >>> from diffusers.utils import load_image

        >>> model_id = "lodestones/Chroma"
        >>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
        >>> transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
        >>> pipe = ChromaImg2ImgPipeline.from_pretrained(
        ...     model_id,
        ...     transformer=transformer,
        ...     torch_dtype=torch.bfloat16,
        ... )
        >>> pipe.enable_model_cpu_offload()
        >>> init_image = load_image(
        ...     "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
        ... )
        >>> prompt = "a scenic fantasy landscape with a river and mountains in the background, vibrant colors, detailed, high resolution"
        >>> negative_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
        >>> image = pipe(prompt, image=init_image, negative_prompt=negative_prompt).images[0]
        >>> image.save("chroma-img2img.png")
        ```
"""
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
# Adapted from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """Linearly interpolate the flow-matching timestep shift ``mu`` from the image sequence length.

    The shift goes from ``base_shift`` at ``base_seq_len`` tokens to ``max_shift`` at
    ``max_seq_len`` tokens (and extrapolates linearly outside that range).
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Pull a latent tensor out of a VAE encoder output.

    Supports outputs exposing a ``latent_dist`` (sampled or taken at the mode depending on
    ``sample_mode``) or a plain ``latents`` attribute; raises otherwise.
    """
    dist = getattr(encoder_output, "latent_dist", None)
    if dist is not None:
        if sample_mode == "sample":
            return dist.sample(generator)
        if sample_mode == "argmax":
            return dist.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Call ``scheduler.set_timesteps`` and return the resulting schedule.

    Handles custom ``timesteps`` / ``sigmas`` schedules when the scheduler supports them; any
    extra kwargs are forwarded to ``scheduler.set_timesteps``.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model.
            If used, `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If
            `timesteps` is passed, `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If
            `sigmas` is passed, `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from
        the scheduler and the second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    if timesteps is None and sigmas is None:
        # Default path: the scheduler decides the schedule; the caller's step count is preserved
        # (it may differ from len(timesteps) for higher-order schedulers).
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    accepted = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
    if timesteps is not None:
        if "timesteps" not in accepted:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in accepted:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    timesteps = scheduler.timesteps
    return timesteps, len(timesteps)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class ChromaImg2ImgPipeline(
|
| 163 |
+
DiffusionPipeline,
|
| 164 |
+
FluxLoraLoaderMixin,
|
| 165 |
+
FromSingleFileMixin,
|
| 166 |
+
TextualInversionLoaderMixin,
|
| 167 |
+
FluxIPAdapterMixin,
|
| 168 |
+
):
|
| 169 |
+
r"""
|
| 170 |
+
The Chroma pipeline for image-to-image generation.
|
| 171 |
+
|
| 172 |
+
Reference: https://huggingface.co/lodestones/Chroma/
|
| 173 |
+
|
| 174 |
+
Args:
|
| 175 |
+
transformer ([`ChromaTransformer2DModel`]):
|
| 176 |
+
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
| 177 |
+
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
| 178 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
| 179 |
+
vae ([`AutoencoderKL`]):
|
| 180 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representation
|
| 181 |
+
text_encoder ([`T5EncoderModel`]):
|
| 182 |
+
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
| 183 |
+
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
| 184 |
+
tokenizer (`T5TokenizerFast`):
|
| 185 |
+
Second Tokenizer of class
|
| 186 |
+
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
| 187 |
+
"""
|
| 188 |
+
|
| 189 |
+
model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
|
| 190 |
+
_optional_components = ["image_encoder", "feature_extractor"]
|
| 191 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
| 192 |
+
|
| 193 |
+
def __init__(
|
| 194 |
+
self,
|
| 195 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 196 |
+
vae: AutoencoderKL,
|
| 197 |
+
text_encoder: T5EncoderModel,
|
| 198 |
+
tokenizer: T5TokenizerFast,
|
| 199 |
+
transformer: ChromaTransformer2DModel,
|
| 200 |
+
image_encoder: CLIPVisionModelWithProjection = None,
|
| 201 |
+
feature_extractor: CLIPImageProcessor = None,
|
| 202 |
+
):
|
| 203 |
+
super().__init__()
|
| 204 |
+
|
| 205 |
+
self.register_modules(
|
| 206 |
+
vae=vae,
|
| 207 |
+
text_encoder=text_encoder,
|
| 208 |
+
tokenizer=tokenizer,
|
| 209 |
+
transformer=transformer,
|
| 210 |
+
scheduler=scheduler,
|
| 211 |
+
image_encoder=image_encoder,
|
| 212 |
+
feature_extractor=feature_extractor,
|
| 213 |
+
)
|
| 214 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
| 215 |
+
self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
|
| 216 |
+
|
| 217 |
+
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
|
| 218 |
+
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
|
| 219 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
| 220 |
+
self.default_sample_size = 128
|
| 221 |
+
|
| 222 |
+
def _get_t5_prompt_embeds(
|
| 223 |
+
self,
|
| 224 |
+
prompt: Union[str, List[str]] = None,
|
| 225 |
+
num_images_per_prompt: int = 1,
|
| 226 |
+
max_sequence_length: int = 512,
|
| 227 |
+
device: Optional[torch.device] = None,
|
| 228 |
+
dtype: Optional[torch.dtype] = None,
|
| 229 |
+
):
|
| 230 |
+
device = device or self._execution_device
|
| 231 |
+
dtype = dtype or self.text_encoder.dtype
|
| 232 |
+
|
| 233 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 234 |
+
batch_size = len(prompt)
|
| 235 |
+
|
| 236 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
| 237 |
+
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
| 238 |
+
|
| 239 |
+
text_inputs = self.tokenizer(
|
| 240 |
+
prompt,
|
| 241 |
+
padding="max_length",
|
| 242 |
+
max_length=max_sequence_length,
|
| 243 |
+
truncation=True,
|
| 244 |
+
return_length=False,
|
| 245 |
+
return_overflowing_tokens=False,
|
| 246 |
+
return_tensors="pt",
|
| 247 |
+
)
|
| 248 |
+
text_input_ids = text_inputs.input_ids
|
| 249 |
+
attention_mask = text_inputs.attention_mask.clone()
|
| 250 |
+
|
| 251 |
+
# Chroma requires the attention mask to include one padding token
|
| 252 |
+
seq_lengths = attention_mask.sum(dim=1)
|
| 253 |
+
mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
|
| 254 |
+
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).long()
|
| 255 |
+
|
| 256 |
+
prompt_embeds = self.text_encoder(
|
| 257 |
+
text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
|
| 258 |
+
)[0]
|
| 259 |
+
|
| 260 |
+
dtype = self.text_encoder.dtype
|
| 261 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 262 |
+
attention_mask = attention_mask.to(dtype=dtype, device=device)
|
| 263 |
+
|
| 264 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 265 |
+
|
| 266 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
| 267 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 268 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 269 |
+
|
| 270 |
+
attention_mask = attention_mask.repeat(1, num_images_per_prompt)
|
| 271 |
+
attention_mask = attention_mask.view(batch_size * num_images_per_prompt, seq_len)
|
| 272 |
+
|
| 273 |
+
return prompt_embeds, attention_mask
|
| 274 |
+
|
| 275 |
+
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
|
| 276 |
+
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
|
| 277 |
+
if isinstance(generator, list):
|
| 278 |
+
image_latents = [
|
| 279 |
+
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
|
| 280 |
+
for i in range(image.shape[0])
|
| 281 |
+
]
|
| 282 |
+
image_latents = torch.cat(image_latents, dim=0)
|
| 283 |
+
else:
|
| 284 |
+
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
|
| 285 |
+
|
| 286 |
+
image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
| 287 |
+
|
| 288 |
+
return image_latents
|
| 289 |
+
|
| 290 |
+
def encode_prompt(
|
| 291 |
+
self,
|
| 292 |
+
prompt: Union[str, List[str]],
|
| 293 |
+
negative_prompt: Union[str, List[str]] = None,
|
| 294 |
+
device: Optional[torch.device] = None,
|
| 295 |
+
num_images_per_prompt: int = 1,
|
| 296 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 297 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 298 |
+
prompt_attention_mask: Optional[torch.Tensor] = None,
|
| 299 |
+
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
| 300 |
+
do_classifier_free_guidance: bool = True,
|
| 301 |
+
max_sequence_length: int = 512,
|
| 302 |
+
lora_scale: Optional[float] = None,
|
| 303 |
+
):
|
| 304 |
+
r"""
|
| 305 |
+
|
| 306 |
+
Args:
|
| 307 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 308 |
+
prompt to be encoded
|
| 309 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 310 |
+
The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
|
| 311 |
+
instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
|
| 312 |
+
device: (`torch.device`):
|
| 313 |
+
torch device
|
| 314 |
+
num_images_per_prompt (`int`):
|
| 315 |
+
number of images that should be generated per prompt
|
| 316 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 317 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 318 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 319 |
+
lora_scale (`float`, *optional*):
|
| 320 |
+
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
| 321 |
+
"""
|
| 322 |
+
device = device or self._execution_device
|
| 323 |
+
|
| 324 |
+
# set lora scale so that monkey patched LoRA
|
| 325 |
+
# function of text encoder can correctly access it
|
| 326 |
+
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
|
| 327 |
+
self._lora_scale = lora_scale
|
| 328 |
+
|
| 329 |
+
# dynamically adjust the LoRA scale
|
| 330 |
+
if self.text_encoder is not None and USE_PEFT_BACKEND:
|
| 331 |
+
scale_lora_layers(self.text_encoder, lora_scale)
|
| 332 |
+
|
| 333 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 334 |
+
|
| 335 |
+
if prompt is not None:
|
| 336 |
+
batch_size = len(prompt)
|
| 337 |
+
else:
|
| 338 |
+
batch_size = prompt_embeds.shape[0]
|
| 339 |
+
|
| 340 |
+
if prompt_embeds is None:
|
| 341 |
+
prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
|
| 342 |
+
prompt=prompt,
|
| 343 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 344 |
+
max_sequence_length=max_sequence_length,
|
| 345 |
+
device=device,
|
| 346 |
+
)
|
| 347 |
+
|
| 348 |
+
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
|
| 349 |
+
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
| 350 |
+
negative_text_ids = None
|
| 351 |
+
|
| 352 |
+
if do_classifier_free_guidance:
|
| 353 |
+
if negative_prompt_embeds is None:
|
| 354 |
+
negative_prompt = negative_prompt or ""
|
| 355 |
+
negative_prompt = (
|
| 356 |
+
batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
| 360 |
+
raise TypeError(
|
| 361 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
| 362 |
+
f" {type(prompt)}."
|
| 363 |
+
)
|
| 364 |
+
elif batch_size != len(negative_prompt):
|
| 365 |
+
raise ValueError(
|
| 366 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
| 367 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
| 368 |
+
" the batch size of `prompt`."
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
|
| 372 |
+
prompt=negative_prompt,
|
| 373 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 374 |
+
max_sequence_length=max_sequence_length,
|
| 375 |
+
device=device,
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
negative_text_ids = torch.zeros(negative_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
| 379 |
+
|
| 380 |
+
if self.text_encoder is not None:
|
| 381 |
+
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
| 382 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
| 383 |
+
unscale_lora_layers(self.text_encoder, lora_scale)
|
| 384 |
+
|
| 385 |
+
return (
|
| 386 |
+
prompt_embeds,
|
| 387 |
+
text_ids,
|
| 388 |
+
prompt_attention_mask,
|
| 389 |
+
negative_prompt_embeds,
|
| 390 |
+
negative_text_ids,
|
| 391 |
+
negative_prompt_attention_mask,
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
|
| 395 |
+
def encode_image(self, image, device, num_images_per_prompt):
|
| 396 |
+
dtype = next(self.image_encoder.parameters()).dtype
|
| 397 |
+
|
| 398 |
+
if not isinstance(image, torch.Tensor):
|
| 399 |
+
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
| 400 |
+
|
| 401 |
+
image = image.to(device=device, dtype=dtype)
|
| 402 |
+
image_embeds = self.image_encoder(image).image_embeds
|
| 403 |
+
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
| 404 |
+
return image_embeds
|
| 405 |
+
|
| 406 |
+
def prepare_ip_adapter_image_embeds(
|
| 407 |
+
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
|
| 408 |
+
):
|
| 409 |
+
device = device or self._execution_device
|
| 410 |
+
|
| 411 |
+
image_embeds = []
|
| 412 |
+
if ip_adapter_image_embeds is None:
|
| 413 |
+
if not isinstance(ip_adapter_image, list):
|
| 414 |
+
ip_adapter_image = [ip_adapter_image]
|
| 415 |
+
|
| 416 |
+
if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
|
| 417 |
+
raise ValueError(
|
| 418 |
+
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
for single_ip_adapter_image in ip_adapter_image:
|
| 422 |
+
single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
|
| 423 |
+
image_embeds.append(single_image_embeds[None, :])
|
| 424 |
+
else:
|
| 425 |
+
if not isinstance(ip_adapter_image_embeds, list):
|
| 426 |
+
ip_adapter_image_embeds = [ip_adapter_image_embeds]
|
| 427 |
+
|
| 428 |
+
if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
|
| 429 |
+
raise ValueError(
|
| 430 |
+
f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
for single_image_embeds in ip_adapter_image_embeds:
|
| 434 |
+
image_embeds.append(single_image_embeds)
|
| 435 |
+
|
| 436 |
+
ip_adapter_image_embeds = []
|
| 437 |
+
for single_image_embeds in image_embeds:
|
| 438 |
+
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
|
| 439 |
+
single_image_embeds = single_image_embeds.to(device=device)
|
| 440 |
+
ip_adapter_image_embeds.append(single_image_embeds)
|
| 441 |
+
|
| 442 |
+
return ip_adapter_image_embeds
|
| 443 |
+
|
| 444 |
+
def check_inputs(
|
| 445 |
+
self,
|
| 446 |
+
prompt,
|
| 447 |
+
height,
|
| 448 |
+
width,
|
| 449 |
+
strength,
|
| 450 |
+
negative_prompt=None,
|
| 451 |
+
prompt_embeds=None,
|
| 452 |
+
negative_prompt_embeds=None,
|
| 453 |
+
prompt_attention_mask=None,
|
| 454 |
+
negative_prompt_attention_mask=None,
|
| 455 |
+
callback_on_step_end_tensor_inputs=None,
|
| 456 |
+
max_sequence_length=None,
|
| 457 |
+
):
|
| 458 |
+
if strength < 0 or strength > 1:
|
| 459 |
+
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
| 460 |
+
|
| 461 |
+
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
|
| 462 |
+
logger.warning(
|
| 463 |
+
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
|
| 464 |
+
)
|
| 465 |
+
|
| 466 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 467 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 468 |
+
):
|
| 469 |
+
raise ValueError(
|
| 470 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 471 |
+
)
|
| 472 |
+
|
| 473 |
+
if prompt is not None and prompt_embeds is not None:
|
| 474 |
+
raise ValueError(
|
| 475 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 476 |
+
" only forward one of the two."
|
| 477 |
+
)
|
| 478 |
+
elif prompt is None and prompt_embeds is None:
|
| 479 |
+
raise ValueError(
|
| 480 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 481 |
+
)
|
| 482 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 483 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 484 |
+
|
| 485 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 486 |
+
raise ValueError(
|
| 487 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 488 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 489 |
+
)
|
| 490 |
+
|
| 491 |
+
if prompt_embeds is not None and prompt_attention_mask is None:
|
| 492 |
+
raise ValueError("Cannot provide `prompt_embeds` without also providing `prompt_attention_mask")
|
| 493 |
+
|
| 494 |
+
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
|
| 495 |
+
raise ValueError(
|
| 496 |
+
"Cannot provide `negative_prompt_embeds` without also providing `negative_prompt_attention_mask"
|
| 497 |
+
)
|
| 498 |
+
|
| 499 |
+
if max_sequence_length is not None and max_sequence_length > 512:
|
| 500 |
+
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
| 501 |
+
|
| 502 |
+
@staticmethod
|
| 503 |
+
def _prepare_latent_image_ids(height, width, device, dtype):
|
| 504 |
+
latent_image_ids = torch.zeros(height, width, 3)
|
| 505 |
+
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
|
| 506 |
+
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
|
| 507 |
+
|
| 508 |
+
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
| 509 |
+
|
| 510 |
+
latent_image_ids = latent_image_ids.reshape(
|
| 511 |
+
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
| 512 |
+
)
|
| 513 |
+
|
| 514 |
+
return latent_image_ids.to(device=device, dtype=dtype)
|
| 515 |
+
|
| 516 |
+
@staticmethod
|
| 517 |
+
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
| 518 |
+
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
| 519 |
+
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
| 520 |
+
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
| 521 |
+
|
| 522 |
+
return latents
|
| 523 |
+
|
| 524 |
+
@staticmethod
|
| 525 |
+
def _unpack_latents(latents, height, width, vae_scale_factor):
|
| 526 |
+
batch_size, num_patches, channels = latents.shape
|
| 527 |
+
|
| 528 |
+
# VAE applies 8x compression on images but we must also account for packing which requires
|
| 529 |
+
# latent height and width to be divisible by 2.
|
| 530 |
+
height = 2 * (int(height) // (vae_scale_factor * 2))
|
| 531 |
+
width = 2 * (int(width) // (vae_scale_factor * 2))
|
| 532 |
+
|
| 533 |
+
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
|
| 534 |
+
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
| 535 |
+
|
| 536 |
+
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
|
| 537 |
+
|
| 538 |
+
return latents
|
| 539 |
+
|
| 540 |
+
def enable_vae_slicing(self):
|
| 541 |
+
r"""
|
| 542 |
+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
| 543 |
+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
| 544 |
+
"""
|
| 545 |
+
self.vae.enable_slicing()
|
| 546 |
+
|
| 547 |
+
def disable_vae_slicing(self):
|
| 548 |
+
r"""
|
| 549 |
+
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
| 550 |
+
computing decoding in one step.
|
| 551 |
+
"""
|
| 552 |
+
self.vae.disable_slicing()
|
| 553 |
+
|
| 554 |
+
def enable_vae_tiling(self):
|
| 555 |
+
r"""
|
| 556 |
+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
| 557 |
+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
| 558 |
+
processing larger images.
|
| 559 |
+
"""
|
| 560 |
+
self.vae.enable_tiling()
|
| 561 |
+
|
| 562 |
+
def disable_vae_tiling(self):
|
| 563 |
+
r"""
|
| 564 |
+
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
| 565 |
+
computing decoding in one step.
|
| 566 |
+
"""
|
| 567 |
+
self.vae.disable_tiling()
|
| 568 |
+
|
| 569 |
+
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
|
| 570 |
+
def get_timesteps(self, num_inference_steps, strength, device):
|
| 571 |
+
# get the original timestep using init_timestep
|
| 572 |
+
init_timestep = min(num_inference_steps * strength, num_inference_steps)
|
| 573 |
+
|
| 574 |
+
t_start = int(max(num_inference_steps - init_timestep, 0))
|
| 575 |
+
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
| 576 |
+
if hasattr(self.scheduler, "set_begin_index"):
|
| 577 |
+
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
| 578 |
+
|
| 579 |
+
return timesteps, num_inference_steps - t_start
|
| 580 |
+
|
| 581 |
+
def prepare_latents(
|
| 582 |
+
self,
|
| 583 |
+
image,
|
| 584 |
+
timestep,
|
| 585 |
+
batch_size,
|
| 586 |
+
num_channels_latents,
|
| 587 |
+
height,
|
| 588 |
+
width,
|
| 589 |
+
dtype,
|
| 590 |
+
device,
|
| 591 |
+
generator,
|
| 592 |
+
latents=None,
|
| 593 |
+
):
|
| 594 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 595 |
+
raise ValueError(
|
| 596 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 597 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 598 |
+
)
|
| 599 |
+
|
| 600 |
+
# VAE applies 8x compression on images but we must also account for packing which requires
|
| 601 |
+
# latent height and width to be divisible by 2.
|
| 602 |
+
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
| 603 |
+
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
| 604 |
+
shape = (batch_size, num_channels_latents, height, width)
|
| 605 |
+
latent_image_ids = self._prepare_latent_image_ids(height // 2, width // 2, device, dtype)
|
| 606 |
+
|
| 607 |
+
if latents is not None:
|
| 608 |
+
return latents.to(device=device, dtype=dtype), latent_image_ids
|
| 609 |
+
|
| 610 |
+
image = image.to(device=device, dtype=dtype)
|
| 611 |
+
if image.shape[1] != self.latent_channels:
|
| 612 |
+
image_latents = self._encode_vae_image(image=image, generator=generator)
|
| 613 |
+
else:
|
| 614 |
+
image_latents = image
|
| 615 |
+
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
| 616 |
+
# expand init_latents for batch_size
|
| 617 |
+
additional_image_per_prompt = batch_size // image_latents.shape[0]
|
| 618 |
+
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
|
| 619 |
+
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
|
| 620 |
+
raise ValueError(
|
| 621 |
+
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
|
| 622 |
+
)
|
| 623 |
+
else:
|
| 624 |
+
image_latents = torch.cat([image_latents], dim=0)
|
| 625 |
+
|
| 626 |
+
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 627 |
+
latents = self.scheduler.scale_noise(image_latents, timestep, noise)
|
| 628 |
+
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
| 629 |
+
return latents, latent_image_ids
|
| 630 |
+
|
| 631 |
+
def _prepare_attention_mask(
|
| 632 |
+
self,
|
| 633 |
+
batch_size,
|
| 634 |
+
sequence_length,
|
| 635 |
+
dtype,
|
| 636 |
+
attention_mask=None,
|
| 637 |
+
):
|
| 638 |
+
if attention_mask is None:
|
| 639 |
+
return attention_mask
|
| 640 |
+
|
| 641 |
+
# Extend the prompt attention mask to account for image tokens in the final sequence
|
| 642 |
+
attention_mask = torch.cat(
|
| 643 |
+
[attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device)],
|
| 644 |
+
dim=1,
|
| 645 |
+
)
|
| 646 |
+
attention_mask = attention_mask.to(dtype)
|
| 647 |
+
|
| 648 |
+
return attention_mask
|
| 649 |
+
|
| 650 |
+
@property
|
| 651 |
+
def guidance_scale(self):
|
| 652 |
+
return self._guidance_scale
|
| 653 |
+
|
| 654 |
+
@property
|
| 655 |
+
def joint_attention_kwargs(self):
|
| 656 |
+
return self._joint_attention_kwargs
|
| 657 |
+
|
| 658 |
+
@property
|
| 659 |
+
def do_classifier_free_guidance(self):
|
| 660 |
+
return self._guidance_scale > 1
|
| 661 |
+
|
| 662 |
+
@property
|
| 663 |
+
def num_timesteps(self):
|
| 664 |
+
return self._num_timesteps
|
| 665 |
+
|
| 666 |
+
@property
|
| 667 |
+
def current_timestep(self):
|
| 668 |
+
return self._current_timestep
|
| 669 |
+
|
| 670 |
+
@property
|
| 671 |
+
def interrupt(self):
|
| 672 |
+
return self._interrupt
|
| 673 |
+
|
| 674 |
+
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        negative_prompt: Union[str, List[str]] = None,
        image: PipelineImageInput = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 35,
        sigmas: Optional[List[float]] = None,
        guidance_scale: float = 5.0,
        strength: float = 0.9,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        negative_ip_adapter_image: Optional[PipelineImageInput] = None,
        negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        prompt_attention_mask: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                not greater than `1`).
            image (`PipelineImageInput`, *optional*):
                The reference image used as the starting point for image-to-image generation.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image. This is set to 1024 by default for the best results.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. This is set to 1024 by default for the best results.
            num_inference_steps (`int`, *optional*, defaults to 35):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                will be used.
            guidance_scale (`float`, *optional*, defaults to 5.0):
                Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
                a model to generate images more aligned with `prompt` at the expense of lower image quality.

                Guidance-distilled models approximate true classifier-free guidance for `guidance_scale` > 1. Refer to
                the [paper](https://huggingface.co/papers/2210.03142) to learn more.
            strength (`float`, *optional*, defaults to 0.9):
                Conceptually, indicates how much to transform the reference image. Must be between 0 and 1. image will
                be used as a starting point, adding more noise to it the larger the strength. The number of denoising
                steps depends on the amount of noise initially added. When strength is 1, added noise will be maximum
                and the denoising process will run for the full number of iterations specified in num_inference_steps.
                A value of 1, therefore, essentially ignores image.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
                provided, embeddings are computed from the `ip_adapter_image` input argument.
            negative_ip_adapter_image:
                (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
            negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
                provided, embeddings are computed from the `ip_adapter_image` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            prompt_attention_mask (`torch.Tensor`, *optional*):
                Attention mask for the prompt embeddings. Used to mask out padding tokens in the prompt sequence.
                Chroma requires a single padding token remain unmasked. Please refer to
                https://huggingface.co/lodestones/Chroma#tldr-masking-t5-padding-tokens-enhanced-fidelity-and-increased-stability-during-training
            negative_prompt_attention_mask (`torch.Tensor`, *optional*):
                Attention mask for the negative prompt embeddings. Used to mask out padding tokens in the negative
                prompt sequence. Chroma requires a single padding token remain unmasked. Please refer to
                https://huggingface.co/lodestones/Chroma#tldr-masking-t5-padding-tokens-enhanced-fidelity-and-increased-stability-during-training
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.flux.ChromaPipelineOutput`] instead of a plain tuple.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.

        Examples:

        Returns:
            [`~pipelines.chroma.ChromaPipelineOutput`] or `tuple`: [`~pipelines.chroma.ChromaPipelineOutput`] if
            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
            generated images.
        """

        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            strength,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            prompt_attention_mask=prompt_attention_mask,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )

        # Stash per-call state read back through the pipeline properties.
        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._current_timestep = None
        self._interrupt = False

        # 2. Preprocess image
        init_image = self.image_processor.preprocess(image, height=height, width=width)
        init_image = init_image.to(dtype=torch.float32)

        # 3. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        lora_scale = (
            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
        )

        (
            prompt_embeds,
            text_ids,
            prompt_attention_mask,
            negative_prompt_embeds,
            negative_text_ids,
            negative_prompt_attention_mask,
        ) = self.encode_prompt(
            prompt=prompt,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            prompt_attention_mask=prompt_attention_mask,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )

        # 4. Prepare timesteps
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
        image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
        # Resolution-dependent timestep shift for flow-matching schedulers.
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.get("base_image_seq_len", 256),
            self.scheduler.config.get("max_image_seq_len", 4096),
            self.scheduler.config.get("base_shift", 0.5),
            self.scheduler.config.get("max_shift", 1.15),
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            sigmas=sigmas,
            mu=mu,
        )
        # Truncate the schedule according to `strength` (img2img starts part-way).
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)

        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        if num_inference_steps < 1:
            raise ValueError(
                f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
                f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
            )
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

        # 5. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels // 4
        latents, latent_image_ids = self.prepare_latents(
            init_image,
            latent_timestep,
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # Extend the text masks to cover the image tokens appended to the sequence.
        attention_mask = self._prepare_attention_mask(
            batch_size=latents.shape[0],
            sequence_length=image_seq_len,
            dtype=latents.dtype,
            attention_mask=prompt_attention_mask,
        )
        negative_attention_mask = self._prepare_attention_mask(
            batch_size=latents.shape[0],
            sequence_length=image_seq_len,
            dtype=latents.dtype,
            attention_mask=negative_prompt_attention_mask,
        )

        # 6. Prepare image embeddings
        # If only one side (positive/negative) of the IP-adapter inputs is given,
        # fill the other side with a black image so both transformer passes get embeds.
        if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
            negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
        ):
            negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
            negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters

        elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
            negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
        ):
            ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
            ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters

        if self.joint_attention_kwargs is None:
            self._joint_attention_kwargs = {}

        image_embeds = None
        negative_image_embeds = None
        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
            )
        if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
            negative_image_embeds = self.prepare_ip_adapter_image_embeds(
                negative_ip_adapter_image,
                negative_ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
            )

        # 7. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                self._current_timestep = t
                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latents.shape[0])

                if image_embeds is not None:
                    self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds

                noise_pred = self.transformer(
                    hidden_states=latents,
                    timestep=timestep / 1000,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    attention_mask=attention_mask,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

                if self.do_classifier_free_guidance:
                    if negative_image_embeds is not None:
                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds

                    # Second (unconditional) pass with the negative prompt embeddings.
                    noise_pred_uncond = self.transformer(
                        hidden_states=latents,
                        timestep=timestep / 1000,
                        encoder_hidden_states=negative_prompt_embeds,
                        txt_ids=negative_text_ids,
                        img_ids=latent_image_ids,
                        attention_mask=negative_attention_mask,
                        joint_attention_kwargs=self.joint_attention_kwargs,
                        return_dict=False,
                    )[0]
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        self._current_timestep = None

        if output_type == "latent":
            image = latents
        else:
            # Unpack the 2x2-patched latent sequence back to image layout, then decode.
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return ChromaPipelineOutput(images=image)
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/chroma/pipeline_output.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import List, Union
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import PIL.Image
|
| 6 |
+
|
| 7 |
+
from ...utils import BaseOutput
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass
class ChromaPipelineOutput(BaseOutput):
    """
    Output class for Chroma pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogvideo/__init__.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Lazy-import machinery for the `cogvideo` pipeline subpackage: real pipeline
# classes are only imported on first attribute access (or eagerly for type
# checkers / DIFFUSERS_SLOW_IMPORT), and dummy placeholders are exposed when the
# optional torch/transformers dependencies are missing.
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


# Placeholder objects registered when optional dependencies are unavailable.
_dummy_objects = {}
# Maps submodule name -> list of public symbols, consumed by _LazyModule below.
_import_structure = {}


try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # Expose dummies that raise an informative error when instantiated.
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_cogvideox"] = ["CogVideoXPipeline"]
    _import_structure["pipeline_cogvideox_fun_control"] = ["CogVideoXFunControlPipeline"]
    _import_structure["pipeline_cogvideox_image2video"] = ["CogVideoXImageToVideoPipeline"]
    _import_structure["pipeline_cogvideox_video2video"] = ["CogVideoXVideoToVideoPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager imports: give static analyzers (and slow-import mode) real symbols.
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()

    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_cogvideox import CogVideoXPipeline
        from .pipeline_cogvideox_fun_control import CogVideoXFunControlPipeline
        from .pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline
        from .pipeline_cogvideox_video2video import CogVideoXVideoToVideoPipeline

else:
    import sys

    # Replace this module with a lazy proxy that imports submodules on demand.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
ADDED
|
@@ -0,0 +1,789 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import inspect
|
| 17 |
+
import math
|
| 18 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
from transformers import T5EncoderModel, T5Tokenizer
|
| 22 |
+
|
| 23 |
+
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 24 |
+
from ...loaders import CogVideoXLoraLoaderMixin
|
| 25 |
+
from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
|
| 26 |
+
from ...models.embeddings import get_3d_rotary_pos_embed
|
| 27 |
+
from ...pipelines.pipeline_utils import DiffusionPipeline
|
| 28 |
+
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
|
| 29 |
+
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
| 30 |
+
from ...utils.torch_utils import randn_tensor
|
| 31 |
+
from ...video_processor import VideoProcessor
|
| 32 |
+
from .pipeline_output import CogVideoXPipelineOutput
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# Import the XLA runtime only when present so the pipeline also works on non-TPU installs;
# XLA_AVAILABLE gates the per-step `xm.mark_step()` calls in the denoising loop.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
EXAMPLE_DOC_STRING = """
|
| 46 |
+
Examples:
|
| 47 |
+
```python
|
| 48 |
+
>>> import torch
|
| 49 |
+
>>> from diffusers import CogVideoXPipeline
|
| 50 |
+
>>> from diffusers.utils import export_to_video
|
| 51 |
+
|
| 52 |
+
>>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
|
| 53 |
+
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
|
| 54 |
+
>>> prompt = (
|
| 55 |
+
... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
|
| 56 |
+
... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
|
| 57 |
+
... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
|
| 58 |
+
... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
|
| 59 |
+
... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
|
| 60 |
+
... "atmosphere of this unique musical performance."
|
| 61 |
+
... )
|
| 62 |
+
>>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
|
| 63 |
+
>>> export_to_video(video, "output.mp4", fps=8)
|
| 64 |
+
```
|
| 65 |
+
"""
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
|
| 69 |
+
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
|
| 70 |
+
tw = tgt_width
|
| 71 |
+
th = tgt_height
|
| 72 |
+
h, w = src
|
| 73 |
+
r = h / w
|
| 74 |
+
if r > (th / tw):
|
| 75 |
+
resize_height = th
|
| 76 |
+
resize_width = int(round(th / h * w))
|
| 77 |
+
else:
|
| 78 |
+
resize_width = tw
|
| 79 |
+
resize_height = int(round(tw / w * h))
|
| 80 |
+
|
| 81 |
+
crop_top = int(round((th - resize_height) / 2.0))
|
| 82 |
+
crop_left = int(round((tw - resize_width) / 2.0))
|
| 83 |
+
|
| 84 |
+
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# Based on diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Configure `scheduler` and return its resulting timestep schedule.

    Exactly one scheduling source may drive the call: a plain step count
    (`num_inference_steps`), a custom `timesteps` list, or a custom `sigmas` list.
    Custom schedules are only forwarded when the scheduler's `set_timesteps`
    signature accepts the corresponding keyword. Extra kwargs go straight to
    `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`, *optional*):
            Number of diffusion steps; mutually exclusive with `timesteps` / `sigmas`.
        device (`str` or `torch.device`, *optional*):
            Device the timesteps should be moved to; `None` leaves them in place.
        timesteps (`List[int]`, *optional*):
            Custom timestep schedule (must be supported by the scheduler).
        sigmas (`List[float]`, *optional*):
            Custom sigma schedule (must be supported by the scheduler).

    Returns:
        `Tuple[torch.Tensor, int]`: the timestep schedule and the number of inference steps.

    Raises:
        ValueError: if both `timesteps` and `sigmas` are given, or the scheduler does
            not support the requested custom schedule.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    if timesteps is not None:
        if "timesteps" not in inspect.signature(scheduler.set_timesteps).parameters:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        if "sigmas" not in inspect.signature(scheduler.set_timesteps).parameters:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps

    return timesteps, num_inference_steps
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
| 148 |
+
r"""
|
| 149 |
+
Pipeline for text-to-video generation using CogVideoX.
|
| 150 |
+
|
| 151 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 152 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
| 153 |
+
|
| 154 |
+
Args:
|
| 155 |
+
vae ([`AutoencoderKL`]):
|
| 156 |
+
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
| 157 |
+
text_encoder ([`T5EncoderModel`]):
|
| 158 |
+
Frozen text-encoder. CogVideoX uses
|
| 159 |
+
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
|
| 160 |
+
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
|
| 161 |
+
tokenizer (`T5Tokenizer`):
|
| 162 |
+
Tokenizer of class
|
| 163 |
+
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
| 164 |
+
transformer ([`CogVideoXTransformer3DModel`]):
|
| 165 |
+
A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
|
| 166 |
+
scheduler ([`SchedulerMixin`]):
|
| 167 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
|
| 168 |
+
"""
|
| 169 |
+
|
| 170 |
+
_optional_components = []
|
| 171 |
+
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
| 172 |
+
|
| 173 |
+
_callback_tensor_inputs = [
|
| 174 |
+
"latents",
|
| 175 |
+
"prompt_embeds",
|
| 176 |
+
"negative_prompt_embeds",
|
| 177 |
+
]
|
| 178 |
+
|
| 179 |
+
    def __init__(
        self,
        tokenizer: T5Tokenizer,
        text_encoder: T5EncoderModel,
        vae: AutoencoderKLCogVideoX,
        transformer: CogVideoXTransformer3DModel,
        scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
    ):
        """Register sub-models and derive the VAE compression factors used throughout the pipeline."""
        super().__init__()

        # register_modules (from the DiffusionPipeline base) presumably exposes each
        # module as an attribute, hence the getattr guards below — confirm upstream.
        self.register_modules(
            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
        )
        # Spatial downscale of the VAE: 2^(num_levels - 1); fall back to 8 without a VAE.
        self.vae_scale_factor_spatial = (
            2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        )
        # Temporal downscale (pixel frames per latent frame); fall back to 4 without a VAE.
        self.vae_scale_factor_temporal = (
            self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
        )
        # Latent scaling factor applied when decoding; fall back to 0.7 without a VAE.
        self.vae_scaling_factor_image = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.7

        # Handles tensor <-> PIL/np video conversions for pre/post-processing.
        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
| 201 |
+
|
| 202 |
+
    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_videos_per_prompt: int = 1,
        max_sequence_length: int = 226,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Tokenize `prompt` and encode it with the T5 text encoder.

        Prompts longer than `max_sequence_length` tokens are truncated (with a warning
        listing the removed text). The embeddings are repeated `num_videos_per_prompt`
        times and returned with shape `(batch * num_videos_per_prompt, seq_len, dim)`.
        """
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Tokenize again without truncation so we can report exactly what was cut off.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

        return prompt_embeds
|
| 243 |
+
|
| 244 |
+
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        do_classifier_free_guidance: bool = True,
        num_videos_per_prompt: int = 1,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        max_sequence_length: int = 226,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier free guidance or not.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            device: (`torch.device`, *optional*):
                torch device
            dtype: (`torch.dtype`, *optional*):
                torch dtype

        Returns:
            `Tuple[torch.Tensor, Optional[torch.Tensor]]`: `(prompt_embeds, negative_prompt_embeds)`;
            the negative embeddings stay `None` when classifier-free guidance is disabled
            and none were passed in.
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        # Only compute negative embeddings when CFG is active and none were supplied.
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            # An unset negative prompt defaults to the empty string, broadcast per batch item.
            negative_prompt = negative_prompt or ""
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds = self._get_t5_prompt_embeds(
                prompt=negative_prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        return prompt_embeds, negative_prompt_embeds
|
| 324 |
+
|
| 325 |
+
def prepare_latents(
|
| 326 |
+
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
| 327 |
+
):
|
| 328 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 329 |
+
raise ValueError(
|
| 330 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 331 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
shape = (
|
| 335 |
+
batch_size,
|
| 336 |
+
(num_frames - 1) // self.vae_scale_factor_temporal + 1,
|
| 337 |
+
num_channels_latents,
|
| 338 |
+
height // self.vae_scale_factor_spatial,
|
| 339 |
+
width // self.vae_scale_factor_spatial,
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
if latents is None:
|
| 343 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 344 |
+
else:
|
| 345 |
+
latents = latents.to(device)
|
| 346 |
+
|
| 347 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
| 348 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 349 |
+
return latents
|
| 350 |
+
|
| 351 |
+
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
|
| 352 |
+
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
|
| 353 |
+
latents = 1 / self.vae_scaling_factor_image * latents
|
| 354 |
+
|
| 355 |
+
frames = self.vae.decode(latents).sample
|
| 356 |
+
return frames
|
| 357 |
+
|
| 358 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
| 359 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
| 360 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 361 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 362 |
+
# eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
|
| 363 |
+
# and should be between [0, 1]
|
| 364 |
+
|
| 365 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 366 |
+
extra_step_kwargs = {}
|
| 367 |
+
if accepts_eta:
|
| 368 |
+
extra_step_kwargs["eta"] = eta
|
| 369 |
+
|
| 370 |
+
# check if the scheduler accepts generator
|
| 371 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 372 |
+
if accepts_generator:
|
| 373 |
+
extra_step_kwargs["generator"] = generator
|
| 374 |
+
return extra_step_kwargs
|
| 375 |
+
|
| 376 |
+
    # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
    def check_inputs(
        self,
        prompt,
        height,
        width,
        negative_prompt,
        callback_on_step_end_tensor_inputs,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        """Validate `__call__` arguments before any heavy work.

        Raises `ValueError` when: the resolution is not divisible by 8; a callback tensor
        name is not registered in `_callback_tensor_inputs`; both or neither of
        `prompt`/`prompt_embeds` are given; a prompt coexists with pre-computed negative
        embeddings; or the two embedding tensors have mismatched shapes.
        """
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )
        # Exactly one of `prompt` / `prompt_embeds` must be provided.
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        # Pre-computed embeddings must be shape-compatible for CFG concatenation.
        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )
|
| 427 |
+
|
| 428 |
+
def fuse_qkv_projections(self) -> None:
|
| 429 |
+
r"""Enables fused QKV projections."""
|
| 430 |
+
self.fusing_transformer = True
|
| 431 |
+
self.transformer.fuse_qkv_projections()
|
| 432 |
+
|
| 433 |
+
def unfuse_qkv_projections(self) -> None:
|
| 434 |
+
r"""Disable QKV projection fusion if enabled."""
|
| 435 |
+
if not self.fusing_transformer:
|
| 436 |
+
logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
|
| 437 |
+
else:
|
| 438 |
+
self.transformer.unfuse_qkv_projections()
|
| 439 |
+
self.fusing_transformer = False
|
| 440 |
+
|
| 441 |
+
    def _prepare_rotary_positional_embeddings(
        self,
        height: int,
        width: int,
        num_frames: int,
        device: torch.device,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build the 3D rotary positional embedding (cos, sin) tables for the transformer.

        `num_frames` is the number of latent frames (callers pass `latents.size(1)`).
        CogVideoX 1.0 (no temporal patching) uses a resized/cropped spatial grid;
        CogVideoX 1.5 patches the temporal axis and uses slice-type embeddings.
        """
        grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
        grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)

        p = self.transformer.config.patch_size
        p_t = self.transformer.config.patch_size_t

        # Reference grid size the transformer was configured for, in patch units.
        base_size_width = self.transformer.config.sample_width // p
        base_size_height = self.transformer.config.sample_height // p

        if p_t is None:
            # CogVideoX 1.0
            grid_crops_coords = get_resize_crop_region_for_grid(
                (grid_height, grid_width), base_size_width, base_size_height
            )
            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
                embed_dim=self.transformer.config.attention_head_dim,
                crops_coords=grid_crops_coords,
                grid_size=(grid_height, grid_width),
                temporal_size=num_frames,
                device=device,
            )
        else:
            # CogVideoX 1.5
            # Ceil-divide so partially filled temporal patches still get positions.
            base_num_frames = (num_frames + p_t - 1) // p_t

            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
                embed_dim=self.transformer.config.attention_head_dim,
                crops_coords=None,
                grid_size=(grid_height, grid_width),
                temporal_size=base_num_frames,
                grid_type="slice",
                max_size=(base_size_height, base_size_width),
                device=device,
            )

        return freqs_cos, freqs_sin
|
| 484 |
+
|
| 485 |
+
    @property
    def guidance_scale(self):
        # CFG weight set at the start of `__call__`; updated per-step when
        # `use_dynamic_cfg` is enabled.
        return self._guidance_scale

    @property
    def num_timesteps(self):
        # Number of denoising steps scheduled for the current/most recent run.
        return self._num_timesteps

    @property
    def attention_kwargs(self):
        # Extra kwargs forwarded to the transformer's attention processors.
        return self._attention_kwargs

    @property
    def current_timestep(self):
        # Timestep being denoised right now; `None` outside the denoising loop.
        return self._current_timestep

    @property
    def interrupt(self):
        # When True, the denoising loop skips its remaining steps.
        return self._interrupt
|
| 504 |
+
|
| 505 |
+
@torch.no_grad()
|
| 506 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 507 |
+
def __call__(
|
| 508 |
+
self,
|
| 509 |
+
prompt: Optional[Union[str, List[str]]] = None,
|
| 510 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 511 |
+
height: Optional[int] = None,
|
| 512 |
+
width: Optional[int] = None,
|
| 513 |
+
num_frames: Optional[int] = None,
|
| 514 |
+
num_inference_steps: int = 50,
|
| 515 |
+
timesteps: Optional[List[int]] = None,
|
| 516 |
+
guidance_scale: float = 6,
|
| 517 |
+
use_dynamic_cfg: bool = False,
|
| 518 |
+
num_videos_per_prompt: int = 1,
|
| 519 |
+
eta: float = 0.0,
|
| 520 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 521 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 522 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 523 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 524 |
+
output_type: str = "pil",
|
| 525 |
+
return_dict: bool = True,
|
| 526 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 527 |
+
callback_on_step_end: Optional[
|
| 528 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
| 529 |
+
] = None,
|
| 530 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 531 |
+
max_sequence_length: int = 226,
|
| 532 |
+
) -> Union[CogVideoXPipelineOutput, Tuple]:
|
| 533 |
+
"""
|
| 534 |
+
Function invoked when calling the pipeline for generation.
|
| 535 |
+
|
| 536 |
+
Args:
|
| 537 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 538 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
| 539 |
+
instead.
|
| 540 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 541 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 542 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 543 |
+
less than `1`).
|
| 544 |
+
height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
|
| 545 |
+
The height in pixels of the generated image. This is set to 480 by default for the best results.
|
| 546 |
+
width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
|
| 547 |
+
The width in pixels of the generated image. This is set to 720 by default for the best results.
|
| 548 |
+
num_frames (`int`, defaults to `48`):
|
| 549 |
+
Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
|
| 550 |
+
contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
|
| 551 |
+
num_seconds is 6 and fps is 8. However, since videos can be saved at any fps, the only condition that
|
| 552 |
+
needs to be satisfied is that of divisibility mentioned above.
|
| 553 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 554 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 555 |
+
expense of slower inference.
|
| 556 |
+
timesteps (`List[int]`, *optional*):
|
| 557 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
| 558 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
| 559 |
+
passed will be used. Must be in descending order.
|
| 560 |
+
guidance_scale (`float`, *optional*, defaults to 7.0):
|
| 561 |
+
Guidance scale as defined in [Classifier-Free Diffusion
|
| 562 |
+
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
| 563 |
+
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
| 564 |
+
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
| 565 |
+
the text `prompt`, usually at the expense of lower image quality.
|
| 566 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 567 |
+
The number of videos to generate per prompt.
|
| 568 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 569 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 570 |
+
to make generation deterministic.
|
| 571 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 572 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 573 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 574 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
| 575 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 576 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 577 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 578 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 579 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 580 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 581 |
+
argument.
|
| 582 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 583 |
+
The output format of the generate image. Choose between
|
| 584 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 585 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 586 |
+
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
| 587 |
+
of a plain tuple.
|
| 588 |
+
attention_kwargs (`dict`, *optional*):
|
| 589 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 590 |
+
`self.processor` in
|
| 591 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 592 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 593 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
| 594 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 595 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 596 |
+
`callback_on_step_end_tensor_inputs`.
|
| 597 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 598 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 599 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 600 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 601 |
+
max_sequence_length (`int`, defaults to `226`):
|
| 602 |
+
Maximum sequence length in encoded prompt. Must be consistent with
|
| 603 |
+
`self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
|
| 604 |
+
|
| 605 |
+
Examples:
|
| 606 |
+
|
| 607 |
+
Returns:
|
| 608 |
+
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`:
|
| 609 |
+
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
|
| 610 |
+
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
| 611 |
+
"""
|
| 612 |
+
|
| 613 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 614 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 615 |
+
|
| 616 |
+
height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
|
| 617 |
+
width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
|
| 618 |
+
num_frames = num_frames or self.transformer.config.sample_frames
|
| 619 |
+
|
| 620 |
+
num_videos_per_prompt = 1
|
| 621 |
+
|
| 622 |
+
# 1. Check inputs. Raise error if not correct
|
| 623 |
+
self.check_inputs(
|
| 624 |
+
prompt,
|
| 625 |
+
height,
|
| 626 |
+
width,
|
| 627 |
+
negative_prompt,
|
| 628 |
+
callback_on_step_end_tensor_inputs,
|
| 629 |
+
prompt_embeds,
|
| 630 |
+
negative_prompt_embeds,
|
| 631 |
+
)
|
| 632 |
+
self._guidance_scale = guidance_scale
|
| 633 |
+
self._attention_kwargs = attention_kwargs
|
| 634 |
+
self._current_timestep = None
|
| 635 |
+
self._interrupt = False
|
| 636 |
+
|
| 637 |
+
# 2. Default call parameters
|
| 638 |
+
if prompt is not None and isinstance(prompt, str):
|
| 639 |
+
batch_size = 1
|
| 640 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 641 |
+
batch_size = len(prompt)
|
| 642 |
+
else:
|
| 643 |
+
batch_size = prompt_embeds.shape[0]
|
| 644 |
+
|
| 645 |
+
device = self._execution_device
|
| 646 |
+
|
| 647 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 648 |
+
# of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
|
| 649 |
+
# corresponds to doing no classifier free guidance.
|
| 650 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
| 651 |
+
|
| 652 |
+
# 3. Encode input prompt
|
| 653 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
| 654 |
+
prompt,
|
| 655 |
+
negative_prompt,
|
| 656 |
+
do_classifier_free_guidance,
|
| 657 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 658 |
+
prompt_embeds=prompt_embeds,
|
| 659 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 660 |
+
max_sequence_length=max_sequence_length,
|
| 661 |
+
device=device,
|
| 662 |
+
)
|
| 663 |
+
if do_classifier_free_guidance:
|
| 664 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
| 665 |
+
|
| 666 |
+
# 4. Prepare timesteps
|
| 667 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
| 668 |
+
self._num_timesteps = len(timesteps)
|
| 669 |
+
|
| 670 |
+
# 5. Prepare latents
|
| 671 |
+
latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
| 672 |
+
|
| 673 |
+
# For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
|
| 674 |
+
patch_size_t = self.transformer.config.patch_size_t
|
| 675 |
+
additional_frames = 0
|
| 676 |
+
if patch_size_t is not None and latent_frames % patch_size_t != 0:
|
| 677 |
+
additional_frames = patch_size_t - latent_frames % patch_size_t
|
| 678 |
+
num_frames += additional_frames * self.vae_scale_factor_temporal
|
| 679 |
+
|
| 680 |
+
latent_channels = self.transformer.config.in_channels
|
| 681 |
+
latents = self.prepare_latents(
|
| 682 |
+
batch_size * num_videos_per_prompt,
|
| 683 |
+
latent_channels,
|
| 684 |
+
num_frames,
|
| 685 |
+
height,
|
| 686 |
+
width,
|
| 687 |
+
prompt_embeds.dtype,
|
| 688 |
+
device,
|
| 689 |
+
generator,
|
| 690 |
+
latents,
|
| 691 |
+
)
|
| 692 |
+
|
| 693 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 694 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 695 |
+
|
| 696 |
+
# 7. Create rotary embeds if required
|
| 697 |
+
image_rotary_emb = (
|
| 698 |
+
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
|
| 699 |
+
if self.transformer.config.use_rotary_positional_embeddings
|
| 700 |
+
else None
|
| 701 |
+
)
|
| 702 |
+
|
| 703 |
+
# 8. Denoising loop
|
| 704 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 705 |
+
|
| 706 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 707 |
+
# for DPM-solver++
|
| 708 |
+
old_pred_original_sample = None
|
| 709 |
+
for i, t in enumerate(timesteps):
|
| 710 |
+
if self.interrupt:
|
| 711 |
+
continue
|
| 712 |
+
|
| 713 |
+
self._current_timestep = t
|
| 714 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 715 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 716 |
+
|
| 717 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 718 |
+
timestep = t.expand(latent_model_input.shape[0])
|
| 719 |
+
|
| 720 |
+
# predict noise model_output
|
| 721 |
+
with self.transformer.cache_context("cond_uncond"):
|
| 722 |
+
noise_pred = self.transformer(
|
| 723 |
+
hidden_states=latent_model_input,
|
| 724 |
+
encoder_hidden_states=prompt_embeds,
|
| 725 |
+
timestep=timestep,
|
| 726 |
+
image_rotary_emb=image_rotary_emb,
|
| 727 |
+
attention_kwargs=attention_kwargs,
|
| 728 |
+
return_dict=False,
|
| 729 |
+
)[0]
|
| 730 |
+
noise_pred = noise_pred.float()
|
| 731 |
+
|
| 732 |
+
# perform guidance
|
| 733 |
+
if use_dynamic_cfg:
|
| 734 |
+
self._guidance_scale = 1 + guidance_scale * (
|
| 735 |
+
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
|
| 736 |
+
)
|
| 737 |
+
if do_classifier_free_guidance:
|
| 738 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 739 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 740 |
+
|
| 741 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 742 |
+
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
|
| 743 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
| 744 |
+
else:
|
| 745 |
+
latents, old_pred_original_sample = self.scheduler.step(
|
| 746 |
+
noise_pred,
|
| 747 |
+
old_pred_original_sample,
|
| 748 |
+
t,
|
| 749 |
+
timesteps[i - 1] if i > 0 else None,
|
| 750 |
+
latents,
|
| 751 |
+
**extra_step_kwargs,
|
| 752 |
+
return_dict=False,
|
| 753 |
+
)
|
| 754 |
+
latents = latents.to(prompt_embeds.dtype)
|
| 755 |
+
|
| 756 |
+
# call the callback, if provided
|
| 757 |
+
if callback_on_step_end is not None:
|
| 758 |
+
callback_kwargs = {}
|
| 759 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 760 |
+
callback_kwargs[k] = locals()[k]
|
| 761 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 762 |
+
|
| 763 |
+
latents = callback_outputs.pop("latents", latents)
|
| 764 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 765 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 766 |
+
|
| 767 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 768 |
+
progress_bar.update()
|
| 769 |
+
|
| 770 |
+
if XLA_AVAILABLE:
|
| 771 |
+
xm.mark_step()
|
| 772 |
+
|
| 773 |
+
self._current_timestep = None
|
| 774 |
+
|
| 775 |
+
if not output_type == "latent":
|
| 776 |
+
# Discard any padding frames that were added for CogVideoX 1.5
|
| 777 |
+
latents = latents[:, additional_frames:]
|
| 778 |
+
video = self.decode_latents(latents)
|
| 779 |
+
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
| 780 |
+
else:
|
| 781 |
+
video = latents
|
| 782 |
+
|
| 783 |
+
# Offload all models
|
| 784 |
+
self.maybe_free_model_hooks()
|
| 785 |
+
|
| 786 |
+
if not return_dict:
|
| 787 |
+
return (video,)
|
| 788 |
+
|
| 789 |
+
return CogVideoXPipelineOutput(frames=video)
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py
ADDED
|
@@ -0,0 +1,842 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI, Alibaba-PAI and The HuggingFace Team.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import inspect
|
| 17 |
+
import math
|
| 18 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
from PIL import Image
|
| 22 |
+
from transformers import T5EncoderModel, T5Tokenizer
|
| 23 |
+
|
| 24 |
+
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 25 |
+
from ...loaders import CogVideoXLoraLoaderMixin
|
| 26 |
+
from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
|
| 27 |
+
from ...models.embeddings import get_3d_rotary_pos_embed
|
| 28 |
+
from ...pipelines.pipeline_utils import DiffusionPipeline
|
| 29 |
+
from ...schedulers import KarrasDiffusionSchedulers
|
| 30 |
+
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
| 31 |
+
from ...utils.torch_utils import randn_tensor
|
| 32 |
+
from ...video_processor import VideoProcessor
|
| 33 |
+
from .pipeline_output import CogVideoXPipelineOutput
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# Optional TPU support: torch_xla is imported lazily so the module still
# imports cleanly on installs without it. XLA_AVAILABLE simply records
# whether the import succeeded.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

# Module-level logger following the diffusers convention.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
EXAMPLE_DOC_STRING = """
|
| 47 |
+
Examples:
|
| 48 |
+
```python
|
| 49 |
+
>>> import torch
|
| 50 |
+
>>> from diffusers import CogVideoXFunControlPipeline, DDIMScheduler
|
| 51 |
+
>>> from diffusers.utils import export_to_video, load_video
|
| 52 |
+
|
| 53 |
+
>>> pipe = CogVideoXFunControlPipeline.from_pretrained(
|
| 54 |
+
... "alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose", torch_dtype=torch.bfloat16
|
| 55 |
+
... )
|
| 56 |
+
>>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
| 57 |
+
>>> pipe.to("cuda")
|
| 58 |
+
|
| 59 |
+
>>> control_video = load_video(
|
| 60 |
+
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4"
|
| 61 |
+
... )
|
| 62 |
+
>>> prompt = (
|
| 63 |
+
... "An astronaut stands triumphantly at the peak of a towering mountain. Panorama of rugged peaks and "
|
| 64 |
+
... "valleys. Very futuristic vibe and animated aesthetic. Highlights of purple and golden colors in "
|
| 65 |
+
... "the scene. The sky is looks like an animated/cartoonish dream of galaxies, nebulae, stars, planets, "
|
| 66 |
+
... "moons, but the remainder of the scene is mostly realistic."
|
| 67 |
+
... )
|
| 68 |
+
|
| 69 |
+
>>> video = pipe(prompt=prompt, control_video=control_video).frames[0]
|
| 70 |
+
>>> export_to_video(video, "output.mp4", fps=8)
|
| 71 |
+
```
|
| 72 |
+
"""
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.get_resize_crop_region_for_grid
|
| 76 |
+
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
    """Compute the centered crop window of a source grid resized to cover a target grid.

    The source aspect ratio is preserved: the dimension that fits exactly is
    matched and the other is scaled proportionally, then the crop is centered.

    Args:
        src: ``(height, width)`` of the source grid.
        tgt_width: Target grid width.
        tgt_height: Target grid height.

    Returns:
        A pair of corner tuples ``(top, left)`` and ``(bottom, right)``
        delimiting the crop region inside the target grid.
    """
    src_h, src_w = src
    aspect = src_h / src_w

    if aspect > tgt_height / tgt_width:
        # Source is taller than the target: match heights, scale the width.
        new_h = tgt_height
        new_w = int(round(tgt_height / src_h * src_w))
    else:
        # Source is wider (or equal): match widths, scale the height.
        new_w = tgt_width
        new_h = int(round(tgt_width / src_w * src_h))

    top = int(round((tgt_height - new_h) / 2.0))
    left = int(round((tgt_width - new_w) / 2.0))

    return (top, left), (top + new_h, left + new_w)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 95 |
+
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Call `scheduler.set_timesteps` and return the resulting timestep schedule.

    Exactly one of `num_inference_steps`, `timesteps` or `sigmas` drives the schedule.
    `timesteps`/`sigmas` override the scheduler's own spacing strategy and are only
    accepted when the scheduler's `set_timesteps` signature supports them. Any extra
    `kwargs` are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`, *optional*):
            Number of diffusion steps; mutually exclusive with `timesteps` and `sigmas`.
        device (`str` or `torch.device`, *optional*):
            Device the timesteps should be moved to. If `None`, they are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timestep schedule; mutually exclusive with `num_inference_steps` and `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigma schedule; mutually exclusive with `num_inference_steps` and `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: The timestep schedule and the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    def _accepts(param: str) -> bool:
        # Whether this scheduler's `set_timesteps` takes the given keyword.
        return param in set(inspect.signature(scheduler.set_timesteps).parameters.keys())

    if timesteps is not None:
        if not _accepts("timesteps"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        if not _accepts("sigmas"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps

    return timesteps, num_inference_steps
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
| 155 |
+
r"""
|
| 156 |
+
Pipeline for controlled text-to-video generation using CogVideoX Fun.
|
| 157 |
+
|
| 158 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 159 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
| 160 |
+
|
| 161 |
+
Args:
|
| 162 |
+
vae ([`AutoencoderKL`]):
|
| 163 |
+
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
| 164 |
+
text_encoder ([`T5EncoderModel`]):
|
| 165 |
+
Frozen text-encoder. CogVideoX uses
|
| 166 |
+
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
|
| 167 |
+
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
|
| 168 |
+
tokenizer (`T5Tokenizer`):
|
| 169 |
+
Tokenizer of class
|
| 170 |
+
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
| 171 |
+
transformer ([`CogVideoXTransformer3DModel`]):
|
| 172 |
+
A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
|
| 173 |
+
scheduler ([`SchedulerMixin`]):
|
| 174 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
|
| 175 |
+
"""
|
| 176 |
+
|
| 177 |
+
# Components that may be omitted when loading the pipeline.
_optional_components = []

# CPU-offload order: text encoding runs first, the transformer then denoises,
# and the VAE decodes last. The original value ("text_encoder->vae->transformer->vae")
# listed `vae` twice and moved it to GPU before the transformer needed it.
model_cpu_offload_seq = "text_encoder->transformer->vae"

# Tensor names that `callback_on_step_end` is allowed to read and override
# between denoising steps (validated in `check_inputs`).
_callback_tensor_inputs = [
    "latents",
    "prompt_embeds",
    "negative_prompt_embeds",
]
|
| 185 |
+
|
| 186 |
+
def __init__(
    self,
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    vae: AutoencoderKLCogVideoX,
    transformer: CogVideoXTransformer3DModel,
    scheduler: KarrasDiffusionSchedulers,
):
    """Register all sub-models and precompute the VAE scale factors.

    Falls back to the stock CogVideoX constants (spatial 8, temporal 4,
    image scaling 0.7) when no VAE is registered.
    """
    super().__init__()

    self.register_modules(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        vae=vae,
        transformer=transformer,
        scheduler=scheduler,
    )

    if getattr(self, "vae", None):
        self.vae_scale_factor_spatial = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio
        self.vae_scaling_factor_image = self.vae.config.scaling_factor
    else:
        self.vae_scale_factor_spatial = 8
        self.vae_scale_factor_temporal = 4
        self.vae_scaling_factor_image = 0.7

    self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
| 208 |
+
|
| 209 |
+
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
|
| 210 |
+
def _get_t5_prompt_embeds(
    self,
    prompt: Union[str, List[str]] = None,
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """Tokenize `prompt`, run the T5 encoder, and tile the embeddings per requested video.

    Returns a tensor of shape ``[batch_size * num_videos_per_prompt, seq_len, hidden]``.
    """
    device = device or self._execution_device
    dtype = dtype or self.text_encoder.dtype

    if isinstance(prompt, str):
        prompt = [prompt]
    batch_size = len(prompt)

    tokenized = self.tokenizer(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    input_ids = tokenized.input_ids
    full_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

    # Warn the user when truncation actually dropped tokens.
    if full_ids.shape[-1] >= input_ids.shape[-1] and not torch.equal(input_ids, full_ids):
        removed_text = self.tokenizer.batch_decode(full_ids[:, max_sequence_length - 1 : -1])
        logger.warning(
            "The following part of your input was truncated because `max_sequence_length` is set to "
            f" {max_sequence_length} tokens: {removed_text}"
        )

    prompt_embeds = self.text_encoder(input_ids.to(device))[0]
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    # Duplicate embeddings for each video per prompt (mps-friendly repeat+view).
    _, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
    return prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
| 251 |
+
|
| 252 |
+
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
|
| 253 |
+
def encode_prompt(
    self,
    prompt: Union[str, List[str]],
    negative_prompt: Optional[Union[str, List[str]]] = None,
    do_classifier_free_guidance: bool = True,
    num_videos_per_prompt: int = 1,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    r"""
    Encode `prompt` (and, when guidance is enabled, `negative_prompt`) into T5 hidden states.

    Args:
        prompt (`str` or `List[str]`):
            Prompt to be encoded.
        negative_prompt (`str` or `List[str]`, *optional*):
            Prompt(s) not to guide generation; ignored when guidance is off. Defaults to `""`.
        do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
            Whether a negative embedding should be produced as well.
        num_videos_per_prompt (`int`, *optional*, defaults to 1):
            Number of videos generated per prompt; embeddings are tiled accordingly.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-computed text embeddings; skips encoding of `prompt` when given.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-computed negative text embeddings; skips encoding of `negative_prompt` when given.
        max_sequence_length (`int`, *optional*, defaults to 226):
            Token budget for the tokenizer.
        device (`torch.device`, *optional*):
            Torch device.
        dtype (`torch.dtype`, *optional*):
            Torch dtype.

    Returns:
        The `(prompt_embeds, negative_prompt_embeds)` pair; the second element is
        whatever was passed in (possibly `None`) when guidance is disabled.
    """
    device = device or self._execution_device

    if isinstance(prompt, str):
        prompt = [prompt]
    # Batch size comes from the prompts when given, else from the embeddings.
    batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0]

    if prompt_embeds is None:
        prompt_embeds = self._get_t5_prompt_embeds(
            prompt=prompt,
            num_videos_per_prompt=num_videos_per_prompt,
            max_sequence_length=max_sequence_length,
            device=device,
            dtype=dtype,
        )

    if do_classifier_free_guidance and negative_prompt_embeds is None:
        negative_prompt = negative_prompt or ""
        if isinstance(negative_prompt, str):
            negative_prompt = batch_size * [negative_prompt]

        if prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )

        negative_prompt_embeds = self._get_t5_prompt_embeds(
            prompt=negative_prompt,
            num_videos_per_prompt=num_videos_per_prompt,
            max_sequence_length=max_sequence_length,
            device=device,
            dtype=dtype,
        )

    return prompt_embeds, negative_prompt_embeds
|
| 333 |
+
|
| 334 |
+
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.prepare_latents
|
| 335 |
+
def prepare_latents(
    self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
    """Create (or re-use) the initial latent tensor, scaled by the scheduler's init sigma.

    Latent shape is ``[batch, latent_frames, channels, H/spatial, W/spatial]``,
    where `latent_frames` compresses `num_frames` by the temporal VAE factor.
    """
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
    shape = (
        batch_size,
        num_latent_frames,
        num_channels_latents,
        height // self.vae_scale_factor_spatial,
        width // self.vae_scale_factor_spatial,
    )

    if latents is not None:
        latents = latents.to(device)
    else:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

    # Scale the initial noise by the standard deviation required by the scheduler.
    return latents * self.scheduler.init_noise_sigma
|
| 360 |
+
|
| 361 |
+
# Adapted from https://github.com/aigc-apps/CogVideoX-Fun/blob/2a93e5c14e02b2b5921d533fd59fc8c0ed69fb24/cogvideox/pipeline/pipeline_cogvideox_control.py#L366
|
| 362 |
+
def prepare_control_latents(
    self, mask: Optional[torch.Tensor] = None, masked_image: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Encode the control mask / control video pixels into VAE latent space.

    Either argument may be `None`; its slot in the returned pair is then left
    untouched (`None` for `masked_image`).
    """

    def _encode_per_sample(pixels: torch.Tensor) -> torch.Tensor:
        # Encode one sample at a time (presumably to bound peak VAE memory —
        # the upstream implementation does the same), take the distribution
        # mode, and apply the VAE scaling factor.
        encoded = [self.vae.encode(pixels[i].unsqueeze(0))[0].mode() for i in range(pixels.size(0))]
        return torch.cat(encoded, dim=0) * self.vae.config.scaling_factor

    if mask is not None:
        mask = _encode_per_sample(mask)

    masked_image_latents = _encode_per_sample(masked_image) if masked_image is not None else None

    return mask, masked_image_latents
|
| 388 |
+
|
| 389 |
+
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents
|
| 390 |
+
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
    """Decode video latents back to pixel frames with the VAE.

    Args:
        latents: Tensor laid out as ``[batch, frames, channels, height, width]``.

    Returns:
        Decoded frames laid out as ``[batch, channels, frames, height, width]``.
    """
    # Move channels ahead of frames for the VAE decoder, then undo the
    # latent scaling applied at encode time.
    inv_scale = 1 / self.vae_scaling_factor_image
    latents = inv_scale * latents.permute(0, 2, 1, 3, 4)
    return self.vae.decode(latents).sample
|
| 396 |
+
|
| 397 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
| 398 |
+
def prepare_extra_step_kwargs(self, generator, eta):
    """Build the optional keyword arguments for `scheduler.step`.

    `eta` corresponds to η in the DDIM paper (https://huggingface.co/papers/2010.02502),
    should lie in [0, 1], and is only honored by DDIM-style schedulers; `generator`
    is forwarded only when the scheduler's `step` signature accepts it. Schedulers
    whose `step` takes neither get an empty dict.
    """
    step_params = inspect.signature(self.scheduler.step).parameters

    extra_step_kwargs = {}
    if "eta" in step_params:
        extra_step_kwargs["eta"] = eta
    if "generator" in step_params:
        extra_step_kwargs["generator"] = generator

    return extra_step_kwargs
|
| 414 |
+
|
| 415 |
+
def check_inputs(
    self,
    prompt,
    height,
    width,
    negative_prompt,
    callback_on_step_end_tensor_inputs,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    control_video=None,
    control_video_latents=None,
):
    """Validate the user-facing `__call__` arguments, raising early on misuse.

    Checks, in order: spatial dims divisible by 8, callback tensor names,
    prompt/prompt_embeds exclusivity and typing, embedding shape agreement,
    and that at most one of `control_video`/`control_video_latents` is given.
    """
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    if callback_on_step_end_tensor_inputs is not None:
        unknown = [k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]
        if unknown:
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {unknown}"
            )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    if control_video is not None and control_video_latents is not None:
        raise ValueError(
            "Cannot pass both `control_video` and `control_video_latents`. Please make sure to pass only one of these parameters."
        )
|
| 472 |
+
|
| 473 |
+
    def fuse_qkv_projections(self) -> None:
        r"""Enables fused QKV projections.

        Delegates to the underlying transformer so that its attention
        query/key/value projections are merged, and records the state so that
        `unfuse_qkv_projections` can undo it later.
        """
        # The flag is set before the delegated call; if the transformer call
        # raised, the flag would incorrectly remain True. NOTE(review): confirm
        # this ordering is acceptable upstream.
        self.fusing_transformer = True
        self.transformer.fuse_qkv_projections()
    def unfuse_qkv_projections(self) -> None:
        r"""Disable QKV projection fusion if enabled.

        A warning (rather than an exception) is emitted when fusion was never
        enabled, making the call a harmless no-op in that case.
        """
        if not self.fusing_transformer:
            logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
        else:
            self.transformer.unfuse_qkv_projections()
            self.fusing_transformer = False
    # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings
    def _prepare_rotary_positional_embeddings(
        self,
        height: int,
        width: int,
        num_frames: int,
        device: torch.device,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build the 3D rotary positional embedding tables (cos, sin) for the
        transformer, sized to the latent grid implied by `height`, `width` and
        `num_frames`.

        Two layouts exist: when `patch_size_t` is None (CogVideoX 1.0) a
        center-crop grid region is used; otherwise (CogVideoX 1.5) the "slice"
        grid mode is used with the frame count rounded up to a whole number of
        temporal patches.
        """
        # Spatial latent grid size after VAE downscaling and patchification.
        grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
        grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)

        p = self.transformer.config.patch_size
        p_t = self.transformer.config.patch_size_t

        # Reference grid size the model was trained at, in patch units.
        base_size_width = self.transformer.config.sample_width // p
        base_size_height = self.transformer.config.sample_height // p

        if p_t is None:
            # CogVideoX 1.0
            grid_crops_coords = get_resize_crop_region_for_grid(
                (grid_height, grid_width), base_size_width, base_size_height
            )
            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
                embed_dim=self.transformer.config.attention_head_dim,
                crops_coords=grid_crops_coords,
                grid_size=(grid_height, grid_width),
                temporal_size=num_frames,
                device=device,
            )
        else:
            # CogVideoX 1.5
            # Round frames up to a whole number of temporal patches.
            base_num_frames = (num_frames + p_t - 1) // p_t

            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
                embed_dim=self.transformer.config.attention_head_dim,
                crops_coords=None,
                grid_size=(grid_height, grid_width),
                temporal_size=base_num_frames,
                grid_type="slice",
                max_size=(base_size_height, base_size_width),
                device=device,
            )

        return freqs_cos, freqs_sin
    @property
    def guidance_scale(self):
        # Classifier-free guidance weight `w`; set in `__call__` and updated
        # per-step there when `use_dynamic_cfg` is enabled.
        return self._guidance_scale
    @property
    def num_timesteps(self):
        # Number of denoising timesteps of the current/most recent call.
        return self._num_timesteps
    @property
    def attention_kwargs(self):
        # Extra kwargs forwarded to the attention processors during a call.
        return self._attention_kwargs
    @property
    def current_timestep(self):
        # Timestep currently being denoised; None outside the denoising loop.
        return self._current_timestep
    @property
    def interrupt(self):
        # When set True (e.g. from a callback), remaining denoising steps are
        # skipped in `__call__`.
        return self._interrupt
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        control_video: Optional[List[Image.Image]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        timesteps: Optional[List[int]] = None,
        guidance_scale: float = 6,
        use_dynamic_cfg: bool = False,
        num_videos_per_prompt: int = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        control_video_latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        output_type: str = "pil",
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 226,
    ) -> Union[CogVideoXPipelineOutput, Tuple]:
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                instead.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            control_video (`List[PIL.Image.Image]`):
                The control video to condition the generation on. Must be a list of images/frames of the video. If not
                provided, `control_video_latents` must be provided.
            height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
                The height in pixels of the generated image. This is set to 480 by default for the best results.
            width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
                The width in pixels of the generated image. This is set to 720 by default for the best results.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 6.0):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of videos to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            control_video_latents (`torch.Tensor`, *optional*):
                Pre-generated control latents, sampled from a Gaussian distribution, to be used as inputs for
                controlled video generation. If not provided, `control_video` must be provided.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
                of a plain tuple.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int`, defaults to `226`):
                Maximum sequence length in encoded prompt. Must be consistent with
                `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.

        Examples:

        Returns:
            [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`:
            [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images.
        """

        # A structured callback object carries its own declaration of which
        # tensors it wants; honor it over the plain-list argument.
        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        # A flat list of frames is treated as a batch containing one video.
        if control_video is not None and isinstance(control_video[0], Image.Image):
            control_video = [control_video]

        height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
        width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
        # Frame count comes from the control video frames, or from the control
        # latents' size(2) when only latents are supplied.
        num_frames = len(control_video[0]) if control_video is not None else control_video_latents.size(2)

        # NOTE(review): the `num_videos_per_prompt` argument is overridden to 1
        # here — confirm multi-video generation is intentionally unsupported.
        num_videos_per_prompt = 1

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            negative_prompt,
            callback_on_step_end_tensor_inputs,
            prompt_embeds,
            negative_prompt_embeds,
            control_video,
            control_video_latents,
        )
        self._guidance_scale = guidance_scale
        self._attention_kwargs = attention_kwargs
        self._current_timestep = None
        self._interrupt = False

        # 2. Default call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            negative_prompt,
            do_classifier_free_guidance,
            num_videos_per_prompt=num_videos_per_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            max_sequence_length=max_sequence_length,
            device=device,
        )
        if do_classifier_free_guidance:
            # Batch negative and positive embeddings so a single forward pass
            # produces both unconditional and conditional predictions.
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
        self._num_timesteps = len(timesteps)

        # 5. Prepare latents
        latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1

        # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
        patch_size_t = self.transformer.config.patch_size_t
        if patch_size_t is not None and latent_frames % patch_size_t != 0:
            raise ValueError(
                f"The number of latent frames must be divisible by `{patch_size_t=}` but the given video "
                f"contains {latent_frames=}, which is not divisible."
            )

        # Presumably half the transformer's input channels are reserved for the
        # control latents concatenated below — TODO confirm against the model.
        latent_channels = self.transformer.config.in_channels // 2
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            latent_channels,
            num_frames,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        if control_video_latents is None:
            control_video = self.video_processor.preprocess_video(control_video, height=height, width=width)
            control_video = control_video.to(device=device, dtype=prompt_embeds.dtype)

            _, control_video_latents = self.prepare_control_latents(None, control_video)
            # Swap the first two non-batch axes — assumes the VAE emits
            # (B, C, F, H, W) and the transformer expects (B, F, C, H, W);
            # TODO confirm.
            control_video_latents = control_video_latents.permute(0, 2, 1, 3, 4)

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Create rotary embeds if required
        image_rotary_emb = (
            self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
            if self.transformer.config.use_rotary_positional_embeddings
            else None
        )

        # 8. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            # for DPM-solver++
            # NOTE(review): `old_pred_original_sample` is never updated or read
            # in this loop; it appears vestigial — confirm before removing.
            old_pred_original_sample = None
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                self._current_timestep = t
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                latent_control_input = (
                    torch.cat([control_video_latents] * 2) if do_classifier_free_guidance else control_video_latents
                )
                # Concatenate control latents with the noisy latents along
                # dim=2 (the channel axis after the earlier permute).
                latent_model_input = torch.cat([latent_model_input, latent_control_input], dim=2)

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0])

                # predict noise model_output
                with self.transformer.cache_context("cond_uncond"):
                    noise_pred = self.transformer(
                        hidden_states=latent_model_input,
                        encoder_hidden_states=prompt_embeds,
                        timestep=timestep,
                        image_rotary_emb=image_rotary_emb,
                        attention_kwargs=attention_kwargs,
                        return_dict=False,
                    )[0]
                noise_pred = noise_pred.float()

                # perform guidance
                if use_dynamic_cfg:
                    # Cosine schedule ramping guidance up as denoising proceeds.
                    self._guidance_scale = 1 + guidance_scale * (
                        (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
                    )
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                latents = latents.to(prompt_embeds.dtype)

                # call the callback, if provided
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    # Tensors are looked up by name in this frame's locals(), so
                    # only names bound above can be requested by the callback.
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        self._current_timestep = None

        if not output_type == "latent":
            video = self.decode_latents(latents)
            video = self.video_processor.postprocess_video(video=video, output_type=output_type)
        else:
            video = latents

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (video,)

        return CogVideoXPipelineOutput(frames=video)
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
ADDED
|
@@ -0,0 +1,903 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import inspect
|
| 17 |
+
import math
|
| 18 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 19 |
+
|
| 20 |
+
import PIL
|
| 21 |
+
import torch
|
| 22 |
+
from transformers import T5EncoderModel, T5Tokenizer
|
| 23 |
+
|
| 24 |
+
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 25 |
+
from ...image_processor import PipelineImageInput
|
| 26 |
+
from ...loaders import CogVideoXLoraLoaderMixin
|
| 27 |
+
from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
|
| 28 |
+
from ...models.embeddings import get_3d_rotary_pos_embed
|
| 29 |
+
from ...pipelines.pipeline_utils import DiffusionPipeline
|
| 30 |
+
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
|
| 31 |
+
from ...utils import (
|
| 32 |
+
is_torch_xla_available,
|
| 33 |
+
logging,
|
| 34 |
+
replace_example_docstring,
|
| 35 |
+
)
|
| 36 |
+
from ...utils.torch_utils import randn_tensor
|
| 37 |
+
from ...video_processor import VideoProcessor
|
| 38 |
+
from .pipeline_output import CogVideoXPipelineOutput
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# Import the torch-XLA runtime lazily so the module also loads on
# installations without TPU/XLA support; `XLA_AVAILABLE` gates the
# `xm.mark_step()` calls in the denoising loop.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

# Module-level logger following the diffusers convention.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
# Usage example spliced into `__call__`'s docstring via
# `@replace_example_docstring(EXAMPLE_DOC_STRING)`.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import CogVideoXImageToVideoPipeline
        >>> from diffusers.utils import export_to_video, load_image

        >>> pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16)
        >>> pipe.to("cuda")

        >>> prompt = "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
        >>> image = load_image(
        ...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
        ... )
        >>> video = pipe(image, prompt, use_dynamic_cfg=True)
        >>> export_to_video(video.frames[0], "output.mp4", fps=8)
        ```
"""
|
| 70 |
+
|
| 71 |
+
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
    """Compute a centered resize-and-crop region fitting the source grid into a target.

    `src` is a `(height, width)` pair. The source aspect ratio is preserved:
    whichever dimension would overflow the `tgt_height` x `tgt_width` target is
    pinned to the target size and the other dimension is scaled to match; the
    region is then centered.

    Returns `((top, left), (bottom, right))` crop coordinates.
    """
    src_h, src_w = src
    aspect = src_h / src_w

    if aspect > tgt_height / tgt_width:
        # Source is proportionally taller than the target: pin the height.
        new_h = tgt_height
        new_w = int(round(tgt_height / src_h * src_w))
    else:
        # Source is proportionally wider (or equal): pin the width.
        new_w = tgt_width
        new_h = int(round(tgt_width / src_w * src_h))

    # Center the resized region inside the target.
    top = int(round((tgt_height - new_h) / 2.0))
    left = int(round((tgt_width - new_w) / 2.0))

    return (top, left), (top + new_h, left + new_w)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Call the scheduler's `set_timesteps` and return the resulting schedule.

    Exactly one of `timesteps`, `sigmas`, or plain `num_inference_steps` drives
    the schedule; extra `kwargs` are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            Number of diffusion steps for the default schedule. Must be `None`
            when `timesteps` is given.
        device (`str` or `torch.device`, *optional*):
            Device the timesteps should be moved to; `None` leaves them in place.
        timesteps (`List[int]`, *optional*):
            Custom timestep schedule; mutually exclusive with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigma schedule; mutually exclusive with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: the timestep schedule and the number of
        inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    def _scheduler_accepts(arg_name):
        # Inspect the scheduler's `set_timesteps` signature to see whether it
        # supports the requested custom-schedule argument.
        return arg_name in set(inspect.signature(scheduler.set_timesteps).parameters.keys())

    if timesteps is not None:
        if not _scheduler_accepts("timesteps"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        if not _scheduler_accepts("sigmas"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        # Default path: the scheduler builds its own schedule of the given length.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Extract a latent tensor from a VAE encoder output.

    Supports outputs that expose a `latent_dist` (sampled with `generator` for
    `sample_mode="sample"`, or the distribution mode for `"argmax"`) as well as
    outputs carrying a plain `latents` attribute.

    Raises:
        AttributeError: if neither `latent_dist` nor `latents` is available.
    """
    has_dist = hasattr(encoder_output, "latent_dist")
    if has_dist and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    if has_dist and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")
+
|
| 164 |
+
class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
| 165 |
+
r"""
|
| 166 |
+
Pipeline for image-to-video generation using CogVideoX.
|
| 167 |
+
|
| 168 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 169 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
| 170 |
+
|
| 171 |
+
Args:
|
| 172 |
+
vae ([`AutoencoderKL`]):
|
| 173 |
+
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
| 174 |
+
text_encoder ([`T5EncoderModel`]):
|
| 175 |
+
Frozen text-encoder. CogVideoX uses
|
| 176 |
+
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
|
| 177 |
+
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
|
| 178 |
+
tokenizer (`T5Tokenizer`):
|
| 179 |
+
Tokenizer of class
|
| 180 |
+
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
| 181 |
+
transformer ([`CogVideoXTransformer3DModel`]):
|
| 182 |
+
A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
|
| 183 |
+
scheduler ([`SchedulerMixin`]):
|
| 184 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
|
| 185 |
+
"""
|
| 186 |
+
|
| 187 |
+
_optional_components = []
|
| 188 |
+
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
| 189 |
+
|
| 190 |
+
_callback_tensor_inputs = [
|
| 191 |
+
"latents",
|
| 192 |
+
"prompt_embeds",
|
| 193 |
+
"negative_prompt_embeds",
|
| 194 |
+
]
|
| 195 |
+
|
| 196 |
+
    def __init__(
        self,
        tokenizer: T5Tokenizer,
        text_encoder: T5EncoderModel,
        vae: AutoencoderKLCogVideoX,
        transformer: CogVideoXTransformer3DModel,
        scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
    ):
        """Wire up all sub-models and derive the VAE compression factors used throughout the pipeline."""
        super().__init__()

        # Register sub-models so the DiffusionPipeline machinery (saving/loading,
        # device placement, cpu offloading) can discover them by name.
        self.register_modules(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            vae=vae,
            transformer=transformer,
            scheduler=scheduler,
        )
        # Spatial downsampling factor: 2x per VAE downsampling block.
        # Falls back to 8 when no VAE is registered (e.g. partially-loaded pipeline).
        self.vae_scale_factor_spatial = (
            2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        )
        # Temporal compression: number of video frames per latent frame; default 4.
        self.vae_scale_factor_temporal = (
            self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
        )
        # Scaling applied to image latents (see `prepare_latents`/`decode_latents`); default 0.7.
        self.vae_scaling_factor_image = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.7

        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
| 222 |
+
|
| 223 |
+
    # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_videos_per_prompt: int = 1,
        max_sequence_length: int = 226,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Tokenize `prompt` and encode it with the T5 text encoder.

        Prompts longer than `max_sequence_length` tokens are truncated (a warning
        reports the dropped text). Embeddings are repeated `num_videos_per_prompt`
        times along the batch dimension.

        Returns:
            `torch.Tensor` of shape `(batch_size * num_videos_per_prompt, seq_len, hidden_dim)`.
        """
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Re-tokenize without truncation so we can detect and report dropped tokens.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

        return prompt_embeds
|
| 265 |
+
|
| 266 |
+
    # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        do_classifier_free_guidance: bool = True,
        num_videos_per_prompt: int = 1,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        max_sequence_length: int = 226,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier free guidance or not.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            device: (`torch.device`, *optional*):
                torch device
            dtype: (`torch.dtype`, *optional*):
                torch dtype

        Returns:
            Tuple of `(prompt_embeds, negative_prompt_embeds)`; the negative embeddings
            are only computed when `do_classifier_free_guidance` is set and none were passed in.
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            # No raw prompt: infer batch size from the precomputed embeddings.
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            # Default to the empty prompt, broadcast to the positive batch size.
            negative_prompt = negative_prompt or ""
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds = self._get_t5_prompt_embeds(
                prompt=negative_prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        return prompt_embeds, negative_prompt_embeds
|
| 347 |
+
|
| 348 |
+
    def prepare_latents(
        self,
        image: torch.Tensor,
        batch_size: int = 1,
        num_channels_latents: int = 16,
        num_frames: int = 13,
        height: int = 60,
        width: int = 90,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
    ):
        """Build the initial noise latents and the VAE-encoded image-conditioning latents.

        Returns a tuple `(latents, image_latents)`, both of shape
        `[batch_size, latent_frames, num_channels_latents, H', W']` where H'/W' are the
        spatially compressed height/width. `image_latents` holds the encoded first frame
        followed by zero padding for the remaining latent frames.

        Raises:
            ValueError: If a list of generators is passed whose length does not match `batch_size`.
        """
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        # Convert pixel-space frame count to latent frame count.
        num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
        shape = (
            batch_size,
            num_frames,
            num_channels_latents,
            height // self.vae_scale_factor_spatial,
            width // self.vae_scale_factor_spatial,
        )

        # For CogVideoX1.5, the latent should add 1 for padding (Not use)
        if self.transformer.config.patch_size_t is not None:
            shape = shape[:1] + (shape[1] + shape[1] % self.transformer.config.patch_size_t,) + shape[2:]

        image = image.unsqueeze(2)  # [B, C, F, H, W]

        # Encode the conditioning image through the VAE, one sample at a time so that
        # per-sample generators (if provided) are honored.
        if isinstance(generator, list):
            image_latents = [
                retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
            ]
        else:
            image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]

        image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]

        if not self.vae.config.invert_scale_latents:
            image_latents = self.vae_scaling_factor_image * image_latents
        else:
            # This is awkward but required because the CogVideoX team forgot to multiply the
            # scaling factor during training :)
            image_latents = 1 / self.vae_scaling_factor_image * image_latents

        # Zero padding for every latent frame after the conditioning frame.
        padding_shape = (
            batch_size,
            num_frames - 1,
            num_channels_latents,
            height // self.vae_scale_factor_spatial,
            width // self.vae_scale_factor_spatial,
        )

        latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype)
        image_latents = torch.cat([image_latents, latent_padding], dim=1)

        # Select the first frame along the second dimension
        if self.transformer.config.patch_size_t is not None:
            first_frame = image_latents[:, : image_latents.size(1) % self.transformer.config.patch_size_t, ...]
            image_latents = torch.cat([first_frame, image_latents], dim=1)

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents, image_latents
|
| 422 |
+
|
| 423 |
+
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents
|
| 424 |
+
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
|
| 425 |
+
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
|
| 426 |
+
latents = 1 / self.vae_scaling_factor_image * latents
|
| 427 |
+
|
| 428 |
+
frames = self.vae.decode(latents).sample
|
| 429 |
+
return frames
|
| 430 |
+
|
| 431 |
+
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps
|
| 432 |
+
def get_timesteps(self, num_inference_steps, timesteps, strength, device):
|
| 433 |
+
# get the original timestep using init_timestep
|
| 434 |
+
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
| 435 |
+
|
| 436 |
+
t_start = max(num_inference_steps - init_timestep, 0)
|
| 437 |
+
timesteps = timesteps[t_start * self.scheduler.order :]
|
| 438 |
+
|
| 439 |
+
return timesteps, num_inference_steps - t_start
|
| 440 |
+
|
| 441 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
| 442 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
| 443 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 444 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 445 |
+
# eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
|
| 446 |
+
# and should be between [0, 1]
|
| 447 |
+
|
| 448 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 449 |
+
extra_step_kwargs = {}
|
| 450 |
+
if accepts_eta:
|
| 451 |
+
extra_step_kwargs["eta"] = eta
|
| 452 |
+
|
| 453 |
+
# check if the scheduler accepts generator
|
| 454 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 455 |
+
if accepts_generator:
|
| 456 |
+
extra_step_kwargs["generator"] = generator
|
| 457 |
+
return extra_step_kwargs
|
| 458 |
+
|
| 459 |
+
    def check_inputs(
        self,
        image,
        prompt,
        height,
        width,
        negative_prompt,
        callback_on_step_end_tensor_inputs,
        latents=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        """Validate the user-facing `__call__` arguments, raising on the first inconsistency.

        Checks the image type, spatial divisibility, callback tensor names, and the
        mutually-exclusive prompt / prompt-embedding combinations.

        Raises:
            ValueError: On any invalid or conflicting argument combination.
        """
        # The conditioning image must be a tensor, a PIL image, or a list of PIL images.
        if (
            not isinstance(image, torch.Tensor)
            and not isinstance(image, PIL.Image.Image)
            and not isinstance(image, list)
        ):
            raise ValueError(
                "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
                f" {type(image)}"
            )

        # Spatial dims must be divisible by 8 (VAE spatial compression granularity — TODO confirm
        # this matches `vae_scale_factor_spatial` for all configs).
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # Every requested callback tensor must be one the pipeline actually tracks.
        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )
        # `prompt` and `prompt_embeds` are mutually exclusive, and exactly one is required.
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        # Paired embeddings must agree in shape so they can be concatenated for CFG.
        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )
|
| 521 |
+
|
| 522 |
+
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.fuse_qkv_projections
|
| 523 |
+
def fuse_qkv_projections(self) -> None:
|
| 524 |
+
r"""Enables fused QKV projections."""
|
| 525 |
+
self.fusing_transformer = True
|
| 526 |
+
self.transformer.fuse_qkv_projections()
|
| 527 |
+
|
| 528 |
+
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.unfuse_qkv_projections
|
| 529 |
+
def unfuse_qkv_projections(self) -> None:
|
| 530 |
+
r"""Disable QKV projection fusion if enabled."""
|
| 531 |
+
if not self.fusing_transformer:
|
| 532 |
+
logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
|
| 533 |
+
else:
|
| 534 |
+
self.transformer.unfuse_qkv_projections()
|
| 535 |
+
self.fusing_transformer = False
|
| 536 |
+
|
| 537 |
+
    # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings
    def _prepare_rotary_positional_embeddings(
        self,
        height: int,
        width: int,
        num_frames: int,
        device: torch.device,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build 3D rotary positional embeddings (cos, sin) for the transformer.

        `height`/`width` are in pixels; `num_frames` is in latent frames. The grid is
        computed at transformer-patch resolution. CogVideoX 1.0 (no temporal patching)
        uses crop-based coordinates; CogVideoX 1.5 (`patch_size_t` set) uses the
        "slice" grid type with temporally patched frames.
        """
        # Token-grid resolution: pixels -> latents -> patches.
        grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
        grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)

        p = self.transformer.config.patch_size
        p_t = self.transformer.config.patch_size_t

        # Base grid size the transformer was trained at, in patches.
        base_size_width = self.transformer.config.sample_width // p
        base_size_height = self.transformer.config.sample_height // p

        if p_t is None:
            # CogVideoX 1.0
            grid_crops_coords = get_resize_crop_region_for_grid(
                (grid_height, grid_width), base_size_width, base_size_height
            )
            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
                embed_dim=self.transformer.config.attention_head_dim,
                crops_coords=grid_crops_coords,
                grid_size=(grid_height, grid_width),
                temporal_size=num_frames,
                device=device,
            )
        else:
            # CogVideoX 1.5
            # Temporal size after patching (ceil division by patch_size_t).
            base_num_frames = (num_frames + p_t - 1) // p_t

            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
                embed_dim=self.transformer.config.attention_head_dim,
                crops_coords=None,
                grid_size=(grid_height, grid_width),
                temporal_size=base_num_frames,
                grid_type="slice",
                max_size=(base_size_height, base_size_width),
                device=device,
            )

        return freqs_cos, freqs_sin
|
| 581 |
+
|
| 582 |
+
    @property
    def guidance_scale(self):
        # Classifier-free guidance weight; set at the start of `__call__`.
        return self._guidance_scale

    @property
    def num_timesteps(self):
        # Length of the timestep schedule for the current run; set in `__call__`.
        return self._num_timesteps

    @property
    def attention_kwargs(self):
        # Extra kwargs forwarded to the attention processors; set in `__call__`.
        return self._attention_kwargs

    @property
    def current_timestep(self):
        # Timestep currently being processed; `__call__` resets this to None
        # before denoising (presumably updated inside the loop — not visible here).
        return self._current_timestep

    @property
    def interrupt(self):
        # Cancellation flag; `__call__` initializes it to False. NOTE(review):
        # looks like it is meant to be checked inside the denoising loop — confirm.
        return self._interrupt
|
| 601 |
+
|
| 602 |
+
@torch.no_grad()
|
| 603 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 604 |
+
def __call__(
|
| 605 |
+
self,
|
| 606 |
+
image: PipelineImageInput,
|
| 607 |
+
prompt: Optional[Union[str, List[str]]] = None,
|
| 608 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 609 |
+
height: Optional[int] = None,
|
| 610 |
+
width: Optional[int] = None,
|
| 611 |
+
num_frames: int = 49,
|
| 612 |
+
num_inference_steps: int = 50,
|
| 613 |
+
timesteps: Optional[List[int]] = None,
|
| 614 |
+
guidance_scale: float = 6,
|
| 615 |
+
use_dynamic_cfg: bool = False,
|
| 616 |
+
num_videos_per_prompt: int = 1,
|
| 617 |
+
eta: float = 0.0,
|
| 618 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 619 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 620 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 621 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 622 |
+
output_type: str = "pil",
|
| 623 |
+
return_dict: bool = True,
|
| 624 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 625 |
+
callback_on_step_end: Optional[
|
| 626 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
| 627 |
+
] = None,
|
| 628 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 629 |
+
max_sequence_length: int = 226,
|
| 630 |
+
) -> Union[CogVideoXPipelineOutput, Tuple]:
|
| 631 |
+
"""
|
| 632 |
+
Function invoked when calling the pipeline for generation.
|
| 633 |
+
|
| 634 |
+
Args:
|
| 635 |
+
image (`PipelineImageInput`):
|
| 636 |
+
The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
|
| 637 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 638 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
| 639 |
+
instead.
|
| 640 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 641 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 642 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 643 |
+
less than `1`).
|
| 644 |
+
height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
|
| 645 |
+
The height in pixels of the generated image. This is set to 480 by default for the best results.
|
| 646 |
+
width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
|
| 647 |
+
The width in pixels of the generated image. This is set to 720 by default for the best results.
|
| 648 |
+
num_frames (`int`, defaults to `48`):
|
| 649 |
+
Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
|
| 650 |
+
contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
|
| 651 |
+
num_seconds is 6 and fps is 8. However, since videos can be saved at any fps, the only condition that
|
| 652 |
+
needs to be satisfied is that of divisibility mentioned above.
|
| 653 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 654 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 655 |
+
expense of slower inference.
|
| 656 |
+
timesteps (`List[int]`, *optional*):
|
| 657 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
| 658 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
| 659 |
+
passed will be used. Must be in descending order.
|
| 660 |
+
guidance_scale (`float`, *optional*, defaults to 7.0):
|
| 661 |
+
Guidance scale as defined in [Classifier-Free Diffusion
|
| 662 |
+
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
| 663 |
+
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
| 664 |
+
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
| 665 |
+
the text `prompt`, usually at the expense of lower image quality.
|
| 666 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 667 |
+
The number of videos to generate per prompt.
|
| 668 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 669 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 670 |
+
to make generation deterministic.
|
| 671 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 672 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 673 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 674 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
| 675 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 676 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 677 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 678 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 679 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 680 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 681 |
+
argument.
|
| 682 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 683 |
+
The output format of the generate image. Choose between
|
| 684 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 685 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 686 |
+
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
| 687 |
+
of a plain tuple.
|
| 688 |
+
attention_kwargs (`dict`, *optional*):
|
| 689 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 690 |
+
`self.processor` in
|
| 691 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 692 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 693 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
| 694 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 695 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 696 |
+
`callback_on_step_end_tensor_inputs`.
|
| 697 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 698 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 699 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 700 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 701 |
+
max_sequence_length (`int`, defaults to `226`):
|
| 702 |
+
Maximum sequence length in encoded prompt. Must be consistent with
|
| 703 |
+
`self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
|
| 704 |
+
|
| 705 |
+
Examples:
|
| 706 |
+
|
| 707 |
+
Returns:
|
| 708 |
+
[`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] or `tuple`:
|
| 709 |
+
[`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
|
| 710 |
+
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
| 711 |
+
"""
|
| 712 |
+
|
| 713 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 714 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 715 |
+
|
| 716 |
+
height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
|
| 717 |
+
width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
|
| 718 |
+
num_frames = num_frames or self.transformer.config.sample_frames
|
| 719 |
+
|
| 720 |
+
num_videos_per_prompt = 1
|
| 721 |
+
|
| 722 |
+
# 1. Check inputs. Raise error if not correct
|
| 723 |
+
self.check_inputs(
|
| 724 |
+
image=image,
|
| 725 |
+
prompt=prompt,
|
| 726 |
+
height=height,
|
| 727 |
+
width=width,
|
| 728 |
+
negative_prompt=negative_prompt,
|
| 729 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
| 730 |
+
latents=latents,
|
| 731 |
+
prompt_embeds=prompt_embeds,
|
| 732 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 733 |
+
)
|
| 734 |
+
self._guidance_scale = guidance_scale
|
| 735 |
+
self._current_timestep = None
|
| 736 |
+
self._attention_kwargs = attention_kwargs
|
| 737 |
+
self._interrupt = False
|
| 738 |
+
|
| 739 |
+
# 2. Default call parameters
|
| 740 |
+
if prompt is not None and isinstance(prompt, str):
|
| 741 |
+
batch_size = 1
|
| 742 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 743 |
+
batch_size = len(prompt)
|
| 744 |
+
else:
|
| 745 |
+
batch_size = prompt_embeds.shape[0]
|
| 746 |
+
|
| 747 |
+
device = self._execution_device
|
| 748 |
+
|
| 749 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 750 |
+
# of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
|
| 751 |
+
# corresponds to doing no classifier free guidance.
|
| 752 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
| 753 |
+
|
| 754 |
+
# 3. Encode input prompt
|
| 755 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
| 756 |
+
prompt=prompt,
|
| 757 |
+
negative_prompt=negative_prompt,
|
| 758 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
| 759 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 760 |
+
prompt_embeds=prompt_embeds,
|
| 761 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 762 |
+
max_sequence_length=max_sequence_length,
|
| 763 |
+
device=device,
|
| 764 |
+
)
|
| 765 |
+
if do_classifier_free_guidance:
|
| 766 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
| 767 |
+
|
| 768 |
+
# 4. Prepare timesteps
|
| 769 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
| 770 |
+
self._num_timesteps = len(timesteps)
|
| 771 |
+
|
| 772 |
+
# 5. Prepare latents
|
| 773 |
+
latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
| 774 |
+
|
| 775 |
+
# For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
|
| 776 |
+
patch_size_t = self.transformer.config.patch_size_t
|
| 777 |
+
additional_frames = 0
|
| 778 |
+
if patch_size_t is not None and latent_frames % patch_size_t != 0:
|
| 779 |
+
additional_frames = patch_size_t - latent_frames % patch_size_t
|
| 780 |
+
num_frames += additional_frames * self.vae_scale_factor_temporal
|
| 781 |
+
|
| 782 |
+
image = self.video_processor.preprocess(image, height=height, width=width).to(
|
| 783 |
+
device, dtype=prompt_embeds.dtype
|
| 784 |
+
)
|
| 785 |
+
|
| 786 |
+
latent_channels = self.transformer.config.in_channels // 2
|
| 787 |
+
latents, image_latents = self.prepare_latents(
|
| 788 |
+
image,
|
| 789 |
+
batch_size * num_videos_per_prompt,
|
| 790 |
+
latent_channels,
|
| 791 |
+
num_frames,
|
| 792 |
+
height,
|
| 793 |
+
width,
|
| 794 |
+
prompt_embeds.dtype,
|
| 795 |
+
device,
|
| 796 |
+
generator,
|
| 797 |
+
latents,
|
| 798 |
+
)
|
| 799 |
+
|
| 800 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 801 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 802 |
+
|
| 803 |
+
# 7. Create rotary embeds if required
|
| 804 |
+
image_rotary_emb = (
|
| 805 |
+
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
|
| 806 |
+
if self.transformer.config.use_rotary_positional_embeddings
|
| 807 |
+
else None
|
| 808 |
+
)
|
| 809 |
+
|
| 810 |
+
# 8. Create ofs embeds if required
|
| 811 |
+
ofs_emb = None if self.transformer.config.ofs_embed_dim is None else latents.new_full((1,), fill_value=2.0)
|
| 812 |
+
|
| 813 |
+
# 8. Denoising loop
|
| 814 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 815 |
+
|
| 816 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 817 |
+
# for DPM-solver++
|
| 818 |
+
old_pred_original_sample = None
|
| 819 |
+
for i, t in enumerate(timesteps):
|
| 820 |
+
if self.interrupt:
|
| 821 |
+
continue
|
| 822 |
+
|
| 823 |
+
self._current_timestep = t
|
| 824 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 825 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 826 |
+
|
| 827 |
+
latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents
|
| 828 |
+
latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
|
| 829 |
+
|
| 830 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 831 |
+
timestep = t.expand(latent_model_input.shape[0])
|
| 832 |
+
|
| 833 |
+
# predict noise model_output
|
| 834 |
+
with self.transformer.cache_context("cond_uncond"):
|
| 835 |
+
noise_pred = self.transformer(
|
| 836 |
+
hidden_states=latent_model_input,
|
| 837 |
+
encoder_hidden_states=prompt_embeds,
|
| 838 |
+
timestep=timestep,
|
| 839 |
+
ofs=ofs_emb,
|
| 840 |
+
image_rotary_emb=image_rotary_emb,
|
| 841 |
+
attention_kwargs=attention_kwargs,
|
| 842 |
+
return_dict=False,
|
| 843 |
+
)[0]
|
| 844 |
+
noise_pred = noise_pred.float()
|
| 845 |
+
|
| 846 |
+
# perform guidance
|
| 847 |
+
if use_dynamic_cfg:
|
| 848 |
+
self._guidance_scale = 1 + guidance_scale * (
|
| 849 |
+
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
|
| 850 |
+
)
|
| 851 |
+
if do_classifier_free_guidance:
|
| 852 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 853 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 854 |
+
|
| 855 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 856 |
+
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
|
| 857 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
| 858 |
+
else:
|
| 859 |
+
latents, old_pred_original_sample = self.scheduler.step(
|
| 860 |
+
noise_pred,
|
| 861 |
+
old_pred_original_sample,
|
| 862 |
+
t,
|
| 863 |
+
timesteps[i - 1] if i > 0 else None,
|
| 864 |
+
latents,
|
| 865 |
+
**extra_step_kwargs,
|
| 866 |
+
return_dict=False,
|
| 867 |
+
)
|
| 868 |
+
latents = latents.to(prompt_embeds.dtype)
|
| 869 |
+
|
| 870 |
+
# call the callback, if provided
|
| 871 |
+
if callback_on_step_end is not None:
|
| 872 |
+
callback_kwargs = {}
|
| 873 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 874 |
+
callback_kwargs[k] = locals()[k]
|
| 875 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 876 |
+
|
| 877 |
+
latents = callback_outputs.pop("latents", latents)
|
| 878 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 879 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 880 |
+
|
| 881 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 882 |
+
progress_bar.update()
|
| 883 |
+
|
| 884 |
+
if XLA_AVAILABLE:
|
| 885 |
+
xm.mark_step()
|
| 886 |
+
|
| 887 |
+
self._current_timestep = None
|
| 888 |
+
|
| 889 |
+
if not output_type == "latent":
|
| 890 |
+
# Discard any padding frames that were added for CogVideoX 1.5
|
| 891 |
+
latents = latents[:, additional_frames:]
|
| 892 |
+
video = self.decode_latents(latents)
|
| 893 |
+
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
| 894 |
+
else:
|
| 895 |
+
video = latents
|
| 896 |
+
|
| 897 |
+
# Offload all models
|
| 898 |
+
self.maybe_free_model_hooks()
|
| 899 |
+
|
| 900 |
+
if not return_dict:
|
| 901 |
+
return (video,)
|
| 902 |
+
|
| 903 |
+
return CogVideoXPipelineOutput(frames=video)
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
ADDED
|
@@ -0,0 +1,868 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import inspect
|
| 17 |
+
import math
|
| 18 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
from PIL import Image
|
| 22 |
+
from transformers import T5EncoderModel, T5Tokenizer
|
| 23 |
+
|
| 24 |
+
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 25 |
+
from ...loaders import CogVideoXLoraLoaderMixin
|
| 26 |
+
from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
|
| 27 |
+
from ...models.embeddings import get_3d_rotary_pos_embed
|
| 28 |
+
from ...pipelines.pipeline_utils import DiffusionPipeline
|
| 29 |
+
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
|
| 30 |
+
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
| 31 |
+
from ...utils.torch_utils import randn_tensor
|
| 32 |
+
from ...video_processor import VideoProcessor
|
| 33 |
+
from .pipeline_output import CogVideoXPipelineOutput
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
if is_torch_xla_available():
|
| 37 |
+
import torch_xla.core.xla_model as xm
|
| 38 |
+
|
| 39 |
+
XLA_AVAILABLE = True
|
| 40 |
+
else:
|
| 41 |
+
XLA_AVAILABLE = False
|
| 42 |
+
|
| 43 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
EXAMPLE_DOC_STRING = """
|
| 47 |
+
Examples:
|
| 48 |
+
```python
|
| 49 |
+
>>> import torch
|
| 50 |
+
>>> from diffusers import CogVideoXDPMScheduler, CogVideoXVideoToVideoPipeline
|
| 51 |
+
>>> from diffusers.utils import export_to_video, load_video
|
| 52 |
+
|
| 53 |
+
>>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
|
| 54 |
+
>>> pipe = CogVideoXVideoToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
|
| 55 |
+
>>> pipe.to("cuda")
|
| 56 |
+
>>> pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)
|
| 57 |
+
|
| 58 |
+
>>> input_video = load_video(
|
| 59 |
+
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4"
|
| 60 |
+
... )
|
| 61 |
+
>>> prompt = (
|
| 62 |
+
... "An astronaut stands triumphantly at the peak of a towering mountain. Panorama of rugged peaks and "
|
| 63 |
+
... "valleys. Very futuristic vibe and animated aesthetic. Highlights of purple and golden colors in "
|
| 64 |
+
... "the scene. The sky is looks like an animated/cartoonish dream of galaxies, nebulae, stars, planets, "
|
| 65 |
+
... "moons, but the remainder of the scene is mostly realistic."
|
| 66 |
+
... )
|
| 67 |
+
|
| 68 |
+
>>> video = pipe(
|
| 69 |
+
... video=input_video, prompt=prompt, strength=0.8, guidance_scale=6, num_inference_steps=50
|
| 70 |
+
... ).frames[0]
|
| 71 |
+
>>> export_to_video(video, "output.mp4", fps=8)
|
| 72 |
+
```
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
|
| 77 |
+
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
|
| 78 |
+
tw = tgt_width
|
| 79 |
+
th = tgt_height
|
| 80 |
+
h, w = src
|
| 81 |
+
r = h / w
|
| 82 |
+
if r > (th / tw):
|
| 83 |
+
resize_height = th
|
| 84 |
+
resize_width = int(round(th / h * w))
|
| 85 |
+
else:
|
| 86 |
+
resize_width = tw
|
| 87 |
+
resize_height = int(round(tw / w * h))
|
| 88 |
+
|
| 89 |
+
crop_top = int(round((th - resize_height) / 2.0))
|
| 90 |
+
crop_left = int(round((tw - resize_width) / 2.0))
|
| 91 |
+
|
| 92 |
+
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 96 |
+
def retrieve_timesteps(
|
| 97 |
+
scheduler,
|
| 98 |
+
num_inference_steps: Optional[int] = None,
|
| 99 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 100 |
+
timesteps: Optional[List[int]] = None,
|
| 101 |
+
sigmas: Optional[List[float]] = None,
|
| 102 |
+
**kwargs,
|
| 103 |
+
):
|
| 104 |
+
r"""
|
| 105 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 106 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 107 |
+
|
| 108 |
+
Args:
|
| 109 |
+
scheduler (`SchedulerMixin`):
|
| 110 |
+
The scheduler to get timesteps from.
|
| 111 |
+
num_inference_steps (`int`):
|
| 112 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 113 |
+
must be `None`.
|
| 114 |
+
device (`str` or `torch.device`, *optional*):
|
| 115 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 116 |
+
timesteps (`List[int]`, *optional*):
|
| 117 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 118 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 119 |
+
sigmas (`List[float]`, *optional*):
|
| 120 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 121 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 122 |
+
|
| 123 |
+
Returns:
|
| 124 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 125 |
+
second element is the number of inference steps.
|
| 126 |
+
"""
|
| 127 |
+
if timesteps is not None and sigmas is not None:
|
| 128 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 129 |
+
if timesteps is not None:
|
| 130 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 131 |
+
if not accepts_timesteps:
|
| 132 |
+
raise ValueError(
|
| 133 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 134 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 135 |
+
)
|
| 136 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 137 |
+
timesteps = scheduler.timesteps
|
| 138 |
+
num_inference_steps = len(timesteps)
|
| 139 |
+
elif sigmas is not None:
|
| 140 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 141 |
+
if not accept_sigmas:
|
| 142 |
+
raise ValueError(
|
| 143 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 144 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 145 |
+
)
|
| 146 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 147 |
+
timesteps = scheduler.timesteps
|
| 148 |
+
num_inference_steps = len(timesteps)
|
| 149 |
+
else:
|
| 150 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 151 |
+
timesteps = scheduler.timesteps
|
| 152 |
+
return timesteps, num_inference_steps
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
| 156 |
+
def retrieve_latents(
|
| 157 |
+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
| 158 |
+
):
|
| 159 |
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
| 160 |
+
return encoder_output.latent_dist.sample(generator)
|
| 161 |
+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
| 162 |
+
return encoder_output.latent_dist.mode()
|
| 163 |
+
elif hasattr(encoder_output, "latents"):
|
| 164 |
+
return encoder_output.latents
|
| 165 |
+
else:
|
| 166 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
| 170 |
+
r"""
|
| 171 |
+
Pipeline for video-to-video generation using CogVideoX.
|
| 172 |
+
|
| 173 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 174 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
| 175 |
+
|
| 176 |
+
Args:
|
| 177 |
+
vae ([`AutoencoderKL`]):
|
| 178 |
+
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
| 179 |
+
text_encoder ([`T5EncoderModel`]):
|
| 180 |
+
Frozen text-encoder. CogVideoX uses
|
| 181 |
+
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
|
| 182 |
+
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
|
| 183 |
+
tokenizer (`T5Tokenizer`):
|
| 184 |
+
Tokenizer of class
|
| 185 |
+
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
| 186 |
+
transformer ([`CogVideoXTransformer3DModel`]):
|
| 187 |
+
A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
|
| 188 |
+
scheduler ([`SchedulerMixin`]):
|
| 189 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
|
| 190 |
+
"""
|
| 191 |
+
|
| 192 |
+
_optional_components = []
|
| 193 |
+
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
| 194 |
+
|
| 195 |
+
_callback_tensor_inputs = [
|
| 196 |
+
"latents",
|
| 197 |
+
"prompt_embeds",
|
| 198 |
+
"negative_prompt_embeds",
|
| 199 |
+
]
|
| 200 |
+
|
| 201 |
+
def __init__(
|
| 202 |
+
self,
|
| 203 |
+
tokenizer: T5Tokenizer,
|
| 204 |
+
text_encoder: T5EncoderModel,
|
| 205 |
+
vae: AutoencoderKLCogVideoX,
|
| 206 |
+
transformer: CogVideoXTransformer3DModel,
|
| 207 |
+
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
|
| 208 |
+
):
|
| 209 |
+
super().__init__()
|
| 210 |
+
|
| 211 |
+
self.register_modules(
|
| 212 |
+
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
self.vae_scale_factor_spatial = (
|
| 216 |
+
2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
| 217 |
+
)
|
| 218 |
+
self.vae_scale_factor_temporal = (
|
| 219 |
+
self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
|
| 220 |
+
)
|
| 221 |
+
self.vae_scaling_factor_image = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.7
|
| 222 |
+
|
| 223 |
+
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
| 224 |
+
|
| 225 |
+
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
|
| 226 |
+
def _get_t5_prompt_embeds(
|
| 227 |
+
self,
|
| 228 |
+
prompt: Union[str, List[str]] = None,
|
| 229 |
+
num_videos_per_prompt: int = 1,
|
| 230 |
+
max_sequence_length: int = 226,
|
| 231 |
+
device: Optional[torch.device] = None,
|
| 232 |
+
dtype: Optional[torch.dtype] = None,
|
| 233 |
+
):
|
| 234 |
+
device = device or self._execution_device
|
| 235 |
+
dtype = dtype or self.text_encoder.dtype
|
| 236 |
+
|
| 237 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 238 |
+
batch_size = len(prompt)
|
| 239 |
+
|
| 240 |
+
text_inputs = self.tokenizer(
|
| 241 |
+
prompt,
|
| 242 |
+
padding="max_length",
|
| 243 |
+
max_length=max_sequence_length,
|
| 244 |
+
truncation=True,
|
| 245 |
+
add_special_tokens=True,
|
| 246 |
+
return_tensors="pt",
|
| 247 |
+
)
|
| 248 |
+
text_input_ids = text_inputs.input_ids
|
| 249 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
| 250 |
+
|
| 251 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 252 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
| 253 |
+
logger.warning(
|
| 254 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
| 255 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
|
| 259 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 260 |
+
|
| 261 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 262 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 263 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
| 264 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
| 265 |
+
|
| 266 |
+
return prompt_embeds
|
| 267 |
+
|
| 268 |
+
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
|
| 269 |
+
def encode_prompt(
|
| 270 |
+
self,
|
| 271 |
+
prompt: Union[str, List[str]],
|
| 272 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 273 |
+
do_classifier_free_guidance: bool = True,
|
| 274 |
+
num_videos_per_prompt: int = 1,
|
| 275 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 276 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 277 |
+
max_sequence_length: int = 226,
|
| 278 |
+
device: Optional[torch.device] = None,
|
| 279 |
+
dtype: Optional[torch.dtype] = None,
|
| 280 |
+
):
|
| 281 |
+
r"""
|
| 282 |
+
Encodes the prompt into text encoder hidden states.
|
| 283 |
+
|
| 284 |
+
Args:
|
| 285 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 286 |
+
prompt to be encoded
|
| 287 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 288 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 289 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 290 |
+
less than `1`).
|
| 291 |
+
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
| 292 |
+
Whether to use classifier free guidance or not.
|
| 293 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 294 |
+
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
| 295 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 296 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 297 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 298 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 299 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 300 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 301 |
+
argument.
|
| 302 |
+
device: (`torch.device`, *optional*):
|
| 303 |
+
torch device
|
| 304 |
+
dtype: (`torch.dtype`, *optional*):
|
| 305 |
+
torch dtype
|
| 306 |
+
"""
|
| 307 |
+
device = device or self._execution_device
|
| 308 |
+
|
| 309 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 310 |
+
if prompt is not None:
|
| 311 |
+
batch_size = len(prompt)
|
| 312 |
+
else:
|
| 313 |
+
batch_size = prompt_embeds.shape[0]
|
| 314 |
+
|
| 315 |
+
if prompt_embeds is None:
|
| 316 |
+
prompt_embeds = self._get_t5_prompt_embeds(
|
| 317 |
+
prompt=prompt,
|
| 318 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 319 |
+
max_sequence_length=max_sequence_length,
|
| 320 |
+
device=device,
|
| 321 |
+
dtype=dtype,
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
| 325 |
+
negative_prompt = negative_prompt or ""
|
| 326 |
+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
| 327 |
+
|
| 328 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
| 329 |
+
raise TypeError(
|
| 330 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
| 331 |
+
f" {type(prompt)}."
|
| 332 |
+
)
|
| 333 |
+
elif batch_size != len(negative_prompt):
|
| 334 |
+
raise ValueError(
|
| 335 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
| 336 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
| 337 |
+
" the batch size of `prompt`."
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
| 341 |
+
prompt=negative_prompt,
|
| 342 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 343 |
+
max_sequence_length=max_sequence_length,
|
| 344 |
+
device=device,
|
| 345 |
+
dtype=dtype,
|
| 346 |
+
)
|
| 347 |
+
|
| 348 |
+
return prompt_embeds, negative_prompt_embeds
|
| 349 |
+
|
| 350 |
+
def prepare_latents(
|
| 351 |
+
self,
|
| 352 |
+
video: Optional[torch.Tensor] = None,
|
| 353 |
+
batch_size: int = 1,
|
| 354 |
+
num_channels_latents: int = 16,
|
| 355 |
+
height: int = 60,
|
| 356 |
+
width: int = 90,
|
| 357 |
+
dtype: Optional[torch.dtype] = None,
|
| 358 |
+
device: Optional[torch.device] = None,
|
| 359 |
+
generator: Optional[torch.Generator] = None,
|
| 360 |
+
latents: Optional[torch.Tensor] = None,
|
| 361 |
+
timestep: Optional[torch.Tensor] = None,
|
| 362 |
+
):
|
| 363 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 364 |
+
raise ValueError(
|
| 365 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 366 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
num_frames = (video.size(2) - 1) // self.vae_scale_factor_temporal + 1 if latents is None else latents.size(1)
|
| 370 |
+
|
| 371 |
+
shape = (
|
| 372 |
+
batch_size,
|
| 373 |
+
num_frames,
|
| 374 |
+
num_channels_latents,
|
| 375 |
+
height // self.vae_scale_factor_spatial,
|
| 376 |
+
width // self.vae_scale_factor_spatial,
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
+
if latents is None:
|
| 380 |
+
if isinstance(generator, list):
|
| 381 |
+
init_latents = [
|
| 382 |
+
retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
|
| 383 |
+
]
|
| 384 |
+
else:
|
| 385 |
+
init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
|
| 386 |
+
|
| 387 |
+
init_latents = torch.cat(init_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
|
| 388 |
+
init_latents = self.vae_scaling_factor_image * init_latents
|
| 389 |
+
|
| 390 |
+
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 391 |
+
latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
| 392 |
+
else:
|
| 393 |
+
latents = latents.to(device)
|
| 394 |
+
|
| 395 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
| 396 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 397 |
+
return latents
|
| 398 |
+
|
| 399 |
+
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents
|
| 400 |
+
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
|
| 401 |
+
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
|
| 402 |
+
latents = 1 / self.vae_scaling_factor_image * latents
|
| 403 |
+
|
| 404 |
+
frames = self.vae.decode(latents).sample
|
| 405 |
+
return frames
|
| 406 |
+
|
| 407 |
+
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps
|
| 408 |
+
def get_timesteps(self, num_inference_steps, timesteps, strength, device):
|
| 409 |
+
# get the original timestep using init_timestep
|
| 410 |
+
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
| 411 |
+
|
| 412 |
+
t_start = max(num_inference_steps - init_timestep, 0)
|
| 413 |
+
timesteps = timesteps[t_start * self.scheduler.order :]
|
| 414 |
+
|
| 415 |
+
return timesteps, num_inference_steps - t_start
|
| 416 |
+
|
| 417 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
| 418 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
| 419 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 420 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 421 |
+
# eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
|
| 422 |
+
# and should be between [0, 1]
|
| 423 |
+
|
| 424 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 425 |
+
extra_step_kwargs = {}
|
| 426 |
+
if accepts_eta:
|
| 427 |
+
extra_step_kwargs["eta"] = eta
|
| 428 |
+
|
| 429 |
+
# check if the scheduler accepts generator
|
| 430 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 431 |
+
if accepts_generator:
|
| 432 |
+
extra_step_kwargs["generator"] = generator
|
| 433 |
+
return extra_step_kwargs
|
| 434 |
+
|
| 435 |
+
def check_inputs(
|
| 436 |
+
self,
|
| 437 |
+
prompt,
|
| 438 |
+
height,
|
| 439 |
+
width,
|
| 440 |
+
strength,
|
| 441 |
+
negative_prompt,
|
| 442 |
+
callback_on_step_end_tensor_inputs,
|
| 443 |
+
video=None,
|
| 444 |
+
latents=None,
|
| 445 |
+
prompt_embeds=None,
|
| 446 |
+
negative_prompt_embeds=None,
|
| 447 |
+
):
|
| 448 |
+
if height % 8 != 0 or width % 8 != 0:
|
| 449 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 450 |
+
|
| 451 |
+
if strength < 0 or strength > 1:
|
| 452 |
+
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
| 453 |
+
|
| 454 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 455 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 456 |
+
):
|
| 457 |
+
raise ValueError(
|
| 458 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 459 |
+
)
|
| 460 |
+
if prompt is not None and prompt_embeds is not None:
|
| 461 |
+
raise ValueError(
|
| 462 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 463 |
+
" only forward one of the two."
|
| 464 |
+
)
|
| 465 |
+
elif prompt is None and prompt_embeds is None:
|
| 466 |
+
raise ValueError(
|
| 467 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 468 |
+
)
|
| 469 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 470 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 471 |
+
|
| 472 |
+
if prompt is not None and negative_prompt_embeds is not None:
|
| 473 |
+
raise ValueError(
|
| 474 |
+
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
|
| 475 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 476 |
+
)
|
| 477 |
+
|
| 478 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 479 |
+
raise ValueError(
|
| 480 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 481 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 482 |
+
)
|
| 483 |
+
|
| 484 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 485 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 486 |
+
raise ValueError(
|
| 487 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 488 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 489 |
+
f" {negative_prompt_embeds.shape}."
|
| 490 |
+
)
|
| 491 |
+
|
| 492 |
+
if video is not None and latents is not None:
|
| 493 |
+
raise ValueError("Only one of `video` or `latents` should be provided")
|
| 494 |
+
|
| 495 |
+
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.fuse_qkv_projections
|
| 496 |
+
def fuse_qkv_projections(self) -> None:
|
| 497 |
+
r"""Enables fused QKV projections."""
|
| 498 |
+
self.fusing_transformer = True
|
| 499 |
+
self.transformer.fuse_qkv_projections()
|
| 500 |
+
|
| 501 |
+
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.unfuse_qkv_projections
|
| 502 |
+
def unfuse_qkv_projections(self) -> None:
|
| 503 |
+
r"""Disable QKV projection fusion if enabled."""
|
| 504 |
+
if not self.fusing_transformer:
|
| 505 |
+
logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
|
| 506 |
+
else:
|
| 507 |
+
self.transformer.unfuse_qkv_projections()
|
| 508 |
+
self.fusing_transformer = False
|
| 509 |
+
|
| 510 |
+
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings
|
| 511 |
+
def _prepare_rotary_positional_embeddings(
|
| 512 |
+
self,
|
| 513 |
+
height: int,
|
| 514 |
+
width: int,
|
| 515 |
+
num_frames: int,
|
| 516 |
+
device: torch.device,
|
| 517 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 518 |
+
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
| 519 |
+
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
| 520 |
+
|
| 521 |
+
p = self.transformer.config.patch_size
|
| 522 |
+
p_t = self.transformer.config.patch_size_t
|
| 523 |
+
|
| 524 |
+
base_size_width = self.transformer.config.sample_width // p
|
| 525 |
+
base_size_height = self.transformer.config.sample_height // p
|
| 526 |
+
|
| 527 |
+
if p_t is None:
|
| 528 |
+
# CogVideoX 1.0
|
| 529 |
+
grid_crops_coords = get_resize_crop_region_for_grid(
|
| 530 |
+
(grid_height, grid_width), base_size_width, base_size_height
|
| 531 |
+
)
|
| 532 |
+
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
| 533 |
+
embed_dim=self.transformer.config.attention_head_dim,
|
| 534 |
+
crops_coords=grid_crops_coords,
|
| 535 |
+
grid_size=(grid_height, grid_width),
|
| 536 |
+
temporal_size=num_frames,
|
| 537 |
+
device=device,
|
| 538 |
+
)
|
| 539 |
+
else:
|
| 540 |
+
# CogVideoX 1.5
|
| 541 |
+
base_num_frames = (num_frames + p_t - 1) // p_t
|
| 542 |
+
|
| 543 |
+
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
| 544 |
+
embed_dim=self.transformer.config.attention_head_dim,
|
| 545 |
+
crops_coords=None,
|
| 546 |
+
grid_size=(grid_height, grid_width),
|
| 547 |
+
temporal_size=base_num_frames,
|
| 548 |
+
grid_type="slice",
|
| 549 |
+
max_size=(base_size_height, base_size_width),
|
| 550 |
+
device=device,
|
| 551 |
+
)
|
| 552 |
+
|
| 553 |
+
return freqs_cos, freqs_sin
|
| 554 |
+
|
| 555 |
+
@property
|
| 556 |
+
def guidance_scale(self):
|
| 557 |
+
return self._guidance_scale
|
| 558 |
+
|
| 559 |
+
@property
|
| 560 |
+
def num_timesteps(self):
|
| 561 |
+
return self._num_timesteps
|
| 562 |
+
|
| 563 |
+
@property
|
| 564 |
+
def attention_kwargs(self):
|
| 565 |
+
return self._attention_kwargs
|
| 566 |
+
|
| 567 |
+
@property
|
| 568 |
+
def current_timestep(self):
|
| 569 |
+
return self._current_timestep
|
| 570 |
+
|
| 571 |
+
@property
|
| 572 |
+
def interrupt(self):
|
| 573 |
+
return self._interrupt
|
| 574 |
+
|
| 575 |
+
@torch.no_grad()
|
| 576 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 577 |
+
def __call__(
|
| 578 |
+
self,
|
| 579 |
+
video: List[Image.Image] = None,
|
| 580 |
+
prompt: Optional[Union[str, List[str]]] = None,
|
| 581 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 582 |
+
height: Optional[int] = None,
|
| 583 |
+
width: Optional[int] = None,
|
| 584 |
+
num_inference_steps: int = 50,
|
| 585 |
+
timesteps: Optional[List[int]] = None,
|
| 586 |
+
strength: float = 0.8,
|
| 587 |
+
guidance_scale: float = 6,
|
| 588 |
+
use_dynamic_cfg: bool = False,
|
| 589 |
+
num_videos_per_prompt: int = 1,
|
| 590 |
+
eta: float = 0.0,
|
| 591 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 592 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 593 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 594 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 595 |
+
output_type: str = "pil",
|
| 596 |
+
return_dict: bool = True,
|
| 597 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 598 |
+
callback_on_step_end: Optional[
|
| 599 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
| 600 |
+
] = None,
|
| 601 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 602 |
+
max_sequence_length: int = 226,
|
| 603 |
+
) -> Union[CogVideoXPipelineOutput, Tuple]:
|
| 604 |
+
"""
|
| 605 |
+
Function invoked when calling the pipeline for generation.
|
| 606 |
+
|
| 607 |
+
Args:
|
| 608 |
+
video (`List[PIL.Image.Image]`):
|
| 609 |
+
The input video to condition the generation on. Must be a list of images/frames of the video.
|
| 610 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 611 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
| 612 |
+
instead.
|
| 613 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 614 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 615 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 616 |
+
less than `1`).
|
| 617 |
+
height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
|
| 618 |
+
The height in pixels of the generated image. This is set to 480 by default for the best results.
|
| 619 |
+
width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
|
| 620 |
+
The width in pixels of the generated image. This is set to 720 by default for the best results.
|
| 621 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 622 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 623 |
+
expense of slower inference.
|
| 624 |
+
timesteps (`List[int]`, *optional*):
|
| 625 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
| 626 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
| 627 |
+
passed will be used. Must be in descending order.
|
| 628 |
+
strength (`float`, *optional*, defaults to 0.8):
|
| 629 |
+
Higher strength leads to more differences between original video and generated video.
|
| 630 |
+
guidance_scale (`float`, *optional*, defaults to 7.0):
|
| 631 |
+
Guidance scale as defined in [Classifier-Free Diffusion
|
| 632 |
+
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
| 633 |
+
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
| 634 |
+
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
| 635 |
+
the text `prompt`, usually at the expense of lower image quality.
|
| 636 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 637 |
+
The number of videos to generate per prompt.
|
| 638 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 639 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 640 |
+
to make generation deterministic.
|
| 641 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 642 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 643 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 644 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
| 645 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 646 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 647 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 648 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 649 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 650 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 651 |
+
argument.
|
| 652 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 653 |
+
The output format of the generate image. Choose between
|
| 654 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 655 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 656 |
+
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
| 657 |
+
of a plain tuple.
|
| 658 |
+
attention_kwargs (`dict`, *optional*):
|
| 659 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 660 |
+
`self.processor` in
|
| 661 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 662 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 663 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
| 664 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 665 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 666 |
+
`callback_on_step_end_tensor_inputs`.
|
| 667 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 668 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 669 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 670 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 671 |
+
max_sequence_length (`int`, defaults to `226`):
|
| 672 |
+
Maximum sequence length in encoded prompt. Must be consistent with
|
| 673 |
+
`self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
|
| 674 |
+
|
| 675 |
+
Examples:
|
| 676 |
+
|
| 677 |
+
Returns:
|
| 678 |
+
[`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] or `tuple`:
|
| 679 |
+
[`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
|
| 680 |
+
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
| 681 |
+
"""
|
| 682 |
+
|
| 683 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 684 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 685 |
+
|
| 686 |
+
height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
|
| 687 |
+
width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
|
| 688 |
+
num_frames = len(video) if latents is None else latents.size(1)
|
| 689 |
+
|
| 690 |
+
num_videos_per_prompt = 1
|
| 691 |
+
|
| 692 |
+
# 1. Check inputs. Raise error if not correct
|
| 693 |
+
self.check_inputs(
|
| 694 |
+
prompt=prompt,
|
| 695 |
+
height=height,
|
| 696 |
+
width=width,
|
| 697 |
+
strength=strength,
|
| 698 |
+
negative_prompt=negative_prompt,
|
| 699 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
| 700 |
+
video=video,
|
| 701 |
+
latents=latents,
|
| 702 |
+
prompt_embeds=prompt_embeds,
|
| 703 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 704 |
+
)
|
| 705 |
+
self._guidance_scale = guidance_scale
|
| 706 |
+
self._attention_kwargs = attention_kwargs
|
| 707 |
+
self._current_timestep = None
|
| 708 |
+
self._interrupt = False
|
| 709 |
+
|
| 710 |
+
# 2. Default call parameters
|
| 711 |
+
if prompt is not None and isinstance(prompt, str):
|
| 712 |
+
batch_size = 1
|
| 713 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 714 |
+
batch_size = len(prompt)
|
| 715 |
+
else:
|
| 716 |
+
batch_size = prompt_embeds.shape[0]
|
| 717 |
+
|
| 718 |
+
device = self._execution_device
|
| 719 |
+
|
| 720 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 721 |
+
# of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
|
| 722 |
+
# corresponds to doing no classifier free guidance.
|
| 723 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
| 724 |
+
|
| 725 |
+
# 3. Encode input prompt
|
| 726 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
| 727 |
+
prompt,
|
| 728 |
+
negative_prompt,
|
| 729 |
+
do_classifier_free_guidance,
|
| 730 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 731 |
+
prompt_embeds=prompt_embeds,
|
| 732 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 733 |
+
max_sequence_length=max_sequence_length,
|
| 734 |
+
device=device,
|
| 735 |
+
)
|
| 736 |
+
if do_classifier_free_guidance:
|
| 737 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
| 738 |
+
|
| 739 |
+
# 4. Prepare timesteps
|
| 740 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
| 741 |
+
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
|
| 742 |
+
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
|
| 743 |
+
self._num_timesteps = len(timesteps)
|
| 744 |
+
|
| 745 |
+
# 5. Prepare latents
|
| 746 |
+
latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
| 747 |
+
|
| 748 |
+
# For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
|
| 749 |
+
patch_size_t = self.transformer.config.patch_size_t
|
| 750 |
+
if patch_size_t is not None and latent_frames % patch_size_t != 0:
|
| 751 |
+
raise ValueError(
|
| 752 |
+
f"The number of latent frames must be divisible by `{patch_size_t=}` but the given video "
|
| 753 |
+
f"contains {latent_frames=}, which is not divisible."
|
| 754 |
+
)
|
| 755 |
+
|
| 756 |
+
if latents is None:
|
| 757 |
+
video = self.video_processor.preprocess_video(video, height=height, width=width)
|
| 758 |
+
video = video.to(device=device, dtype=prompt_embeds.dtype)
|
| 759 |
+
|
| 760 |
+
latent_channels = self.transformer.config.in_channels
|
| 761 |
+
latents = self.prepare_latents(
|
| 762 |
+
video,
|
| 763 |
+
batch_size * num_videos_per_prompt,
|
| 764 |
+
latent_channels,
|
| 765 |
+
height,
|
| 766 |
+
width,
|
| 767 |
+
prompt_embeds.dtype,
|
| 768 |
+
device,
|
| 769 |
+
generator,
|
| 770 |
+
latents,
|
| 771 |
+
latent_timestep,
|
| 772 |
+
)
|
| 773 |
+
|
| 774 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 775 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 776 |
+
|
| 777 |
+
# 7. Create rotary embeds if required
|
| 778 |
+
image_rotary_emb = (
|
| 779 |
+
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
|
| 780 |
+
if self.transformer.config.use_rotary_positional_embeddings
|
| 781 |
+
else None
|
| 782 |
+
)
|
| 783 |
+
|
| 784 |
+
# 8. Denoising loop
|
| 785 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 786 |
+
|
| 787 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 788 |
+
# for DPM-solver++
|
| 789 |
+
old_pred_original_sample = None
|
| 790 |
+
for i, t in enumerate(timesteps):
|
| 791 |
+
if self.interrupt:
|
| 792 |
+
continue
|
| 793 |
+
|
| 794 |
+
self._current_timestep = t
|
| 795 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 796 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 797 |
+
|
| 798 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 799 |
+
timestep = t.expand(latent_model_input.shape[0])
|
| 800 |
+
|
| 801 |
+
# predict noise model_output
|
| 802 |
+
with self.transformer.cache_context("cond_uncond"):
|
| 803 |
+
noise_pred = self.transformer(
|
| 804 |
+
hidden_states=latent_model_input,
|
| 805 |
+
encoder_hidden_states=prompt_embeds,
|
| 806 |
+
timestep=timestep,
|
| 807 |
+
image_rotary_emb=image_rotary_emb,
|
| 808 |
+
attention_kwargs=attention_kwargs,
|
| 809 |
+
return_dict=False,
|
| 810 |
+
)[0]
|
| 811 |
+
noise_pred = noise_pred.float()
|
| 812 |
+
|
| 813 |
+
# perform guidance
|
| 814 |
+
if use_dynamic_cfg:
|
| 815 |
+
self._guidance_scale = 1 + guidance_scale * (
|
| 816 |
+
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
|
| 817 |
+
)
|
| 818 |
+
if do_classifier_free_guidance:
|
| 819 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 820 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 821 |
+
|
| 822 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 823 |
+
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
|
| 824 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
| 825 |
+
else:
|
| 826 |
+
latents, old_pred_original_sample = self.scheduler.step(
|
| 827 |
+
noise_pred,
|
| 828 |
+
old_pred_original_sample,
|
| 829 |
+
t,
|
| 830 |
+
timesteps[i - 1] if i > 0 else None,
|
| 831 |
+
latents,
|
| 832 |
+
**extra_step_kwargs,
|
| 833 |
+
return_dict=False,
|
| 834 |
+
)
|
| 835 |
+
latents = latents.to(prompt_embeds.dtype)
|
| 836 |
+
|
| 837 |
+
# call the callback, if provided
|
| 838 |
+
if callback_on_step_end is not None:
|
| 839 |
+
callback_kwargs = {}
|
| 840 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 841 |
+
callback_kwargs[k] = locals()[k]
|
| 842 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 843 |
+
|
| 844 |
+
latents = callback_outputs.pop("latents", latents)
|
| 845 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 846 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 847 |
+
|
| 848 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 849 |
+
progress_bar.update()
|
| 850 |
+
|
| 851 |
+
if XLA_AVAILABLE:
|
| 852 |
+
xm.mark_step()
|
| 853 |
+
|
| 854 |
+
self._current_timestep = None
|
| 855 |
+
|
| 856 |
+
if not output_type == "latent":
|
| 857 |
+
video = self.decode_latents(latents)
|
| 858 |
+
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
| 859 |
+
else:
|
| 860 |
+
video = latents
|
| 861 |
+
|
| 862 |
+
# Offload all models
|
| 863 |
+
self.maybe_free_model_hooks()
|
| 864 |
+
|
| 865 |
+
if not return_dict:
|
| 866 |
+
return (video,)
|
| 867 |
+
|
| 868 |
+
return CogVideoXPipelineOutput(frames=video)
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogvideo/pipeline_output.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from diffusers.utils import BaseOutput
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@dataclass
|
| 9 |
+
class CogVideoXPipelineOutput(BaseOutput):
|
| 10 |
+
r"""
|
| 11 |
+
Output class for CogVideo pipelines.
|
| 12 |
+
|
| 13 |
+
Args:
|
| 14 |
+
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
| 15 |
+
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
|
| 16 |
+
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
|
| 17 |
+
`(batch_size, num_frames, channels, height, width)`.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
frames: torch.Tensor
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogview3/__init__.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import TYPE_CHECKING
|
| 2 |
+
|
| 3 |
+
from ...utils import (
|
| 4 |
+
DIFFUSERS_SLOW_IMPORT,
|
| 5 |
+
OptionalDependencyNotAvailable,
|
| 6 |
+
_LazyModule,
|
| 7 |
+
get_objects_from_module,
|
| 8 |
+
is_torch_available,
|
| 9 |
+
is_transformers_available,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
_dummy_objects = {}
|
| 14 |
+
_additional_imports = {}
|
| 15 |
+
_import_structure = {"pipeline_output": ["CogView3PlusPipelineOutput"]}
|
| 16 |
+
|
| 17 |
+
try:
|
| 18 |
+
if not (is_transformers_available() and is_torch_available()):
|
| 19 |
+
raise OptionalDependencyNotAvailable()
|
| 20 |
+
except OptionalDependencyNotAvailable:
|
| 21 |
+
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
| 22 |
+
|
| 23 |
+
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
| 24 |
+
else:
|
| 25 |
+
_import_structure["pipeline_cogview3plus"] = ["CogView3PlusPipeline"]
|
| 26 |
+
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
| 27 |
+
try:
|
| 28 |
+
if not (is_transformers_available() and is_torch_available()):
|
| 29 |
+
raise OptionalDependencyNotAvailable()
|
| 30 |
+
except OptionalDependencyNotAvailable:
|
| 31 |
+
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
| 32 |
+
else:
|
| 33 |
+
from .pipeline_cogview3plus import CogView3PlusPipeline
|
| 34 |
+
else:
|
| 35 |
+
import sys
|
| 36 |
+
|
| 37 |
+
sys.modules[__name__] = _LazyModule(
|
| 38 |
+
__name__,
|
| 39 |
+
globals()["__file__"],
|
| 40 |
+
_import_structure,
|
| 41 |
+
module_spec=__spec__,
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
for name, value in _dummy_objects.items():
|
| 45 |
+
setattr(sys.modules[__name__], name, value)
|
| 46 |
+
for name, value in _additional_imports.items():
|
| 47 |
+
setattr(sys.modules[__name__], name, value)
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
ADDED
|
@@ -0,0 +1,682 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import inspect
|
| 17 |
+
from typing import Callable, Dict, List, Optional, Tuple, Union
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
from transformers import T5EncoderModel, T5Tokenizer
|
| 21 |
+
|
| 22 |
+
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 23 |
+
from ...image_processor import VaeImageProcessor
|
| 24 |
+
from ...models import AutoencoderKL, CogView3PlusTransformer2DModel
|
| 25 |
+
from ...pipelines.pipeline_utils import DiffusionPipeline
|
| 26 |
+
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
|
| 27 |
+
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
| 28 |
+
from ...utils.torch_utils import randn_tensor
|
| 29 |
+
from .pipeline_output import CogView3PipelineOutput
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# Optional torch_xla support: when running on XLA devices (e.g. TPUs),
# `xm.mark_step()` is called inside the denoising loop to flush the lazy graph.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

# Module-level logger, keyed to this module's dotted path.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# Usage example spliced into `__call__`'s docstring via `replace_example_docstring`.
EXAMPLE_DOC_STRING = """
    Examples:
        ```python
        >>> import torch
        >>> from diffusers import CogView3PlusPipeline

        >>> pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3-Plus-3B", torch_dtype=torch.bfloat16)
        >>> pipe.to("cuda")

        >>> prompt = "A photo of an astronaut riding a horse on mars"
        >>> image = pipe(prompt).images[0]
        >>> image.save("output.png")
        ```
"""
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 59 |
+
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    # Custom timesteps and custom sigmas are mutually exclusive overrides.
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    def _scheduler_accepts(param_name: str) -> bool:
        # True when this scheduler's `set_timesteps` signature exposes `param_name`.
        return param_name in inspect.signature(scheduler.set_timesteps).parameters

    uses_custom_schedule = timesteps is not None or sigmas is not None

    if timesteps is not None:
        if not _scheduler_accepts("timesteps"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    elif sigmas is not None:
        if not _scheduler_accepts("sigmas"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)

    # Always read the schedule back from the scheduler; with a custom override the
    # effective number of steps is whatever the scheduler produced.
    timesteps = scheduler.timesteps
    if uses_custom_schedule:
        num_inference_steps = len(timesteps)
    return timesteps, num_inference_steps
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class CogView3PlusPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using CogView3Plus.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`T5EncoderModel`]):
            Frozen text-encoder. CogView3Plus uses
            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
            [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
        tokenizer (`T5Tokenizer`):
            Tokenizer of class
            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
        transformer ([`CogView3PlusTransformer2DModel`]):
            A text conditioned `CogView3PlusTransformer2DModel` to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
    """

    _optional_components = []
    # Order in which submodules are moved to GPU under sequential CPU offload.
    model_cpu_offload_seq = "text_encoder->transformer->vae"

    # Tensors that `callback_on_step_end` may read and overwrite via `callback_kwargs`.
    _callback_tensor_inputs = [
        "latents",
        "prompt_embeds",
        "negative_prompt_embeds",
    ]

    def __init__(
        self,
        tokenizer: T5Tokenizer,
        text_encoder: T5EncoderModel,
        vae: AutoencoderKL,
        transformer: CogView3PlusTransformer2DModel,
        scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
    ):
        super().__init__()

        self.register_modules(
            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
        )
        # Spatial downsampling factor of the VAE (one factor of 2 per down block);
        # falls back to 8 when the pipeline is constructed without a VAE.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8

        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds with num_videos_per_prompt->num_images_per_prompt
    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_images_per_prompt: int = 1,
        max_sequence_length: int = 226,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Tokenize `prompt` and run it through the T5 text encoder, returning
        embeddings of shape (batch * num_images_per_prompt, seq_len, hidden)."""
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Re-tokenize without truncation only to warn the user about dropped tokens.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        return prompt_embeds

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        do_classifier_free_guidance: bool = True,
        num_images_per_prompt: int = 1,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        max_sequence_length: int = 224,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier free guidance or not.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                Number of images that should be generated per prompt. torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            max_sequence_length (`int`, defaults to `224`):
                Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
            device: (`torch.device`, *optional*):
                torch device
            dtype: (`torch.dtype`, *optional*):
                torch dtype
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        # Without an explicit negative prompt the unconditional branch uses
        # all-zero embeddings instead of encoding an empty string.
        if do_classifier_free_guidance and negative_prompt is None:
            negative_prompt_embeds = prompt_embeds.new_zeros(prompt_embeds.shape)

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds = self._get_t5_prompt_embeds(
                prompt=negative_prompt,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        return prompt_embeds, negative_prompt_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        """Sample (or reuse) initial latents of shape (batch, channels, h/vsf, w/vsf)
        and scale them by the scheduler's initial noise sigma."""
        shape = (
            batch_size,
            num_channels_latents,
            int(height) // self.vae_scale_factor,
            int(width) // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
    def check_inputs(
        self,
        prompt,
        height,
        width,
        negative_prompt,
        callback_on_step_end_tensor_inputs,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        """Validate `__call__` arguments; raises `ValueError`/`TypeError` on bad input."""
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    @property
    def guidance_scale(self):
        # Set in `__call__`; exposed read-only for callbacks and internal use.
        return self._guidance_scale

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1

    @property
    def num_timesteps(self):
        # Number of scheduler timesteps for the current run; set in `__call__`.
        return self._num_timesteps

    @property
    def interrupt(self):
        # Cooperative cancellation flag checked each denoising step.
        return self._interrupt

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        timesteps: Optional[List[int]] = None,
        guidance_scale: float = 5.0,
        num_images_per_prompt: int = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        original_size: Optional[Tuple[int, int]] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        output_type: str = "pil",
        return_dict: bool = True,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 224,
    ) -> Union[CogView3PipelineOutput, Tuple]:
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image. If not provided, it is set to 1024.
            width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. If not provided it is set to 1024.
            num_inference_steps (`int`, *optional*, defaults to `50`):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to `5.0`):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to `1`):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
                `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
                explained in section 2.2 of
                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
                `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
                of a plain tuple.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int`, defaults to `224`):
                Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.

        Examples:

        Returns:
            [`~pipelines.cogview3.pipeline_cogview3plus.CogView3PipelineOutput`] or `tuple`:
            [`~pipelines.cogview3.pipeline_cogview3plus.CogView3PipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images.
        """

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        height = height or self.transformer.config.sample_size * self.vae_scale_factor
        width = width or self.transformer.config.sample_size * self.vae_scale_factor

        original_size = original_size or (height, width)
        target_size = (height, width)

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            negative_prompt,
            callback_on_step_end_tensor_inputs,
            prompt_embeds,
            negative_prompt_embeds,
        )
        self._guidance_scale = guidance_scale
        self._interrupt = False

        # 2. Default call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            negative_prompt,
            self.do_classifier_free_guidance,
            num_images_per_prompt=num_images_per_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            max_sequence_length=max_sequence_length,
            device=device,
        )
        if self.do_classifier_free_guidance:
            # Unconditional embeddings first so `noise_pred.chunk(2)` below yields
            # (uncond, text) in that order.
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
        self._num_timesteps = len(timesteps)

        # 5. Prepare latents.
        latent_channels = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            latent_channels,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Prepare additional timestep conditions
        original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype)
        target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype)
        crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype)

        if self.do_classifier_free_guidance:
            # Duplicate the micro-conditioning to match the doubled (uncond + text) batch.
            original_size = torch.cat([original_size, original_size])
            target_size = torch.cat([target_size, target_size])
            crops_coords_top_left = torch.cat([crops_coords_top_left, crops_coords_top_left])

        original_size = original_size.to(device).repeat(batch_size * num_images_per_prompt, 1)
        target_size = target_size.to(device).repeat(batch_size * num_images_per_prompt, 1)
        crops_coords_top_left = crops_coords_top_left.to(device).repeat(batch_size * num_images_per_prompt, 1)

        # 8. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            # for DPM-solver++
            old_pred_original_sample = None
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0])

                # predict noise model_output
                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    encoder_hidden_states=prompt_embeds,
                    timestep=timestep,
                    original_size=original_size,
                    target_size=target_size,
                    crop_coords=crops_coords_top_left,
                    return_dict=False,
                )[0]
                # Guidance arithmetic in float32 for numerical stability regardless of model dtype.
                noise_pred = noise_pred.float()

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                if not isinstance(self.scheduler, CogVideoXDPMScheduler):
                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                else:
                    # The DPM scheduler additionally threads the previous x0 prediction
                    # and the previous timestep (None on the first step) through each step.
                    latents, old_pred_original_sample = self.scheduler.step(
                        noise_pred,
                        old_pred_original_sample,
                        t,
                        timesteps[i - 1] if i > 0 else None,
                        latents,
                        **extra_step_kwargs,
                        return_dict=False,
                    )
                latents = latents.to(prompt_embeds.dtype)

                # call the callback, if provided
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    # `locals()[k]` resolves each requested tensor by its local variable name;
                    # valid keys are constrained to `_callback_tensor_inputs` by `check_inputs`.
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
                0
            ]
        else:
            image = latents

        image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return CogView3PipelineOutput(images=image)
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogview3/pipeline_output.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import List, Union
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import PIL.Image
|
| 6 |
+
|
| 7 |
+
from ...utils import BaseOutput
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass
|
| 11 |
+
class CogView3PipelineOutput(BaseOutput):
|
| 12 |
+
"""
|
| 13 |
+
Output class for CogView3 pipelines.
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
| 17 |
+
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
|
| 18 |
+
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
images: Union[List[PIL.Image.Image], np.ndarray]
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogview4/__init__.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import TYPE_CHECKING
|
| 2 |
+
|
| 3 |
+
from ...utils import (
|
| 4 |
+
DIFFUSERS_SLOW_IMPORT,
|
| 5 |
+
OptionalDependencyNotAvailable,
|
| 6 |
+
_LazyModule,
|
| 7 |
+
get_objects_from_module,
|
| 8 |
+
is_torch_available,
|
| 9 |
+
is_transformers_available,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
_dummy_objects = {}
|
| 14 |
+
_additional_imports = {}
|
| 15 |
+
_import_structure = {"pipeline_output": ["CogView4PlusPipelineOutput"]}
|
| 16 |
+
|
| 17 |
+
try:
|
| 18 |
+
if not (is_transformers_available() and is_torch_available()):
|
| 19 |
+
raise OptionalDependencyNotAvailable()
|
| 20 |
+
except OptionalDependencyNotAvailable:
|
| 21 |
+
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
| 22 |
+
|
| 23 |
+
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
| 24 |
+
else:
|
| 25 |
+
_import_structure["pipeline_cogview4"] = ["CogView4Pipeline"]
|
| 26 |
+
_import_structure["pipeline_cogview4_control"] = ["CogView4ControlPipeline"]
|
| 27 |
+
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
| 28 |
+
try:
|
| 29 |
+
if not (is_transformers_available() and is_torch_available()):
|
| 30 |
+
raise OptionalDependencyNotAvailable()
|
| 31 |
+
except OptionalDependencyNotAvailable:
|
| 32 |
+
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
| 33 |
+
else:
|
| 34 |
+
from .pipeline_cogview4 import CogView4Pipeline
|
| 35 |
+
from .pipeline_cogview4_control import CogView4ControlPipeline
|
| 36 |
+
else:
|
| 37 |
+
import sys
|
| 38 |
+
|
| 39 |
+
sys.modules[__name__] = _LazyModule(
|
| 40 |
+
__name__,
|
| 41 |
+
globals()["__file__"],
|
| 42 |
+
_import_structure,
|
| 43 |
+
module_spec=__spec__,
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
for name, value in _dummy_objects.items():
|
| 47 |
+
setattr(sys.modules[__name__], name, value)
|
| 48 |
+
for name, value in _additional_imports.items():
|
| 49 |
+
setattr(sys.modules[__name__], name, value)
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogview4/pipeline_cogview4.py
ADDED
|
@@ -0,0 +1,685 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import inspect
|
| 17 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torch
|
| 21 |
+
from transformers import AutoTokenizer, GlmModel
|
| 22 |
+
|
| 23 |
+
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 24 |
+
from ...image_processor import VaeImageProcessor
|
| 25 |
+
from ...loaders import CogView4LoraLoaderMixin
|
| 26 |
+
from ...models import AutoencoderKL, CogView4Transformer2DModel
|
| 27 |
+
from ...pipelines.pipeline_utils import DiffusionPipeline
|
| 28 |
+
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
| 29 |
+
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
| 30 |
+
from ...utils.torch_utils import randn_tensor
|
| 31 |
+
from .pipeline_output import CogView4PipelineOutput
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
if is_torch_xla_available():
|
| 35 |
+
import torch_xla.core.xla_model as xm
|
| 36 |
+
|
| 37 |
+
XLA_AVAILABLE = True
|
| 38 |
+
else:
|
| 39 |
+
XLA_AVAILABLE = False
|
| 40 |
+
|
| 41 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 42 |
+
|
| 43 |
+
EXAMPLE_DOC_STRING = """
|
| 44 |
+
Examples:
|
| 45 |
+
```python
|
| 46 |
+
>>> import torch
|
| 47 |
+
>>> from diffusers import CogView4Pipeline
|
| 48 |
+
|
| 49 |
+
>>> pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
|
| 50 |
+
>>> pipe.to("cuda")
|
| 51 |
+
|
| 52 |
+
>>> prompt = "A photo of an astronaut riding a horse on mars"
|
| 53 |
+
>>> image = pipe(prompt).images[0]
|
| 54 |
+
>>> image.save("output.png")
|
| 55 |
+
```
|
| 56 |
+
"""
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def calculate_shift(
|
| 60 |
+
image_seq_len,
|
| 61 |
+
base_seq_len: int = 256,
|
| 62 |
+
base_shift: float = 0.25,
|
| 63 |
+
max_shift: float = 0.75,
|
| 64 |
+
) -> float:
|
| 65 |
+
m = (image_seq_len / base_seq_len) ** 0.5
|
| 66 |
+
mu = m * max_shift + base_shift
|
| 67 |
+
return mu
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def retrieve_timesteps(
|
| 71 |
+
scheduler,
|
| 72 |
+
num_inference_steps: Optional[int] = None,
|
| 73 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 74 |
+
timesteps: Optional[List[int]] = None,
|
| 75 |
+
sigmas: Optional[List[float]] = None,
|
| 76 |
+
**kwargs,
|
| 77 |
+
):
|
| 78 |
+
r"""
|
| 79 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 80 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 81 |
+
|
| 82 |
+
Args:
|
| 83 |
+
scheduler (`SchedulerMixin`):
|
| 84 |
+
The scheduler to get timesteps from.
|
| 85 |
+
num_inference_steps (`int`):
|
| 86 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 87 |
+
must be `None`.
|
| 88 |
+
device (`str` or `torch.device`, *optional*):
|
| 89 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 90 |
+
timesteps (`List[int]`, *optional*):
|
| 91 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 92 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 93 |
+
sigmas (`List[float]`, *optional*):
|
| 94 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 95 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 96 |
+
|
| 97 |
+
Returns:
|
| 98 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 99 |
+
second element is the number of inference steps.
|
| 100 |
+
"""
|
| 101 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 102 |
+
accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 103 |
+
|
| 104 |
+
if timesteps is not None and sigmas is not None:
|
| 105 |
+
if not accepts_timesteps and not accepts_sigmas:
|
| 106 |
+
raise ValueError(
|
| 107 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 108 |
+
f" timestep or sigma schedules. Please check whether you are using the correct scheduler."
|
| 109 |
+
)
|
| 110 |
+
scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs)
|
| 111 |
+
timesteps = scheduler.timesteps
|
| 112 |
+
num_inference_steps = len(timesteps)
|
| 113 |
+
elif timesteps is not None and sigmas is None:
|
| 114 |
+
if not accepts_timesteps:
|
| 115 |
+
raise ValueError(
|
| 116 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 117 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 118 |
+
)
|
| 119 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 120 |
+
timesteps = scheduler.timesteps
|
| 121 |
+
num_inference_steps = len(timesteps)
|
| 122 |
+
elif timesteps is None and sigmas is not None:
|
| 123 |
+
if not accepts_sigmas:
|
| 124 |
+
raise ValueError(
|
| 125 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 126 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 127 |
+
)
|
| 128 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 129 |
+
timesteps = scheduler.timesteps
|
| 130 |
+
num_inference_steps = len(timesteps)
|
| 131 |
+
else:
|
| 132 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 133 |
+
timesteps = scheduler.timesteps
|
| 134 |
+
return timesteps, num_inference_steps
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
|
| 138 |
+
r"""
|
| 139 |
+
Pipeline for text-to-image generation using CogView4.
|
| 140 |
+
|
| 141 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 142 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
| 143 |
+
|
| 144 |
+
Args:
|
| 145 |
+
vae ([`AutoencoderKL`]):
|
| 146 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
| 147 |
+
text_encoder ([`GLMModel`]):
|
| 148 |
+
Frozen text-encoder. CogView4 uses [glm-4-9b-hf](https://huggingface.co/THUDM/glm-4-9b-hf).
|
| 149 |
+
tokenizer (`PreTrainedTokenizer`):
|
| 150 |
+
Tokenizer of class
|
| 151 |
+
[PreTrainedTokenizer](https://huggingface.co/docs/transformers/main/en/main_classes/tokenizer#transformers.PreTrainedTokenizer).
|
| 152 |
+
transformer ([`CogView4Transformer2DModel`]):
|
| 153 |
+
A text conditioned `CogView4Transformer2DModel` to denoise the encoded image latents.
|
| 154 |
+
scheduler ([`SchedulerMixin`]):
|
| 155 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
| 156 |
+
"""
|
| 157 |
+
|
| 158 |
+
_optional_components = []
|
| 159 |
+
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
| 160 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
| 161 |
+
|
| 162 |
+
def __init__(
|
| 163 |
+
self,
|
| 164 |
+
tokenizer: AutoTokenizer,
|
| 165 |
+
text_encoder: GlmModel,
|
| 166 |
+
vae: AutoencoderKL,
|
| 167 |
+
transformer: CogView4Transformer2DModel,
|
| 168 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 169 |
+
):
|
| 170 |
+
super().__init__()
|
| 171 |
+
|
| 172 |
+
self.register_modules(
|
| 173 |
+
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
|
| 174 |
+
)
|
| 175 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
| 176 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
| 177 |
+
|
| 178 |
+
def _get_glm_embeds(
|
| 179 |
+
self,
|
| 180 |
+
prompt: Union[str, List[str]] = None,
|
| 181 |
+
max_sequence_length: int = 1024,
|
| 182 |
+
device: Optional[torch.device] = None,
|
| 183 |
+
dtype: Optional[torch.dtype] = None,
|
| 184 |
+
):
|
| 185 |
+
device = device or self._execution_device
|
| 186 |
+
dtype = dtype or self.text_encoder.dtype
|
| 187 |
+
|
| 188 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 189 |
+
|
| 190 |
+
text_inputs = self.tokenizer(
|
| 191 |
+
prompt,
|
| 192 |
+
padding="longest", # not use max length
|
| 193 |
+
max_length=max_sequence_length,
|
| 194 |
+
truncation=True,
|
| 195 |
+
add_special_tokens=True,
|
| 196 |
+
return_tensors="pt",
|
| 197 |
+
)
|
| 198 |
+
text_input_ids = text_inputs.input_ids
|
| 199 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
| 200 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 201 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
| 202 |
+
logger.warning(
|
| 203 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
| 204 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
| 205 |
+
)
|
| 206 |
+
current_length = text_input_ids.shape[1]
|
| 207 |
+
pad_length = (16 - (current_length % 16)) % 16
|
| 208 |
+
if pad_length > 0:
|
| 209 |
+
pad_ids = torch.full(
|
| 210 |
+
(text_input_ids.shape[0], pad_length),
|
| 211 |
+
fill_value=self.tokenizer.pad_token_id,
|
| 212 |
+
dtype=text_input_ids.dtype,
|
| 213 |
+
device=text_input_ids.device,
|
| 214 |
+
)
|
| 215 |
+
text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
|
| 216 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=True).hidden_states[-2]
|
| 217 |
+
|
| 218 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 219 |
+
return prompt_embeds
|
| 220 |
+
|
| 221 |
+
def encode_prompt(
|
| 222 |
+
self,
|
| 223 |
+
prompt: Union[str, List[str]],
|
| 224 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 225 |
+
do_classifier_free_guidance: bool = True,
|
| 226 |
+
num_images_per_prompt: int = 1,
|
| 227 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 228 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 229 |
+
device: Optional[torch.device] = None,
|
| 230 |
+
dtype: Optional[torch.dtype] = None,
|
| 231 |
+
max_sequence_length: int = 1024,
|
| 232 |
+
):
|
| 233 |
+
r"""
|
| 234 |
+
Encodes the prompt into text encoder hidden states.
|
| 235 |
+
|
| 236 |
+
Args:
|
| 237 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 238 |
+
prompt to be encoded
|
| 239 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 240 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 241 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 242 |
+
less than `1`).
|
| 243 |
+
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
| 244 |
+
Whether to use classifier free guidance or not.
|
| 245 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 246 |
+
Number of images that should be generated per prompt. torch device to place the resulting embeddings on
|
| 247 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 248 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 249 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 250 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 251 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 252 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 253 |
+
argument.
|
| 254 |
+
device: (`torch.device`, *optional*):
|
| 255 |
+
torch device
|
| 256 |
+
dtype: (`torch.dtype`, *optional*):
|
| 257 |
+
torch dtype
|
| 258 |
+
max_sequence_length (`int`, defaults to `1024`):
|
| 259 |
+
Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
|
| 260 |
+
"""
|
| 261 |
+
device = device or self._execution_device
|
| 262 |
+
|
| 263 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 264 |
+
if prompt is not None:
|
| 265 |
+
batch_size = len(prompt)
|
| 266 |
+
else:
|
| 267 |
+
batch_size = prompt_embeds.shape[0]
|
| 268 |
+
|
| 269 |
+
if prompt_embeds is None:
|
| 270 |
+
prompt_embeds = self._get_glm_embeds(prompt, max_sequence_length, device, dtype)
|
| 271 |
+
|
| 272 |
+
seq_len = prompt_embeds.size(1)
|
| 273 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 274 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 275 |
+
|
| 276 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
| 277 |
+
negative_prompt = negative_prompt or ""
|
| 278 |
+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
| 279 |
+
|
| 280 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
| 281 |
+
raise TypeError(
|
| 282 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
| 283 |
+
f" {type(prompt)}."
|
| 284 |
+
)
|
| 285 |
+
elif batch_size != len(negative_prompt):
|
| 286 |
+
raise ValueError(
|
| 287 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
| 288 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
| 289 |
+
" the batch size of `prompt`."
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
negative_prompt_embeds = self._get_glm_embeds(negative_prompt, max_sequence_length, device, dtype)
|
| 293 |
+
|
| 294 |
+
seq_len = negative_prompt_embeds.size(1)
|
| 295 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 296 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 297 |
+
|
| 298 |
+
return prompt_embeds, negative_prompt_embeds
|
| 299 |
+
|
| 300 |
+
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
| 301 |
+
if latents is not None:
|
| 302 |
+
return latents.to(device)
|
| 303 |
+
|
| 304 |
+
shape = (
|
| 305 |
+
batch_size,
|
| 306 |
+
num_channels_latents,
|
| 307 |
+
int(height) // self.vae_scale_factor,
|
| 308 |
+
int(width) // self.vae_scale_factor,
|
| 309 |
+
)
|
| 310 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 311 |
+
raise ValueError(
|
| 312 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 313 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 314 |
+
)
|
| 315 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 316 |
+
return latents
|
| 317 |
+
|
| 318 |
+
def check_inputs(
|
| 319 |
+
self,
|
| 320 |
+
prompt,
|
| 321 |
+
height,
|
| 322 |
+
width,
|
| 323 |
+
negative_prompt,
|
| 324 |
+
callback_on_step_end_tensor_inputs,
|
| 325 |
+
prompt_embeds=None,
|
| 326 |
+
negative_prompt_embeds=None,
|
| 327 |
+
):
|
| 328 |
+
if height % 16 != 0 or width % 16 != 0:
|
| 329 |
+
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
| 330 |
+
|
| 331 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 332 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 333 |
+
):
|
| 334 |
+
raise ValueError(
|
| 335 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 336 |
+
)
|
| 337 |
+
if prompt is not None and prompt_embeds is not None:
|
| 338 |
+
raise ValueError(
|
| 339 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 340 |
+
" only forward one of the two."
|
| 341 |
+
)
|
| 342 |
+
elif prompt is None and prompt_embeds is None:
|
| 343 |
+
raise ValueError(
|
| 344 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 345 |
+
)
|
| 346 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 347 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 348 |
+
|
| 349 |
+
if prompt is not None and negative_prompt_embeds is not None:
|
| 350 |
+
raise ValueError(
|
| 351 |
+
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
|
| 352 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 356 |
+
raise ValueError(
|
| 357 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 358 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 362 |
+
if prompt_embeds.shape[0] != negative_prompt_embeds.shape[0]:
|
| 363 |
+
raise ValueError(
|
| 364 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same batch size when passed directly, but"
|
| 365 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} and `negative_prompt_embeds`"
|
| 366 |
+
f" {negative_prompt_embeds.shape}."
|
| 367 |
+
)
|
| 368 |
+
if prompt_embeds.shape[-1] != negative_prompt_embeds.shape[-1]:
|
| 369 |
+
raise ValueError(
|
| 370 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same dimension when passed directly, but"
|
| 371 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} and `negative_prompt_embeds`"
|
| 372 |
+
f" {negative_prompt_embeds.shape}."
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
@property
|
| 376 |
+
def guidance_scale(self):
|
| 377 |
+
return self._guidance_scale
|
| 378 |
+
|
| 379 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 380 |
+
# of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
|
| 381 |
+
# corresponds to doing no classifier free guidance.
|
| 382 |
+
@property
|
| 383 |
+
def do_classifier_free_guidance(self):
|
| 384 |
+
return self._guidance_scale > 1
|
| 385 |
+
|
| 386 |
+
@property
|
| 387 |
+
def num_timesteps(self):
|
| 388 |
+
return self._num_timesteps
|
| 389 |
+
|
| 390 |
+
@property
|
| 391 |
+
def attention_kwargs(self):
|
| 392 |
+
return self._attention_kwargs
|
| 393 |
+
|
| 394 |
+
@property
|
| 395 |
+
def current_timestep(self):
|
| 396 |
+
return self._current_timestep
|
| 397 |
+
|
| 398 |
+
@property
|
| 399 |
+
def interrupt(self):
|
| 400 |
+
return self._interrupt
|
| 401 |
+
|
| 402 |
+
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        timesteps: Optional[List[int]] = None,
        sigmas: Optional[List[float]] = None,
        guidance_scale: float = 5.0,
        num_images_per_prompt: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        original_size: Optional[Tuple[int, int]] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        output_type: str = "pil",
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 1024,
    ) -> Union[CogView4PipelineOutput, Tuple]:
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image. If not provided, it is set to 1024.
            width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. If not provided it is set to 1024.
            num_inference_steps (`int`, *optional*, defaults to `50`):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                will be used.
            guidance_scale (`float`, *optional*, defaults to `5.0`):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to `1`):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
                `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
                explained in section 2.2 of
                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
                `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
                of a plain tuple.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int`, defaults to `1024`):
                Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.

        Examples:

        Returns:
            [`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] or `tuple`:
                [`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] if `return_dict` is True, otherwise a
                `tuple`. When returning a tuple, the first element is a list with the generated images.
        """

        # A PipelineCallback bundle declares its own tensor inputs; honor them.
        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        height = height or self.transformer.config.sample_size * self.vae_scale_factor
        width = width or self.transformer.config.sample_size * self.vae_scale_factor

        original_size = original_size or (height, width)
        target_size = (height, width)

        # Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            negative_prompt,
            callback_on_step_end_tensor_inputs,
            prompt_embeds,
            negative_prompt_embeds,
        )
        # Stash per-call state read back through the properties above.
        self._guidance_scale = guidance_scale
        self._attention_kwargs = attention_kwargs
        self._current_timestep = None
        self._interrupt = False

        # Default call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # Encode input prompt
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            negative_prompt,
            self.do_classifier_free_guidance,
            num_images_per_prompt=num_images_per_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            max_sequence_length=max_sequence_length,
            device=device,
        )

        # Prepare latents (always sampled/kept in float32; cast per step below).
        latent_channels = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            latent_channels,
            height,
            width,
            torch.float32,
            device,
            generator,
            latents,
        )

        # Prepare additional timestep conditions (SDXL-style micro-conditioning tensors).
        original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype, device=device)
        target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=device)
        crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype, device=device)

        original_size = original_size.repeat(batch_size * num_images_per_prompt, 1)
        target_size = target_size.repeat(batch_size * num_images_per_prompt, 1)
        crops_coords_top_left = crops_coords_top_left.repeat(batch_size * num_images_per_prompt, 1)

        # Prepare timesteps. The number of latent patches drives the flow-matching shift `mu`.
        image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // (
            self.transformer.config.patch_size**2
        )
        timesteps = (
            np.linspace(self.scheduler.config.num_train_timesteps, 1.0, num_inference_steps)
            if timesteps is None
            else np.array(timesteps)
        )
        # Truncate to integer timesteps, then back to float32 for the scheduler.
        timesteps = timesteps.astype(np.int64).astype(np.float32)
        sigmas = timesteps / self.scheduler.config.num_train_timesteps if sigmas is None else sigmas
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.get("base_image_seq_len", 256),
            self.scheduler.config.get("base_shift", 0.25),
            self.scheduler.config.get("max_shift", 0.75),
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu
        )
        self._num_timesteps = len(timesteps)

        # Denoising loop
        transformer_dtype = self.transformer.dtype
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                self._current_timestep = t
                latent_model_input = latents.to(transformer_dtype)

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latents.shape[0])

                with self.transformer.cache_context("cond"):
                    noise_pred_cond = self.transformer(
                        hidden_states=latent_model_input,
                        encoder_hidden_states=prompt_embeds,
                        timestep=timestep,
                        original_size=original_size,
                        target_size=target_size,
                        crop_coords=crops_coords_top_left,
                        attention_kwargs=attention_kwargs,
                        return_dict=False,
                    )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    with self.transformer.cache_context("uncond"):
                        noise_pred_uncond = self.transformer(
                            hidden_states=latent_model_input,
                            encoder_hidden_states=negative_prompt_embeds,
                            timestep=timestep,
                            original_size=original_size,
                            target_size=target_size,
                            crop_coords=crops_coords_top_left,
                            attention_kwargs=attention_kwargs,
                            return_dict=False,
                        )[0]
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
                else:
                    noise_pred = noise_pred_cond

                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                # call the callback, if provided
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        # `locals()` lookup relies on the tensor names above matching
                        # `_callback_tensor_inputs` exactly.
                        callback_kwargs[k] = locals()[k]
                    # NOTE(review): the callback receives the current sigma, not `t`, as its
                    # "timestep" argument — presumably intentional for flow matching; confirm.
                    callback_outputs = callback_on_step_end(self, i, self.scheduler.sigmas[i], callback_kwargs)
                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        self._current_timestep = None

        if not output_type == "latent":
            latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
            image = self.vae.decode(latents, return_dict=False, generator=generator)[0]
        else:
            image = latents

        image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return CogView4PipelineOutput(images=image)
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogview4/pipeline_cogview4_control.py
ADDED
|
@@ -0,0 +1,732 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import inspect
|
| 17 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torch
|
| 21 |
+
from transformers import AutoTokenizer, GlmModel
|
| 22 |
+
|
| 23 |
+
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 24 |
+
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
| 25 |
+
from ...models import AutoencoderKL, CogView4Transformer2DModel
|
| 26 |
+
from ...pipelines.pipeline_utils import DiffusionPipeline
|
| 27 |
+
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
| 28 |
+
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
| 29 |
+
from ...utils.torch_utils import randn_tensor
|
| 30 |
+
from .pipeline_output import CogView4PipelineOutput
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
if is_torch_xla_available():
|
| 34 |
+
import torch_xla.core.xla_model as xm
|
| 35 |
+
|
| 36 |
+
XLA_AVAILABLE = True
|
| 37 |
+
else:
|
| 38 |
+
XLA_AVAILABLE = False
|
| 39 |
+
|
| 40 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 41 |
+
|
| 42 |
+
# Doctest-style usage example spliced into `__call__` via `@replace_example_docstring`.
# Fix: the example used `load_image` without importing it; the import line is now shown.
EXAMPLE_DOC_STRING = """
    Examples:
        ```python
        >>> import torch
        >>> from diffusers import CogView4ControlPipeline
        >>> from diffusers.utils import load_image

        >>> pipe = CogView4ControlPipeline.from_pretrained("THUDM/CogView4-6B-Control", torch_dtype=torch.bfloat16)
        >>> control_image = load_image(
        ...     "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
        ... )
        >>> prompt = "A bird in space"
        >>> image = pipe(prompt, control_image=control_image, height=1024, width=1024, guidance_scale=3.5).images[0]
        >>> image.save("cogview4-control.png")
        ```
"""
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# Copied from diffusers.pipelines.cogview4.pipeline_cogview4.calculate_shift
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    base_shift: float = 0.25,
    max_shift: float = 0.75,
) -> float:
    """Return the flow-matching shift `mu` for a latent sequence of `image_seq_len` patches.

    The shift grows with the square root of the sequence length relative to
    `base_seq_len`: `mu = sqrt(image_seq_len / base_seq_len) * max_shift + base_shift`.
    """
    scale = (image_seq_len / base_seq_len) ** 0.5
    return scale * max_shift + base_shift
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# Copied from diffusers.pipelines.cogview4.pipeline_cogview4.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Call the scheduler's `set_timesteps` and return the resulting timestep schedule.

    Custom `timesteps` and/or `sigmas` override the scheduler's default spacing, provided
    the scheduler's `set_timesteps` signature accepts them; any extra `kwargs` are passed
    straight through to `set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    # Inspect the signature once; feature-detect custom-schedule support.
    set_timesteps_params = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
    accepts_timesteps = "timesteps" in set_timesteps_params
    accepts_sigmas = "sigmas" in set_timesteps_params

    if timesteps is not None and sigmas is not None:
        if not (accepts_timesteps or accepts_sigmas):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep or sigma schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs)
    elif timesteps is not None:
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    elif sigmas is not None:
        if not accepts_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
    else:
        # Default spacing: let the scheduler derive the schedule from the step count.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Any custom-schedule branch redefines how many steps there actually are.
    timesteps = scheduler.timesteps
    return timesteps, len(timesteps)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class CogView4ControlPipeline(DiffusionPipeline):
|
| 140 |
+
r"""
|
| 141 |
+
Pipeline for text-to-image generation using CogView4.
|
| 142 |
+
|
| 143 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 144 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
| 145 |
+
|
| 146 |
+
Args:
|
| 147 |
+
vae ([`AutoencoderKL`]):
|
| 148 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
| 149 |
+
text_encoder ([`GLMModel`]):
|
| 150 |
+
Frozen text-encoder. CogView4 uses [glm-4-9b-hf](https://huggingface.co/THUDM/glm-4-9b-hf).
|
| 151 |
+
tokenizer (`PreTrainedTokenizer`):
|
| 152 |
+
Tokenizer of class
|
| 153 |
+
[PreTrainedTokenizer](https://huggingface.co/docs/transformers/main/en/main_classes/tokenizer#transformers.PreTrainedTokenizer).
|
| 154 |
+
transformer ([`CogView4Transformer2DModel`]):
|
| 155 |
+
A text conditioned `CogView4Transformer2DModel` to denoise the encoded image latents.
|
| 156 |
+
scheduler ([`SchedulerMixin`]):
|
| 157 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
| 158 |
+
"""
|
| 159 |
+
|
| 160 |
+
_optional_components = []
|
| 161 |
+
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
| 162 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
| 163 |
+
|
| 164 |
+
    def __init__(
        self,
        tokenizer: AutoTokenizer,
        text_encoder: GlmModel,
        vae: AutoencoderKL,
        transformer: CogView4Transformer2DModel,
        scheduler: FlowMatchEulerDiscreteScheduler,
    ):
        super().__init__()

        # Register sub-models with the DiffusionPipeline machinery so saving,
        # device placement, and CPU offloading all see them.
        self.register_modules(
            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
        )
        # Spatial downsampling factor derived from the VAE's block count
        # (2**(levels-1)); falls back to 8 when no VAE is registered.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
| 179 |
+
|
| 180 |
+
# Copied from diffusers.pipelines.cogview4.pipeline_cogview4.CogView4Pipeline._get_glm_embeds
|
| 181 |
+
def _get_glm_embeds(
|
| 182 |
+
self,
|
| 183 |
+
prompt: Union[str, List[str]] = None,
|
| 184 |
+
max_sequence_length: int = 1024,
|
| 185 |
+
device: Optional[torch.device] = None,
|
| 186 |
+
dtype: Optional[torch.dtype] = None,
|
| 187 |
+
):
|
| 188 |
+
device = device or self._execution_device
|
| 189 |
+
dtype = dtype or self.text_encoder.dtype
|
| 190 |
+
|
| 191 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 192 |
+
|
| 193 |
+
text_inputs = self.tokenizer(
|
| 194 |
+
prompt,
|
| 195 |
+
padding="longest", # not use max length
|
| 196 |
+
max_length=max_sequence_length,
|
| 197 |
+
truncation=True,
|
| 198 |
+
add_special_tokens=True,
|
| 199 |
+
return_tensors="pt",
|
| 200 |
+
)
|
| 201 |
+
text_input_ids = text_inputs.input_ids
|
| 202 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
| 203 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 204 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
| 205 |
+
logger.warning(
|
| 206 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
| 207 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
| 208 |
+
)
|
| 209 |
+
current_length = text_input_ids.shape[1]
|
| 210 |
+
pad_length = (16 - (current_length % 16)) % 16
|
| 211 |
+
if pad_length > 0:
|
| 212 |
+
pad_ids = torch.full(
|
| 213 |
+
(text_input_ids.shape[0], pad_length),
|
| 214 |
+
fill_value=self.tokenizer.pad_token_id,
|
| 215 |
+
dtype=text_input_ids.dtype,
|
| 216 |
+
device=text_input_ids.device,
|
| 217 |
+
)
|
| 218 |
+
text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
|
| 219 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=True).hidden_states[-2]
|
| 220 |
+
|
| 221 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 222 |
+
return prompt_embeds
|
| 223 |
+
|
| 224 |
+
# Copied from diffusers.pipelines.cogview4.pipeline_cogview4.CogView4Pipeline.encode_prompt
|
| 225 |
+
def encode_prompt(
|
| 226 |
+
self,
|
| 227 |
+
prompt: Union[str, List[str]],
|
| 228 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 229 |
+
do_classifier_free_guidance: bool = True,
|
| 230 |
+
num_images_per_prompt: int = 1,
|
| 231 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 232 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 233 |
+
device: Optional[torch.device] = None,
|
| 234 |
+
dtype: Optional[torch.dtype] = None,
|
| 235 |
+
max_sequence_length: int = 1024,
|
| 236 |
+
):
|
| 237 |
+
r"""
|
| 238 |
+
Encodes the prompt into text encoder hidden states.
|
| 239 |
+
|
| 240 |
+
Args:
|
| 241 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 242 |
+
prompt to be encoded
|
| 243 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 244 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 245 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 246 |
+
less than `1`).
|
| 247 |
+
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
| 248 |
+
Whether to use classifier free guidance or not.
|
| 249 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 250 |
+
Number of images that should be generated per prompt. torch device to place the resulting embeddings on
|
| 251 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 252 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 253 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 254 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 255 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 256 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 257 |
+
argument.
|
| 258 |
+
device: (`torch.device`, *optional*):
|
| 259 |
+
torch device
|
| 260 |
+
dtype: (`torch.dtype`, *optional*):
|
| 261 |
+
torch dtype
|
| 262 |
+
max_sequence_length (`int`, defaults to `1024`):
|
| 263 |
+
Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
|
| 264 |
+
"""
|
| 265 |
+
device = device or self._execution_device
|
| 266 |
+
|
| 267 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 268 |
+
if prompt is not None:
|
| 269 |
+
batch_size = len(prompt)
|
| 270 |
+
else:
|
| 271 |
+
batch_size = prompt_embeds.shape[0]
|
| 272 |
+
|
| 273 |
+
if prompt_embeds is None:
|
| 274 |
+
prompt_embeds = self._get_glm_embeds(prompt, max_sequence_length, device, dtype)
|
| 275 |
+
|
| 276 |
+
seq_len = prompt_embeds.size(1)
|
| 277 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 278 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 279 |
+
|
| 280 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
| 281 |
+
negative_prompt = negative_prompt or ""
|
| 282 |
+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
| 283 |
+
|
| 284 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
| 285 |
+
raise TypeError(
|
| 286 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
| 287 |
+
f" {type(prompt)}."
|
| 288 |
+
)
|
| 289 |
+
elif batch_size != len(negative_prompt):
|
| 290 |
+
raise ValueError(
|
| 291 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
| 292 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
| 293 |
+
" the batch size of `prompt`."
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
negative_prompt_embeds = self._get_glm_embeds(negative_prompt, max_sequence_length, device, dtype)
|
| 297 |
+
|
| 298 |
+
seq_len = negative_prompt_embeds.size(1)
|
| 299 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 300 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 301 |
+
|
| 302 |
+
return prompt_embeds, negative_prompt_embeds
|
| 303 |
+
|
| 304 |
+
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
| 305 |
+
if latents is not None:
|
| 306 |
+
return latents.to(device)
|
| 307 |
+
|
| 308 |
+
shape = (
|
| 309 |
+
batch_size,
|
| 310 |
+
num_channels_latents,
|
| 311 |
+
int(height) // self.vae_scale_factor,
|
| 312 |
+
int(width) // self.vae_scale_factor,
|
| 313 |
+
)
|
| 314 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 315 |
+
raise ValueError(
|
| 316 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 317 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 318 |
+
)
|
| 319 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 320 |
+
return latents
|
| 321 |
+
|
| 322 |
+
def prepare_image(
|
| 323 |
+
self,
|
| 324 |
+
image,
|
| 325 |
+
width,
|
| 326 |
+
height,
|
| 327 |
+
batch_size,
|
| 328 |
+
num_images_per_prompt,
|
| 329 |
+
device,
|
| 330 |
+
dtype,
|
| 331 |
+
do_classifier_free_guidance=False,
|
| 332 |
+
guess_mode=False,
|
| 333 |
+
):
|
| 334 |
+
if isinstance(image, torch.Tensor):
|
| 335 |
+
pass
|
| 336 |
+
else:
|
| 337 |
+
image = self.image_processor.preprocess(image, height=height, width=width)
|
| 338 |
+
|
| 339 |
+
image_batch_size = image.shape[0]
|
| 340 |
+
|
| 341 |
+
if image_batch_size == 1:
|
| 342 |
+
repeat_by = batch_size
|
| 343 |
+
else:
|
| 344 |
+
# image batch size is the same as prompt batch size
|
| 345 |
+
repeat_by = num_images_per_prompt
|
| 346 |
+
|
| 347 |
+
image = image.repeat_interleave(repeat_by, dim=0, output_size=image.shape[0] * repeat_by)
|
| 348 |
+
|
| 349 |
+
image = image.to(device=device, dtype=dtype)
|
| 350 |
+
|
| 351 |
+
if do_classifier_free_guidance and not guess_mode:
|
| 352 |
+
image = torch.cat([image] * 2)
|
| 353 |
+
|
| 354 |
+
return image
|
| 355 |
+
|
| 356 |
+
def check_inputs(
|
| 357 |
+
self,
|
| 358 |
+
prompt,
|
| 359 |
+
height,
|
| 360 |
+
width,
|
| 361 |
+
negative_prompt,
|
| 362 |
+
callback_on_step_end_tensor_inputs,
|
| 363 |
+
prompt_embeds=None,
|
| 364 |
+
negative_prompt_embeds=None,
|
| 365 |
+
):
|
| 366 |
+
if height % 16 != 0 or width % 16 != 0:
|
| 367 |
+
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
| 368 |
+
|
| 369 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 370 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 371 |
+
):
|
| 372 |
+
raise ValueError(
|
| 373 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 374 |
+
)
|
| 375 |
+
if prompt is not None and prompt_embeds is not None:
|
| 376 |
+
raise ValueError(
|
| 377 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 378 |
+
" only forward one of the two."
|
| 379 |
+
)
|
| 380 |
+
elif prompt is None and prompt_embeds is None:
|
| 381 |
+
raise ValueError(
|
| 382 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 383 |
+
)
|
| 384 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 385 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 386 |
+
|
| 387 |
+
if prompt is not None and negative_prompt_embeds is not None:
|
| 388 |
+
raise ValueError(
|
| 389 |
+
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
|
| 390 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 391 |
+
)
|
| 392 |
+
|
| 393 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 394 |
+
raise ValueError(
|
| 395 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 396 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 397 |
+
)
|
| 398 |
+
|
| 399 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 400 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 401 |
+
raise ValueError(
|
| 402 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 403 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 404 |
+
f" {negative_prompt_embeds.shape}."
|
| 405 |
+
)
|
| 406 |
+
|
| 407 |
+
@property
|
| 408 |
+
def guidance_scale(self):
|
| 409 |
+
return self._guidance_scale
|
| 410 |
+
|
| 411 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 412 |
+
# of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
|
| 413 |
+
# corresponds to doing no classifier free guidance.
|
| 414 |
+
@property
|
| 415 |
+
def do_classifier_free_guidance(self):
|
| 416 |
+
return self._guidance_scale > 1
|
| 417 |
+
|
| 418 |
+
@property
|
| 419 |
+
def num_timesteps(self):
|
| 420 |
+
return self._num_timesteps
|
| 421 |
+
|
| 422 |
+
@property
|
| 423 |
+
def attention_kwargs(self):
|
| 424 |
+
return self._attention_kwargs
|
| 425 |
+
|
| 426 |
+
@property
|
| 427 |
+
def current_timestep(self):
|
| 428 |
+
return self._current_timestep
|
| 429 |
+
|
| 430 |
+
@property
|
| 431 |
+
def interrupt(self):
|
| 432 |
+
return self._interrupt
|
| 433 |
+
|
| 434 |
+
@torch.no_grad()
|
| 435 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 436 |
+
def __call__(
|
| 437 |
+
self,
|
| 438 |
+
prompt: Optional[Union[str, List[str]]] = None,
|
| 439 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 440 |
+
control_image: PipelineImageInput = None,
|
| 441 |
+
height: Optional[int] = None,
|
| 442 |
+
width: Optional[int] = None,
|
| 443 |
+
num_inference_steps: int = 50,
|
| 444 |
+
timesteps: Optional[List[int]] = None,
|
| 445 |
+
sigmas: Optional[List[float]] = None,
|
| 446 |
+
guidance_scale: float = 5.0,
|
| 447 |
+
num_images_per_prompt: int = 1,
|
| 448 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 449 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 450 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 451 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 452 |
+
original_size: Optional[Tuple[int, int]] = None,
|
| 453 |
+
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
| 454 |
+
output_type: str = "pil",
|
| 455 |
+
return_dict: bool = True,
|
| 456 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 457 |
+
callback_on_step_end: Optional[
|
| 458 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
| 459 |
+
] = None,
|
| 460 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 461 |
+
max_sequence_length: int = 1024,
|
| 462 |
+
) -> Union[CogView4PipelineOutput, Tuple]:
|
| 463 |
+
"""
|
| 464 |
+
Function invoked when calling the pipeline for generation.
|
| 465 |
+
|
| 466 |
+
Args:
|
| 467 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 468 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
| 469 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 470 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 471 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 472 |
+
less than `1`).
|
| 473 |
+
height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
|
| 474 |
+
The height in pixels of the generated image. If not provided, it is set to 1024.
|
| 475 |
+
width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
|
| 476 |
+
The width in pixels of the generated image. If not provided it is set to 1024.
|
| 477 |
+
num_inference_steps (`int`, *optional*, defaults to `50`):
|
| 478 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 479 |
+
expense of slower inference.
|
| 480 |
+
timesteps (`List[int]`, *optional*):
|
| 481 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
| 482 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
| 483 |
+
passed will be used. Must be in descending order.
|
| 484 |
+
sigmas (`List[float]`, *optional*):
|
| 485 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
| 486 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
| 487 |
+
will be used.
|
| 488 |
+
guidance_scale (`float`, *optional*, defaults to `5.0`):
|
| 489 |
+
Guidance scale as defined in [Classifier-Free Diffusion
|
| 490 |
+
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
| 491 |
+
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
| 492 |
+
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
| 493 |
+
the text `prompt`, usually at the expense of lower image quality.
|
| 494 |
+
num_images_per_prompt (`int`, *optional*, defaults to `1`):
|
| 495 |
+
The number of images to generate per prompt.
|
| 496 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 497 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 498 |
+
to make generation deterministic.
|
| 499 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 500 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 501 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 502 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
| 503 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 504 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 505 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 506 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 507 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 508 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 509 |
+
argument.
|
| 510 |
+
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
| 511 |
+
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
|
| 512 |
+
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
|
| 513 |
+
explained in section 2.2 of
|
| 514 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
| 515 |
+
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
| 516 |
+
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
|
| 517 |
+
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
|
| 518 |
+
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
|
| 519 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
| 520 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 521 |
+
The output format of the generate image. Choose between
|
| 522 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 523 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 524 |
+
Whether or not to return a [`~pipelines.pipeline_CogView4.CogView4PipelineOutput`] instead of a plain
|
| 525 |
+
tuple.
|
| 526 |
+
attention_kwargs (`dict`, *optional*):
|
| 527 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 528 |
+
`self.processor` in
|
| 529 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 530 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 531 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
| 532 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 533 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 534 |
+
`callback_on_step_end_tensor_inputs`.
|
| 535 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 536 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 537 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 538 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 539 |
+
max_sequence_length (`int`, defaults to `224`):
|
| 540 |
+
Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
|
| 541 |
+
Examples:
|
| 542 |
+
|
| 543 |
+
Returns:
|
| 544 |
+
[`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] or `tuple`:
|
| 545 |
+
[`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] if `return_dict` is True, otherwise a
|
| 546 |
+
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
| 547 |
+
"""
|
| 548 |
+
|
| 549 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 550 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 551 |
+
|
| 552 |
+
height = height or self.transformer.config.sample_size * self.vae_scale_factor
|
| 553 |
+
width = width or self.transformer.config.sample_size * self.vae_scale_factor
|
| 554 |
+
|
| 555 |
+
original_size = original_size or (height, width)
|
| 556 |
+
target_size = (height, width)
|
| 557 |
+
|
| 558 |
+
# Check inputs. Raise error if not correct
|
| 559 |
+
self.check_inputs(
|
| 560 |
+
prompt,
|
| 561 |
+
height,
|
| 562 |
+
width,
|
| 563 |
+
negative_prompt,
|
| 564 |
+
callback_on_step_end_tensor_inputs,
|
| 565 |
+
prompt_embeds,
|
| 566 |
+
negative_prompt_embeds,
|
| 567 |
+
)
|
| 568 |
+
self._guidance_scale = guidance_scale
|
| 569 |
+
self._attention_kwargs = attention_kwargs
|
| 570 |
+
self._current_timestep = None
|
| 571 |
+
self._interrupt = False
|
| 572 |
+
|
| 573 |
+
# Default call parameters
|
| 574 |
+
if prompt is not None and isinstance(prompt, str):
|
| 575 |
+
batch_size = 1
|
| 576 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 577 |
+
batch_size = len(prompt)
|
| 578 |
+
else:
|
| 579 |
+
batch_size = prompt_embeds.shape[0]
|
| 580 |
+
|
| 581 |
+
device = self._execution_device
|
| 582 |
+
|
| 583 |
+
# Encode input prompt
|
| 584 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
| 585 |
+
prompt,
|
| 586 |
+
negative_prompt,
|
| 587 |
+
self.do_classifier_free_guidance,
|
| 588 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 589 |
+
prompt_embeds=prompt_embeds,
|
| 590 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 591 |
+
max_sequence_length=max_sequence_length,
|
| 592 |
+
device=device,
|
| 593 |
+
)
|
| 594 |
+
|
| 595 |
+
# Prepare latents
|
| 596 |
+
latent_channels = self.transformer.config.in_channels // 2
|
| 597 |
+
|
| 598 |
+
control_image = self.prepare_image(
|
| 599 |
+
image=control_image,
|
| 600 |
+
width=width,
|
| 601 |
+
height=height,
|
| 602 |
+
batch_size=batch_size * num_images_per_prompt,
|
| 603 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 604 |
+
device=device,
|
| 605 |
+
dtype=self.vae.dtype,
|
| 606 |
+
)
|
| 607 |
+
height, width = control_image.shape[-2:]
|
| 608 |
+
|
| 609 |
+
vae_shift_factor = 0
|
| 610 |
+
|
| 611 |
+
control_image = self.vae.encode(control_image).latent_dist.sample()
|
| 612 |
+
control_image = (control_image - vae_shift_factor) * self.vae.config.scaling_factor
|
| 613 |
+
|
| 614 |
+
latents = self.prepare_latents(
|
| 615 |
+
batch_size * num_images_per_prompt,
|
| 616 |
+
latent_channels,
|
| 617 |
+
height,
|
| 618 |
+
width,
|
| 619 |
+
torch.float32,
|
| 620 |
+
device,
|
| 621 |
+
generator,
|
| 622 |
+
latents,
|
| 623 |
+
)
|
| 624 |
+
|
| 625 |
+
# Prepare additional timestep conditions
|
| 626 |
+
original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype, device=device)
|
| 627 |
+
target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=device)
|
| 628 |
+
crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype, device=device)
|
| 629 |
+
|
| 630 |
+
original_size = original_size.repeat(batch_size * num_images_per_prompt, 1)
|
| 631 |
+
target_size = target_size.repeat(batch_size * num_images_per_prompt, 1)
|
| 632 |
+
crops_coords_top_left = crops_coords_top_left.repeat(batch_size * num_images_per_prompt, 1)
|
| 633 |
+
|
| 634 |
+
# Prepare timesteps
|
| 635 |
+
image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // (
|
| 636 |
+
self.transformer.config.patch_size**2
|
| 637 |
+
)
|
| 638 |
+
|
| 639 |
+
timesteps = (
|
| 640 |
+
np.linspace(self.scheduler.config.num_train_timesteps, 1.0, num_inference_steps)
|
| 641 |
+
if timesteps is None
|
| 642 |
+
else np.array(timesteps)
|
| 643 |
+
)
|
| 644 |
+
timesteps = timesteps.astype(np.int64).astype(np.float32)
|
| 645 |
+
sigmas = timesteps / self.scheduler.config.num_train_timesteps if sigmas is None else sigmas
|
| 646 |
+
mu = calculate_shift(
|
| 647 |
+
image_seq_len,
|
| 648 |
+
self.scheduler.config.get("base_image_seq_len", 256),
|
| 649 |
+
self.scheduler.config.get("base_shift", 0.25),
|
| 650 |
+
self.scheduler.config.get("max_shift", 0.75),
|
| 651 |
+
)
|
| 652 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 653 |
+
self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu
|
| 654 |
+
)
|
| 655 |
+
self._num_timesteps = len(timesteps)
|
| 656 |
+
# Denoising loop
|
| 657 |
+
transformer_dtype = self.transformer.dtype
|
| 658 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 659 |
+
|
| 660 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 661 |
+
for i, t in enumerate(timesteps):
|
| 662 |
+
if self.interrupt:
|
| 663 |
+
continue
|
| 664 |
+
|
| 665 |
+
self._current_timestep = t
|
| 666 |
+
latent_model_input = torch.cat([latents, control_image], dim=1).to(transformer_dtype)
|
| 667 |
+
|
| 668 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 669 |
+
timestep = t.expand(latents.shape[0])
|
| 670 |
+
|
| 671 |
+
noise_pred_cond = self.transformer(
|
| 672 |
+
hidden_states=latent_model_input,
|
| 673 |
+
encoder_hidden_states=prompt_embeds,
|
| 674 |
+
timestep=timestep,
|
| 675 |
+
original_size=original_size,
|
| 676 |
+
target_size=target_size,
|
| 677 |
+
crop_coords=crops_coords_top_left,
|
| 678 |
+
attention_kwargs=attention_kwargs,
|
| 679 |
+
return_dict=False,
|
| 680 |
+
)[0]
|
| 681 |
+
|
| 682 |
+
# perform guidance
|
| 683 |
+
if self.do_classifier_free_guidance:
|
| 684 |
+
noise_pred_uncond = self.transformer(
|
| 685 |
+
hidden_states=latent_model_input,
|
| 686 |
+
encoder_hidden_states=negative_prompt_embeds,
|
| 687 |
+
timestep=timestep,
|
| 688 |
+
original_size=original_size,
|
| 689 |
+
target_size=target_size,
|
| 690 |
+
crop_coords=crops_coords_top_left,
|
| 691 |
+
attention_kwargs=attention_kwargs,
|
| 692 |
+
return_dict=False,
|
| 693 |
+
)[0]
|
| 694 |
+
|
| 695 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
| 696 |
+
else:
|
| 697 |
+
noise_pred = noise_pred_cond
|
| 698 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
| 699 |
+
|
| 700 |
+
# call the callback, if provided
|
| 701 |
+
if callback_on_step_end is not None:
|
| 702 |
+
callback_kwargs = {}
|
| 703 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 704 |
+
callback_kwargs[k] = locals()[k]
|
| 705 |
+
callback_outputs = callback_on_step_end(self, i, self.scheduler.sigmas[i], callback_kwargs)
|
| 706 |
+
latents = callback_outputs.pop("latents", latents)
|
| 707 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 708 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 709 |
+
|
| 710 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 711 |
+
progress_bar.update()
|
| 712 |
+
|
| 713 |
+
if XLA_AVAILABLE:
|
| 714 |
+
xm.mark_step()
|
| 715 |
+
|
| 716 |
+
self._current_timestep = None
|
| 717 |
+
|
| 718 |
+
if not output_type == "latent":
|
| 719 |
+
latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
|
| 720 |
+
image = self.vae.decode(latents, return_dict=False, generator=generator)[0]
|
| 721 |
+
else:
|
| 722 |
+
image = latents
|
| 723 |
+
|
| 724 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
| 725 |
+
|
| 726 |
+
# Offload all models
|
| 727 |
+
self.maybe_free_model_hooks()
|
| 728 |
+
|
| 729 |
+
if not return_dict:
|
| 730 |
+
return (image,)
|
| 731 |
+
|
| 732 |
+
return CogView4PipelineOutput(images=image)
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogview4/pipeline_output.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import List, Union
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import PIL.Image
|
| 6 |
+
|
| 7 |
+
from ...utils import BaseOutput
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass
|
| 11 |
+
class CogView4PipelineOutput(BaseOutput):
|
| 12 |
+
"""
|
| 13 |
+
Output class for CogView3 pipelines.
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
| 17 |
+
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
|
| 18 |
+
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
images: Union[List[PIL.Image.Image], np.ndarray]
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/consisid/__init__.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import TYPE_CHECKING
|
| 2 |
+
|
| 3 |
+
from ...utils import (
|
| 4 |
+
DIFFUSERS_SLOW_IMPORT,
|
| 5 |
+
OptionalDependencyNotAvailable,
|
| 6 |
+
_LazyModule,
|
| 7 |
+
get_objects_from_module,
|
| 8 |
+
is_opencv_available,
|
| 9 |
+
is_torch_available,
|
| 10 |
+
is_transformers_available,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
_dummy_objects = {}
|
| 15 |
+
_import_structure = {}
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
if not (is_transformers_available() and is_torch_available() and is_opencv_available()):
|
| 20 |
+
raise OptionalDependencyNotAvailable()
|
| 21 |
+
except OptionalDependencyNotAvailable:
|
| 22 |
+
from ...utils import dummy_torch_and_transformers_and_opencv_objects # noqa F403
|
| 23 |
+
|
| 24 |
+
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_opencv_objects))
|
| 25 |
+
else:
|
| 26 |
+
_import_structure["pipeline_consisid"] = ["ConsisIDPipeline"]
|
| 27 |
+
|
| 28 |
+
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
| 29 |
+
try:
|
| 30 |
+
if not (is_transformers_available() and is_torch_available()):
|
| 31 |
+
raise OptionalDependencyNotAvailable()
|
| 32 |
+
|
| 33 |
+
except OptionalDependencyNotAvailable:
|
| 34 |
+
from ...utils.dummy_torch_and_transformers_objects import *
|
| 35 |
+
else:
|
| 36 |
+
from .pipeline_consisid import ConsisIDPipeline
|
| 37 |
+
|
| 38 |
+
else:
|
| 39 |
+
import sys
|
| 40 |
+
|
| 41 |
+
sys.modules[__name__] = _LazyModule(
|
| 42 |
+
__name__,
|
| 43 |
+
globals()["__file__"],
|
| 44 |
+
_import_structure,
|
| 45 |
+
module_spec=__spec__,
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
for name, value in _dummy_objects.items():
|
| 49 |
+
setattr(sys.modules[__name__], name, value)
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/consisid/consisid_utils.py
ADDED
|
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib.util
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from PIL import Image, ImageOps
|
| 8 |
+
from torchvision.transforms import InterpolationMode
|
| 9 |
+
from torchvision.transforms.functional import normalize, resize
|
| 10 |
+
|
| 11 |
+
from ...utils import get_logger, load_image
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
logger = get_logger(__name__)
|
| 15 |
+
|
| 16 |
+
_insightface_available = importlib.util.find_spec("insightface") is not None
|
| 17 |
+
_consisid_eva_clip_available = importlib.util.find_spec("consisid_eva_clip") is not None
|
| 18 |
+
_facexlib_available = importlib.util.find_spec("facexlib") is not None
|
| 19 |
+
|
| 20 |
+
if _insightface_available:
|
| 21 |
+
import insightface
|
| 22 |
+
from insightface.app import FaceAnalysis
|
| 23 |
+
else:
|
| 24 |
+
raise ImportError("insightface is not available. Please install it using 'pip install insightface'.")
|
| 25 |
+
|
| 26 |
+
if _consisid_eva_clip_available:
|
| 27 |
+
from consisid_eva_clip import create_model_and_transforms
|
| 28 |
+
from consisid_eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
|
| 29 |
+
else:
|
| 30 |
+
raise ImportError("consisid_eva_clip is not available. Please install it using 'pip install consisid_eva_clip'.")
|
| 31 |
+
|
| 32 |
+
if _facexlib_available:
|
| 33 |
+
from facexlib.parsing import init_parsing_model
|
| 34 |
+
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
| 35 |
+
else:
|
| 36 |
+
raise ImportError("facexlib is not available. Please install it using 'pip install facexlib'.")
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def resize_numpy_image_long(image, resize_long_edge=768):
    """
    Downscale an image so that its longer edge equals `resize_long_edge`, preserving aspect ratio.

    Args:
        image (numpy.ndarray): Input image (H x W x C or H x W).
        resize_long_edge (int): Target size for the long edge of the image. Default is 768.

    Returns:
        numpy.ndarray: The resized image, or the original array unchanged when its longer
        edge is already no larger than `resize_long_edge` (no upscaling is performed).
    """

    height, width = image.shape[:2]
    long_edge = max(height, width)
    if long_edge <= resize_long_edge:
        return image
    scale = resize_long_edge / long_edge
    height = int(height * scale)
    width = int(width * scale)
    # cv2.resize expects the target size as (width, height)
    image = cv2.resize(image, (width, height), interpolation=cv2.INTER_LANCZOS4)
    return image
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def img2tensor(imgs, bgr2rgb=True, float32=True):
    """Convert numpy image(s) in HWC layout to torch tensor(s) in CHW layout.

    Args:
        imgs (list[ndarray] | ndarray): Input images.
        bgr2rgb (bool): Whether to change bgr to rgb.
        float32 (bool): Whether to change to float32.

    Returns:
        list[tensor] | tensor: Converted tensor(s). A single tensor is returned
        when a single ndarray is given.
    """

    def _convert_one(img, bgr2rgb, float32):
        if img.shape[2] == 3 and bgr2rgb:
            # cv2.cvtColor cannot consume float64 input, so downcast first.
            if img.dtype == "float64":
                img = img.astype("float32")
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        tensor = torch.from_numpy(img.transpose(2, 0, 1))
        if float32:
            tensor = tensor.float()
        return tensor

    if isinstance(imgs, list):
        return [_convert_one(img, bgr2rgb, float32) for img in imgs]
    return _convert_one(imgs, bgr2rgb, float32)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def to_gray(img):
    """
    Convert a batch of RGB images to 3-channel grayscale using the standard
    luminosity (BT.601 luma) weights.

    Args:
        img (torch.Tensor): Input tensor of shape (batch_size, 3, height, width) in RGB order.

    Returns:
        torch.Tensor: Tensor of shape (batch_size, 3, height, width) where the grayscale
        value is replicated across all three channels.
    """
    red, green, blue = img[:, 0:1], img[:, 1:2], img[:, 2:3]
    luma = 0.299 * red + 0.587 * green + 0.114 * blue
    return luma.repeat(1, 3, 1, 1)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def process_face_embeddings(
    face_helper_1,
    clip_vision_model,
    face_helper_2,
    eva_transform_mean,
    eva_transform_std,
    app,
    device,
    weight_dtype,
    image,
    original_id_image=None,
    is_align_face=True,
):
    """
    Process face embeddings from an image, extracting relevant features such as face embeddings, landmarks, and parsed
    face features using a series of face detection and alignment tools.

    Args:
        face_helper_1: Face helper object (first helper) for alignment and landmark detection.
        clip_vision_model: Pre-trained CLIP vision model used for feature extraction.
        face_helper_2: Face helper object (second helper) for embedding extraction.
        eva_transform_mean: Mean values for image normalization before passing to EVA model.
        eva_transform_std: Standard deviation values for image normalization before passing to EVA model.
        app: Application instance used for face detection.
        device: Device (CPU or GPU) where the computations will be performed.
        weight_dtype: Data type of the weights for precision (e.g., `torch.float32`).
        image: Input image in RGB format with pixel values in the range [0, 255].
        original_id_image: (Optional) Original image for feature extraction if `is_align_face` is False.
        is_align_face: Boolean flag indicating whether face alignment should be performed.

    Returns:
        Tuple:
            - id_cond: Concatenated tensor of Ante face embedding and CLIP vision embedding
            - id_vit_hidden: Hidden state of the CLIP vision model, a list of tensors.
            - return_face_features_image_2: Processed face features image after normalization and parsing.
            - face_kps: Keypoints of the face detected in the image.

    Raises:
        RuntimeError: If facexlib fails to align any face in the image.
    """

    face_helper_1.clean_all()
    # insightface and facexlib both operate on BGR images
    image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    # get antelopev2 embedding
    face_info = app.get(image_bgr)
    if len(face_info) > 0:
        # rank detections by bounding-box area and keep the largest face
        face_info = sorted(face_info, key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]))[
            -1
        ]  # only use the maximum face
        id_ante_embedding = face_info["embedding"]  # (512,)
        face_kps = face_info["kps"]
    else:
        id_ante_embedding = None
        face_kps = None

    # using facexlib to detect and align face
    face_helper_1.read_image(image_bgr)
    face_helper_1.get_face_landmarks_5(only_center_face=True)
    if face_kps is None:
        # fall back to facexlib landmarks when insightface found no face
        # NOTE(review): raises IndexError if facexlib also found no landmarks — confirm callers guarantee a face
        face_kps = face_helper_1.all_landmarks_5[0]
    face_helper_1.align_warp_face()
    if len(face_helper_1.cropped_faces) == 0:
        raise RuntimeError("facexlib align face fail")
    align_face = face_helper_1.cropped_faces[0]  # (512, 512, 3) # RGB

    # in case insightface didn't detect face
    if id_ante_embedding is None:
        logger.warning("Failed to detect face using insightface. Extracting embedding with align face")
        id_ante_embedding = face_helper_2.get_feat(align_face)

    id_ante_embedding = torch.from_numpy(id_ante_embedding).to(device, weight_dtype)  # torch.Size([512])
    if id_ante_embedding.ndim == 1:
        # add a batch dimension so the final concatenation yields (1, 1280)
        id_ante_embedding = id_ante_embedding.unsqueeze(0)  # torch.Size([1, 512])

    # parsing
    if is_align_face:
        input = img2tensor(align_face, bgr2rgb=True).unsqueeze(0) / 255.0  # torch.Size([1, 3, 512, 512])
        input = input.to(device)
        # ImageNet mean/std normalization expected by the face parsing network
        parsing_out = face_helper_1.face_parse(normalize(input, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))[0]
        parsing_out = parsing_out.argmax(dim=1, keepdim=True)  # torch.Size([1, 1, 512, 512])
        # labels treated as background (non-face regions); presumably BiSeNet class ids — verify against the parser
        bg_label = [0, 16, 18, 7, 8, 9, 14, 15]
        bg = sum(parsing_out == i for i in bg_label).bool()
        white_image = torch.ones_like(input)  # torch.Size([1, 3, 512, 512])
        # only keep the face features: background becomes white, face kept (gray / color)
        return_face_features_image = torch.where(bg, white_image, to_gray(input))  # torch.Size([1, 3, 512, 512])
        return_face_features_image_2 = torch.where(bg, white_image, input)  # torch.Size([1, 3, 512, 512])
    else:
        # skip parsing; use the original image directly
        original_image_bgr = cv2.cvtColor(original_id_image, cv2.COLOR_RGB2BGR)
        input = img2tensor(original_image_bgr, bgr2rgb=True).unsqueeze(0) / 255.0  # torch.Size([1, 3, 512, 512])
        input = input.to(device)
        return_face_features_image = return_face_features_image_2 = input

    # transform img before sending to eva-clip-vit
    face_features_image = resize(
        return_face_features_image, clip_vision_model.image_size, InterpolationMode.BICUBIC
    )  # torch.Size([1, 3, 336, 336])
    face_features_image = normalize(face_features_image, eva_transform_mean, eva_transform_std)
    id_cond_vit, id_vit_hidden = clip_vision_model(
        face_features_image.to(weight_dtype), return_all_features=False, return_hidden=True, shuffle=False
    )  # torch.Size([1, 768]), list(torch.Size([1, 577, 1024]))
    # L2-normalize the CLIP embedding before concatenating with the ArcFace embedding
    id_cond_vit_norm = torch.norm(id_cond_vit, 2, 1, True)
    id_cond_vit = torch.div(id_cond_vit, id_cond_vit_norm)

    id_cond = torch.cat(
        [id_ante_embedding, id_cond_vit], dim=-1
    )  # torch.Size([1, 512]), torch.Size([1, 768]) -> torch.Size([1, 1280])

    return (
        id_cond,
        id_vit_hidden,
        return_face_features_image_2,
        face_kps,
    )  # torch.Size([1, 1280]), list(torch.Size([1, 577, 1024]))
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def process_face_embeddings_infer(
    face_helper_1,
    clip_vision_model,
    face_helper_2,
    eva_transform_mean,
    eva_transform_std,
    app,
    device,
    weight_dtype,
    img_file_path,
    is_align_face=True,
):
    """
    Process face embeddings from an input image for inference, including alignment, feature extraction, and embedding
    concatenation.

    Args:
        face_helper_1: Face helper object (first helper) for alignment and landmark detection.
        clip_vision_model: Pre-trained CLIP vision model used for feature extraction.
        face_helper_2: Face helper object (second helper) for embedding extraction.
        eva_transform_mean: Mean values for image normalization before passing to EVA model.
        eva_transform_std: Standard deviation values for image normalization before passing to EVA model.
        app: Application instance used for face detection.
        device: Device (CPU or GPU) where the computations will be performed.
        weight_dtype: Data type of the weights for precision (e.g., `torch.float32`).
        img_file_path: Path to the input image file (string) or a numpy array representing an image.
        is_align_face: Boolean flag indicating whether face alignment should be performed (default: True).

    Returns:
        Tuple:
            - id_cond: Concatenated tensor of Ante face embedding and CLIP vision embedding.
            - id_vit_hidden: Hidden state of the CLIP vision model, a list of tensors.
            - image: Processed face image (PIL.Image) after feature extraction and alignment.
            - face_kps: Keypoints of the face detected in the image.
    """

    # Load and preprocess the input image
    if isinstance(img_file_path, str):
        image = np.array(load_image(image=img_file_path).convert("RGB"))
    else:
        # honor the EXIF orientation tag before converting to an RGB array
        image = np.array(ImageOps.exif_transpose(Image.fromarray(img_file_path)).convert("RGB"))

    # Resize image to ensure the longer side is 1024 pixels
    image = resize_numpy_image_long(image, 1024)
    original_id_image = image

    # Process the image to extract face embeddings and related features
    id_cond, id_vit_hidden, align_crop_face_image, face_kps = process_face_embeddings(
        face_helper_1,
        clip_vision_model,
        face_helper_2,
        eva_transform_mean,
        eva_transform_std,
        app,
        device,
        weight_dtype,
        image,
        original_id_image,
        is_align_face,
    )

    # Convert the aligned cropped face image (torch tensor) to a numpy array
    tensor = align_crop_face_image.cpu().detach()
    tensor = tensor.squeeze()
    tensor = tensor.permute(1, 2, 0)  # CHW -> HWC
    tensor = tensor.numpy() * 255  # values were in [0, 1]; rescale to [0, 255]
    tensor = tensor.astype(np.uint8)
    image = ImageOps.exif_transpose(Image.fromarray(tensor))

    return id_cond, id_vit_hidden, image, face_kps
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def prepare_face_models(model_path, device, dtype):
    """
    Prepare all face models for the facial recognition task.

    Parameters:
        - model_path: Path to the directory containing model files.
        - device: The device (e.g., 'cuda', 'xpu', 'cpu') where models will be loaded.
        - dtype: Data type (e.g., torch.float32) used for the CLIP face model.

    Returns:
        - face_helper_1: First face restoration helper (detection, alignment, parsing).
        - face_helper_2: Second face helper (recognition/embedding extraction).
        - face_clip_model: CLIP model for face feature extraction.
        - face_main_model: Main face analysis model.
        - eva_transform_mean: Per-channel mean for EVA-CLIP image normalization.
        - eva_transform_std: Per-channel std for EVA-CLIP image normalization.
    """
    # get helper model for face detection, alignment and parsing
    face_helper_1 = FaceRestoreHelper(
        upscale_factor=1,
        face_size=512,
        crop_ratio=(1, 1),
        det_model="retinaface_resnet50",
        save_ext="png",
        device=device,
        model_rootpath=os.path.join(model_path, "face_encoder"),
    )
    face_helper_1.face_parse = init_parsing_model(
        model_name="bisenet", device=device, model_rootpath=os.path.join(model_path, "face_encoder")
    )
    # recognition model for identity embeddings
    # NOTE(review): ONNX providers are hard-coded to CUDA regardless of `device` — confirm intended
    face_helper_2 = insightface.model_zoo.get_model(
        f"{model_path}/face_encoder/models/antelopev2/glintr100.onnx", providers=["CUDAExecutionProvider"]
    )
    face_helper_2.prepare(ctx_id=0)

    # get local facial extractor part 1 (EVA-CLIP vision tower)
    model, _, _ = create_model_and_transforms(
        "EVA02-CLIP-L-14-336",
        os.path.join(model_path, "face_encoder", "EVA02_CLIP_L_336_psz14_s6B.pt"),
        force_custom_clip=True,
    )
    face_clip_model = model.visual
    eva_transform_mean = getattr(face_clip_model, "image_mean", OPENAI_DATASET_MEAN)
    eva_transform_std = getattr(face_clip_model, "image_std", OPENAI_DATASET_STD)
    # broadcast scalar mean/std to the three RGB channels
    if not isinstance(eva_transform_mean, (list, tuple)):
        eva_transform_mean = (eva_transform_mean,) * 3
    if not isinstance(eva_transform_std, (list, tuple)):
        eva_transform_std = (eva_transform_std,) * 3

    # get local facial extractor part 2 (face detection / analysis)
    face_main_model = FaceAnalysis(
        name="antelopev2", root=os.path.join(model_path, "face_encoder"), providers=["CUDAExecutionProvider"]
    )
    face_main_model.prepare(ctx_id=0, det_size=(640, 640))

    # move face models to device and put them in eval mode
    face_helper_1.face_det.eval()
    face_helper_1.face_parse.eval()
    face_clip_model.eval()
    face_helper_1.face_det.to(device)
    face_helper_1.face_parse.to(device)
    face_clip_model.to(device, dtype=dtype)

    return face_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, eva_transform_std
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/consisid/pipeline_consisid.py
ADDED
|
@@ -0,0 +1,974 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 ConsisID Authors and The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import inspect
|
| 16 |
+
import math
|
| 17 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import PIL
|
| 21 |
+
import torch
|
| 22 |
+
from transformers import T5EncoderModel, T5Tokenizer
|
| 23 |
+
|
| 24 |
+
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 25 |
+
from ...image_processor import PipelineImageInput
|
| 26 |
+
from ...loaders import CogVideoXLoraLoaderMixin
|
| 27 |
+
from ...models import AutoencoderKLCogVideoX, ConsisIDTransformer3DModel
|
| 28 |
+
from ...models.embeddings import get_3d_rotary_pos_embed
|
| 29 |
+
from ...pipelines.pipeline_utils import DiffusionPipeline
|
| 30 |
+
from ...schedulers import CogVideoXDPMScheduler
|
| 31 |
+
from ...utils import is_opencv_available, logging, replace_example_docstring
|
| 32 |
+
from ...utils.torch_utils import randn_tensor
|
| 33 |
+
from ...video_processor import VideoProcessor
|
| 34 |
+
from .pipeline_output import ConsisIDPipelineOutput
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
if is_opencv_available():
|
| 38 |
+
import cv2
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
EXAMPLE_DOC_STRING = """
|
| 45 |
+
Examples:
|
| 46 |
+
```python
|
| 47 |
+
>>> import torch
|
| 48 |
+
>>> from diffusers import ConsisIDPipeline
|
| 49 |
+
>>> from diffusers.pipelines.consisid.consisid_utils import prepare_face_models, process_face_embeddings_infer
|
| 50 |
+
>>> from diffusers.utils import export_to_video
|
| 51 |
+
>>> from huggingface_hub import snapshot_download
|
| 52 |
+
|
| 53 |
+
>>> snapshot_download(repo_id="BestWishYsh/ConsisID-preview", local_dir="BestWishYsh/ConsisID-preview")
|
| 54 |
+
>>> (
|
| 55 |
+
... face_helper_1,
|
| 56 |
+
... face_helper_2,
|
| 57 |
+
... face_clip_model,
|
| 58 |
+
... face_main_model,
|
| 59 |
+
... eva_transform_mean,
|
| 60 |
+
... eva_transform_std,
|
| 61 |
+
... ) = prepare_face_models("BestWishYsh/ConsisID-preview", device="cuda", dtype=torch.bfloat16)
|
| 62 |
+
>>> pipe = ConsisIDPipeline.from_pretrained("BestWishYsh/ConsisID-preview", torch_dtype=torch.bfloat16)
|
| 63 |
+
>>> pipe.to("cuda")
|
| 64 |
+
|
| 65 |
+
>>> # ConsisID works well with long and well-described prompts. Make sure the face in the image is clearly visible (e.g., preferably half-body or full-body).
|
| 66 |
+
>>> prompt = "The video captures a boy walking along a city street, filmed in black and white on a classic 35mm camera. His expression is thoughtful, his brow slightly furrowed as if he's lost in contemplation. The film grain adds a textured, timeless quality to the image, evoking a sense of nostalgia. Around him, the cityscape is filled with vintage buildings, cobblestone sidewalks, and softly blurred figures passing by, their outlines faint and indistinct. Streetlights cast a gentle glow, while shadows play across the boy's path, adding depth to the scene. The lighting highlights the boy's subtle smile, hinting at a fleeting moment of curiosity. The overall cinematic atmosphere, complete with classic film still aesthetics and dramatic contrasts, gives the scene an evocative and introspective feel."
|
| 67 |
+
>>> image = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_input.png?download=true"
|
| 68 |
+
|
| 69 |
+
>>> id_cond, id_vit_hidden, image, face_kps = process_face_embeddings_infer(
|
| 70 |
+
... face_helper_1,
|
| 71 |
+
... face_clip_model,
|
| 72 |
+
... face_helper_2,
|
| 73 |
+
... eva_transform_mean,
|
| 74 |
+
... eva_transform_std,
|
| 75 |
+
... face_main_model,
|
| 76 |
+
... "cuda",
|
| 77 |
+
... torch.bfloat16,
|
| 78 |
+
... image,
|
| 79 |
+
... is_align_face=True,
|
| 80 |
+
... )
|
| 81 |
+
|
| 82 |
+
>>> video = pipe(
|
| 83 |
+
... image=image,
|
| 84 |
+
... prompt=prompt,
|
| 85 |
+
... num_inference_steps=50,
|
| 86 |
+
... guidance_scale=6.0,
|
| 87 |
+
... use_dynamic_cfg=False,
|
| 88 |
+
... id_vit_hidden=id_vit_hidden,
|
| 89 |
+
... id_cond=id_cond,
|
| 90 |
+
... kps_cond=face_kps,
|
| 91 |
+
... generator=torch.Generator("cuda").manual_seed(42),
|
| 92 |
+
... )
|
| 93 |
+
>>> export_to_video(video.frames[0], "output.mp4", fps=8)
|
| 94 |
+
```
|
| 95 |
+
"""
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def draw_kps(image_pil, kps, color_list=((255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255))):
    """
    This function draws keypoints and the limbs connecting them on a black canvas
    sized like the input image.

    Parameters:
    - image_pil (PIL.Image): Input image as a PIL object; only its size is used.
    - kps (list of tuples): A list of keypoints where each keypoint is a tuple of (x, y) coordinates.
      The limb layout assumes five points with the point at index 2 being the hub.
    - color_list (sequence of tuples, optional): List of colors (in RGB format) for each keypoint.
      The default is an immutable tuple of five colors (a tuple is used instead of a list to avoid
      the shared-mutable-default-argument pitfall; behavior is unchanged).

    Returns:
    - PIL.Image: Image with the keypoints and limbs drawn.
    """

    stickwidth = 4
    # Each limb connects an outer keypoint (0, 1, 3, 4) to the hub keypoint (index 2).
    limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
    kps = np.array(kps)

    w, h = image_pil.size
    out_img = np.zeros([h, w, 3])

    for i in range(len(limbSeq)):
        index = limbSeq[i]
        color = color_list[index[0]]

        x = kps[index][:, 0]
        y = kps[index][:, 1]
        length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
        angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
        # Approximate the limb as a filled rotated ellipse spanning the two endpoints.
        polygon = cv2.ellipse2Poly(
            (int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1
        )
        out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
    # Dim the limbs so the keypoint dots drawn next stand out.
    out_img = (out_img * 0.6).astype(np.uint8)

    for idx_kp, kp in enumerate(kps):
        color = color_list[idx_kp]
        x, y = kp
        out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)

    out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8))
    return out_img_pil
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
|
| 143 |
+
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
    """
    Compute the centered crop region obtained when resizing `src` to fill the target
    resolution while preserving its aspect ratio.

    Parameters:
    - src (tuple): A tuple containing the source image's height (h) and width (w).
    - tgt_width (int): The target width to resize the image.
    - tgt_height (int): The target height to resize the image.

    Returns:
    - tuple: Two tuples representing the crop region:
        1. The top-left coordinates of the crop region.
        2. The bottom-right coordinates of the crop region.
    """

    h, w = src
    # Compare aspect ratios to decide which target dimension the source should match.
    if h / w > tgt_height / tgt_width:
        # Source is relatively taller: match the target height, scale width to keep ratio.
        resize_height = tgt_height
        resize_width = int(round(tgt_height / h * w))
    else:
        # Source is relatively wider (or equal): match the target width, scale height.
        resize_width = tgt_width
        resize_height = int(round(tgt_width / w * h))

    # Center the resized image within the target frame.
    crop_top = int(round((tgt_height - resize_height) / 2.0))
    crop_left = int(round((tgt_width - resize_width) / 2.0))

    return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    def _accepts(keyword):
        # Whether this scheduler's `set_timesteps` signature exposes the given keyword.
        return keyword in set(inspect.signature(scheduler.set_timesteps).parameters.keys())

    if timesteps is not None:
        if not _accepts("timesteps"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        if not _accepts("sigmas"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        # Default path: let the scheduler build its own schedule for the step count.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
| 237 |
+
def retrieve_latents(
|
| 238 |
+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
| 239 |
+
):
|
| 240 |
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
| 241 |
+
return encoder_output.latent_dist.sample(generator)
|
| 242 |
+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
| 243 |
+
return encoder_output.latent_dist.mode()
|
| 244 |
+
elif hasattr(encoder_output, "latents"):
|
| 245 |
+
return encoder_output.latents
|
| 246 |
+
else:
|
| 247 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
class ConsisIDPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
| 251 |
+
r"""
|
| 252 |
+
Pipeline for image-to-video generation using ConsisID.
|
| 253 |
+
|
| 254 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 255 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
| 256 |
+
|
| 257 |
+
Args:
|
| 258 |
+
vae ([`AutoencoderKL`]):
|
| 259 |
+
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
| 260 |
+
text_encoder ([`T5EncoderModel`]):
|
| 261 |
+
Frozen text-encoder. ConsisID uses
|
| 262 |
+
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
|
| 263 |
+
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
|
| 264 |
+
tokenizer (`T5Tokenizer`):
|
| 265 |
+
Tokenizer of class
|
| 266 |
+
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
| 267 |
+
transformer ([`ConsisIDTransformer3DModel`]):
|
| 268 |
+
A text conditioned `ConsisIDTransformer3DModel` to denoise the encoded video latents.
|
| 269 |
+
scheduler ([`SchedulerMixin`]):
|
| 270 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
|
| 271 |
+
"""
|
| 272 |
+
|
| 273 |
+
    # No sub-modules are optional for this pipeline.
    _optional_components = []
    # Order in which sub-models are shuttled to the accelerator when model CPU offload is enabled.
    model_cpu_offload_seq = "text_encoder->transformer->vae"

    # Tensor names that step-end callbacks are allowed to request/override.
    _callback_tensor_inputs = [
        "latents",
        "prompt_embeds",
        "negative_prompt_embeds",
    ]
|
| 281 |
+
|
| 282 |
+
    def __init__(
        self,
        tokenizer: T5Tokenizer,
        text_encoder: T5EncoderModel,
        vae: AutoencoderKLCogVideoX,
        transformer: ConsisIDTransformer3DModel,
        scheduler: CogVideoXDPMScheduler,
    ):
        """Register the sub-models and derive VAE scale factors used throughout the pipeline."""
        super().__init__()

        self.register_modules(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            vae=vae,
            transformer=transformer,
            scheduler=scheduler,
        )
        # Spatial downsampling factor of the VAE; falls back to 8 when no VAE is registered.
        self.vae_scale_factor_spatial = (
            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
        )
        # Temporal compression ratio of the VAE (frames per latent frame); fallback 4.
        self.vae_scale_factor_temporal = (
            self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
        )
        # Scaling applied to encoded latents (and undone when decoding); fallback 0.7.
        self.vae_scaling_factor_image = (
            self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7
        )

        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
| 310 |
+
|
| 311 |
+
    # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_videos_per_prompt: int = 1,
        max_sequence_length: int = 226,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """
        Tokenize `prompt` (truncated to `max_sequence_length`) and run it through the T5 text
        encoder, returning embeddings repeated `num_videos_per_prompt` times per prompt.
        """
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Re-tokenize without truncation so we can warn the user about any dropped tokens.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

        return prompt_embeds
|
| 353 |
+
|
| 354 |
+
    # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        do_classifier_free_guidance: bool = True,
        num_videos_per_prompt: int = 1,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        max_sequence_length: int = 226,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier free guidance or not.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            device: (`torch.device`, *optional*):
                torch device
            dtype: (`torch.dtype`, *optional*):
                torch dtype
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            # No raw prompt given; infer the batch size from the precomputed embeddings.
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            # Default negative prompt is the empty string, broadcast over the batch.
            negative_prompt = negative_prompt or ""
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds = self._get_t5_prompt_embeds(
                prompt=negative_prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        return prompt_embeds, negative_prompt_embeds
|
| 435 |
+
|
| 436 |
+
    def prepare_latents(
        self,
        image: torch.Tensor,
        batch_size: int = 1,
        num_channels_latents: int = 16,
        num_frames: int = 13,
        height: int = 60,
        width: int = 90,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.Tensor] = None,
        kps_cond: Optional[torch.Tensor] = None,
    ):
        """
        Build the initial noise latents and the image-conditioning latents.

        The conditioning image (and optional keypoint-condition image `kps_cond`) is
        VAE-encoded into the first latent frame(s); the remaining frames are zero-padded
        so the conditioning tensor matches the noise latents' frame count.

        Returns:
            Tuple of (`latents`, `image_latents`), both shaped [B, F_latent, C, H_latent, W_latent].
        """
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        # Convert pixel-frame count to latent-frame count via the VAE's temporal compression.
        num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
        shape = (
            batch_size,
            num_frames,
            num_channels_latents,
            height // self.vae_scale_factor_spatial,
            width // self.vae_scale_factor_spatial,
        )

        # assumes `image` is [B, C, H, W]; a singleton frame axis is inserted — TODO confirm with caller
        image = image.unsqueeze(2)  # [B, C, F, H, W]

        if isinstance(generator, list):
            # One generator per batch element for reproducible per-sample encoding.
            image_latents = [
                retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
            ]
            if kps_cond is not None:
                kps_cond = kps_cond.unsqueeze(2)
                kps_cond_latents = [
                    retrieve_latents(self.vae.encode(kps_cond[i].unsqueeze(0)), generator[i])
                    for i in range(batch_size)
                ]
        else:
            image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]
            if kps_cond is not None:
                kps_cond = kps_cond.unsqueeze(2)
                kps_cond_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in kps_cond]

        image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]
        image_latents = self.vae_scaling_factor_image * image_latents

        if kps_cond is not None:
            kps_cond_latents = torch.cat(kps_cond_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]
            kps_cond_latents = self.vae_scaling_factor_image * kps_cond_latents

            # Two conditioning frames (image + keypoints) occupy the front; pad the rest.
            padding_shape = (
                batch_size,
                num_frames - 2,
                num_channels_latents,
                height // self.vae_scale_factor_spatial,
                width // self.vae_scale_factor_spatial,
            )
        else:
            # Only the image frame occupies the front; pad the remaining frames.
            padding_shape = (
                batch_size,
                num_frames - 1,
                num_channels_latents,
                height // self.vae_scale_factor_spatial,
                width // self.vae_scale_factor_spatial,
            )

        latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype)
        if kps_cond is not None:
            image_latents = torch.cat([image_latents, kps_cond_latents, latent_padding], dim=1)
        else:
            image_latents = torch.cat([image_latents, latent_padding], dim=1)

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents, image_latents
|
| 520 |
+
|
| 521 |
+
    # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents
    def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
        """Decode video latents back into pixel-space frames with the VAE."""
        latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
        # Undo the scaling that was applied when the latents were encoded.
        latents = 1 / self.vae_scaling_factor_image * latents

        frames = self.vae.decode(latents).sample
        return frames
|
| 528 |
+
|
| 529 |
+
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps
|
| 530 |
+
def get_timesteps(self, num_inference_steps, timesteps, strength, device):
|
| 531 |
+
# get the original timestep using init_timestep
|
| 532 |
+
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
| 533 |
+
|
| 534 |
+
t_start = max(num_inference_steps - init_timestep, 0)
|
| 535 |
+
timesteps = timesteps[t_start * self.scheduler.order :]
|
| 536 |
+
|
| 537 |
+
return timesteps, num_inference_steps - t_start
|
| 538 |
+
|
| 539 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
| 540 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
| 541 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 542 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 543 |
+
# eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
|
| 544 |
+
# and should be between [0, 1]
|
| 545 |
+
|
| 546 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 547 |
+
extra_step_kwargs = {}
|
| 548 |
+
if accepts_eta:
|
| 549 |
+
extra_step_kwargs["eta"] = eta
|
| 550 |
+
|
| 551 |
+
# check if the scheduler accepts generator
|
| 552 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 553 |
+
if accepts_generator:
|
| 554 |
+
extra_step_kwargs["generator"] = generator
|
| 555 |
+
return extra_step_kwargs
|
| 556 |
+
|
| 557 |
+
    def check_inputs(
        self,
        image,
        prompt,
        height,
        width,
        negative_prompt,
        callback_on_step_end_tensor_inputs,
        latents=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        """
        Validate user-supplied call arguments, raising a descriptive error on the first
        inconsistency (bad image type, non-multiple-of-8 resolution, prompt/embedding
        conflicts, or mismatched embedding shapes).
        """
        # The conditioning image must be a tensor, a PIL image, or a list of images.
        if (
            not isinstance(image, torch.Tensor)
            and not isinstance(image, PIL.Image.Image)
            and not isinstance(image, list)
        ):
            raise ValueError(
                "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
                f" {type(image)}"
            )

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # Callbacks may only request tensors the pipeline exposes via `_callback_tensor_inputs`.
        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )
        # Exactly one of `prompt` / `prompt_embeds` must be supplied.
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        # Precomputed positive/negative embeddings must agree in shape for CFG concat.
        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )
|
| 619 |
+
|
| 620 |
+
    def _prepare_rotary_positional_embeddings(
        self,
        height: int,
        width: int,
        num_frames: int,
        device: torch.device,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build 3D rotary positional-embedding (cos, sin) tables for the transformer's patch grid."""
        # Latent-patch grid for the requested resolution.
        grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
        grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
        # Patch grid of the resolution the transformer was configured for.
        base_size_width = self.transformer.config.sample_width // self.transformer.config.patch_size
        base_size_height = self.transformer.config.sample_height // self.transformer.config.patch_size

        # Map the requested grid onto the base grid, preserving aspect ratio.
        grid_crops_coords = get_resize_crop_region_for_grid(
            (grid_height, grid_width), base_size_width, base_size_height
        )
        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
            embed_dim=self.transformer.config.attention_head_dim,
            crops_coords=grid_crops_coords,
            grid_size=(grid_height, grid_width),
            temporal_size=num_frames,
            device=device,
        )

        return freqs_cos, freqs_sin
|
| 644 |
+
|
| 645 |
+
    @property
    def guidance_scale(self):
        # Classifier-free guidance scale; presumably set in `__call__` (not visible here) — TODO confirm.
        return self._guidance_scale

    @property
    def num_timesteps(self):
        # Number of denoising timesteps of the current/most recent run.
        return self._num_timesteps

    @property
    def attention_kwargs(self):
        # Optional kwargs forwarded to the attention processors during denoising.
        return self._attention_kwargs

    @property
    def interrupt(self):
        # True when a callback has requested the denoising loop to stop early.
        return self._interrupt
|
| 660 |
+
|
| 661 |
+
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        image: PipelineImageInput,
        prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 480,
        width: int = 720,
        num_frames: int = 49,
        num_inference_steps: int = 50,
        guidance_scale: float = 6.0,
        use_dynamic_cfg: bool = False,
        num_videos_per_prompt: int = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: str = "pil",
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 226,
        id_vit_hidden: Optional[torch.Tensor] = None,
        id_cond: Optional[torch.Tensor] = None,
        kps_cond: Optional[torch.Tensor] = None,
    ) -> Union[ConsisIDPipelineOutput, Tuple]:
        """
        Function invoked when calling the pipeline for generation.

        Args:
            image (`PipelineImageInput`):
                The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                instead.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
                The height in pixels of the generated image. This is set to 480 by default for the best results.
            width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
                The width in pixels of the generated image. This is set to 720 by default for the best results.
            num_frames (`int`, defaults to `49`):
                Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
                contain 1 extra frame because ConsisID is conditioned with (num_seconds * fps + 1) frames where
                num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
                needs to be satisfied is that of divisibility mentioned above.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 6):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
            use_dynamic_cfg (`bool`, *optional*, defaults to `False`):
                If True, dynamically adjusts the guidance scale during inference. This allows the model to use a
                progressive guidance scale, improving the balance between text-guided generation and image quality over
                the course of the inference steps. Typically, early inference steps use a higher guidance scale for
                more faithful image generation, while later steps reduce it for more diverse and natural results.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of videos to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
                of a plain tuple.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int`, defaults to `226`):
                Maximum sequence length in encoded prompt. Must be consistent with
                `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
            id_vit_hidden (`Optional[torch.Tensor]`, *optional*):
                The tensor representing the hidden features extracted from the face model, which are used to condition
                the local facial extractor. This is crucial for the model to obtain high-frequency information of the
                face. If not provided, the local facial extractor will not run normally.
            id_cond (`Optional[torch.Tensor]`, *optional*):
                The tensor representing the hidden features extracted from the clip model, which are used to condition
                the local facial extractor. This is crucial for the model to edit facial features If not provided, the
                local facial extractor will not run normally.
            kps_cond (`Optional[torch.Tensor]`, *optional*):
                A tensor that determines whether the global facial extractor use keypoint information for conditioning.
                If provided, this tensor controls whether facial keypoints such as eyes, nose, and mouth landmarks are
                used during the generation process. This helps ensure the model retains more facial low-frequency
                information.

        Examples:

        Returns:
            [`~pipelines.consisid.pipeline_output.ConsisIDPipelineOutput`] or `tuple`:
                [`~pipelines.consisid.pipeline_output.ConsisIDPipelineOutput`] if `return_dict` is True, otherwise a
                `tuple`. When returning a tuple, the first element is a list with the generated images.
        """

        # Callback objects carry their own declared tensor inputs; they override the list argument.
        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        # Fall back to the transformer's native sample size if the caller passed falsy values.
        height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
        width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
        num_frames = num_frames or self.transformer.config.sample_frames

        # NOTE(review): `num_videos_per_prompt` is forced to 1 here, silently overriding the
        # caller's argument — presumably a current limitation of this pipeline; confirm upstream.
        num_videos_per_prompt = 1

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            image=image,
            prompt=prompt,
            height=height,
            width=width,
            negative_prompt=negative_prompt,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            latents=latents,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )
        self._guidance_scale = guidance_scale
        self._attention_kwargs = attention_kwargs
        self._interrupt = False

        # 2. Default call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            # No prompt text: batch size comes from the pre-computed embeddings.
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt=prompt,
            negative_prompt=negative_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            num_videos_per_prompt=num_videos_per_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            max_sequence_length=max_sequence_length,
            device=device,
        )
        if do_classifier_free_guidance:
            # Stack [uncond, cond] so both branches run in a single forward pass.
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device)
        self._num_timesteps = len(timesteps)

        # 5. Prepare latents
        # Keypoint conditioning is only used when the transformer was trained with it.
        is_kps = getattr(self.transformer.config, "is_kps", False)
        kps_cond = kps_cond if is_kps else None
        if kps_cond is not None:
            kps_cond = draw_kps(image, kps_cond)
            kps_cond = self.video_processor.preprocess(kps_cond, height=height, width=width).to(
                device, dtype=prompt_embeds.dtype
            )

        image = self.video_processor.preprocess(image, height=height, width=width).to(
            device, dtype=prompt_embeds.dtype
        )

        # Half of the transformer input channels hold noise latents, the other half image latents.
        latent_channels = self.transformer.config.in_channels // 2
        latents, image_latents = self.prepare_latents(
            image,
            batch_size * num_videos_per_prompt,
            latent_channels,
            num_frames,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
            kps_cond,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Create rotary embeds if required
        image_rotary_emb = (
            self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
            if self.transformer.config.use_rotary_positional_embeddings
            else None
        )

        # 8. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            # for DPM-solver++
            old_pred_original_sample = None
            # Indexed on CPU inside the loop to avoid per-step device sync.
            timesteps_cpu = timesteps.cpu()
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # Image latents are concatenated along the channel dimension of each frame.
                latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents
                latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0])

                # predict noise model_output
                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    encoder_hidden_states=prompt_embeds,
                    timestep=timestep,
                    image_rotary_emb=image_rotary_emb,
                    attention_kwargs=attention_kwargs,
                    return_dict=False,
                    id_vit_hidden=id_vit_hidden,
                    id_cond=id_cond,
                )[0]
                noise_pred = noise_pred.float()

                # perform guidance
                if use_dynamic_cfg:
                    # Cosine ramp: guidance grows from ~1 toward 1 + guidance_scale as t decreases.
                    self._guidance_scale = 1 + guidance_scale * (
                        (
                            1
                            - math.cos(
                                math.pi
                                * ((num_inference_steps - timesteps_cpu[i].item()) / num_inference_steps) ** 5.0
                            )
                        )
                        / 2
                    )
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                if not isinstance(self.scheduler, CogVideoXDPMScheduler):
                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                else:
                    # The DPM scheduler additionally threads the previous x0 prediction through steps.
                    latents, old_pred_original_sample = self.scheduler.step(
                        noise_pred,
                        old_pred_original_sample,
                        t,
                        timesteps[i - 1] if i > 0 else None,
                        latents,
                        **extra_step_kwargs,
                        return_dict=False,
                    )
                latents = latents.to(prompt_embeds.dtype)

                # call the callback, if provided
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    # `locals()` lookup relies on the tensor-input names matching local variables here.
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        if not output_type == "latent":
            video = self.decode_latents(latents)
            video = self.video_processor.postprocess_video(video=video, output_type=output_type)
        else:
            video = latents

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (video,)

        return ConsisIDPipelineOutput(frames=video)
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/consisid/pipeline_output.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from diffusers.utils import BaseOutput
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@dataclass
class ConsisIDPipelineOutput(BaseOutput):
    r"""
    Output class for ConsisID pipelines.

    Args:
        frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
            List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
            denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
            `(batch_size, num_frames, channels, height, width)`.
    """

    # Generated video frames; the concrete type depends on the pipeline's `output_type` argument.
    frames: torch.Tensor
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/consistency_models/__init__.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    _LazyModule,
)


# Maps submodule name -> the public names it exports; consumed by `_LazyModule` below.
_import_structure = {
    "pipeline_consistency_models": ["ConsistencyModelPipeline"],
}

# Import eagerly for static type checkers, or when slow (eager) imports are requested;
# otherwise replace this module with a lazy proxy that defers the heavy import until
# an attribute is first accessed.
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    from .pipeline_consistency_models import ConsistencyModelPipeline

else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
    )
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/consistency_models/pipeline_consistency_models.py
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from typing import Callable, List, Optional, Union
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
from ...models import UNet2DModel
|
| 20 |
+
from ...schedulers import CMStochasticIterativeScheduler
|
| 21 |
+
from ...utils import (
|
| 22 |
+
is_torch_xla_available,
|
| 23 |
+
logging,
|
| 24 |
+
replace_example_docstring,
|
| 25 |
+
)
|
| 26 |
+
from ...utils.torch_utils import randn_tensor
|
| 27 |
+
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# torch_xla is optional; record its availability so XLA-specific hooks can be gated on it.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# Usage example spliced into `__call__`'s docstring via `replace_example_docstring`.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch

        >>> from diffusers import ConsistencyModelPipeline

        >>> device = "cuda"
        >>> # Load the cd_imagenet64_l2 checkpoint.
        >>> model_id_or_path = "openai/diffusers-cd_imagenet64_l2"
        >>> pipe = ConsistencyModelPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
        >>> pipe.to(device)

        >>> # Onestep Sampling
        >>> image = pipe(num_inference_steps=1).images[0]
        >>> image.save("cd_imagenet64_l2_onestep_sample.png")

        >>> # Onestep sampling, class-conditional image generation
        >>> # ImageNet-64 class label 145 corresponds to king penguins
        >>> image = pipe(num_inference_steps=1, class_labels=145).images[0]
        >>> image.save("cd_imagenet64_l2_onestep_sample_penguin.png")

        >>> # Multistep sampling, class-conditional image generation
        >>> # Timesteps can be explicitly specified; the particular timesteps below are from the original GitHub repo:
        >>> # https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L77
        >>> image = pipe(num_inference_steps=None, timesteps=[22, 0], class_labels=145).images[0]
        >>> image.save("cd_imagenet64_l2_multistep_sample_penguin.png")
        ```
"""
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class ConsistencyModelPipeline(DiffusionPipeline):
|
| 72 |
+
r"""
|
| 73 |
+
Pipeline for unconditional or class-conditional image generation.
|
| 74 |
+
|
| 75 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
| 76 |
+
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
| 77 |
+
|
| 78 |
+
Args:
|
| 79 |
+
unet ([`UNet2DModel`]):
|
| 80 |
+
A `UNet2DModel` to denoise the encoded image latents.
|
| 81 |
+
scheduler ([`SchedulerMixin`]):
|
| 82 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Currently only
|
| 83 |
+
compatible with [`CMStochasticIterativeScheduler`].
|
| 84 |
+
"""
|
| 85 |
+
|
| 86 |
+
model_cpu_offload_seq = "unet"
|
| 87 |
+
|
| 88 |
+
    def __init__(self, unet: UNet2DModel, scheduler: CMStochasticIterativeScheduler) -> None:
        """Assemble the pipeline from a denoising UNet and a consistency-model scheduler."""
        super().__init__()

        self.register_modules(
            unet=unet,
            scheduler=scheduler,
        )

        # No safety checker is used for consistency models; the attribute is kept for API parity.
        self.safety_checker = None
def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None):
|
| 99 |
+
shape = (batch_size, num_channels, height, width)
|
| 100 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 101 |
+
raise ValueError(
|
| 102 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 103 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
if latents is None:
|
| 107 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 108 |
+
else:
|
| 109 |
+
latents = latents.to(device=device, dtype=dtype)
|
| 110 |
+
|
| 111 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
| 112 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 113 |
+
return latents
|
| 114 |
+
|
| 115 |
+
# Follows diffusers.VaeImageProcessor.postprocess
|
| 116 |
+
def postprocess_image(self, sample: torch.Tensor, output_type: str = "pil"):
|
| 117 |
+
if output_type not in ["pt", "np", "pil"]:
|
| 118 |
+
raise ValueError(
|
| 119 |
+
f"output_type={output_type} is not supported. Make sure to choose one of ['pt', 'np', or 'pil']"
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
# Equivalent to diffusers.VaeImageProcessor.denormalize
|
| 123 |
+
sample = (sample / 2 + 0.5).clamp(0, 1)
|
| 124 |
+
if output_type == "pt":
|
| 125 |
+
return sample
|
| 126 |
+
|
| 127 |
+
# Equivalent to diffusers.VaeImageProcessor.pt_to_numpy
|
| 128 |
+
sample = sample.cpu().permute(0, 2, 3, 1).numpy()
|
| 129 |
+
if output_type == "np":
|
| 130 |
+
return sample
|
| 131 |
+
|
| 132 |
+
# Output_type must be 'pil'
|
| 133 |
+
sample = self.numpy_to_pil(sample)
|
| 134 |
+
return sample
|
| 135 |
+
|
| 136 |
+
def prepare_class_labels(self, batch_size, device, class_labels=None):
|
| 137 |
+
if self.unet.config.num_class_embeds is not None:
|
| 138 |
+
if isinstance(class_labels, list):
|
| 139 |
+
class_labels = torch.tensor(class_labels, dtype=torch.int)
|
| 140 |
+
elif isinstance(class_labels, int):
|
| 141 |
+
assert batch_size == 1, "Batch size must be 1 if classes is an int"
|
| 142 |
+
class_labels = torch.tensor([class_labels], dtype=torch.int)
|
| 143 |
+
elif class_labels is None:
|
| 144 |
+
# Randomly generate batch_size class labels
|
| 145 |
+
# TODO: should use generator here? int analogue of randn_tensor is not exposed in ...utils
|
| 146 |
+
class_labels = torch.randint(0, self.unet.config.num_class_embeds, size=(batch_size,))
|
| 147 |
+
class_labels = class_labels.to(device)
|
| 148 |
+
else:
|
| 149 |
+
class_labels = None
|
| 150 |
+
return class_labels
|
| 151 |
+
|
| 152 |
+
def check_inputs(self, num_inference_steps, timesteps, latents, batch_size, img_size, callback_steps):
|
| 153 |
+
if num_inference_steps is None and timesteps is None:
|
| 154 |
+
raise ValueError("Exactly one of `num_inference_steps` or `timesteps` must be supplied.")
|
| 155 |
+
|
| 156 |
+
if num_inference_steps is not None and timesteps is not None:
|
| 157 |
+
logger.warning(
|
| 158 |
+
f"Both `num_inference_steps`: {num_inference_steps} and `timesteps`: {timesteps} are supplied;"
|
| 159 |
+
" `timesteps` will be used over `num_inference_steps`."
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
if latents is not None:
|
| 163 |
+
expected_shape = (batch_size, 3, img_size, img_size)
|
| 164 |
+
if latents.shape != expected_shape:
|
| 165 |
+
raise ValueError(f"The shape of latents is {latents.shape} but is expected to be {expected_shape}.")
|
| 166 |
+
|
| 167 |
+
if (callback_steps is None) or (
|
| 168 |
+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
| 169 |
+
):
|
| 170 |
+
raise ValueError(
|
| 171 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
| 172 |
+
f" {type(callback_steps)}."
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
@torch.no_grad()
|
| 176 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 177 |
+
def __call__(
|
| 178 |
+
self,
|
| 179 |
+
batch_size: int = 1,
|
| 180 |
+
class_labels: Optional[Union[torch.Tensor, List[int], int]] = None,
|
| 181 |
+
num_inference_steps: int = 1,
|
| 182 |
+
timesteps: List[int] = None,
|
| 183 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 184 |
+
latents: Optional[torch.Tensor] = None,
|
| 185 |
+
output_type: Optional[str] = "pil",
|
| 186 |
+
return_dict: bool = True,
|
| 187 |
+
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
|
| 188 |
+
callback_steps: int = 1,
|
| 189 |
+
):
|
| 190 |
+
r"""
|
| 191 |
+
Args:
|
| 192 |
+
batch_size (`int`, *optional*, defaults to 1):
|
| 193 |
+
The number of images to generate.
|
| 194 |
+
class_labels (`torch.Tensor` or `List[int]` or `int`, *optional*):
|
| 195 |
+
Optional class labels for conditioning class-conditional consistency models. Not used if the model is
|
| 196 |
+
not class-conditional.
|
| 197 |
+
num_inference_steps (`int`, *optional*, defaults to 1):
|
| 198 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 199 |
+
expense of slower inference.
|
| 200 |
+
timesteps (`List[int]`, *optional*):
|
| 201 |
+
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
|
| 202 |
+
timesteps are used. Must be in descending order.
|
| 203 |
+
generator (`torch.Generator`, *optional*):
|
| 204 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
| 205 |
+
generation deterministic.
|
| 206 |
+
latents (`torch.Tensor`, *optional*):
|
| 207 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
| 208 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 209 |
+
tensor is generated by sampling using the supplied random `generator`.
|
| 210 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 211 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
| 212 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 213 |
+
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
|
| 214 |
+
callback (`Callable`, *optional*):
|
| 215 |
+
A function that calls every `callback_steps` steps during inference. The function is called with the
|
| 216 |
+
following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
|
| 217 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
| 218 |
+
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
| 219 |
+
every step.
|
| 220 |
+
|
| 221 |
+
Examples:
|
| 222 |
+
|
| 223 |
+
Returns:
|
| 224 |
+
[`~pipelines.ImagePipelineOutput`] or `tuple`:
|
| 225 |
+
If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
|
| 226 |
+
returned where the first element is a list with the generated images.
|
| 227 |
+
"""
|
| 228 |
+
# 0. Prepare call parameters
|
| 229 |
+
img_size = self.unet.config.sample_size
|
| 230 |
+
device = self._execution_device
|
| 231 |
+
|
| 232 |
+
# 1. Check inputs
|
| 233 |
+
self.check_inputs(num_inference_steps, timesteps, latents, batch_size, img_size, callback_steps)
|
| 234 |
+
|
| 235 |
+
# 2. Prepare image latents
|
| 236 |
+
# Sample image latents x_0 ~ N(0, sigma_0^2 * I)
|
| 237 |
+
sample = self.prepare_latents(
|
| 238 |
+
batch_size=batch_size,
|
| 239 |
+
num_channels=self.unet.config.in_channels,
|
| 240 |
+
height=img_size,
|
| 241 |
+
width=img_size,
|
| 242 |
+
dtype=self.unet.dtype,
|
| 243 |
+
device=device,
|
| 244 |
+
generator=generator,
|
| 245 |
+
latents=latents,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
# 3. Handle class_labels for class-conditional models
|
| 249 |
+
class_labels = self.prepare_class_labels(batch_size, device, class_labels=class_labels)
|
| 250 |
+
|
| 251 |
+
# 4. Prepare timesteps
|
| 252 |
+
if timesteps is not None:
|
| 253 |
+
self.scheduler.set_timesteps(timesteps=timesteps, device=device)
|
| 254 |
+
timesteps = self.scheduler.timesteps
|
| 255 |
+
num_inference_steps = len(timesteps)
|
| 256 |
+
else:
|
| 257 |
+
self.scheduler.set_timesteps(num_inference_steps)
|
| 258 |
+
timesteps = self.scheduler.timesteps
|
| 259 |
+
|
| 260 |
+
# 5. Denoising loop
|
| 261 |
+
# Multistep sampling: implements Algorithm 1 in the paper
|
| 262 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 263 |
+
for i, t in enumerate(timesteps):
|
| 264 |
+
scaled_sample = self.scheduler.scale_model_input(sample, t)
|
| 265 |
+
model_output = self.unet(scaled_sample, t, class_labels=class_labels, return_dict=False)[0]
|
| 266 |
+
|
| 267 |
+
sample = self.scheduler.step(model_output, t, sample, generator=generator)[0]
|
| 268 |
+
|
| 269 |
+
# call the callback, if provided
|
| 270 |
+
progress_bar.update()
|
| 271 |
+
if callback is not None and i % callback_steps == 0:
|
| 272 |
+
callback(i, t, sample)
|
| 273 |
+
|
| 274 |
+
if XLA_AVAILABLE:
|
| 275 |
+
xm.mark_step()
|
| 276 |
+
|
| 277 |
+
# 6. Post-process image sample
|
| 278 |
+
image = self.postprocess_image(sample, output_type=output_type)
|
| 279 |
+
|
| 280 |
+
# Offload all models
|
| 281 |
+
self.maybe_free_model_hooks()
|
| 282 |
+
|
| 283 |
+
if not return_dict:
|
| 284 |
+
return (image,)
|
| 285 |
+
|
| 286 |
+
return ImagePipelineOutput(images=image)
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/controlnet/__init__.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import TYPE_CHECKING
|
| 2 |
+
|
| 3 |
+
from ...utils import (
|
| 4 |
+
DIFFUSERS_SLOW_IMPORT,
|
| 5 |
+
OptionalDependencyNotAvailable,
|
| 6 |
+
_LazyModule,
|
| 7 |
+
get_objects_from_module,
|
| 8 |
+
is_flax_available,
|
| 9 |
+
is_torch_available,
|
| 10 |
+
is_transformers_available,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
_dummy_objects = {}
|
| 15 |
+
_import_structure = {}
|
| 16 |
+
|
| 17 |
+
try:
|
| 18 |
+
if not (is_transformers_available() and is_torch_available()):
|
| 19 |
+
raise OptionalDependencyNotAvailable()
|
| 20 |
+
except OptionalDependencyNotAvailable:
|
| 21 |
+
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
| 22 |
+
|
| 23 |
+
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
| 24 |
+
else:
|
| 25 |
+
_import_structure["multicontrolnet"] = ["MultiControlNetModel"]
|
| 26 |
+
_import_structure["pipeline_controlnet"] = ["StableDiffusionControlNetPipeline"]
|
| 27 |
+
_import_structure["pipeline_controlnet_blip_diffusion"] = ["BlipDiffusionControlNetPipeline"]
|
| 28 |
+
_import_structure["pipeline_controlnet_img2img"] = ["StableDiffusionControlNetImg2ImgPipeline"]
|
| 29 |
+
_import_structure["pipeline_controlnet_inpaint"] = ["StableDiffusionControlNetInpaintPipeline"]
|
| 30 |
+
_import_structure["pipeline_controlnet_inpaint_sd_xl"] = ["StableDiffusionXLControlNetInpaintPipeline"]
|
| 31 |
+
_import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"]
|
| 32 |
+
_import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"]
|
| 33 |
+
_import_structure["pipeline_controlnet_union_inpaint_sd_xl"] = ["StableDiffusionXLControlNetUnionInpaintPipeline"]
|
| 34 |
+
_import_structure["pipeline_controlnet_union_sd_xl"] = ["StableDiffusionXLControlNetUnionPipeline"]
|
| 35 |
+
_import_structure["pipeline_controlnet_union_sd_xl_img2img"] = ["StableDiffusionXLControlNetUnionImg2ImgPipeline"]
|
| 36 |
+
try:
|
| 37 |
+
if not (is_transformers_available() and is_flax_available()):
|
| 38 |
+
raise OptionalDependencyNotAvailable()
|
| 39 |
+
except OptionalDependencyNotAvailable:
|
| 40 |
+
from ...utils import dummy_flax_and_transformers_objects # noqa F403
|
| 41 |
+
|
| 42 |
+
_dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
|
| 43 |
+
else:
|
| 44 |
+
_import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"]
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
| 48 |
+
try:
|
| 49 |
+
if not (is_transformers_available() and is_torch_available()):
|
| 50 |
+
raise OptionalDependencyNotAvailable()
|
| 51 |
+
|
| 52 |
+
except OptionalDependencyNotAvailable:
|
| 53 |
+
from ...utils.dummy_torch_and_transformers_objects import *
|
| 54 |
+
else:
|
| 55 |
+
from .multicontrolnet import MultiControlNetModel
|
| 56 |
+
from .pipeline_controlnet import StableDiffusionControlNetPipeline
|
| 57 |
+
from .pipeline_controlnet_blip_diffusion import BlipDiffusionControlNetPipeline
|
| 58 |
+
from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
|
| 59 |
+
from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
|
| 60 |
+
from .pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline
|
| 61 |
+
from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
|
| 62 |
+
from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline
|
| 63 |
+
from .pipeline_controlnet_union_inpaint_sd_xl import StableDiffusionXLControlNetUnionInpaintPipeline
|
| 64 |
+
from .pipeline_controlnet_union_sd_xl import StableDiffusionXLControlNetUnionPipeline
|
| 65 |
+
from .pipeline_controlnet_union_sd_xl_img2img import StableDiffusionXLControlNetUnionImg2ImgPipeline
|
| 66 |
+
|
| 67 |
+
try:
|
| 68 |
+
if not (is_transformers_available() and is_flax_available()):
|
| 69 |
+
raise OptionalDependencyNotAvailable()
|
| 70 |
+
except OptionalDependencyNotAvailable:
|
| 71 |
+
from ...utils.dummy_flax_and_transformers_objects import * # noqa F403
|
| 72 |
+
else:
|
| 73 |
+
from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
else:
|
| 77 |
+
import sys
|
| 78 |
+
|
| 79 |
+
sys.modules[__name__] = _LazyModule(
|
| 80 |
+
__name__,
|
| 81 |
+
globals()["__file__"],
|
| 82 |
+
_import_structure,
|
| 83 |
+
module_spec=__spec__,
|
| 84 |
+
)
|
| 85 |
+
for name, value in _dummy_objects.items():
|
| 86 |
+
setattr(sys.modules[__name__], name, value)
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/controlnet/multicontrolnet.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ...models.controlnets.multicontrolnet import MultiControlNetModel
|
| 2 |
+
from ...utils import deprecate, logging
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
logger = logging.get_logger(__name__)
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class MultiControlNetModel(MultiControlNetModel):
|
| 9 |
+
def __init__(self, *args, **kwargs):
|
| 10 |
+
deprecation_message = "Importing `MultiControlNetModel` from `diffusers.pipelines.controlnet.multicontrolnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.multicontrolnet import MultiControlNetModel`, instead."
|
| 11 |
+
deprecate("diffusers.pipelines.controlnet.multicontrolnet.MultiControlNetModel", "0.34", deprecation_message)
|
| 12 |
+
super().__init__(*args, **kwargs)
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/controlnet/pipeline_controlnet.py
ADDED
|
@@ -0,0 +1,1366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
import inspect
|
| 17 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import PIL.Image
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
| 24 |
+
|
| 25 |
+
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 26 |
+
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
| 27 |
+
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
| 28 |
+
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
|
| 29 |
+
from ...models.lora import adjust_lora_scale_text_encoder
|
| 30 |
+
from ...schedulers import KarrasDiffusionSchedulers
|
| 31 |
+
from ...utils import (
|
| 32 |
+
USE_PEFT_BACKEND,
|
| 33 |
+
deprecate,
|
| 34 |
+
is_torch_xla_available,
|
| 35 |
+
logging,
|
| 36 |
+
replace_example_docstring,
|
| 37 |
+
scale_lora_layers,
|
| 38 |
+
unscale_lora_layers,
|
| 39 |
+
)
|
| 40 |
+
from ...utils.torch_utils import empty_device_cache, is_compiled_module, is_torch_version, randn_tensor
|
| 41 |
+
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
| 42 |
+
from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
|
| 43 |
+
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
if is_torch_xla_available():
|
| 47 |
+
import torch_xla.core.xla_model as xm
|
| 48 |
+
|
| 49 |
+
XLA_AVAILABLE = True
|
| 50 |
+
else:
|
| 51 |
+
XLA_AVAILABLE = False
|
| 52 |
+
|
| 53 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
EXAMPLE_DOC_STRING = """
|
| 57 |
+
Examples:
|
| 58 |
+
```py
|
| 59 |
+
>>> # !pip install opencv-python transformers accelerate
|
| 60 |
+
>>> from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
|
| 61 |
+
>>> from diffusers.utils import load_image
|
| 62 |
+
>>> import numpy as np
|
| 63 |
+
>>> import torch
|
| 64 |
+
|
| 65 |
+
>>> import cv2
|
| 66 |
+
>>> from PIL import Image
|
| 67 |
+
|
| 68 |
+
>>> # download an image
|
| 69 |
+
>>> image = load_image(
|
| 70 |
+
... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
|
| 71 |
+
... )
|
| 72 |
+
>>> image = np.array(image)
|
| 73 |
+
|
| 74 |
+
>>> # get canny image
|
| 75 |
+
>>> image = cv2.Canny(image, 100, 200)
|
| 76 |
+
>>> image = image[:, :, None]
|
| 77 |
+
>>> image = np.concatenate([image, image, image], axis=2)
|
| 78 |
+
>>> canny_image = Image.fromarray(image)
|
| 79 |
+
|
| 80 |
+
>>> # load control net and stable diffusion v1-5
|
| 81 |
+
>>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
|
| 82 |
+
>>> pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
| 83 |
+
... "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
|
| 84 |
+
... )
|
| 85 |
+
|
| 86 |
+
>>> # speed up diffusion process with faster scheduler and memory optimization
|
| 87 |
+
>>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
| 88 |
+
>>> # remove following line if xformers is not installed
|
| 89 |
+
>>> pipe.enable_xformers_memory_efficient_attention()
|
| 90 |
+
|
| 91 |
+
>>> pipe.enable_model_cpu_offload()
|
| 92 |
+
|
| 93 |
+
>>> # generate image
|
| 94 |
+
>>> generator = torch.manual_seed(0)
|
| 95 |
+
>>> image = pipe(
|
| 96 |
+
... "futuristic-looking woman", num_inference_steps=20, generator=generator, image=canny_image
|
| 97 |
+
... ).images[0]
|
| 98 |
+
```
|
| 99 |
+
"""
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 103 |
+
def retrieve_timesteps(
|
| 104 |
+
scheduler,
|
| 105 |
+
num_inference_steps: Optional[int] = None,
|
| 106 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 107 |
+
timesteps: Optional[List[int]] = None,
|
| 108 |
+
sigmas: Optional[List[float]] = None,
|
| 109 |
+
**kwargs,
|
| 110 |
+
):
|
| 111 |
+
r"""
|
| 112 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 113 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 114 |
+
|
| 115 |
+
Args:
|
| 116 |
+
scheduler (`SchedulerMixin`):
|
| 117 |
+
The scheduler to get timesteps from.
|
| 118 |
+
num_inference_steps (`int`):
|
| 119 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 120 |
+
must be `None`.
|
| 121 |
+
device (`str` or `torch.device`, *optional*):
|
| 122 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 123 |
+
timesteps (`List[int]`, *optional*):
|
| 124 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 125 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 126 |
+
sigmas (`List[float]`, *optional*):
|
| 127 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 128 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 129 |
+
|
| 130 |
+
Returns:
|
| 131 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 132 |
+
second element is the number of inference steps.
|
| 133 |
+
"""
|
| 134 |
+
if timesteps is not None and sigmas is not None:
|
| 135 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 136 |
+
if timesteps is not None:
|
| 137 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 138 |
+
if not accepts_timesteps:
|
| 139 |
+
raise ValueError(
|
| 140 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 141 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 142 |
+
)
|
| 143 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 144 |
+
timesteps = scheduler.timesteps
|
| 145 |
+
num_inference_steps = len(timesteps)
|
| 146 |
+
elif sigmas is not None:
|
| 147 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 148 |
+
if not accept_sigmas:
|
| 149 |
+
raise ValueError(
|
| 150 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 151 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 152 |
+
)
|
| 153 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 154 |
+
timesteps = scheduler.timesteps
|
| 155 |
+
num_inference_steps = len(timesteps)
|
| 156 |
+
else:
|
| 157 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 158 |
+
timesteps = scheduler.timesteps
|
| 159 |
+
return timesteps, num_inference_steps
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class StableDiffusionControlNetPipeline(
|
| 163 |
+
DiffusionPipeline,
|
| 164 |
+
StableDiffusionMixin,
|
| 165 |
+
TextualInversionLoaderMixin,
|
| 166 |
+
StableDiffusionLoraLoaderMixin,
|
| 167 |
+
IPAdapterMixin,
|
| 168 |
+
FromSingleFileMixin,
|
| 169 |
+
):
|
| 170 |
+
r"""
|
| 171 |
+
Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
|
| 172 |
+
|
| 173 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
| 174 |
+
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
| 175 |
+
|
| 176 |
+
The pipeline also inherits the following loading methods:
|
| 177 |
+
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
| 178 |
+
- [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
| 179 |
+
- [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
| 180 |
+
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
| 181 |
+
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
| 182 |
+
|
| 183 |
+
Args:
|
| 184 |
+
vae ([`AutoencoderKL`]):
|
| 185 |
+
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
| 186 |
+
text_encoder ([`~transformers.CLIPTextModel`]):
|
| 187 |
+
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
| 188 |
+
tokenizer ([`~transformers.CLIPTokenizer`]):
|
| 189 |
+
A `CLIPTokenizer` to tokenize text.
|
| 190 |
+
unet ([`UNet2DConditionModel`]):
|
| 191 |
+
A `UNet2DConditionModel` to denoise the encoded image latents.
|
| 192 |
+
controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
|
| 193 |
+
Provides additional conditioning to the `unet` during the denoising process. If you set multiple
|
| 194 |
+
ControlNets as a list, the outputs from each ControlNet are added together to create one combined
|
| 195 |
+
additional conditioning.
|
| 196 |
+
scheduler ([`SchedulerMixin`]):
|
| 197 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
| 198 |
+
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
| 199 |
+
safety_checker ([`StableDiffusionSafetyChecker`]):
|
| 200 |
+
Classification module that estimates whether generated images could be considered offensive or harmful.
|
| 201 |
+
Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
|
| 202 |
+
more details about a model's potential harms.
|
| 203 |
+
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
| 204 |
+
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
| 205 |
+
"""
|
| 206 |
+
|
| 207 |
+
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
|
| 208 |
+
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
|
| 209 |
+
_exclude_from_cpu_offload = ["safety_checker"]
|
| 210 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "image"]
|
| 211 |
+
|
| 212 |
+
def __init__(
|
| 213 |
+
self,
|
| 214 |
+
vae: AutoencoderKL,
|
| 215 |
+
text_encoder: CLIPTextModel,
|
| 216 |
+
tokenizer: CLIPTokenizer,
|
| 217 |
+
unet: UNet2DConditionModel,
|
| 218 |
+
controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
|
| 219 |
+
scheduler: KarrasDiffusionSchedulers,
|
| 220 |
+
safety_checker: StableDiffusionSafetyChecker,
|
| 221 |
+
feature_extractor: CLIPImageProcessor,
|
| 222 |
+
image_encoder: CLIPVisionModelWithProjection = None,
|
| 223 |
+
requires_safety_checker: bool = True,
|
| 224 |
+
):
|
| 225 |
+
super().__init__()
|
| 226 |
+
|
| 227 |
+
if safety_checker is None and requires_safety_checker:
|
| 228 |
+
logger.warning(
|
| 229 |
+
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
| 230 |
+
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
| 231 |
+
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
| 232 |
+
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
| 233 |
+
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
| 234 |
+
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
if safety_checker is not None and feature_extractor is None:
|
| 238 |
+
raise ValueError(
|
| 239 |
+
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
| 240 |
+
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
if isinstance(controlnet, (list, tuple)):
|
| 244 |
+
controlnet = MultiControlNetModel(controlnet)
|
| 245 |
+
|
| 246 |
+
self.register_modules(
|
| 247 |
+
vae=vae,
|
| 248 |
+
text_encoder=text_encoder,
|
| 249 |
+
tokenizer=tokenizer,
|
| 250 |
+
unet=unet,
|
| 251 |
+
controlnet=controlnet,
|
| 252 |
+
scheduler=scheduler,
|
| 253 |
+
safety_checker=safety_checker,
|
| 254 |
+
feature_extractor=feature_extractor,
|
| 255 |
+
image_encoder=image_encoder,
|
| 256 |
+
)
|
| 257 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
| 258 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
|
| 259 |
+
self.control_image_processor = VaeImageProcessor(
|
| 260 |
+
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
|
| 261 |
+
)
|
| 262 |
+
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
| 263 |
+
|
| 264 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
| 265 |
+
def _encode_prompt(
|
| 266 |
+
self,
|
| 267 |
+
prompt,
|
| 268 |
+
device,
|
| 269 |
+
num_images_per_prompt,
|
| 270 |
+
do_classifier_free_guidance,
|
| 271 |
+
negative_prompt=None,
|
| 272 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 273 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 274 |
+
lora_scale: Optional[float] = None,
|
| 275 |
+
**kwargs,
|
| 276 |
+
):
|
| 277 |
+
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
| 278 |
+
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
|
| 279 |
+
|
| 280 |
+
prompt_embeds_tuple = self.encode_prompt(
|
| 281 |
+
prompt=prompt,
|
| 282 |
+
device=device,
|
| 283 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 284 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
| 285 |
+
negative_prompt=negative_prompt,
|
| 286 |
+
prompt_embeds=prompt_embeds,
|
| 287 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 288 |
+
lora_scale=lora_scale,
|
| 289 |
+
**kwargs,
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
# concatenate for backwards comp
|
| 293 |
+
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
|
| 294 |
+
|
| 295 |
+
return prompt_embeds
|
| 296 |
+
|
| 297 |
+
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
    def encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.

        Returns:
            Tuple `(prompt_embeds, negative_prompt_embeds)`; `negative_prompt_embeds` stays `None`
            when classifier-free guidance is disabled and none were passed in.
        """
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        # Derive the batch size from `prompt` when given, otherwise from the
        # precomputed embeddings.
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            # Re-tokenize without truncation to detect (and warn about) clipped input.
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            # Only pass an attention mask if the text encoder's config asks for one.
            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            if clip_skip is None:
                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
                prompt_embeds = prompt_embeds[0]
            else:
                prompt_embeds = self.text_encoder(
                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
                )
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
                # the tuple to access the hidden states from the desired layer.
                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
                # We also need to apply the final LayerNorm here to not mess with the
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

        # Pick a target dtype for the embeddings: text encoder, else UNet, else keep as-is.
        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet is not None:
            prompt_embeds_dtype = self.unet.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                # No negative prompt given: use the empty string for every batch item.
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            # Pad the negative prompt to the same sequence length as the positive one.
            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        if self.text_encoder is not None:
            if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        return prompt_embeds, negative_prompt_embeds
|
| 479 |
+
|
| 480 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
|
| 481 |
+
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
|
| 482 |
+
dtype = next(self.image_encoder.parameters()).dtype
|
| 483 |
+
|
| 484 |
+
if not isinstance(image, torch.Tensor):
|
| 485 |
+
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
| 486 |
+
|
| 487 |
+
image = image.to(device=device, dtype=dtype)
|
| 488 |
+
if output_hidden_states:
|
| 489 |
+
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
|
| 490 |
+
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
| 491 |
+
uncond_image_enc_hidden_states = self.image_encoder(
|
| 492 |
+
torch.zeros_like(image), output_hidden_states=True
|
| 493 |
+
).hidden_states[-2]
|
| 494 |
+
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
|
| 495 |
+
num_images_per_prompt, dim=0
|
| 496 |
+
)
|
| 497 |
+
return image_enc_hidden_states, uncond_image_enc_hidden_states
|
| 498 |
+
else:
|
| 499 |
+
image_embeds = self.image_encoder(image).image_embeds
|
| 500 |
+
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
| 501 |
+
uncond_image_embeds = torch.zeros_like(image_embeds)
|
| 502 |
+
|
| 503 |
+
return image_embeds, uncond_image_embeds
|
| 504 |
+
|
| 505 |
+
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
    def prepare_ip_adapter_image_embeds(
        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
    ):
        """Build the per-adapter image embeddings consumed by the IP-Adapter layers.

        Either encodes `ip_adapter_image` (one entry per loaded IP Adapter) via
        `encode_image`, or reuses precomputed `ip_adapter_image_embeds` (whose leading
        dimension is assumed to hold [negative, positive] halves under classifier-free
        guidance). Each result is tiled `num_images_per_prompt` times and, with
        guidance, the negative embeddings are concatenated in front.

        Returns:
            `List[torch.Tensor]`, one tensor per IP Adapter, on `device`.
        """
        image_embeds = []
        if do_classifier_free_guidance:
            negative_image_embeds = []
        if ip_adapter_image_embeds is None:
            if not isinstance(ip_adapter_image, list):
                ip_adapter_image = [ip_adapter_image]

            # One conditioning image is required per loaded IP-Adapter projection layer.
            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
                raise ValueError(
                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
                )

            for single_ip_adapter_image, image_proj_layer in zip(
                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
            ):
                # Plain `ImageProjection` layers take pooled embeds; other projection
                # types take penultimate hidden states instead.
                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
                single_image_embeds, single_negative_image_embeds = self.encode_image(
                    single_ip_adapter_image, device, 1, output_hidden_state
                )

                # Prepend a length-1 "stack" dim so the tiling below works uniformly.
                image_embeds.append(single_image_embeds[None, :])
                if do_classifier_free_guidance:
                    negative_image_embeds.append(single_negative_image_embeds[None, :])
        else:
            for single_image_embeds in ip_adapter_image_embeds:
                if do_classifier_free_guidance:
                    # Precomputed embeds store [negative, positive] along dim 0.
                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
                    negative_image_embeds.append(single_negative_image_embeds)
                image_embeds.append(single_image_embeds)

        ip_adapter_image_embeds = []
        for i, single_image_embeds in enumerate(image_embeds):
            # Tile for num_images_per_prompt, then put negatives first for CFG batching.
            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
            if do_classifier_free_guidance:
                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)

            single_image_embeds = single_image_embeds.to(device=device)
            ip_adapter_image_embeds.append(single_image_embeds)

        return ip_adapter_image_embeds
|
| 550 |
+
|
| 551 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
| 552 |
+
def run_safety_checker(self, image, device, dtype):
|
| 553 |
+
if self.safety_checker is None:
|
| 554 |
+
has_nsfw_concept = None
|
| 555 |
+
else:
|
| 556 |
+
if torch.is_tensor(image):
|
| 557 |
+
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
|
| 558 |
+
else:
|
| 559 |
+
feature_extractor_input = self.image_processor.numpy_to_pil(image)
|
| 560 |
+
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
|
| 561 |
+
image, has_nsfw_concept = self.safety_checker(
|
| 562 |
+
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
| 563 |
+
)
|
| 564 |
+
return image, has_nsfw_concept
|
| 565 |
+
|
| 566 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
|
| 567 |
+
def decode_latents(self, latents):
|
| 568 |
+
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
|
| 569 |
+
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
|
| 570 |
+
|
| 571 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
| 572 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
| 573 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
| 574 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
| 575 |
+
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
| 576 |
+
return image
|
| 577 |
+
|
| 578 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
| 579 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
| 580 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 581 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 582 |
+
# eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
|
| 583 |
+
# and should be between [0, 1]
|
| 584 |
+
|
| 585 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 586 |
+
extra_step_kwargs = {}
|
| 587 |
+
if accepts_eta:
|
| 588 |
+
extra_step_kwargs["eta"] = eta
|
| 589 |
+
|
| 590 |
+
# check if the scheduler accepts generator
|
| 591 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 592 |
+
if accepts_generator:
|
| 593 |
+
extra_step_kwargs["generator"] = generator
|
| 594 |
+
return extra_step_kwargs
|
| 595 |
+
|
| 596 |
+
    def check_inputs(
        self,
        prompt,
        image,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        ip_adapter_image=None,
        ip_adapter_image_embeds=None,
        controlnet_conditioning_scale=1.0,
        control_guidance_start=0.0,
        control_guidance_end=1.0,
        callback_on_step_end_tensor_inputs=None,
    ):
        """Validate the user-facing call arguments before running the pipeline.

        Checks mutual exclusivity of `prompt`/`prompt_embeds` and of the IP-Adapter
        inputs, validates the conditioning image(s) and scales against the configured
        ControlNet(s), and range-checks the ControlNet guidance schedule. Raises
        `ValueError` or `TypeError` on the first violation; returns nothing.
        """
        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        # Callback tensor names must be a subset of the class-level whitelist.
        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        # Exactly one of `prompt` / `prompt_embeds` must be provided.
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

        # Check `image`
        # A `torch.compile`d ControlNet is wrapped in `OptimizedModule`; the real model
        # then lives on `_orig_mod`. The hasattr check limits this to torch builds that
        # expose scaled_dot_product_attention (presumably a torch>=2.0 proxy — see upstream).
        is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
            self.controlnet, torch._dynamo.eval_frame.OptimizedModule
        )
        # NOTE: `A or B and C` parses as `A or (B and C)` — the compiled check is the
        # intended grouping here.
        if (
            isinstance(self.controlnet, ControlNetModel)
            or is_compiled
            and isinstance(self.controlnet._orig_mod, ControlNetModel)
        ):
            self.check_image(image, prompt, prompt_embeds)
        elif (
            isinstance(self.controlnet, MultiControlNetModel)
            or is_compiled
            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
        ):
            if not isinstance(image, list):
                raise TypeError("For multiple controlnets: `image` must be type `list`")

            # When `image` is a nested list:
            # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
            elif any(isinstance(i, list) for i in image):
                # Transpose from per-batch-item lists to per-controlnet lists.
                transposed_image = [list(t) for t in zip(*image)]
                if len(transposed_image) != len(self.controlnet.nets):
                    raise ValueError(
                        f"For multiple controlnets: if you pass`image` as a list of list, each sublist must have the same length as the number of controlnets, but the sublists in `image` got {len(transposed_image)} images and {len(self.controlnet.nets)} ControlNets."
                    )
                for image_ in transposed_image:
                    self.check_image(image_, prompt, prompt_embeds)
            elif len(image) != len(self.controlnet.nets):
                raise ValueError(
                    f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
                )
            else:
                for image_ in image:
                    self.check_image(image_, prompt, prompt_embeds)
        else:
            # `self.controlnet` is neither a single nor a multi ControlNet — should be
            # unreachable given `__init__` normalization.
            assert False

        # Check `controlnet_conditioning_scale`
        if (
            isinstance(self.controlnet, ControlNetModel)
            or is_compiled
            and isinstance(self.controlnet._orig_mod, ControlNetModel)
        ):
            if not isinstance(controlnet_conditioning_scale, float):
                raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
        elif (
            isinstance(self.controlnet, MultiControlNetModel)
            or is_compiled
            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
        ):
            if isinstance(controlnet_conditioning_scale, list):
                # Nested lists (per-batch-item scales) are rejected.
                if any(isinstance(i, list) for i in controlnet_conditioning_scale):
                    raise ValueError(
                        "A single batch of varying conditioning scale settings (e.g. [[1.0, 0.5], [0.2, 0.8]]) is not supported at the moment. "
                        "The conditioning scale must be fixed across the batch."
                    )
            elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
                self.controlnet.nets
            ):
                raise ValueError(
                    "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
                    " the same length as the number of controlnets"
                )
        else:
            assert False

        # Normalize scalars to lists so the schedule checks below are uniform.
        if not isinstance(control_guidance_start, (tuple, list)):
            control_guidance_start = [control_guidance_start]

        if not isinstance(control_guidance_end, (tuple, list)):
            control_guidance_end = [control_guidance_end]

        if len(control_guidance_start) != len(control_guidance_end):
            raise ValueError(
                f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
            )

        if isinstance(self.controlnet, MultiControlNetModel):
            if len(control_guidance_start) != len(self.controlnet.nets):
                raise ValueError(
                    f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
                )

        # Each (start, end) pair must satisfy 0.0 <= start < end <= 1.0.
        for start, end in zip(control_guidance_start, control_guidance_end):
            if start >= end:
                raise ValueError(
                    f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
                )
            if start < 0.0:
                raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
            if end > 1.0:
                raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")

        # IP-Adapter inputs are mutually exclusive.
        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
            raise ValueError(
                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
            )

        if ip_adapter_image_embeds is not None:
            if not isinstance(ip_adapter_image_embeds, list):
                raise ValueError(
                    f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
                )
            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
                raise ValueError(
                    f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
                )
|
| 758 |
+
|
| 759 |
+
    def check_image(self, image, prompt, prompt_embeds):
        """Validate the type and batch size of one ControlNet conditioning image.

        Raises:
            TypeError: if `image` is not a PIL image, numpy array, torch tensor, or a
                list of one of those.
            ValueError: if the image batch size is neither 1 nor equal to the prompt
                batch size.
        """
        # Classify the container type of `image`; exactly one flag should be True.
        image_is_pil = isinstance(image, PIL.Image.Image)
        image_is_tensor = isinstance(image, torch.Tensor)
        image_is_np = isinstance(image, np.ndarray)
        image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
        image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
        image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)

        if (
            not image_is_pil
            and not image_is_tensor
            and not image_is_np
            and not image_is_pil_list
            and not image_is_tensor_list
            and not image_is_np_list
        ):
            raise TypeError(
                f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
            )

        if image_is_pil:
            image_batch_size = 1
        else:
            # For tensors/arrays this is the leading dimension; for lists, the list length.
            image_batch_size = len(image)

        if prompt is not None and isinstance(prompt, str):
            prompt_batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            prompt_batch_size = len(prompt)
        elif prompt_embeds is not None:
            prompt_batch_size = prompt_embeds.shape[0]
        # NOTE(review): if both `prompt` and `prompt_embeds` are None, `prompt_batch_size`
        # is never assigned and the comparison below raises UnboundLocalError — callers
        # presumably rely on `check_inputs` having rejected that combination earlier.

        # A single conditioning image may be broadcast; otherwise sizes must match.
        if image_batch_size != 1 and image_batch_size != prompt_batch_size:
            raise ValueError(
                f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
            )
|
| 795 |
+
|
| 796 |
+
def prepare_image(
|
| 797 |
+
self,
|
| 798 |
+
image,
|
| 799 |
+
width,
|
| 800 |
+
height,
|
| 801 |
+
batch_size,
|
| 802 |
+
num_images_per_prompt,
|
| 803 |
+
device,
|
| 804 |
+
dtype,
|
| 805 |
+
do_classifier_free_guidance=False,
|
| 806 |
+
guess_mode=False,
|
| 807 |
+
):
|
| 808 |
+
image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
|
| 809 |
+
image_batch_size = image.shape[0]
|
| 810 |
+
|
| 811 |
+
if image_batch_size == 1:
|
| 812 |
+
repeat_by = batch_size
|
| 813 |
+
else:
|
| 814 |
+
# image batch size is the same as prompt batch size
|
| 815 |
+
repeat_by = num_images_per_prompt
|
| 816 |
+
|
| 817 |
+
image = image.repeat_interleave(repeat_by, dim=0)
|
| 818 |
+
|
| 819 |
+
image = image.to(device=device, dtype=dtype)
|
| 820 |
+
|
| 821 |
+
if do_classifier_free_guidance and not guess_mode:
|
| 822 |
+
image = torch.cat([image] * 2)
|
| 823 |
+
|
| 824 |
+
return image
|
| 825 |
+
|
| 826 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
| 827 |
+
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
| 828 |
+
shape = (
|
| 829 |
+
batch_size,
|
| 830 |
+
num_channels_latents,
|
| 831 |
+
int(height) // self.vae_scale_factor,
|
| 832 |
+
int(width) // self.vae_scale_factor,
|
| 833 |
+
)
|
| 834 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 835 |
+
raise ValueError(
|
| 836 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 837 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 838 |
+
)
|
| 839 |
+
|
| 840 |
+
if latents is None:
|
| 841 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 842 |
+
else:
|
| 843 |
+
latents = latents.to(device)
|
| 844 |
+
|
| 845 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
| 846 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 847 |
+
return latents
|
| 848 |
+
|
| 849 |
+
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
|
| 850 |
+
def get_guidance_scale_embedding(
|
| 851 |
+
self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
|
| 852 |
+
) -> torch.Tensor:
|
| 853 |
+
"""
|
| 854 |
+
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
|
| 855 |
+
|
| 856 |
+
Args:
|
| 857 |
+
w (`torch.Tensor`):
|
| 858 |
+
Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
|
| 859 |
+
embedding_dim (`int`, *optional*, defaults to 512):
|
| 860 |
+
Dimension of the embeddings to generate.
|
| 861 |
+
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
|
| 862 |
+
Data type of the generated embeddings.
|
| 863 |
+
|
| 864 |
+
Returns:
|
| 865 |
+
`torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
|
| 866 |
+
"""
|
| 867 |
+
assert len(w.shape) == 1
|
| 868 |
+
w = w * 1000.0
|
| 869 |
+
|
| 870 |
+
half_dim = embedding_dim // 2
|
| 871 |
+
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
|
| 872 |
+
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
|
| 873 |
+
emb = w.to(dtype)[:, None] * emb[None, :]
|
| 874 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
| 875 |
+
if embedding_dim % 2 == 1: # zero pad
|
| 876 |
+
emb = torch.nn.functional.pad(emb, (0, 1))
|
| 877 |
+
assert emb.shape == (w.shape[0], embedding_dim)
|
| 878 |
+
return emb
|
| 879 |
+
|
| 880 |
+
@property
|
| 881 |
+
def guidance_scale(self):
|
| 882 |
+
return self._guidance_scale
|
| 883 |
+
|
| 884 |
+
@property
|
| 885 |
+
def clip_skip(self):
|
| 886 |
+
return self._clip_skip
|
| 887 |
+
|
| 888 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 889 |
+
# of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
|
| 890 |
+
# corresponds to doing no classifier free guidance.
|
| 891 |
+
@property
|
| 892 |
+
def do_classifier_free_guidance(self):
|
| 893 |
+
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
|
| 894 |
+
|
| 895 |
+
@property
|
| 896 |
+
def cross_attention_kwargs(self):
|
| 897 |
+
return self._cross_attention_kwargs
|
| 898 |
+
|
| 899 |
+
@property
|
| 900 |
+
def num_timesteps(self):
|
| 901 |
+
return self._num_timesteps
|
| 902 |
+
|
| 903 |
+
@property
|
| 904 |
+
def interrupt(self):
|
| 905 |
+
return self._interrupt
|
| 906 |
+
|
| 907 |
+
@torch.no_grad()
|
| 908 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 909 |
+
def __call__(
|
| 910 |
+
self,
|
| 911 |
+
prompt: Union[str, List[str]] = None,
|
| 912 |
+
image: PipelineImageInput = None,
|
| 913 |
+
height: Optional[int] = None,
|
| 914 |
+
width: Optional[int] = None,
|
| 915 |
+
num_inference_steps: int = 50,
|
| 916 |
+
timesteps: List[int] = None,
|
| 917 |
+
sigmas: List[float] = None,
|
| 918 |
+
guidance_scale: float = 7.5,
|
| 919 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 920 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 921 |
+
eta: float = 0.0,
|
| 922 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 923 |
+
latents: Optional[torch.Tensor] = None,
|
| 924 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 925 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 926 |
+
ip_adapter_image: Optional[PipelineImageInput] = None,
|
| 927 |
+
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
| 928 |
+
output_type: Optional[str] = "pil",
|
| 929 |
+
return_dict: bool = True,
|
| 930 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 931 |
+
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
| 932 |
+
guess_mode: bool = False,
|
| 933 |
+
control_guidance_start: Union[float, List[float]] = 0.0,
|
| 934 |
+
control_guidance_end: Union[float, List[float]] = 1.0,
|
| 935 |
+
clip_skip: Optional[int] = None,
|
| 936 |
+
callback_on_step_end: Optional[
|
| 937 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
| 938 |
+
] = None,
|
| 939 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 940 |
+
**kwargs,
|
| 941 |
+
):
|
| 942 |
+
r"""
|
| 943 |
+
The call function to the pipeline for generation.
|
| 944 |
+
|
| 945 |
+
Args:
|
| 946 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 947 |
+
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
| 948 |
+
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
|
| 949 |
+
`List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
|
| 950 |
+
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
|
| 951 |
+
specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
|
| 952 |
+
as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
|
| 953 |
+
width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
|
| 954 |
+
images must be passed as a list such that each element of the list can be correctly batched for input
|
| 955 |
+
to a single ControlNet. When `prompt` is a list, and if a list of images is passed for a single
|
| 956 |
+
ControlNet, each will be paired with each prompt in the `prompt` list. This also applies to multiple
|
| 957 |
+
ControlNets, where a list of image lists can be passed to batch for each prompt and each ControlNet.
|
| 958 |
+
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
| 959 |
+
The height in pixels of the generated image.
|
| 960 |
+
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
| 961 |
+
The width in pixels of the generated image.
|
| 962 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 963 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 964 |
+
expense of slower inference.
|
| 965 |
+
timesteps (`List[int]`, *optional*):
|
| 966 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
| 967 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
| 968 |
+
passed will be used. Must be in descending order.
|
| 969 |
+
sigmas (`List[float]`, *optional*):
|
| 970 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
| 971 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
| 972 |
+
will be used.
|
| 973 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
| 974 |
+
A higher guidance scale value encourages the model to generate images closely linked to the text
|
| 975 |
+
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
| 976 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 977 |
+
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
| 978 |
+
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
| 979 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 980 |
+
The number of images to generate per prompt.
|
| 981 |
+
eta (`float`, *optional*, defaults to 0.0):
|
| 982 |
+
Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
|
| 983 |
+
applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
| 984 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 985 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
| 986 |
+
generation deterministic.
|
| 987 |
+
latents (`torch.Tensor`, *optional*):
|
| 988 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
| 989 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 990 |
+
tensor is generated by sampling using the supplied random `generator`.
|
| 991 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 992 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
| 993 |
+
provided, text embeddings are generated from the `prompt` input argument.
|
| 994 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 995 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
| 996 |
+
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
| 997 |
+
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
| 998 |
+
ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
|
| 999 |
+
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
|
| 1000 |
+
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
|
| 1001 |
+
contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
|
| 1002 |
+
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
| 1003 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 1004 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
| 1005 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 1006 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
| 1007 |
+
plain tuple.
|
| 1008 |
+
callback (`Callable`, *optional*):
|
| 1009 |
+
A function that calls every `callback_steps` steps during inference. The function is called with the
|
| 1010 |
+
following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
|
| 1011 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
| 1012 |
+
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
| 1013 |
+
every step.
|
| 1014 |
+
cross_attention_kwargs (`dict`, *optional*):
|
| 1015 |
+
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
| 1016 |
+
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 1017 |
+
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
|
| 1018 |
+
The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
|
| 1019 |
+
to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
|
| 1020 |
+
the corresponding scale as a list.
|
| 1021 |
+
guess_mode (`bool`, *optional*, defaults to `False`):
|
| 1022 |
+
The ControlNet encoder tries to recognize the content of the input image even if you remove all
|
| 1023 |
+
prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
|
| 1024 |
+
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
|
| 1025 |
+
The percentage of total steps at which the ControlNet starts applying.
|
| 1026 |
+
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
|
| 1027 |
+
The percentage of total steps at which the ControlNet stops applying.
|
| 1028 |
+
clip_skip (`int`, *optional*):
|
| 1029 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
| 1030 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
| 1031 |
+
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
| 1032 |
+
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
| 1033 |
+
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
| 1034 |
+
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
| 1035 |
+
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
| 1036 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 1037 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 1038 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 1039 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 1040 |
+
|
| 1041 |
+
Examples:
|
| 1042 |
+
|
| 1043 |
+
Returns:
|
| 1044 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
| 1045 |
+
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
| 1046 |
+
otherwise a `tuple` is returned where the first element is a list with the generated images and the
|
| 1047 |
+
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
| 1048 |
+
"not-safe-for-work" (nsfw) content.
|
| 1049 |
+
"""
|
| 1050 |
+
|
| 1051 |
+
callback = kwargs.pop("callback", None)
|
| 1052 |
+
callback_steps = kwargs.pop("callback_steps", None)
|
| 1053 |
+
|
| 1054 |
+
if callback is not None:
|
| 1055 |
+
deprecate(
|
| 1056 |
+
"callback",
|
| 1057 |
+
"1.0.0",
|
| 1058 |
+
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
| 1059 |
+
)
|
| 1060 |
+
if callback_steps is not None:
|
| 1061 |
+
deprecate(
|
| 1062 |
+
"callback_steps",
|
| 1063 |
+
"1.0.0",
|
| 1064 |
+
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
| 1065 |
+
)
|
| 1066 |
+
|
| 1067 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 1068 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 1069 |
+
|
| 1070 |
+
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
|
| 1071 |
+
|
| 1072 |
+
# align format for control guidance
|
| 1073 |
+
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
|
| 1074 |
+
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
|
| 1075 |
+
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
|
| 1076 |
+
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
|
| 1077 |
+
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
|
| 1078 |
+
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
|
| 1079 |
+
control_guidance_start, control_guidance_end = (
|
| 1080 |
+
mult * [control_guidance_start],
|
| 1081 |
+
mult * [control_guidance_end],
|
| 1082 |
+
)
|
| 1083 |
+
|
| 1084 |
+
# 1. Check inputs. Raise error if not correct
|
| 1085 |
+
self.check_inputs(
|
| 1086 |
+
prompt,
|
| 1087 |
+
image,
|
| 1088 |
+
callback_steps,
|
| 1089 |
+
negative_prompt,
|
| 1090 |
+
prompt_embeds,
|
| 1091 |
+
negative_prompt_embeds,
|
| 1092 |
+
ip_adapter_image,
|
| 1093 |
+
ip_adapter_image_embeds,
|
| 1094 |
+
controlnet_conditioning_scale,
|
| 1095 |
+
control_guidance_start,
|
| 1096 |
+
control_guidance_end,
|
| 1097 |
+
callback_on_step_end_tensor_inputs,
|
| 1098 |
+
)
|
| 1099 |
+
|
| 1100 |
+
self._guidance_scale = guidance_scale
|
| 1101 |
+
self._clip_skip = clip_skip
|
| 1102 |
+
self._cross_attention_kwargs = cross_attention_kwargs
|
| 1103 |
+
self._interrupt = False
|
| 1104 |
+
|
| 1105 |
+
# 2. Define call parameters
|
| 1106 |
+
if prompt is not None and isinstance(prompt, str):
|
| 1107 |
+
batch_size = 1
|
| 1108 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 1109 |
+
batch_size = len(prompt)
|
| 1110 |
+
else:
|
| 1111 |
+
batch_size = prompt_embeds.shape[0]
|
| 1112 |
+
|
| 1113 |
+
device = self._execution_device
|
| 1114 |
+
|
| 1115 |
+
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
|
| 1116 |
+
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
|
| 1117 |
+
|
| 1118 |
+
global_pool_conditions = (
|
| 1119 |
+
controlnet.config.global_pool_conditions
|
| 1120 |
+
if isinstance(controlnet, ControlNetModel)
|
| 1121 |
+
else controlnet.nets[0].config.global_pool_conditions
|
| 1122 |
+
)
|
| 1123 |
+
guess_mode = guess_mode or global_pool_conditions
|
| 1124 |
+
|
| 1125 |
+
# 3. Encode input prompt
|
| 1126 |
+
text_encoder_lora_scale = (
|
| 1127 |
+
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
| 1128 |
+
)
|
| 1129 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
| 1130 |
+
prompt,
|
| 1131 |
+
device,
|
| 1132 |
+
num_images_per_prompt,
|
| 1133 |
+
self.do_classifier_free_guidance,
|
| 1134 |
+
negative_prompt,
|
| 1135 |
+
prompt_embeds=prompt_embeds,
|
| 1136 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 1137 |
+
lora_scale=text_encoder_lora_scale,
|
| 1138 |
+
clip_skip=self.clip_skip,
|
| 1139 |
+
)
|
| 1140 |
+
# For classifier free guidance, we need to do two forward passes.
|
| 1141 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
| 1142 |
+
# to avoid doing two forward passes
|
| 1143 |
+
if self.do_classifier_free_guidance:
|
| 1144 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
| 1145 |
+
|
| 1146 |
+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
| 1147 |
+
image_embeds = self.prepare_ip_adapter_image_embeds(
|
| 1148 |
+
ip_adapter_image,
|
| 1149 |
+
ip_adapter_image_embeds,
|
| 1150 |
+
device,
|
| 1151 |
+
batch_size * num_images_per_prompt,
|
| 1152 |
+
self.do_classifier_free_guidance,
|
| 1153 |
+
)
|
| 1154 |
+
|
| 1155 |
+
# 4. Prepare image
|
| 1156 |
+
if isinstance(controlnet, ControlNetModel):
|
| 1157 |
+
image = self.prepare_image(
|
| 1158 |
+
image=image,
|
| 1159 |
+
width=width,
|
| 1160 |
+
height=height,
|
| 1161 |
+
batch_size=batch_size * num_images_per_prompt,
|
| 1162 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 1163 |
+
device=device,
|
| 1164 |
+
dtype=controlnet.dtype,
|
| 1165 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
| 1166 |
+
guess_mode=guess_mode,
|
| 1167 |
+
)
|
| 1168 |
+
height, width = image.shape[-2:]
|
| 1169 |
+
elif isinstance(controlnet, MultiControlNetModel):
|
| 1170 |
+
images = []
|
| 1171 |
+
|
| 1172 |
+
# Nested lists as ControlNet condition
|
| 1173 |
+
if isinstance(image[0], list):
|
| 1174 |
+
# Transpose the nested image list
|
| 1175 |
+
image = [list(t) for t in zip(*image)]
|
| 1176 |
+
|
| 1177 |
+
for image_ in image:
|
| 1178 |
+
image_ = self.prepare_image(
|
| 1179 |
+
image=image_,
|
| 1180 |
+
width=width,
|
| 1181 |
+
height=height,
|
| 1182 |
+
batch_size=batch_size * num_images_per_prompt,
|
| 1183 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 1184 |
+
device=device,
|
| 1185 |
+
dtype=controlnet.dtype,
|
| 1186 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
| 1187 |
+
guess_mode=guess_mode,
|
| 1188 |
+
)
|
| 1189 |
+
|
| 1190 |
+
images.append(image_)
|
| 1191 |
+
|
| 1192 |
+
image = images
|
| 1193 |
+
height, width = image[0].shape[-2:]
|
| 1194 |
+
else:
|
| 1195 |
+
assert False
|
| 1196 |
+
|
| 1197 |
+
# 5. Prepare timesteps
|
| 1198 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 1199 |
+
self.scheduler, num_inference_steps, device, timesteps, sigmas
|
| 1200 |
+
)
|
| 1201 |
+
self._num_timesteps = len(timesteps)
|
| 1202 |
+
|
| 1203 |
+
# 6. Prepare latent variables
|
| 1204 |
+
num_channels_latents = self.unet.config.in_channels
|
| 1205 |
+
latents = self.prepare_latents(
|
| 1206 |
+
batch_size * num_images_per_prompt,
|
| 1207 |
+
num_channels_latents,
|
| 1208 |
+
height,
|
| 1209 |
+
width,
|
| 1210 |
+
prompt_embeds.dtype,
|
| 1211 |
+
device,
|
| 1212 |
+
generator,
|
| 1213 |
+
latents,
|
| 1214 |
+
)
|
| 1215 |
+
|
| 1216 |
+
# 6.5 Optionally get Guidance Scale Embedding
|
| 1217 |
+
timestep_cond = None
|
| 1218 |
+
if self.unet.config.time_cond_proj_dim is not None:
|
| 1219 |
+
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
|
| 1220 |
+
timestep_cond = self.get_guidance_scale_embedding(
|
| 1221 |
+
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
|
| 1222 |
+
).to(device=device, dtype=latents.dtype)
|
| 1223 |
+
|
| 1224 |
+
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 1225 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 1226 |
+
|
| 1227 |
+
# 7.1 Add image embeds for IP-Adapter
|
| 1228 |
+
added_cond_kwargs = (
|
| 1229 |
+
{"image_embeds": image_embeds}
|
| 1230 |
+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
|
| 1231 |
+
else None
|
| 1232 |
+
)
|
| 1233 |
+
|
| 1234 |
+
# 7.2 Create tensor stating which controlnets to keep
|
| 1235 |
+
controlnet_keep = []
|
| 1236 |
+
for i in range(len(timesteps)):
|
| 1237 |
+
keeps = [
|
| 1238 |
+
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
|
| 1239 |
+
for s, e in zip(control_guidance_start, control_guidance_end)
|
| 1240 |
+
]
|
| 1241 |
+
controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
|
| 1242 |
+
|
| 1243 |
+
# 8. Denoising loop
|
| 1244 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
| 1245 |
+
is_unet_compiled = is_compiled_module(self.unet)
|
| 1246 |
+
is_controlnet_compiled = is_compiled_module(self.controlnet)
|
| 1247 |
+
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
|
| 1248 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 1249 |
+
for i, t in enumerate(timesteps):
|
| 1250 |
+
if self.interrupt:
|
| 1251 |
+
continue
|
| 1252 |
+
|
| 1253 |
+
# Relevant thread:
|
| 1254 |
+
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
|
| 1255 |
+
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
|
| 1256 |
+
torch._inductor.cudagraph_mark_step_begin()
|
| 1257 |
+
# expand the latents if we are doing classifier free guidance
|
| 1258 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
| 1259 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 1260 |
+
|
| 1261 |
+
# controlnet(s) inference
|
| 1262 |
+
if guess_mode and self.do_classifier_free_guidance:
|
| 1263 |
+
# Infer ControlNet only for the conditional batch.
|
| 1264 |
+
control_model_input = latents
|
| 1265 |
+
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
|
| 1266 |
+
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
|
| 1267 |
+
else:
|
| 1268 |
+
control_model_input = latent_model_input
|
| 1269 |
+
controlnet_prompt_embeds = prompt_embeds
|
| 1270 |
+
|
| 1271 |
+
if isinstance(controlnet_keep[i], list):
|
| 1272 |
+
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
|
| 1273 |
+
else:
|
| 1274 |
+
controlnet_cond_scale = controlnet_conditioning_scale
|
| 1275 |
+
if isinstance(controlnet_cond_scale, list):
|
| 1276 |
+
controlnet_cond_scale = controlnet_cond_scale[0]
|
| 1277 |
+
cond_scale = controlnet_cond_scale * controlnet_keep[i]
|
| 1278 |
+
|
| 1279 |
+
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
| 1280 |
+
control_model_input,
|
| 1281 |
+
t,
|
| 1282 |
+
encoder_hidden_states=controlnet_prompt_embeds,
|
| 1283 |
+
controlnet_cond=image,
|
| 1284 |
+
conditioning_scale=cond_scale,
|
| 1285 |
+
guess_mode=guess_mode,
|
| 1286 |
+
return_dict=False,
|
| 1287 |
+
)
|
| 1288 |
+
|
| 1289 |
+
if guess_mode and self.do_classifier_free_guidance:
|
| 1290 |
+
# Inferred ControlNet only for the conditional batch.
|
| 1291 |
+
# To apply the output of ControlNet to both the unconditional and conditional batches,
|
| 1292 |
+
# add 0 to the unconditional batch to keep it unchanged.
|
| 1293 |
+
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
|
| 1294 |
+
mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
|
| 1295 |
+
|
| 1296 |
+
# predict the noise residual
|
| 1297 |
+
noise_pred = self.unet(
|
| 1298 |
+
latent_model_input,
|
| 1299 |
+
t,
|
| 1300 |
+
encoder_hidden_states=prompt_embeds,
|
| 1301 |
+
timestep_cond=timestep_cond,
|
| 1302 |
+
cross_attention_kwargs=self.cross_attention_kwargs,
|
| 1303 |
+
down_block_additional_residuals=down_block_res_samples,
|
| 1304 |
+
mid_block_additional_residual=mid_block_res_sample,
|
| 1305 |
+
added_cond_kwargs=added_cond_kwargs,
|
| 1306 |
+
return_dict=False,
|
| 1307 |
+
)[0]
|
| 1308 |
+
|
| 1309 |
+
# perform guidance
|
| 1310 |
+
if self.do_classifier_free_guidance:
|
| 1311 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 1312 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 1313 |
+
|
| 1314 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 1315 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
| 1316 |
+
|
| 1317 |
+
if callback_on_step_end is not None:
|
| 1318 |
+
callback_kwargs = {}
|
| 1319 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 1320 |
+
callback_kwargs[k] = locals()[k]
|
| 1321 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 1322 |
+
|
| 1323 |
+
latents = callback_outputs.pop("latents", latents)
|
| 1324 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 1325 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 1326 |
+
image = callback_outputs.pop("image", image)
|
| 1327 |
+
|
| 1328 |
+
# call the callback, if provided
|
| 1329 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 1330 |
+
progress_bar.update()
|
| 1331 |
+
if callback is not None and i % callback_steps == 0:
|
| 1332 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
| 1333 |
+
callback(step_idx, t, latents)
|
| 1334 |
+
|
| 1335 |
+
if XLA_AVAILABLE:
|
| 1336 |
+
xm.mark_step()
|
| 1337 |
+
# If we do sequential model offloading, let's offload unet and controlnet
|
| 1338 |
+
# manually for max memory savings
|
| 1339 |
+
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
| 1340 |
+
self.unet.to("cpu")
|
| 1341 |
+
self.controlnet.to("cpu")
|
| 1342 |
+
empty_device_cache()
|
| 1343 |
+
|
| 1344 |
+
if not output_type == "latent":
|
| 1345 |
+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
|
| 1346 |
+
0
|
| 1347 |
+
]
|
| 1348 |
+
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
| 1349 |
+
else:
|
| 1350 |
+
image = latents
|
| 1351 |
+
has_nsfw_concept = None
|
| 1352 |
+
|
| 1353 |
+
if has_nsfw_concept is None:
|
| 1354 |
+
do_denormalize = [True] * image.shape[0]
|
| 1355 |
+
else:
|
| 1356 |
+
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
| 1357 |
+
|
| 1358 |
+
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
| 1359 |
+
|
| 1360 |
+
# Offload all models
|
| 1361 |
+
self.maybe_free_model_hooks()
|
| 1362 |
+
|
| 1363 |
+
if not return_dict:
|
| 1364 |
+
return (image, has_nsfw_concept)
|
| 1365 |
+
|
| 1366 |
+
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py
ADDED
|
@@ -0,0 +1,427 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Salesforce.com, inc.
|
| 2 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
from typing import List, Optional, Union
|
| 16 |
+
|
| 17 |
+
import PIL.Image
|
| 18 |
+
import torch
|
| 19 |
+
from transformers import CLIPTokenizer
|
| 20 |
+
|
| 21 |
+
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
|
| 22 |
+
from ...schedulers import PNDMScheduler
|
| 23 |
+
from ...utils import (
|
| 24 |
+
is_torch_xla_available,
|
| 25 |
+
logging,
|
| 26 |
+
replace_example_docstring,
|
| 27 |
+
)
|
| 28 |
+
from ...utils.torch_utils import randn_tensor
|
| 29 |
+
from ..blip_diffusion.blip_image_processing import BlipImageProcessor
|
| 30 |
+
from ..blip_diffusion.modeling_blip2 import Blip2QFormerModel
|
| 31 |
+
from ..blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
|
| 32 |
+
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# Detect torch_xla at import time so the denoising loop can call
# xm.mark_step() when running on TPU; on other backends this is a no-op flag.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

# Module-level logger, following the diffusers convention.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
EXAMPLE_DOC_STRING = """
|
| 46 |
+
Examples:
|
| 47 |
+
```py
|
| 48 |
+
>>> from diffusers.pipelines import BlipDiffusionControlNetPipeline
|
| 49 |
+
>>> from diffusers.utils import load_image
|
| 50 |
+
>>> from controlnet_aux import CannyDetector
|
| 51 |
+
>>> import torch
|
| 52 |
+
|
| 53 |
+
>>> blip_diffusion_pipe = BlipDiffusionControlNetPipeline.from_pretrained(
|
| 54 |
+
... "Salesforce/blipdiffusion-controlnet", torch_dtype=torch.float16
|
| 55 |
+
... ).to("cuda")
|
| 56 |
+
|
| 57 |
+
>>> style_subject = "flower"
|
| 58 |
+
>>> tgt_subject = "teapot"
|
| 59 |
+
>>> text_prompt = "on a marble table"
|
| 60 |
+
|
| 61 |
+
>>> cldm_cond_image = load_image(
|
| 62 |
+
... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/kettle.jpg"
|
| 63 |
+
... ).resize((512, 512))
|
| 64 |
+
>>> canny = CannyDetector()
|
| 65 |
+
>>> cldm_cond_image = canny(cldm_cond_image, 30, 70, output_type="pil")
|
| 66 |
+
>>> style_image = load_image(
|
| 67 |
+
... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/flower.jpg"
|
| 68 |
+
... )
|
| 69 |
+
>>> guidance_scale = 7.5
|
| 70 |
+
>>> num_inference_steps = 50
|
| 71 |
+
>>> negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate"
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
>>> output = blip_diffusion_pipe(
|
| 75 |
+
... text_prompt,
|
| 76 |
+
... style_image,
|
| 77 |
+
... cldm_cond_image,
|
| 78 |
+
... style_subject,
|
| 79 |
+
... tgt_subject,
|
| 80 |
+
... guidance_scale=guidance_scale,
|
| 81 |
+
... num_inference_steps=num_inference_steps,
|
| 82 |
+
... neg_prompt=negative_prompt,
|
| 83 |
+
... height=512,
|
| 84 |
+
... width=512,
|
| 85 |
+
... ).images
|
| 86 |
+
>>> output[0].save("image.png")
|
| 87 |
+
```
|
| 88 |
+
"""
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class BlipDiffusionControlNetPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
    """
    Pipeline for Canny Edge based Controlled subject-driven generation using Blip Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        tokenizer ([`CLIPTokenizer`]):
            Tokenizer for the text encoder
        text_encoder ([`ContextCLIPTextModel`]):
            Text encoder to encode the text prompt
        vae ([`AutoencoderKL`]):
            VAE model to map the latents to the image
        unet ([`UNet2DConditionModel`]):
            Conditional U-Net architecture to denoise the image embedding.
        scheduler ([`PNDMScheduler`]):
            A scheduler to be used in combination with `unet` to generate image latents.
        qformer ([`Blip2QFormerModel`]):
            QFormer model to get multi-modal embeddings from the text and image.
        controlnet ([`ControlNetModel`]):
            ControlNet model to get the conditioning image embedding.
        image_processor ([`BlipImageProcessor`]):
            Image Processor to preprocess and postprocess the image.
        ctx_begin_pos (int, `optional`, defaults to 2):
            Position of the context token in the text encoder.
    """

    # Version marker used by DeprecatedPipelineMixin.
    _last_supported_version = "0.33.1"
    # Order in which sub-models are moved to/from the accelerator during CPU offload.
    model_cpu_offload_seq = "qformer->text_encoder->unet->vae"

    def __init__(
        self,
        tokenizer: CLIPTokenizer,
        text_encoder: ContextCLIPTextModel,
        vae: AutoencoderKL,
        unet: UNet2DConditionModel,
        scheduler: PNDMScheduler,
        qformer: Blip2QFormerModel,
        controlnet: ControlNetModel,
        image_processor: BlipImageProcessor,
        ctx_begin_pos: int = 2,
        mean: List[float] = None,
        std: List[float] = None,
    ):
        # `mean`/`std` are per-channel normalization stats forwarded to the image
        # processor via the pipeline config; None means the processor's defaults apply.
        super().__init__()

        self.register_modules(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            vae=vae,
            unet=unet,
            scheduler=scheduler,
            qformer=qformer,
            controlnet=controlnet,
            image_processor=image_processor,
        )
        self.register_to_config(ctx_begin_pos=ctx_begin_pos, mean=mean, std=std)

    def get_query_embeddings(self, input_image, src_subject):
        """Return the QFormer's multi-modal query embeddings for (image, subject-text)."""
        return self.qformer(image_input=input_image, text_input=src_subject, return_dict=False)

    # from the original Blip Diffusion code, specifies the target subject and augments the prompt by repeating it
    def _build_prompt(self, prompts, tgt_subjects, prompt_strength=1.0, prompt_reps=20):
        """Prefix each prompt with its target subject and repeat it to amplify its effect.

        The number of repetitions is int(prompt_strength * prompt_reps); a strength of 0
        yields an empty prompt string for that entry.
        """
        rv = []
        for prompt, tgt_subject in zip(prompts, tgt_subjects):
            prompt = f"a {tgt_subject} {prompt.strip()}"
            # a trick to amplify the prompt
            rv.append(", ".join([prompt] * int(prompt_strength * prompt_reps)))

        return rv

    # Copied from diffusers.pipelines.consistency_models.pipeline_consistency_models.ConsistencyModelPipeline.prepare_latents
    def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None):
        """Create (or move) the initial noise latents and scale them for the scheduler.

        Raises:
            ValueError: if a list of generators is passed whose length != batch_size.
        """
        shape = (batch_size, num_channels, height, width)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device=device, dtype=dtype)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def encode_prompt(self, query_embeds, prompt, device=None):
        """Encode `prompt` with the context CLIP text encoder, injecting `query_embeds` as context tokens."""
        device = device or self._execution_device

        # embeddings for prompt, with query_embeds as context
        # Reserve room for the query tokens so prompt + context fit in the encoder's
        # position-embedding budget.
        max_len = self.text_encoder.text_model.config.max_position_embeddings
        max_len -= self.qformer.config.num_query_tokens

        tokenized_prompt = self.tokenizer(
            prompt,
            padding="max_length",
            truncation=True,
            max_length=max_len,
            return_tensors="pt",
        ).to(device)

        batch_size = query_embeds.shape[0]
        ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size

        text_embeddings = self.text_encoder(
            input_ids=tokenized_prompt.input_ids,
            ctx_embeddings=query_embeds,
            ctx_begin_pos=ctx_begin_pos,
        )[0]

        return text_embeddings

    # Adapted from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
    def prepare_control_image(
        self,
        image,
        width,
        height,
        batch_size,
        num_images_per_prompt,
        device,
        dtype,
        do_classifier_free_guidance=False,
    ):
        """Preprocess the ControlNet conditioning image and replicate it across the batch.

        When classifier-free guidance is active the image is duplicated so it matches
        the doubled (uncond + cond) latent batch.
        """
        image = self.image_processor.preprocess(
            image,
            size={"width": width, "height": height},
            do_rescale=True,
            do_center_crop=False,
            do_normalize=False,
            return_tensors="pt",
        )["pixel_values"].to(device)
        image_batch_size = image.shape[0]

        if image_batch_size == 1:
            repeat_by = batch_size
        else:
            # image batch size is the same as prompt batch size
            repeat_by = num_images_per_prompt

        image = image.repeat_interleave(repeat_by, dim=0)

        image = image.to(device=device, dtype=dtype)

        if do_classifier_free_guidance:
            image = torch.cat([image] * 2)

        return image

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: List[str],
        reference_image: PIL.Image.Image,
        # NOTE: parameter name keeps the historical typo "condtioning_image";
        # renaming it would break existing keyword callers.
        condtioning_image: PIL.Image.Image,
        source_subject_category: List[str],
        target_subject_category: List[str],
        latents: Optional[torch.Tensor] = None,
        guidance_scale: float = 7.5,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        neg_prompt: Optional[str] = "",
        prompt_strength: float = 1.0,
        prompt_reps: int = 20,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`List[str]`):
                The prompt or prompts to guide the image generation.
            reference_image (`PIL.Image.Image`):
                The reference image to condition the generation on.
            condtioning_image (`PIL.Image.Image`):
                The conditioning canny edge image to condition the generation on.
            source_subject_category (`List[str]`):
                The source subject category.
            target_subject_category (`List[str]`):
                The target subject category.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by random sampling.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
            height (`int`, *optional*, defaults to 512):
                The height of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            neg_prompt (`str`, *optional*, defaults to ""):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            prompt_strength (`float`, *optional*, defaults to 1.0):
                The strength of the prompt. Specifies the number of times the prompt is repeated along with prompt_reps
                to amplify the prompt.
            prompt_reps (`int`, *optional*, defaults to 20):
                The number of times the prompt is repeated along with prompt_strength to amplify the prompt.
        Examples:

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`
        """
        device = self._execution_device

        # 1. Preprocess the subject reference image with the configured mean/std.
        reference_image = self.image_processor.preprocess(
            reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
        )["pixel_values"]
        reference_image = reference_image.to(device)

        # Accept bare strings for convenience; normalize to lists.
        if isinstance(prompt, str):
            prompt = [prompt]
        if isinstance(source_subject_category, str):
            source_subject_category = [source_subject_category]
        if isinstance(target_subject_category, str):
            target_subject_category = [target_subject_category]

        batch_size = len(prompt)

        # 2. Build the amplified subject-aware prompt and encode it, using the
        # QFormer's (image, source-subject) query embeddings as textual context.
        prompt = self._build_prompt(
            prompts=prompt,
            tgt_subjects=target_subject_category,
            prompt_strength=prompt_strength,
            prompt_reps=prompt_reps,
        )
        query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
        text_embeddings = self.encode_prompt(query_embeds, prompt, device)
        # 3. unconditional embedding
        do_classifier_free_guidance = guidance_scale > 1.0
        if do_classifier_free_guidance:
            max_length = self.text_encoder.text_model.config.max_position_embeddings

            uncond_input = self.tokenizer(
                [neg_prompt] * batch_size,
                padding="max_length",
                max_length=max_length,
                return_tensors="pt",
            )
            # No ctx_embeddings for the unconditional branch.
            uncond_embeddings = self.text_encoder(
                input_ids=uncond_input.input_ids.to(device),
                ctx_embeddings=None,
            )[0]
            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
        # Latent resolution is the pixel resolution divided by the UNet's total
        # downsampling factor (2 per down block after the first).
        scale_down_factor = 2 ** (len(self.unet.config.block_out_channels) - 1)
        latents = self.prepare_latents(
            batch_size=batch_size,
            num_channels=self.unet.config.in_channels,
            height=height // scale_down_factor,
            width=width // scale_down_factor,
            generator=generator,
            latents=latents,
            dtype=self.unet.dtype,
            device=device,
        )
        # set timesteps
        extra_set_kwargs = {}
        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        # 4. Prepare the ControlNet conditioning (canny) image.
        cond_image = self.prepare_control_image(
            image=condtioning_image,
            width=width,
            height=height,
            batch_size=batch_size,
            num_images_per_prompt=1,
            device=device,
            dtype=self.controlnet.dtype,
            do_classifier_free_guidance=do_classifier_free_guidance,
        )

        # 5. Denoising loop: ControlNet residuals feed into the UNet each step.
        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
            # expand the latents if we are doing classifier free guidance
            do_classifier_free_guidance = guidance_scale > 1.0

            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            down_block_res_samples, mid_block_res_sample = self.controlnet(
                latent_model_input,
                t,
                encoder_hidden_states=text_embeddings,
                controlnet_cond=cond_image,
                return_dict=False,
            )

            noise_pred = self.unet(
                latent_model_input,
                timestep=t,
                encoder_hidden_states=text_embeddings,
                down_block_additional_residuals=down_block_res_samples,
                mid_block_additional_residual=mid_block_res_sample,
            )["sample"]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            latents = self.scheduler.step(
                noise_pred,
                t,
                latents,
            )["prev_sample"]

            if XLA_AVAILABLE:
                xm.mark_step()

        # 6. Decode latents to images and postprocess to the requested output type.
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
|
pythonProject/diffusers-main/build/lib/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
ADDED
|
@@ -0,0 +1,1338 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import inspect
|
| 16 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import PIL.Image
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn.functional as F
|
| 22 |
+
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
| 23 |
+
|
| 24 |
+
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 25 |
+
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
| 26 |
+
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
| 27 |
+
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
|
| 28 |
+
from ...models.lora import adjust_lora_scale_text_encoder
|
| 29 |
+
from ...schedulers import KarrasDiffusionSchedulers
|
| 30 |
+
from ...utils import (
|
| 31 |
+
USE_PEFT_BACKEND,
|
| 32 |
+
deprecate,
|
| 33 |
+
is_torch_xla_available,
|
| 34 |
+
logging,
|
| 35 |
+
replace_example_docstring,
|
| 36 |
+
scale_lora_layers,
|
| 37 |
+
unscale_lora_layers,
|
| 38 |
+
)
|
| 39 |
+
from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
|
| 40 |
+
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
| 41 |
+
from ..stable_diffusion import StableDiffusionPipelineOutput
|
| 42 |
+
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# Detect torch_xla at import time so the denoising loop can call
# xm.mark_step() when running on TPU; on other backends this is a no-op flag.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

# Module-level logger, following the diffusers convention.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
EXAMPLE_DOC_STRING = """
|
| 56 |
+
Examples:
|
| 57 |
+
```py
|
| 58 |
+
>>> # !pip install opencv-python transformers accelerate
|
| 59 |
+
>>> from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, UniPCMultistepScheduler
|
| 60 |
+
>>> from diffusers.utils import load_image
|
| 61 |
+
>>> import numpy as np
|
| 62 |
+
>>> import torch
|
| 63 |
+
|
| 64 |
+
>>> import cv2
|
| 65 |
+
>>> from PIL import Image
|
| 66 |
+
|
| 67 |
+
>>> # download an image
|
| 68 |
+
>>> image = load_image(
|
| 69 |
+
... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
|
| 70 |
+
... )
|
| 71 |
+
>>> np_image = np.array(image)
|
| 72 |
+
|
| 73 |
+
>>> # get canny image
|
| 74 |
+
>>> np_image = cv2.Canny(np_image, 100, 200)
|
| 75 |
+
>>> np_image = np_image[:, :, None]
|
| 76 |
+
>>> np_image = np.concatenate([np_image, np_image, np_image], axis=2)
|
| 77 |
+
>>> canny_image = Image.fromarray(np_image)
|
| 78 |
+
|
| 79 |
+
>>> # load control net and stable diffusion v1-5
|
| 80 |
+
>>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
|
| 81 |
+
>>> pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
|
| 82 |
+
... "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
|
| 83 |
+
... )
|
| 84 |
+
|
| 85 |
+
>>> # speed up diffusion process with faster scheduler and memory optimization
|
| 86 |
+
>>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
| 87 |
+
>>> pipe.enable_model_cpu_offload()
|
| 88 |
+
|
| 89 |
+
>>> # generate image
|
| 90 |
+
>>> generator = torch.manual_seed(0)
|
| 91 |
+
>>> image = pipe(
|
| 92 |
+
... "futuristic-looking woman",
|
| 93 |
+
... num_inference_steps=20,
|
| 94 |
+
... generator=generator,
|
| 95 |
+
... image=image,
|
| 96 |
+
... control_image=canny_image,
|
| 97 |
+
... ).images[0]
|
| 98 |
+
```
|
| 99 |
+
"""
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Extract latents from a VAE encoder output.

    If the output exposes a `latent_dist`, either sample from it (`sample_mode="sample"`,
    optionally seeded by `generator`) or take its mode (`sample_mode="argmax"`). Outputs
    that expose plain `latents` are returned as-is.

    Raises:
        AttributeError: if neither `latent_dist` nor `latents` is available.
    """
    has_dist = hasattr(encoder_output, "latent_dist")
    if has_dist and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    if has_dist and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def prepare_image(image):
    """Convert ``image`` into a float32 NCHW tensor.

    Tensors are only batched (if 3-D) and cast to float32; PIL images and numpy
    arrays (or lists thereof) are stacked into a batch, moved to channel-first
    layout, and rescaled from ``[0, 255]`` to ``[-1, 1]``.
    """
    if isinstance(image, torch.Tensor):
        # A single CHW tensor becomes a one-element batch; values are left unscaled.
        batched = image.unsqueeze(0) if image.ndim == 3 else image
        return batched.to(dtype=torch.float32)

    # Normalize a single PIL image / ndarray to a one-element list.
    if isinstance(image, (PIL.Image.Image, np.ndarray)):
        image = [image]

    if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
        stacked = np.concatenate([np.array(i.convert("RGB"))[None, :] for i in image], axis=0)
    elif isinstance(image, list) and isinstance(image[0], np.ndarray):
        stacked = np.concatenate([i[None, :] for i in image], axis=0)
    else:
        stacked = image

    # NHWC -> NCHW, then map the uint8 range [0, 255] onto [-1, 1].
    stacked = stacked.transpose(0, 3, 1, 2)
    return torch.from_numpy(stacked).to(dtype=torch.float32) / 127.5 - 1.0
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class StableDiffusionControlNetImg2ImgPipeline(
|
| 141 |
+
DiffusionPipeline,
|
| 142 |
+
StableDiffusionMixin,
|
| 143 |
+
TextualInversionLoaderMixin,
|
| 144 |
+
StableDiffusionLoraLoaderMixin,
|
| 145 |
+
IPAdapterMixin,
|
| 146 |
+
FromSingleFileMixin,
|
| 147 |
+
):
|
| 148 |
+
r"""
|
| 149 |
+
Pipeline for image-to-image generation using Stable Diffusion with ControlNet guidance.
|
| 150 |
+
|
| 151 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
| 152 |
+
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
| 153 |
+
|
| 154 |
+
The pipeline also inherits the following loading methods:
|
| 155 |
+
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
| 156 |
+
- [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
| 157 |
+
- [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
| 158 |
+
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
| 159 |
+
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
| 160 |
+
|
| 161 |
+
Args:
|
| 162 |
+
vae ([`AutoencoderKL`]):
|
| 163 |
+
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
| 164 |
+
text_encoder ([`~transformers.CLIPTextModel`]):
|
| 165 |
+
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
| 166 |
+
tokenizer ([`~transformers.CLIPTokenizer`]):
|
| 167 |
+
A `CLIPTokenizer` to tokenize text.
|
| 168 |
+
unet ([`UNet2DConditionModel`]):
|
| 169 |
+
A `UNet2DConditionModel` to denoise the encoded image latents.
|
| 170 |
+
controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
|
| 171 |
+
Provides additional conditioning to the `unet` during the denoising process. If you set multiple
|
| 172 |
+
ControlNets as a list, the outputs from each ControlNet are added together to create one combined
|
| 173 |
+
additional conditioning.
|
| 174 |
+
scheduler ([`SchedulerMixin`]):
|
| 175 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
| 176 |
+
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
| 177 |
+
safety_checker ([`StableDiffusionSafetyChecker`]):
|
| 178 |
+
Classification module that estimates whether generated images could be considered offensive or harmful.
|
| 179 |
+
Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
|
| 180 |
+
more details about a model's potential harms.
|
| 181 |
+
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
| 182 |
+
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
| 183 |
+
"""
|
| 184 |
+
|
| 185 |
+
model_cpu_offload_seq = "text_encoder->unet->vae"
|
| 186 |
+
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
|
| 187 |
+
_exclude_from_cpu_offload = ["safety_checker"]
|
| 188 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "control_image"]
|
| 189 |
+
|
| 190 |
+
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    image_encoder: CLIPVisionModelWithProjection = None,
    requires_safety_checker: bool = True,
):
    """Register all sub-models and set up the VAE / ControlNet image processors.

    Fix: the feature-extractor ``ValueError`` message was a plain string containing a
    literal ``{self.__class__}`` placeholder; it is now an f-string so the pipeline
    class name is actually interpolated into the error text.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # BUG FIX: this message previously lacked the `f` prefix, so users saw the
        # literal text "{self.__class__}" instead of the actual class name.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    # Wrap multiple ControlNets so they can be driven through a single interface.
    if isinstance(controlnet, (list, tuple)):
        controlnet = MultiControlNetModel(controlnet)

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        controlnet=controlnet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        image_encoder=image_encoder,
    )
    # Spatial down-scaling factor of the VAE; falls back to 8 if no VAE is registered.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
    # ControlNet conditioning images are resized but intentionally NOT normalized to [-1, 1].
    self.control_image_processor = VaeImageProcessor(
        vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
    )
    self.register_to_config(requires_safety_checker=requires_safety_checker)
|
| 241 |
+
|
| 242 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
| 243 |
+
def _encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    **kwargs,
):
    """Deprecated wrapper around :meth:`encode_prompt`.

    Kept for backwards compatibility: returns the negative and positive embeddings
    concatenated into a single tensor (uncond first), instead of the tuple that
    :meth:`encode_prompt` now returns.
    """
    deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
    deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)

    positive_embeds, negative_embeds = self.encode_prompt(
        prompt=prompt,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        do_classifier_free_guidance=do_classifier_free_guidance,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=lora_scale,
        **kwargs,
    )

    # Legacy layout: unconditional embeddings come first in the concatenated tensor.
    return torch.cat([negative_embeds, positive_embeds])
|
| 274 |
+
|
| 275 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
|
| 276 |
+
def encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        lora_scale (`float`, *optional*):
            A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.

    Returns:
        A `(prompt_embeds, negative_prompt_embeds)` tuple; `negative_prompt_embeds` is
        only computed/duplicated when `do_classifier_free_guidance` is set.
    """
    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
        self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
        if not USE_PEFT_BACKEND:
            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
        else:
            scale_lora_layers(self.text_encoder, lora_scale)

    # Batch size comes from the prompt when given, otherwise from the precomputed embeds.
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        # Warn (don't fail) when the prompt was silently truncated to the model max length.
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        if clip_skip is None:
            prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
            prompt_embeds = prompt_embeds[0]
        else:
            prompt_embeds = self.text_encoder(
                text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
            )
            # Access the `hidden_states` first, that contains a tuple of
            # all the hidden states from the encoder layers. Then index into
            # the tuple to access the hidden states from the desired layer.
            prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
            # We also need to apply the final LayerNorm here to not mess with the
            # representations. The `last_hidden_states` that we typically use for
            # obtaining the final prompt representations passes through the LayerNorm
            # layer.
            prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

    # Pick the dtype the UNet will consume (text encoder or unet may be offloaded/None).
    if self.text_encoder is not None:
        prompt_embeds_dtype = self.text_encoder.dtype
    elif self.unet is not None:
        prompt_embeds_dtype = self.unet.dtype
    else:
        prompt_embeds_dtype = prompt_embeds.dtype

    prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        # Pad the unconditional branch to the same sequence length as the prompt embeds.
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = uncond_input.attention_mask.to(device)
        else:
            attention_mask = None

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    if self.text_encoder is not None:
        if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder, lora_scale)

    return prompt_embeds, negative_prompt_embeds
|
| 457 |
+
|
| 458 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
|
| 459 |
+
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
    """Encode an IP-Adapter image into `(conditional, unconditional)` features.

    When `output_hidden_states` is truthy, returns the penultimate hidden states of the
    image encoder (the unconditional branch encodes an all-zeros image); otherwise
    returns pooled `image_embeds` with a zeros tensor as the unconditional embedding.
    Both outputs are repeated `num_images_per_prompt` times along the batch dimension.
    """
    encoder_dtype = next(self.image_encoder.parameters()).dtype

    # Non-tensor inputs (e.g. PIL images) are preprocessed by the feature extractor.
    if not isinstance(image, torch.Tensor):
        image = self.feature_extractor(image, return_tensors="pt").pixel_values
    image = image.to(device=device, dtype=encoder_dtype)

    if not output_hidden_states:
        # Pooled-projection path: the unconditional embedding is simply zeros.
        embeds = self.image_encoder(image).image_embeds
        embeds = embeds.repeat_interleave(num_images_per_prompt, dim=0)
        return embeds, torch.zeros_like(embeds)

    def _penultimate_hidden(pixels):
        # Penultimate layer is what IP-Adapter projection layers expect.
        return self.image_encoder(pixels, output_hidden_states=True).hidden_states[-2]

    cond_hidden = _penultimate_hidden(image).repeat_interleave(num_images_per_prompt, dim=0)
    uncond_hidden = _penultimate_hidden(torch.zeros_like(image)).repeat_interleave(num_images_per_prompt, dim=0)
    return cond_hidden, uncond_hidden
|
| 482 |
+
|
| 483 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
|
| 484 |
+
def prepare_ip_adapter_image_embeds(
    self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
    """Build per-adapter image embeddings for IP-Adapter conditioning.

    Either encodes `ip_adapter_image` (one entry per loaded IP Adapter) with
    `self.encode_image`, or reuses precomputed `ip_adapter_image_embeds` (which, under
    classifier-free guidance, are expected to hold the uncond/cond halves concatenated
    along dim 0). The result is a list of tensors, one per adapter, each repeated
    `num_images_per_prompt` times and — under CFG — laid out as [uncond, cond] on dim 0.
    """
    image_embeds = []
    if do_classifier_free_guidance:
        negative_image_embeds = []
    if ip_adapter_image_embeds is None:
        if not isinstance(ip_adapter_image, list):
            ip_adapter_image = [ip_adapter_image]

        if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
            raise ValueError(
                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
            )

        for single_ip_adapter_image, image_proj_layer in zip(
            ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
        ):
            # Plain `ImageProjection` layers consume pooled embeds; all other
            # projection layers consume hidden states.
            output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
            single_image_embeds, single_negative_image_embeds = self.encode_image(
                single_ip_adapter_image, device, 1, output_hidden_state
            )

            image_embeds.append(single_image_embeds[None, :])
            if do_classifier_free_guidance:
                negative_image_embeds.append(single_negative_image_embeds[None, :])
    else:
        for single_image_embeds in ip_adapter_image_embeds:
            if do_classifier_free_guidance:
                # Precomputed embeds store [uncond, cond] concatenated on dim 0.
                single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
                negative_image_embeds.append(single_negative_image_embeds)
            image_embeds.append(single_image_embeds)

    ip_adapter_image_embeds = []
    for i, single_image_embeds in enumerate(image_embeds):
        # Repeat for num_images_per_prompt, then prepend the negative half under CFG.
        single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
            single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)

        single_image_embeds = single_image_embeds.to(device=device)
        ip_adapter_image_embeds.append(single_image_embeds)

    return ip_adapter_image_embeds
|
| 528 |
+
|
| 529 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
| 530 |
+
def run_safety_checker(self, image, device, dtype):
    """Run the NSFW safety checker over `image`.

    Returns the (possibly blacked-out) images together with a per-image
    `has_nsfw_concept` list, or `(image, None)` when no checker is loaded.
    """
    if self.safety_checker is None:
        return image, None

    # The CLIP feature extractor needs PIL images, so convert from tensor/numpy first.
    if torch.is_tensor(image):
        checker_images = self.image_processor.postprocess(image, output_type="pil")
    else:
        checker_images = self.image_processor.numpy_to_pil(image)
    safety_checker_input = self.feature_extractor(checker_images, return_tensors="pt").to(device)
    image, has_nsfw_concept = self.safety_checker(
        images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
    )
    return image, has_nsfw_concept
|
| 543 |
+
|
| 544 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
|
| 545 |
+
def decode_latents(self, latents):
    """Deprecated: decode VAE latents into a float32 NHWC numpy image in [0, 1]."""
    deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
    deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)

    # Undo the VAE scaling applied at encode time, then decode.
    scaled = 1 / self.vae.config.scaling_factor * latents
    decoded = self.vae.decode(scaled, return_dict=False)[0]
    # Map from [-1, 1] to [0, 1].
    decoded = (decoded / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    return decoded.cpu().permute(0, 2, 3, 1).float().numpy()
|
| 555 |
+
|
| 556 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
| 557 |
+
def prepare_extra_step_kwargs(self, generator, eta):
    """Assemble the optional kwargs accepted by `self.scheduler.step`.

    Not all schedulers share the same `step` signature, so `eta` and `generator`
    are only forwarded when the scheduler actually accepts them. `eta` corresponds
    to η in the DDIM paper (https://huggingface.co/papers/2010.02502), should lie
    in [0, 1], and is ignored by non-DDIM schedulers.
    """
    step_params = inspect.signature(self.scheduler.step).parameters
    extra_step_kwargs = {}
    if "eta" in step_params:
        extra_step_kwargs["eta"] = eta
    if "generator" in step_params:
        extra_step_kwargs["generator"] = generator
    return extra_step_kwargs
|
| 573 |
+
|
| 574 |
+
def check_inputs(
    self,
    prompt,
    image,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    ip_adapter_image=None,
    ip_adapter_image_embeds=None,
    controlnet_conditioning_scale=1.0,
    control_guidance_start=0.0,
    control_guidance_end=1.0,
    callback_on_step_end_tensor_inputs=None,
):
    """Validate the `__call__` arguments, raising ValueError/TypeError on misuse.

    Checks mutual exclusivity of prompt vs. prompt_embeds (and their negative
    counterparts), conditioning-image types and counts against the number of
    ControlNets, conditioning-scale types, guidance start/end ranges, and
    IP-Adapter inputs. Performs no computation and returns nothing.
    """
    if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if callback_on_step_end_tensor_inputs is not None and not all(
        k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    # Exactly one of `prompt` / `prompt_embeds` must be provided.
    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    # `prompt` needs more sophisticated handling when there are multiple
    # conditionings.
    if isinstance(self.controlnet, MultiControlNetModel):
        if isinstance(prompt, list):
            logger.warning(
                f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
                " prompts. The conditionings will be fixed across the prompts."
            )

    # Check `image`
    # A torch.compile-d controlnet is wrapped; inspect `_orig_mod` for the real type.
    is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
        self.controlnet, torch._dynamo.eval_frame.OptimizedModule
    )
    if (
        isinstance(self.controlnet, ControlNetModel)
        or is_compiled
        and isinstance(self.controlnet._orig_mod, ControlNetModel)
    ):
        self.check_image(image, prompt, prompt_embeds)
    elif (
        isinstance(self.controlnet, MultiControlNetModel)
        or is_compiled
        and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
    ):
        if not isinstance(image, list):
            raise TypeError("For multiple controlnets: `image` must be type `list`")

        # When `image` is a nested list:
        # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
        elif any(isinstance(i, list) for i in image):
            raise ValueError("A single batch of multiple conditionings are supported at the moment.")
        elif len(image) != len(self.controlnet.nets):
            raise ValueError(
                f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
            )

        for image_ in image:
            self.check_image(image_, prompt, prompt_embeds)
    else:
        assert False

    # Check `controlnet_conditioning_scale`
    if (
        isinstance(self.controlnet, ControlNetModel)
        or is_compiled
        and isinstance(self.controlnet._orig_mod, ControlNetModel)
    ):
        if not isinstance(controlnet_conditioning_scale, float):
            raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
    elif (
        isinstance(self.controlnet, MultiControlNetModel)
        or is_compiled
        and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
    ):
        if isinstance(controlnet_conditioning_scale, list):
            if any(isinstance(i, list) for i in controlnet_conditioning_scale):
                raise ValueError("A single batch of multiple conditionings are supported at the moment.")
        # NOTE(review): this branch looks unreachable — lists are consumed by the
        # `if` above, so the length mismatch is never raised here. Confirm upstream.
        elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
            self.controlnet.nets
        ):
            raise ValueError(
                "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
                " the same length as the number of controlnets"
            )
    else:
        assert False

    if len(control_guidance_start) != len(control_guidance_end):
        raise ValueError(
            f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
        )

    if isinstance(self.controlnet, MultiControlNetModel):
        if len(control_guidance_start) != len(self.controlnet.nets):
            raise ValueError(
                f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
            )

    # Each (start, end) pair must define a non-empty sub-interval of [0, 1].
    for start, end in zip(control_guidance_start, control_guidance_end):
        if start >= end:
            raise ValueError(
                f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
            )
        if start < 0.0:
            raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
        if end > 1.0:
            raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")

    if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
        raise ValueError(
            "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
        )

    if ip_adapter_image_embeds is not None:
        if not isinstance(ip_adapter_image_embeds, list):
            raise ValueError(
                f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
            )
        elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
            raise ValueError(
                f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
            )
|
| 730 |
+
|
| 731 |
+
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
|
| 732 |
+
def check_image(self, image, prompt, prompt_embeds):
|
| 733 |
+
image_is_pil = isinstance(image, PIL.Image.Image)
|
| 734 |
+
image_is_tensor = isinstance(image, torch.Tensor)
|
| 735 |
+
image_is_np = isinstance(image, np.ndarray)
|
| 736 |
+
image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
|
| 737 |
+
image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
|
| 738 |
+
image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
|
| 739 |
+
|
| 740 |
+
if (
|
| 741 |
+
not image_is_pil
|
| 742 |
+
and not image_is_tensor
|
| 743 |
+
and not image_is_np
|
| 744 |
+
and not image_is_pil_list
|
| 745 |
+
and not image_is_tensor_list
|
| 746 |
+
and not image_is_np_list
|
| 747 |
+
):
|
| 748 |
+
raise TypeError(
|
| 749 |
+
f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
|
| 750 |
+
)
|
| 751 |
+
|
| 752 |
+
if image_is_pil:
|
| 753 |
+
image_batch_size = 1
|
| 754 |
+
else:
|
| 755 |
+
image_batch_size = len(image)
|
| 756 |
+
|
| 757 |
+
if prompt is not None and isinstance(prompt, str):
|
| 758 |
+
prompt_batch_size = 1
|
| 759 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 760 |
+
prompt_batch_size = len(prompt)
|
| 761 |
+
elif prompt_embeds is not None:
|
| 762 |
+
prompt_batch_size = prompt_embeds.shape[0]
|
| 763 |
+
|
| 764 |
+
if image_batch_size != 1 and image_batch_size != prompt_batch_size:
|
| 765 |
+
raise ValueError(
|
| 766 |
+
f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
|
| 767 |
+
)
|
| 768 |
+
|
| 769 |
+
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
|
| 770 |
+
def prepare_control_image(
|
| 771 |
+
self,
|
| 772 |
+
image,
|
| 773 |
+
width,
|
| 774 |
+
height,
|
| 775 |
+
batch_size,
|
| 776 |
+
num_images_per_prompt,
|
| 777 |
+
device,
|
| 778 |
+
dtype,
|
| 779 |
+
do_classifier_free_guidance=False,
|
| 780 |
+
guess_mode=False,
|
| 781 |
+
):
|
| 782 |
+
image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
|
| 783 |
+
image_batch_size = image.shape[0]
|
| 784 |
+
|
| 785 |
+
if image_batch_size == 1:
|
| 786 |
+
repeat_by = batch_size
|
| 787 |
+
else:
|
| 788 |
+
# image batch size is the same as prompt batch size
|
| 789 |
+
repeat_by = num_images_per_prompt
|
| 790 |
+
|
| 791 |
+
image = image.repeat_interleave(repeat_by, dim=0)
|
| 792 |
+
|
| 793 |
+
image = image.to(device=device, dtype=dtype)
|
| 794 |
+
|
| 795 |
+
if do_classifier_free_guidance and not guess_mode:
|
| 796 |
+
image = torch.cat([image] * 2)
|
| 797 |
+
|
| 798 |
+
return image
|
| 799 |
+
|
| 800 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
|
| 801 |
+
def get_timesteps(self, num_inference_steps, strength, device):
|
| 802 |
+
# get the original timestep using init_timestep
|
| 803 |
+
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
| 804 |
+
|
| 805 |
+
t_start = max(num_inference_steps - init_timestep, 0)
|
| 806 |
+
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
| 807 |
+
if hasattr(self.scheduler, "set_begin_index"):
|
| 808 |
+
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
| 809 |
+
|
| 810 |
+
return timesteps, num_inference_steps - t_start
|
| 811 |
+
|
| 812 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents
|
| 813 |
+
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
|
| 814 |
+
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
|
| 815 |
+
raise ValueError(
|
| 816 |
+
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
|
| 817 |
+
)
|
| 818 |
+
|
| 819 |
+
image = image.to(device=device, dtype=dtype)
|
| 820 |
+
|
| 821 |
+
batch_size = batch_size * num_images_per_prompt
|
| 822 |
+
|
| 823 |
+
if image.shape[1] == 4:
|
| 824 |
+
init_latents = image
|
| 825 |
+
|
| 826 |
+
else:
|
| 827 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 828 |
+
raise ValueError(
|
| 829 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 830 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 831 |
+
)
|
| 832 |
+
|
| 833 |
+
elif isinstance(generator, list):
|
| 834 |
+
if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
|
| 835 |
+
image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
|
| 836 |
+
elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
|
| 837 |
+
raise ValueError(
|
| 838 |
+
f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
|
| 839 |
+
)
|
| 840 |
+
|
| 841 |
+
init_latents = [
|
| 842 |
+
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
|
| 843 |
+
for i in range(batch_size)
|
| 844 |
+
]
|
| 845 |
+
init_latents = torch.cat(init_latents, dim=0)
|
| 846 |
+
else:
|
| 847 |
+
init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
|
| 848 |
+
|
| 849 |
+
init_latents = self.vae.config.scaling_factor * init_latents
|
| 850 |
+
|
| 851 |
+
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
|
| 852 |
+
# expand init_latents for batch_size
|
| 853 |
+
deprecation_message = (
|
| 854 |
+
f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
|
| 855 |
+
" images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
|
| 856 |
+
" that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
|
| 857 |
+
" your script to pass as many initial images as text prompts to suppress this warning."
|
| 858 |
+
)
|
| 859 |
+
deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
|
| 860 |
+
additional_image_per_prompt = batch_size // init_latents.shape[0]
|
| 861 |
+
init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
|
| 862 |
+
elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
|
| 863 |
+
raise ValueError(
|
| 864 |
+
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
|
| 865 |
+
)
|
| 866 |
+
else:
|
| 867 |
+
init_latents = torch.cat([init_latents], dim=0)
|
| 868 |
+
|
| 869 |
+
shape = init_latents.shape
|
| 870 |
+
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 871 |
+
|
| 872 |
+
# get latents
|
| 873 |
+
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
| 874 |
+
latents = init_latents
|
| 875 |
+
|
| 876 |
+
return latents
|
| 877 |
+
|
| 878 |
+
@property
|
| 879 |
+
def guidance_scale(self):
|
| 880 |
+
return self._guidance_scale
|
| 881 |
+
|
| 882 |
+
@property
|
| 883 |
+
def clip_skip(self):
|
| 884 |
+
return self._clip_skip
|
| 885 |
+
|
| 886 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 887 |
+
# of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
|
| 888 |
+
# corresponds to doing no classifier free guidance.
|
| 889 |
+
@property
|
| 890 |
+
def do_classifier_free_guidance(self):
|
| 891 |
+
return self._guidance_scale > 1
|
| 892 |
+
|
| 893 |
+
@property
|
| 894 |
+
def cross_attention_kwargs(self):
|
| 895 |
+
return self._cross_attention_kwargs
|
| 896 |
+
|
| 897 |
+
@property
|
| 898 |
+
def num_timesteps(self):
|
| 899 |
+
return self._num_timesteps
|
| 900 |
+
|
| 901 |
+
@property
|
| 902 |
+
def interrupt(self):
|
| 903 |
+
return self._interrupt
|
| 904 |
+
|
| 905 |
+
@torch.no_grad()
|
| 906 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 907 |
+
def __call__(
|
| 908 |
+
self,
|
| 909 |
+
prompt: Union[str, List[str]] = None,
|
| 910 |
+
image: PipelineImageInput = None,
|
| 911 |
+
control_image: PipelineImageInput = None,
|
| 912 |
+
height: Optional[int] = None,
|
| 913 |
+
width: Optional[int] = None,
|
| 914 |
+
strength: float = 0.8,
|
| 915 |
+
num_inference_steps: int = 50,
|
| 916 |
+
guidance_scale: float = 7.5,
|
| 917 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 918 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 919 |
+
eta: float = 0.0,
|
| 920 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 921 |
+
latents: Optional[torch.Tensor] = None,
|
| 922 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 923 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 924 |
+
ip_adapter_image: Optional[PipelineImageInput] = None,
|
| 925 |
+
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
| 926 |
+
output_type: Optional[str] = "pil",
|
| 927 |
+
return_dict: bool = True,
|
| 928 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 929 |
+
controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
|
| 930 |
+
guess_mode: bool = False,
|
| 931 |
+
control_guidance_start: Union[float, List[float]] = 0.0,
|
| 932 |
+
control_guidance_end: Union[float, List[float]] = 1.0,
|
| 933 |
+
clip_skip: Optional[int] = None,
|
| 934 |
+
callback_on_step_end: Optional[
|
| 935 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
| 936 |
+
] = None,
|
| 937 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 938 |
+
**kwargs,
|
| 939 |
+
):
|
| 940 |
+
r"""
|
| 941 |
+
The call function to the pipeline for generation.
|
| 942 |
+
|
| 943 |
+
Args:
|
| 944 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 945 |
+
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
| 946 |
+
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
|
| 947 |
+
`List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
|
| 948 |
+
The initial image to be used as the starting point for the image generation process. Can also accept
|
| 949 |
+
image latents as `image`, and if passing latents directly they are not encoded again.
|
| 950 |
+
control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
|
| 951 |
+
`List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
|
| 952 |
+
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
|
| 953 |
+
specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
|
| 954 |
+
as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
|
| 955 |
+
width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
|
| 956 |
+
images must be passed as a list such that each element of the list can be correctly batched for input
|
| 957 |
+
to a single ControlNet.
|
| 958 |
+
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
| 959 |
+
The height in pixels of the generated image.
|
| 960 |
+
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
| 961 |
+
The width in pixels of the generated image.
|
| 962 |
+
strength (`float`, *optional*, defaults to 0.8):
|
| 963 |
+
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
|
| 964 |
+
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
|
| 965 |
+
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
|
| 966 |
+
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
|
| 967 |
+
essentially ignores `image`.
|
| 968 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 969 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 970 |
+
expense of slower inference.
|
| 971 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
| 972 |
+
A higher guidance scale value encourages the model to generate images closely linked to the text
|
| 973 |
+
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
| 974 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 975 |
+
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
| 976 |
+
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
| 977 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 978 |
+
The number of images to generate per prompt.
|
| 979 |
+
eta (`float`, *optional*, defaults to 0.0):
|
| 980 |
+
Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
|
| 981 |
+
applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
| 982 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 983 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
| 984 |
+
generation deterministic.
|
| 985 |
+
latents (`torch.Tensor`, *optional*):
|
| 986 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
| 987 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 988 |
+
tensor is generated by sampling using the supplied random `generator`.
|
| 989 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 990 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
| 991 |
+
provided, text embeddings are generated from the `prompt` input argument.
|
| 992 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 993 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
| 994 |
+
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
| 995 |
+
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
| 996 |
+
ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
|
| 997 |
+
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
|
| 998 |
+
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
|
| 999 |
+
contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
|
| 1000 |
+
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
| 1001 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 1002 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
| 1003 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 1004 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
| 1005 |
+
plain tuple.
|
| 1006 |
+
cross_attention_kwargs (`dict`, *optional*):
|
| 1007 |
+
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
| 1008 |
+
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 1009 |
+
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
|
| 1010 |
+
The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
|
| 1011 |
+
to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
|
| 1012 |
+
the corresponding scale as a list.
|
| 1013 |
+
guess_mode (`bool`, *optional*, defaults to `False`):
|
| 1014 |
+
The ControlNet encoder tries to recognize the content of the input image even if you remove all
|
| 1015 |
+
prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
|
| 1016 |
+
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
|
| 1017 |
+
The percentage of total steps at which the ControlNet starts applying.
|
| 1018 |
+
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
|
| 1019 |
+
The percentage of total steps at which the ControlNet stops applying.
|
| 1020 |
+
clip_skip (`int`, *optional*):
|
| 1021 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
| 1022 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
| 1023 |
+
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
| 1024 |
+
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
| 1025 |
+
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
| 1026 |
+
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
| 1027 |
+
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
| 1028 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 1029 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 1030 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 1031 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 1032 |
+
|
| 1033 |
+
Examples:
|
| 1034 |
+
|
| 1035 |
+
Returns:
|
| 1036 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
| 1037 |
+
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
| 1038 |
+
otherwise a `tuple` is returned where the first element is a list with the generated images and the
|
| 1039 |
+
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
| 1040 |
+
"not-safe-for-work" (nsfw) content.
|
| 1041 |
+
"""
|
| 1042 |
+
|
| 1043 |
+
callback = kwargs.pop("callback", None)
|
| 1044 |
+
callback_steps = kwargs.pop("callback_steps", None)
|
| 1045 |
+
|
| 1046 |
+
if callback is not None:
|
| 1047 |
+
deprecate(
|
| 1048 |
+
"callback",
|
| 1049 |
+
"1.0.0",
|
| 1050 |
+
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
| 1051 |
+
)
|
| 1052 |
+
if callback_steps is not None:
|
| 1053 |
+
deprecate(
|
| 1054 |
+
"callback_steps",
|
| 1055 |
+
"1.0.0",
|
| 1056 |
+
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
| 1057 |
+
)
|
| 1058 |
+
|
| 1059 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 1060 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 1061 |
+
|
| 1062 |
+
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
|
| 1063 |
+
|
| 1064 |
+
# align format for control guidance
|
| 1065 |
+
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
|
| 1066 |
+
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
|
| 1067 |
+
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
|
| 1068 |
+
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
|
| 1069 |
+
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
|
| 1070 |
+
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
|
| 1071 |
+
control_guidance_start, control_guidance_end = (
|
| 1072 |
+
mult * [control_guidance_start],
|
| 1073 |
+
mult * [control_guidance_end],
|
| 1074 |
+
)
|
| 1075 |
+
|
| 1076 |
+
# 1. Check inputs. Raise error if not correct
|
| 1077 |
+
self.check_inputs(
|
| 1078 |
+
prompt,
|
| 1079 |
+
control_image,
|
| 1080 |
+
callback_steps,
|
| 1081 |
+
negative_prompt,
|
| 1082 |
+
prompt_embeds,
|
| 1083 |
+
negative_prompt_embeds,
|
| 1084 |
+
ip_adapter_image,
|
| 1085 |
+
ip_adapter_image_embeds,
|
| 1086 |
+
controlnet_conditioning_scale,
|
| 1087 |
+
control_guidance_start,
|
| 1088 |
+
control_guidance_end,
|
| 1089 |
+
callback_on_step_end_tensor_inputs,
|
| 1090 |
+
)
|
| 1091 |
+
|
| 1092 |
+
self._guidance_scale = guidance_scale
|
| 1093 |
+
self._clip_skip = clip_skip
|
| 1094 |
+
self._cross_attention_kwargs = cross_attention_kwargs
|
| 1095 |
+
self._interrupt = False
|
| 1096 |
+
|
| 1097 |
+
# 2. Define call parameters
|
| 1098 |
+
if prompt is not None and isinstance(prompt, str):
|
| 1099 |
+
batch_size = 1
|
| 1100 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 1101 |
+
batch_size = len(prompt)
|
| 1102 |
+
else:
|
| 1103 |
+
batch_size = prompt_embeds.shape[0]
|
| 1104 |
+
|
| 1105 |
+
device = self._execution_device
|
| 1106 |
+
|
| 1107 |
+
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
|
| 1108 |
+
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
|
| 1109 |
+
|
| 1110 |
+
global_pool_conditions = (
|
| 1111 |
+
controlnet.config.global_pool_conditions
|
| 1112 |
+
if isinstance(controlnet, ControlNetModel)
|
| 1113 |
+
else controlnet.nets[0].config.global_pool_conditions
|
| 1114 |
+
)
|
| 1115 |
+
guess_mode = guess_mode or global_pool_conditions
|
| 1116 |
+
|
| 1117 |
+
# 3. Encode input prompt
|
| 1118 |
+
text_encoder_lora_scale = (
|
| 1119 |
+
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
| 1120 |
+
)
|
| 1121 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
| 1122 |
+
prompt,
|
| 1123 |
+
device,
|
| 1124 |
+
num_images_per_prompt,
|
| 1125 |
+
self.do_classifier_free_guidance,
|
| 1126 |
+
negative_prompt,
|
| 1127 |
+
prompt_embeds=prompt_embeds,
|
| 1128 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 1129 |
+
lora_scale=text_encoder_lora_scale,
|
| 1130 |
+
clip_skip=self.clip_skip,
|
| 1131 |
+
)
|
| 1132 |
+
# For classifier free guidance, we need to do two forward passes.
|
| 1133 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
| 1134 |
+
# to avoid doing two forward passes
|
| 1135 |
+
if self.do_classifier_free_guidance:
|
| 1136 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
| 1137 |
+
|
| 1138 |
+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
| 1139 |
+
image_embeds = self.prepare_ip_adapter_image_embeds(
|
| 1140 |
+
ip_adapter_image,
|
| 1141 |
+
ip_adapter_image_embeds,
|
| 1142 |
+
device,
|
| 1143 |
+
batch_size * num_images_per_prompt,
|
| 1144 |
+
self.do_classifier_free_guidance,
|
| 1145 |
+
)
|
| 1146 |
+
|
| 1147 |
+
# 4. Prepare image
|
| 1148 |
+
image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
|
| 1149 |
+
|
| 1150 |
+
# 5. Prepare controlnet_conditioning_image
|
| 1151 |
+
if isinstance(controlnet, ControlNetModel):
|
| 1152 |
+
control_image = self.prepare_control_image(
|
| 1153 |
+
image=control_image,
|
| 1154 |
+
width=width,
|
| 1155 |
+
height=height,
|
| 1156 |
+
batch_size=batch_size * num_images_per_prompt,
|
| 1157 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 1158 |
+
device=device,
|
| 1159 |
+
dtype=controlnet.dtype,
|
| 1160 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
| 1161 |
+
guess_mode=guess_mode,
|
| 1162 |
+
)
|
| 1163 |
+
elif isinstance(controlnet, MultiControlNetModel):
|
| 1164 |
+
control_images = []
|
| 1165 |
+
|
| 1166 |
+
for control_image_ in control_image:
|
| 1167 |
+
control_image_ = self.prepare_control_image(
|
| 1168 |
+
image=control_image_,
|
| 1169 |
+
width=width,
|
| 1170 |
+
height=height,
|
| 1171 |
+
batch_size=batch_size * num_images_per_prompt,
|
| 1172 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 1173 |
+
device=device,
|
| 1174 |
+
dtype=controlnet.dtype,
|
| 1175 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
| 1176 |
+
guess_mode=guess_mode,
|
| 1177 |
+
)
|
| 1178 |
+
|
| 1179 |
+
control_images.append(control_image_)
|
| 1180 |
+
|
| 1181 |
+
control_image = control_images
|
| 1182 |
+
else:
|
| 1183 |
+
assert False
|
| 1184 |
+
|
| 1185 |
+
# 5. Prepare timesteps
|
| 1186 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
| 1187 |
+
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
|
| 1188 |
+
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
| 1189 |
+
self._num_timesteps = len(timesteps)
|
| 1190 |
+
|
| 1191 |
+
# 6. Prepare latent variables
|
| 1192 |
+
if latents is None:
|
| 1193 |
+
latents = self.prepare_latents(
|
| 1194 |
+
image,
|
| 1195 |
+
latent_timestep,
|
| 1196 |
+
batch_size,
|
| 1197 |
+
num_images_per_prompt,
|
| 1198 |
+
prompt_embeds.dtype,
|
| 1199 |
+
device,
|
| 1200 |
+
generator,
|
| 1201 |
+
)
|
| 1202 |
+
|
| 1203 |
+
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 1204 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 1205 |
+
|
| 1206 |
+
# 7.1 Add image embeds for IP-Adapter
|
| 1207 |
+
added_cond_kwargs = (
|
| 1208 |
+
{"image_embeds": image_embeds}
|
| 1209 |
+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
|
| 1210 |
+
else None
|
| 1211 |
+
)
|
| 1212 |
+
|
| 1213 |
+
# 7.2 Create tensor stating which controlnets to keep
|
| 1214 |
+
controlnet_keep = []
|
| 1215 |
+
for i in range(len(timesteps)):
|
| 1216 |
+
keeps = [
|
| 1217 |
+
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
|
| 1218 |
+
for s, e in zip(control_guidance_start, control_guidance_end)
|
| 1219 |
+
]
|
| 1220 |
+
controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
|
| 1221 |
+
|
| 1222 |
+
# 8. Denoising loop
|
| 1223 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
| 1224 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 1225 |
+
for i, t in enumerate(timesteps):
|
| 1226 |
+
if self.interrupt:
|
| 1227 |
+
continue
|
| 1228 |
+
|
| 1229 |
+
# expand the latents if we are doing classifier free guidance
|
| 1230 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
| 1231 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 1232 |
+
|
| 1233 |
+
# controlnet(s) inference
|
| 1234 |
+
if guess_mode and self.do_classifier_free_guidance:
|
| 1235 |
+
# Infer ControlNet only for the conditional batch.
|
| 1236 |
+
control_model_input = latents
|
| 1237 |
+
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
|
| 1238 |
+
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
|
| 1239 |
+
else:
|
| 1240 |
+
control_model_input = latent_model_input
|
| 1241 |
+
controlnet_prompt_embeds = prompt_embeds
|
| 1242 |
+
|
| 1243 |
+
if isinstance(controlnet_keep[i], list):
|
| 1244 |
+
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
|
| 1245 |
+
else:
|
| 1246 |
+
controlnet_cond_scale = controlnet_conditioning_scale
|
| 1247 |
+
if isinstance(controlnet_cond_scale, list):
|
| 1248 |
+
controlnet_cond_scale = controlnet_cond_scale[0]
|
| 1249 |
+
cond_scale = controlnet_cond_scale * controlnet_keep[i]
|
| 1250 |
+
|
| 1251 |
+
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
| 1252 |
+
control_model_input,
|
| 1253 |
+
t,
|
| 1254 |
+
encoder_hidden_states=controlnet_prompt_embeds,
|
| 1255 |
+
controlnet_cond=control_image,
|
| 1256 |
+
conditioning_scale=cond_scale,
|
| 1257 |
+
guess_mode=guess_mode,
|
| 1258 |
+
return_dict=False,
|
| 1259 |
+
)
|
| 1260 |
+
|
| 1261 |
+
if guess_mode and self.do_classifier_free_guidance:
|
| 1262 |
+
# Inferred ControlNet only for the conditional batch.
|
| 1263 |
+
# To apply the output of ControlNet to both the unconditional and conditional batches,
|
| 1264 |
+
# add 0 to the unconditional batch to keep it unchanged.
|
| 1265 |
+
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
|
| 1266 |
+
mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
|
| 1267 |
+
|
| 1268 |
+
# predict the noise residual
|
| 1269 |
+
noise_pred = self.unet(
|
| 1270 |
+
latent_model_input,
|
| 1271 |
+
t,
|
| 1272 |
+
encoder_hidden_states=prompt_embeds,
|
| 1273 |
+
cross_attention_kwargs=self.cross_attention_kwargs,
|
| 1274 |
+
down_block_additional_residuals=down_block_res_samples,
|
| 1275 |
+
mid_block_additional_residual=mid_block_res_sample,
|
| 1276 |
+
added_cond_kwargs=added_cond_kwargs,
|
| 1277 |
+
return_dict=False,
|
| 1278 |
+
)[0]
|
| 1279 |
+
|
| 1280 |
+
# perform guidance
|
| 1281 |
+
if self.do_classifier_free_guidance:
|
| 1282 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 1283 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 1284 |
+
|
| 1285 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 1286 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
| 1287 |
+
|
| 1288 |
+
if callback_on_step_end is not None:
|
| 1289 |
+
callback_kwargs = {}
|
| 1290 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 1291 |
+
callback_kwargs[k] = locals()[k]
|
| 1292 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 1293 |
+
|
| 1294 |
+
latents = callback_outputs.pop("latents", latents)
|
| 1295 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 1296 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 1297 |
+
control_image = callback_outputs.pop("control_image", control_image)
|
| 1298 |
+
|
| 1299 |
+
# call the callback, if provided
|
| 1300 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 1301 |
+
progress_bar.update()
|
| 1302 |
+
if callback is not None and i % callback_steps == 0:
|
| 1303 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
| 1304 |
+
callback(step_idx, t, latents)
|
| 1305 |
+
|
| 1306 |
+
if XLA_AVAILABLE:
|
| 1307 |
+
xm.mark_step()
|
| 1308 |
+
|
| 1309 |
+
# If we do sequential model offloading, let's offload unet and controlnet
|
| 1310 |
+
# manually for max memory savings
|
| 1311 |
+
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
| 1312 |
+
self.unet.to("cpu")
|
| 1313 |
+
self.controlnet.to("cpu")
|
| 1314 |
+
empty_device_cache()
|
| 1315 |
+
|
| 1316 |
+
if not output_type == "latent":
|
| 1317 |
+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
|
| 1318 |
+
0
|
| 1319 |
+
]
|
| 1320 |
+
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
| 1321 |
+
else:
|
| 1322 |
+
image = latents
|
| 1323 |
+
has_nsfw_concept = None
|
| 1324 |
+
|
| 1325 |
+
if has_nsfw_concept is None:
|
| 1326 |
+
do_denormalize = [True] * image.shape[0]
|
| 1327 |
+
else:
|
| 1328 |
+
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
| 1329 |
+
|
| 1330 |
+
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
| 1331 |
+
|
| 1332 |
+
# Offload all models
|
| 1333 |
+
self.maybe_free_model_hooks()
|
| 1334 |
+
|
| 1335 |
+
if not return_dict:
|
| 1336 |
+
return (image, has_nsfw_concept)
|
| 1337 |
+
|
| 1338 |
+
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|